Add files via upload
This commit is contained in:
parent
bab14a0208
commit
aafa3a3783
1 changed files with 23 additions and 7 deletions
|
@ -8,7 +8,7 @@ import torchvision as tv # type: ignore
|
|||
|
||||
import os
|
||||
import logging
|
||||
import h5py
|
||||
import h5py # type: ignore
|
||||
|
||||
from functions.create_logger import create_logger
|
||||
from functions.get_torch_device import get_torch_device
|
||||
|
@ -133,11 +133,27 @@ def process_trial(
|
|||
mylogger.info(f"CUDA memory: {free_mem//1024} MByte")
|
||||
|
||||
data_np: np.ndarray = np.load(filename_data, mmap_mode="r").astype(dtype_np)
|
||||
data: torch.Tensor = torch.zeros(data_np.shape, dtype=dtype, device=device)
|
||||
for i in range(0, len(config["required_order"])):
|
||||
mylogger.info(f"Move raw data to PyTorch device: {config['required_order'][i]}")
|
||||
idx = meta_channels.index(config["required_order"][i])
|
||||
data[..., i] = torch.tensor(data_np[..., idx], dtype=dtype, device=device)
|
||||
if config["binning_enable"] and (config["binning_at_the_end"] is False):
|
||||
|
||||
data: torch.Tensor = torch.zeros(
|
||||
data_np.shape, dtype=dtype, device=torch.device("cpu")
|
||||
)
|
||||
for i in range(0, len(config["required_order"])):
|
||||
mylogger.info(
|
||||
f"Move raw data to PyTorch CPU device: {config['required_order'][i]}"
|
||||
)
|
||||
idx = meta_channels.index(config["required_order"][i])
|
||||
data[..., i] = torch.tensor(
|
||||
data_np[..., idx], dtype=dtype, device=torch.device("cpu")
|
||||
)
|
||||
else:
|
||||
data = torch.zeros(data_np.shape, dtype=dtype, device=device)
|
||||
for i in range(0, len(config["required_order"])):
|
||||
mylogger.info(
|
||||
f"Move raw data to PyTorch device: {config['required_order'][i]}"
|
||||
)
|
||||
idx = meta_channels.index(config["required_order"][i])
|
||||
data[..., i] = torch.tensor(data_np[..., idx], dtype=dtype, device=device)
|
||||
|
||||
if device != torch.device("cpu"):
|
||||
free_mem = cuda_total_memory - max(
|
||||
|
@ -231,7 +247,7 @@ def process_trial(
|
|||
kernel_size=int(config["binning_kernel_size"]),
|
||||
stride=int(config["binning_stride"]),
|
||||
divisor_override=int(config["binning_divisor_override"]),
|
||||
)
|
||||
).to(device=device)
|
||||
ref_image_acceptor = (
|
||||
binning(
|
||||
ref_image_acceptor.unsqueeze(-1).unsqueeze(-1),
|
||||
|
|
Loading…
Reference in a new issue