Add files via upload
This commit is contained in:
parent
bab14a0208
commit
aafa3a3783
1 changed files with 23 additions and 7 deletions
|
@ -8,7 +8,7 @@ import torchvision as tv # type: ignore
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
import h5py
|
import h5py # type: ignore
|
||||||
|
|
||||||
from functions.create_logger import create_logger
|
from functions.create_logger import create_logger
|
||||||
from functions.get_torch_device import get_torch_device
|
from functions.get_torch_device import get_torch_device
|
||||||
|
@ -133,11 +133,27 @@ def process_trial(
|
||||||
mylogger.info(f"CUDA memory: {free_mem//1024} MByte")
|
mylogger.info(f"CUDA memory: {free_mem//1024} MByte")
|
||||||
|
|
||||||
data_np: np.ndarray = np.load(filename_data, mmap_mode="r").astype(dtype_np)
|
data_np: np.ndarray = np.load(filename_data, mmap_mode="r").astype(dtype_np)
|
||||||
data: torch.Tensor = torch.zeros(data_np.shape, dtype=dtype, device=device)
|
if config["binning_enable"] and (config["binning_at_the_end"] is False):
|
||||||
for i in range(0, len(config["required_order"])):
|
|
||||||
mylogger.info(f"Move raw data to PyTorch device: {config['required_order'][i]}")
|
data: torch.Tensor = torch.zeros(
|
||||||
idx = meta_channels.index(config["required_order"][i])
|
data_np.shape, dtype=dtype, device=torch.device("cpu")
|
||||||
data[..., i] = torch.tensor(data_np[..., idx], dtype=dtype, device=device)
|
)
|
||||||
|
for i in range(0, len(config["required_order"])):
|
||||||
|
mylogger.info(
|
||||||
|
f"Move raw data to PyTorch CPU device: {config['required_order'][i]}"
|
||||||
|
)
|
||||||
|
idx = meta_channels.index(config["required_order"][i])
|
||||||
|
data[..., i] = torch.tensor(
|
||||||
|
data_np[..., idx], dtype=dtype, device=torch.device("cpu")
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
data = torch.zeros(data_np.shape, dtype=dtype, device=device)
|
||||||
|
for i in range(0, len(config["required_order"])):
|
||||||
|
mylogger.info(
|
||||||
|
f"Move raw data to PyTorch device: {config['required_order'][i]}"
|
||||||
|
)
|
||||||
|
idx = meta_channels.index(config["required_order"][i])
|
||||||
|
data[..., i] = torch.tensor(data_np[..., idx], dtype=dtype, device=device)
|
||||||
|
|
||||||
if device != torch.device("cpu"):
|
if device != torch.device("cpu"):
|
||||||
free_mem = cuda_total_memory - max(
|
free_mem = cuda_total_memory - max(
|
||||||
|
@ -231,7 +247,7 @@ def process_trial(
|
||||||
kernel_size=int(config["binning_kernel_size"]),
|
kernel_size=int(config["binning_kernel_size"]),
|
||||||
stride=int(config["binning_stride"]),
|
stride=int(config["binning_stride"]),
|
||||||
divisor_override=int(config["binning_divisor_override"]),
|
divisor_override=int(config["binning_divisor_override"]),
|
||||||
)
|
).to(device=device)
|
||||||
ref_image_acceptor = (
|
ref_image_acceptor = (
|
||||||
binning(
|
binning(
|
||||||
ref_image_acceptor.unsqueeze(-1).unsqueeze(-1),
|
ref_image_acceptor.unsqueeze(-1).unsqueeze(-1),
|
||||||
|
|
Loading…
Reference in a new issue