Add files via upload

This commit is contained in:
David Rotermund 2024-02-27 20:51:26 +01:00 committed by GitHub
parent bab14a0208
commit aafa3a3783
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -8,7 +8,7 @@ import torchvision as tv # type: ignore
import os
import logging
import h5py
import h5py # type: ignore
from functions.create_logger import create_logger
from functions.get_torch_device import get_torch_device
@ -133,9 +133,25 @@ def process_trial(
mylogger.info(f"CUDA memory: {free_mem//1024} MByte")
data_np: np.ndarray = np.load(filename_data, mmap_mode="r").astype(dtype_np)
data: torch.Tensor = torch.zeros(data_np.shape, dtype=dtype, device=device)
if config["binning_enable"] and (config["binning_at_the_end"] is False):
data: torch.Tensor = torch.zeros(
data_np.shape, dtype=dtype, device=torch.device("cpu")
)
for i in range(0, len(config["required_order"])):
mylogger.info(f"Move raw data to PyTorch device: {config['required_order'][i]}")
mylogger.info(
f"Move raw data to PyTorch CPU device: {config['required_order'][i]}"
)
idx = meta_channels.index(config["required_order"][i])
data[..., i] = torch.tensor(
data_np[..., idx], dtype=dtype, device=torch.device("cpu")
)
else:
data = torch.zeros(data_np.shape, dtype=dtype, device=device)
for i in range(0, len(config["required_order"])):
mylogger.info(
f"Move raw data to PyTorch device: {config['required_order'][i]}"
)
idx = meta_channels.index(config["required_order"][i])
data[..., i] = torch.tensor(data_np[..., idx], dtype=dtype, device=device)
@ -231,7 +247,7 @@ def process_trial(
kernel_size=int(config["binning_kernel_size"]),
stride=int(config["binning_stride"]),
divisor_override=int(config["binning_divisor_override"]),
)
).to(device=device)
ref_image_acceptor = (
binning(
ref_image_acceptor.unsqueeze(-1).unsqueeze(-1),