Add files via upload

This commit is contained in:
David Rotermund 2024-02-27 20:51:26 +01:00 committed by GitHub
parent bab14a0208
commit aafa3a3783
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -8,7 +8,7 @@ import torchvision as tv # type: ignore
import os import os
import logging import logging
import h5py import h5py # type: ignore
from functions.create_logger import create_logger from functions.create_logger import create_logger
from functions.get_torch_device import get_torch_device from functions.get_torch_device import get_torch_device
@ -133,9 +133,25 @@ def process_trial(
mylogger.info(f"CUDA memory: {free_mem//1024} MByte") mylogger.info(f"CUDA memory: {free_mem//1024} MByte")
data_np: np.ndarray = np.load(filename_data, mmap_mode="r").astype(dtype_np) data_np: np.ndarray = np.load(filename_data, mmap_mode="r").astype(dtype_np)
data: torch.Tensor = torch.zeros(data_np.shape, dtype=dtype, device=device) if config["binning_enable"] and (config["binning_at_the_end"] is False):
data: torch.Tensor = torch.zeros(
data_np.shape, dtype=dtype, device=torch.device("cpu")
)
for i in range(0, len(config["required_order"])): for i in range(0, len(config["required_order"])):
mylogger.info(f"Move raw data to PyTorch device: {config['required_order'][i]}") mylogger.info(
f"Move raw data to PyTorch CPU device: {config['required_order'][i]}"
)
idx = meta_channels.index(config["required_order"][i])
data[..., i] = torch.tensor(
data_np[..., idx], dtype=dtype, device=torch.device("cpu")
)
else:
data = torch.zeros(data_np.shape, dtype=dtype, device=device)
for i in range(0, len(config["required_order"])):
mylogger.info(
f"Move raw data to PyTorch device: {config['required_order'][i]}"
)
idx = meta_channels.index(config["required_order"][i]) idx = meta_channels.index(config["required_order"][i])
data[..., i] = torch.tensor(data_np[..., idx], dtype=dtype, device=device) data[..., i] = torch.tensor(data_np[..., idx], dtype=dtype, device=device)
@ -231,7 +247,7 @@ def process_trial(
kernel_size=int(config["binning_kernel_size"]), kernel_size=int(config["binning_kernel_size"]),
stride=int(config["binning_stride"]), stride=int(config["binning_stride"]),
divisor_override=int(config["binning_divisor_override"]), divisor_override=int(config["binning_divisor_override"]),
) ).to(device=device)
ref_image_acceptor = ( ref_image_acceptor = (
binning( binning(
ref_image_acceptor.unsqueeze(-1).unsqueeze(-1), ref_image_acceptor.unsqueeze(-1).unsqueeze(-1),