diff --git a/new_pipeline/stage_4_process.py b/new_pipeline/stage_4_process.py index 33f9b8b..14045ff 100644 --- a/new_pipeline/stage_4_process.py +++ b/new_pipeline/stage_4_process.py @@ -8,7 +8,7 @@ import torchvision as tv # type: ignore import os import logging -import h5py +import h5py # type: ignore from functions.create_logger import create_logger from functions.get_torch_device import get_torch_device @@ -133,11 +133,27 @@ def process_trial( mylogger.info(f"CUDA memory: {free_mem//1024} MByte") data_np: np.ndarray = np.load(filename_data, mmap_mode="r").astype(dtype_np) - data: torch.Tensor = torch.zeros(data_np.shape, dtype=dtype, device=device) - for i in range(0, len(config["required_order"])): - mylogger.info(f"Move raw data to PyTorch device: {config['required_order'][i]}") - idx = meta_channels.index(config["required_order"][i]) - data[..., i] = torch.tensor(data_np[..., idx], dtype=dtype, device=device) + if config["binning_enable"] and (config["binning_at_the_end"] is False): + + data: torch.Tensor = torch.zeros( + data_np.shape, dtype=dtype, device=torch.device("cpu") + ) + for i in range(0, len(config["required_order"])): + mylogger.info( + f"Move raw data to PyTorch CPU device: {config['required_order'][i]}" + ) + idx = meta_channels.index(config["required_order"][i]) + data[..., i] = torch.tensor( + data_np[..., idx], dtype=dtype, device=torch.device("cpu") + ) + else: + data = torch.zeros(data_np.shape, dtype=dtype, device=device) + for i in range(0, len(config["required_order"])): + mylogger.info( + f"Move raw data to PyTorch device: {config['required_order'][i]}" + ) + idx = meta_channels.index(config["required_order"][i]) + data[..., i] = torch.tensor(data_np[..., idx], dtype=dtype, device=device) if device != torch.device("cpu"): free_mem = cuda_total_memory - max( @@ -231,7 +247,7 @@ def process_trial( kernel_size=int(config["binning_kernel_size"]), stride=int(config["binning_stride"]), divisor_override=int(config["binning_divisor_override"]), - ) + ).to(device=device) ref_image_acceptor = ( binning( ref_image_acceptor.unsqueeze(-1).unsqueeze(-1),