Add files via upload

This commit is contained in:
David Rotermund 2024-02-28 16:14:50 +01:00 committed by GitHub
parent cc54cf1a29
commit 84c254ae76
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 2329 additions and 0 deletions

1015
functions/ImageAlignment.py Normal file

File diff suppressed because it is too large Load diff

57
functions/align_refref.py Normal file
View file

@ -0,0 +1,57 @@
import torch
import torchvision as tv # type: ignore
import logging
from functions.ImageAlignment import ImageAlignment
from functions.calculate_translation import calculate_translation
from functions.calculate_rotation import calculate_rotation
@torch.no_grad()
def align_refref(
mylogger: logging.Logger,
ref_image_acceptor: torch.Tensor,
ref_image_donor: torch.Tensor,
image_alignment: ImageAlignment,
batch_size: int,
fill_value: float = 0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
mylogger.info("Rotate ref image acceptor onto donor")
angle_refref = calculate_rotation(
image_alignment=image_alignment,
input=ref_image_acceptor.unsqueeze(0),
reference_image=ref_image_donor,
batch_size=batch_size,
)
ref_image_acceptor = tv.transforms.functional.affine(
img=ref_image_acceptor.unsqueeze(0),
angle=-float(angle_refref),
translate=[0, 0],
scale=1.0,
shear=0,
interpolation=tv.transforms.InterpolationMode.BILINEAR,
fill=fill_value,
)
mylogger.info("Translate ref image acceptor onto donor")
tvec_refref = calculate_translation(
image_alignment=image_alignment,
input=ref_image_acceptor,
reference_image=ref_image_donor,
batch_size=batch_size,
)
tvec_refref = tvec_refref[0, :]
ref_image_acceptor = tv.transforms.functional.affine(
img=ref_image_acceptor,
angle=0,
translate=[tvec_refref[1], tvec_refref[0]],
scale=1.0,
shear=0,
interpolation=tv.transforms.InterpolationMode.BILINEAR,
fill=fill_value,
).squeeze(0)
return angle_refref, tvec_refref, ref_image_acceptor, ref_image_donor

85
functions/bandpass.py Normal file
View file

@ -0,0 +1,85 @@
import torchaudio as ta # type: ignore
import torch
@torch.no_grad()
def filtfilt(
input: torch.Tensor,
butter_a: torch.Tensor,
butter_b: torch.Tensor,
) -> torch.Tensor:
assert butter_a.ndim == 1
assert butter_b.ndim == 1
assert butter_a.shape[0] == butter_b.shape[0]
process_data: torch.Tensor = input.detach().clone()
padding_length = 12 * int(butter_a.shape[0])
left_padding = 2 * process_data[..., 0].unsqueeze(-1) - process_data[
..., 1 : padding_length + 1
].flip(-1)
right_padding = 2 * process_data[..., -1].unsqueeze(-1) - process_data[
..., -(padding_length + 1) : -1
].flip(-1)
process_data_padded = torch.cat((left_padding, process_data, right_padding), dim=-1)
output = ta.functional.filtfilt(
process_data_padded.unsqueeze(0), butter_a, butter_b, clamp=False
).squeeze(0)
output = output[..., padding_length:-padding_length]
return output
@torch.no_grad()
def butter_bandpass(
device: torch.device,
low_frequency: float = 0.1,
high_frequency: float = 1.0,
fs: float = 30.0,
) -> tuple[torch.Tensor, torch.Tensor]:
import scipy # type: ignore
butter_b_np, butter_a_np = scipy.signal.butter(
4, [low_frequency, high_frequency], btype="bandpass", output="ba", fs=fs
)
butter_a = torch.tensor(butter_a_np, device=device, dtype=torch.float32)
butter_b = torch.tensor(butter_b_np, device=device, dtype=torch.float32)
return butter_a, butter_b
@torch.no_grad()
def chunk_iterator(array: torch.Tensor, chunk_size: int):
for i in range(0, array.shape[0], chunk_size):
yield array[i : i + chunk_size]
@torch.no_grad()
def bandpass(
data: torch.Tensor,
device: torch.device,
low_frequency: float = 0.1,
high_frequency: float = 1.0,
fs=30.0,
filtfilt_chuck_size: int = 10,
) -> torch.Tensor:
butter_a, butter_b = butter_bandpass(
device=device,
low_frequency=low_frequency,
high_frequency=high_frequency,
fs=fs,
)
index_full_dataset: torch.Tensor = torch.arange(
0, data.shape[1], device=device, dtype=torch.int64
)
for chunk in chunk_iterator(index_full_dataset, filtfilt_chuck_size):
temp_filtfilt = filtfilt(
data[:, chunk, :],
butter_a=butter_a,
butter_b=butter_b,
)
data[:, chunk, :] = temp_filtfilt
return data

21
functions/binning.py Normal file
View file

@ -0,0 +1,21 @@
import torch
def binning(
data: torch.Tensor,
kernel_size: int = 4,
stride: int = 4,
divisor_override: int | None = 1,
) -> torch.Tensor:
assert data.ndim == 4
return (
torch.nn.functional.avg_pool2d(
input=data.movedim(0, -1).movedim(0, -1),
kernel_size=kernel_size,
stride=stride,
divisor_override=divisor_override,
)
.movedim(-1, 0)
.movedim(-1, 0)
)

View file

@ -0,0 +1,40 @@
import torch
from functions.ImageAlignment import ImageAlignment
@torch.no_grad()
def calculate_rotation(
image_alignment: ImageAlignment,
input: torch.Tensor,
reference_image: torch.Tensor,
batch_size: int,
) -> torch.Tensor:
angle = torch.zeros((input.shape[0]))
data_loader = torch.utils.data.DataLoader(
torch.utils.data.TensorDataset(input),
batch_size=batch_size,
shuffle=False,
)
start_position: int = 0
for input_batch in data_loader:
assert len(input_batch) == 1
end_position = start_position + input_batch[0].shape[0]
angle_temp = image_alignment.dry_run_angle(
input=input_batch[0],
new_reference_image=reference_image,
)
assert angle_temp is not None
angle[start_position:end_position] = angle_temp
start_position += input_batch[0].shape[0]
angle = torch.where(angle >= 180, 360.0 - angle, angle)
angle = torch.where(angle <= -180, 360.0 + angle, angle)
return angle

View file

@ -0,0 +1,37 @@
import torch
from functions.ImageAlignment import ImageAlignment
@torch.no_grad()
def calculate_translation(
image_alignment: ImageAlignment,
input: torch.Tensor,
reference_image: torch.Tensor,
batch_size: int,
) -> torch.Tensor:
tvec = torch.zeros((input.shape[0], 2))
data_loader = torch.utils.data.DataLoader(
torch.utils.data.TensorDataset(input),
batch_size=batch_size,
shuffle=False,
)
start_position: int = 0
for input_batch in data_loader:
assert len(input_batch) == 1
end_position = start_position + input_batch[0].shape[0]
tvec_temp = image_alignment.dry_run_translation(
input=input_batch[0],
new_reference_image=reference_image,
)
assert tvec_temp is not None
tvec[start_position:end_position, :] = tvec_temp
start_position += input_batch[0].shape[0]
return tvec

View file

@ -0,0 +1,37 @@
import logging
import datetime
import os
def create_logger(
save_logging_messages: bool, display_logging_messages: bool, log_stage_name: str
):
now = datetime.datetime.now()
dt_string_filename = now.strftime("%Y_%m_%d_%H_%M_%S")
logger = logging.getLogger("MyLittleLogger")
logger.setLevel(logging.DEBUG)
if save_logging_messages:
time_format = "%b %-d %Y %H:%M:%S"
logformat = "%(asctime)s %(message)s"
file_formatter = logging.Formatter(fmt=logformat, datefmt=time_format)
os.makedirs("logs_" + log_stage_name, exist_ok=True)
file_handler = logging.FileHandler(
os.path.join("logs_" + log_stage_name, f"log_{dt_string_filename}.txt")
)
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(file_formatter)
logger.addHandler(file_handler)
if display_logging_messages:
time_format = "%H:%M:%S"
logformat = "%(asctime)s %(message)s"
stream_formatter = logging.Formatter(fmt=logformat, datefmt=time_format)
stream_handler = logging.StreamHandler()
stream_handler.setLevel(logging.INFO)
stream_handler.setFormatter(stream_formatter)
logger.addHandler(stream_handler)
return logger

View file

@ -0,0 +1,339 @@
import numpy as np
import torch
import os
import logging
import copy
from functions.get_experiments import get_experiments
from functions.get_trials import get_trials
from functions.get_parts import get_parts
from functions.load_meta_data import load_meta_data
def data_raw_loader(
raw_data_path: str,
mylogger: logging.Logger,
experiment_id: int,
trial_id: int,
device: torch.device,
force_to_cpu_memory: bool,
config: dict,
) -> tuple[list[str], str, str, dict, dict, float, float, str, torch.Tensor]:
meta_channels: list[str] = []
meta_mouse_markings: str = ""
meta_recording_date: str = ""
meta_stimulation_times: dict = {}
meta_experiment_names: dict = {}
meta_trial_recording_duration: float = 0.0
meta_frame_time: float = 0.0
meta_mouse: str = ""
data: torch.Tensor = torch.zeros((1))
dtype_str = config["dtype"]
mylogger.info(f"Data precision will be {dtype_str}")
dtype: torch.dtype = getattr(torch, dtype_str)
dtype_np: np.dtype = getattr(np, dtype_str)
if os.path.isdir(raw_data_path) is False:
mylogger.info(f"ERROR: could not find raw directory {raw_data_path}!!!!")
assert os.path.isdir(raw_data_path)
return (
meta_channels,
meta_mouse_markings,
meta_recording_date,
meta_stimulation_times,
meta_experiment_names,
meta_trial_recording_duration,
meta_frame_time,
meta_mouse,
data,
)
if (torch.where(get_experiments(raw_data_path) == experiment_id)[0].shape[0]) != 1:
mylogger.info(f"ERROR: could not find experiment id {experiment_id}!!!!")
assert (
torch.where(get_experiments(raw_data_path) == experiment_id)[0].shape[0]
) == 1
return (
meta_channels,
meta_mouse_markings,
meta_recording_date,
meta_stimulation_times,
meta_experiment_names,
meta_trial_recording_duration,
meta_frame_time,
meta_mouse,
data,
)
if (
torch.where(get_trials(raw_data_path, experiment_id) == trial_id)[0].shape[0]
) != 1:
mylogger.info(f"ERROR: could not find trial id {trial_id}!!!!")
assert (
torch.where(get_trials(raw_data_path, experiment_id) == trial_id)[0].shape[
0
]
) == 1
return (
meta_channels,
meta_mouse_markings,
meta_recording_date,
meta_stimulation_times,
meta_experiment_names,
meta_trial_recording_duration,
meta_frame_time,
meta_mouse,
data,
)
available_parts: torch.Tensor = get_parts(raw_data_path, experiment_id, trial_id)
if available_parts.shape[0] < 1:
mylogger.info("ERROR: could not find any part files")
assert available_parts.shape[0] >= 1
experiment_name = f"Exp{experiment_id:03d}_Trial{trial_id:03d}"
mylogger.info(f"Will work on: {experiment_name}")
mylogger.info(f"We found {int(available_parts.shape[0])} parts.")
first_run: bool = True
mylogger.info("Compare meta data of all parts")
for id in range(0, available_parts.shape[0]):
part_id = available_parts[id]
filename_meta: str = os.path.join(
raw_data_path,
f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part{part_id:03d}_meta.txt",
)
if os.path.isfile(filename_meta) is False:
mylogger.info(f"Could not load meta data... {filename_meta}")
assert os.path.isfile(filename_meta)
return (
meta_channels,
meta_mouse_markings,
meta_recording_date,
meta_stimulation_times,
meta_experiment_names,
meta_trial_recording_duration,
meta_frame_time,
meta_mouse,
data,
)
(
meta_channels,
meta_mouse_markings,
meta_recording_date,
meta_stimulation_times,
meta_experiment_names,
meta_trial_recording_duration,
meta_frame_time,
meta_mouse,
) = load_meta_data(
mylogger=mylogger, filename_meta=filename_meta, silent_mode=True
)
if first_run:
first_run = False
master_meta_channels: list[str] = copy.deepcopy(meta_channels)
master_meta_mouse_markings: str = meta_mouse_markings
master_meta_recording_date: str = meta_recording_date
master_meta_stimulation_times: dict = copy.deepcopy(meta_stimulation_times)
master_meta_experiment_names: dict = copy.deepcopy(meta_experiment_names)
master_meta_trial_recording_duration: float = meta_trial_recording_duration
master_meta_frame_time: float = meta_frame_time
master_meta_mouse: str = meta_mouse
meta_channels_check = master_meta_channels == meta_channels
# Check channel order
if meta_channels_check:
for channel_a, channel_b in zip(master_meta_channels, meta_channels):
if channel_a != channel_b:
meta_channels_check = False
meta_mouse_markings_check = master_meta_mouse_markings == meta_mouse_markings
meta_recording_date_check = master_meta_recording_date == meta_recording_date
meta_stimulation_times_check = (
master_meta_stimulation_times == meta_stimulation_times
)
meta_experiment_names_check = (
master_meta_experiment_names == meta_experiment_names
)
meta_trial_recording_duration_check = (
master_meta_trial_recording_duration == meta_trial_recording_duration
)
meta_frame_time_check = master_meta_frame_time == meta_frame_time
meta_mouse_check = master_meta_mouse == meta_mouse
if meta_channels_check is False:
mylogger.info(f"{filename_meta} failed: channels")
assert meta_channels_check
if meta_mouse_markings_check is False:
mylogger.info(f"{filename_meta} failed: mouse_markings")
assert meta_mouse_markings_check
if meta_recording_date_check is False:
mylogger.info(f"{filename_meta} failed: recording_date")
assert meta_recording_date_check
if meta_stimulation_times_check is False:
mylogger.info(f"{filename_meta} failed: stimulation_times")
assert meta_stimulation_times_check
if meta_experiment_names_check is False:
mylogger.info(f"{filename_meta} failed: experiment_names")
assert meta_experiment_names_check
if meta_trial_recording_duration_check is False:
mylogger.info(f"{filename_meta} failed: trial_recording_duration")
assert meta_trial_recording_duration_check
if meta_frame_time_check is False:
mylogger.info(f"{filename_meta} failed: frame_time_check")
assert meta_frame_time_check
if meta_mouse_check is False:
mylogger.info(f"{filename_meta} failed: mouse")
assert meta_mouse_check
mylogger.info("-==- Done -==-")
mylogger.info(f"Will use: {filename_meta} for meta data")
(
meta_channels,
meta_mouse_markings,
meta_recording_date,
meta_stimulation_times,
meta_experiment_names,
meta_trial_recording_duration,
meta_frame_time,
meta_mouse,
) = load_meta_data(mylogger=mylogger, filename_meta=filename_meta)
#################
# Meta data end #
#################
first_run = True
mylogger.info("Count the number of frames in the data of all parts")
frame_count: int = 0
for id in range(0, available_parts.shape[0]):
part_id = available_parts[id]
filename_data: str = os.path.join(
raw_data_path,
f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part{part_id:03d}.npy",
)
if os.path.isfile(filename_data) is False:
mylogger.info(f"Could not load data... {filename_data}")
assert os.path.isfile(filename_data)
return (
meta_channels,
meta_mouse_markings,
meta_recording_date,
meta_stimulation_times,
meta_experiment_names,
meta_trial_recording_duration,
meta_frame_time,
meta_mouse,
data,
)
data_np: np.ndarray = np.load(filename_data, mmap_mode="r")
if data_np.ndim != 4:
mylogger.info(f"ERROR: Data needs to have 4 dimensions {filename_data}")
assert data_np.ndim == 4
if first_run:
first_run = False
dim_0: int = int(data_np.shape[0])
dim_1: int = int(data_np.shape[1])
dim_3: int = int(data_np.shape[3])
frame_count += int(data_np.shape[2])
if int(data_np.shape[0]) != dim_0:
mylogger.info(
f"ERROR: Data dim 0 is broken {int(data_np.shape[0])} vs {dim_0} {filename_data}"
)
assert int(data_np.shape[0]) == dim_0
if int(data_np.shape[1]) != dim_1:
mylogger.info(
f"ERROR: Data dim 1 is broken {int(data_np.shape[1])} vs {dim_1} {filename_data}"
)
assert int(data_np.shape[1]) == dim_1
if int(data_np.shape[3]) != dim_3:
mylogger.info(
f"ERROR: Data dim 3 is broken {int(data_np.shape[3])} vs {dim_3} {filename_data}"
)
assert int(data_np.shape[3]) == dim_3
mylogger.info(
f"{filename_data}: {int(data_np.shape[2])} frames -> {frame_count} frames total"
)
if force_to_cpu_memory:
mylogger.info("Using CPU memory for data")
data = torch.empty(
(dim_0, dim_1, frame_count, dim_3), dtype=dtype, device=torch.device("cpu")
)
else:
mylogger.info("Using GPU memory for data")
data = torch.empty(
(dim_0, dim_1, frame_count, dim_3), dtype=dtype, device=device
)
start_position: int = 0
end_position: int = 0
for id in range(0, available_parts.shape[0]):
part_id = available_parts[id]
filename_data = os.path.join(
raw_data_path,
f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part{part_id:03d}.npy",
)
mylogger.info(f"Will work on {filename_data}")
mylogger.info("Loading data file")
data_np = np.load(filename_data).astype(dtype_np)
end_position = start_position + int(data_np.shape[2])
for i in range(0, len(config["required_order"])):
mylogger.info(f"Move raw data channel: {config['required_order'][i]}")
idx = meta_channels.index(config["required_order"][i])
data[..., start_position:end_position, i] = torch.tensor(
data_np[..., idx], dtype=dtype, device=data.device
)
start_position = end_position
if start_position != int(data.shape[2]):
mylogger.info("ERROR: data was not fulled fully!!!")
assert start_position == int(data.shape[2])
mylogger.info("-==- Done -==-")
#################
# Raw data end #
#################
return (
meta_channels,
meta_mouse_markings,
meta_recording_date,
meta_stimulation_times,
meta_experiment_names,
meta_trial_recording_duration,
meta_frame_time,
meta_mouse,
data,
)

View file

@ -0,0 +1,127 @@
import torch
import math
@torch.no_grad()
def gauss_smear_individual(
input: torch.Tensor,
spatial_width: float,
temporal_width: float,
overwrite_fft_gauss: None | torch.Tensor = None,
use_matlab_mask: bool = True,
epsilon: float = float(torch.finfo(torch.float64).eps),
) -> tuple[torch.Tensor, torch.Tensor]:
dim_x: int = int(2 * math.ceil(2 * spatial_width) + 1)
dim_y: int = int(2 * math.ceil(2 * spatial_width) + 1)
dim_t: int = int(2 * math.ceil(2 * temporal_width) + 1)
dims_xyt: torch.Tensor = torch.tensor(
[dim_x, dim_y, dim_t], dtype=torch.int64, device=input.device
)
if input.ndim == 2:
input = input.unsqueeze(-1)
input_padded = torch.nn.functional.pad(
input.unsqueeze(0),
pad=(
dim_t,
dim_t,
dim_y,
dim_y,
dim_x,
dim_x,
),
mode="replicate",
).squeeze(0)
if overwrite_fft_gauss is None:
center_x: int = int(math.ceil(input_padded.shape[0] / 2))
center_y: int = int(math.ceil(input_padded.shape[1] / 2))
center_z: int = int(math.ceil(input_padded.shape[2] / 2))
grid_x: torch.Tensor = (
torch.arange(0, input_padded.shape[0], device=input.device) - center_x + 1
)
grid_y: torch.Tensor = (
torch.arange(0, input_padded.shape[1], device=input.device) - center_y + 1
)
grid_z: torch.Tensor = (
torch.arange(0, input_padded.shape[2], device=input.device) - center_z + 1
)
grid_x = grid_x.unsqueeze(-1).unsqueeze(-1) ** 2
grid_y = grid_y.unsqueeze(0).unsqueeze(-1) ** 2
grid_z = grid_z.unsqueeze(0).unsqueeze(0) ** 2
gauss_kernel: torch.Tensor = (
(grid_x / (spatial_width**2))
+ (grid_y / (spatial_width**2))
+ (grid_z / (temporal_width**2))
)
if use_matlab_mask:
filter_radius: torch.Tensor = (dims_xyt - 1) // 2
border_lower: list[int] = [
center_x - int(filter_radius[0]) - 1,
center_y - int(filter_radius[1]) - 1,
center_z - int(filter_radius[2]) - 1,
]
border_upper: list[int] = [
center_x + int(filter_radius[0]),
center_y + int(filter_radius[1]),
center_z + int(filter_radius[2]),
]
matlab_mask: torch.Tensor = torch.zeros_like(gauss_kernel)
matlab_mask[
border_lower[0] : border_upper[0],
border_lower[1] : border_upper[1],
border_lower[2] : border_upper[2],
] = 1.0
gauss_kernel = torch.exp(-gauss_kernel / 2.0)
if use_matlab_mask:
gauss_kernel = gauss_kernel * matlab_mask
gauss_kernel[gauss_kernel < (epsilon * gauss_kernel.max())] = 0
sum_gauss_kernel: float = float(gauss_kernel.sum())
if sum_gauss_kernel != 0.0:
gauss_kernel = gauss_kernel / sum_gauss_kernel
# FFT Shift
gauss_kernel = torch.cat(
(gauss_kernel[center_x - 1 :, :, :], gauss_kernel[: center_x - 1, :, :]),
dim=0,
)
gauss_kernel = torch.cat(
(gauss_kernel[:, center_y - 1 :, :], gauss_kernel[:, : center_y - 1, :]),
dim=1,
)
gauss_kernel = torch.cat(
(gauss_kernel[:, :, center_z - 1 :], gauss_kernel[:, :, : center_z - 1]),
dim=2,
)
overwrite_fft_gauss = torch.fft.fftn(gauss_kernel)
input_padded_gauss_filtered: torch.Tensor = torch.real(
torch.fft.ifftn(torch.fft.fftn(input_padded) * overwrite_fft_gauss)
)
else:
input_padded_gauss_filtered = torch.real(
torch.fft.ifftn(torch.fft.fftn(input_padded) * overwrite_fft_gauss)
)
start = dims_xyt
stop = (
torch.tensor(input_padded.shape, device=dims_xyt.device, dtype=dims_xyt.dtype)
- dims_xyt
)
output = input_padded_gauss_filtered[
start[0] : stop[0], start[1] : stop[1], start[2] : stop[2]
]
return (output, overwrite_fft_gauss)

View file

@ -0,0 +1,19 @@
import torch
import os
import glob
@torch.no_grad()
def get_experiments(path: str) -> torch.Tensor:
filename_np: str = os.path.join(
path,
"Exp*_Part001.npy",
)
list_str = glob.glob(filename_np)
list_int: list[int] = []
for i in range(0, len(list_str)):
list_int.append(int(list_str[i].split("Exp")[-1].split("_Trial")[0]))
list_int = sorted(list_int)
return torch.tensor(list_int).unique()

18
functions/get_parts.py Normal file
View file

@ -0,0 +1,18 @@
import torch
import os
import glob
@torch.no_grad()
def get_parts(path: str, experiment_id: int, trial_id: int) -> torch.Tensor:
filename_np: str = os.path.join(
path,
f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part*.npy",
)
list_str = glob.glob(filename_np)
list_int: list[int] = []
for i in range(0, len(list_str)):
list_int.append(int(list_str[i].split("_Part")[-1].split(".npy")[0]))
list_int = sorted(list_int)
return torch.tensor(list_int).unique()

View file

@ -0,0 +1,17 @@
import torch
import logging
def get_torch_device(mylogger: logging.Logger, force_to_cpu: bool) -> torch.device:
if torch.cuda.is_available():
device_name: str = "cuda:0"
else:
device_name = "cpu"
if force_to_cpu:
device_name = "cpu"
mylogger.info(f"Using device: {device_name}")
device: torch.device = torch.device(device_name)
return device

18
functions/get_trials.py Normal file
View file

@ -0,0 +1,18 @@
import torch
import os
import glob
@torch.no_grad()
def get_trials(path: str, experiment_id: int) -> torch.Tensor:
filename_np: str = os.path.join(
path,
f"Exp{experiment_id:03d}_Trial*_Part001.npy",
)
list_str = glob.glob(filename_np)
list_int: list[int] = []
for i in range(0, len(list_str)):
list_int.append(int(list_str[i].split("_Trial")[-1].split("_Part")[0]))
list_int = sorted(list_int)
return torch.tensor(list_int).unique()

16
functions/load_config.py Normal file
View file

@ -0,0 +1,16 @@
import json
import os
import logging
from jsmin import jsmin # type:ignore
def load_config(mylogger: logging.Logger, filename: str = "config.json") -> dict:
mylogger.info("loading config file")
if os.path.isfile(filename) is False:
mylogger.info(f"{filename} is missing")
with open(filename, "r") as file:
config = json.loads(jsmin(file.read()))
return config

View file

@ -0,0 +1,63 @@
import logging
import json
def load_meta_data(
mylogger: logging.Logger, filename_meta: str, silent_mode=False
) -> tuple[list[str], str, str, dict, dict, float, float, str]:
if silent_mode is False:
mylogger.info("Loading meta data")
with open(filename_meta, "r") as file_handle:
metadata: dict = json.load(file_handle)
channels: list[str] = metadata["channelKey"]
if silent_mode is False:
mylogger.info(f"meta data: channel order: {channels}")
mouse_markings: str = metadata["sessionMetaData"]["mouseMarkings"]
if silent_mode is False:
mylogger.info(f"meta data: mouse markings: {mouse_markings}")
recording_date: str = metadata["sessionMetaData"]["date"]
if silent_mode is False:
mylogger.info(f"meta data: recording data: {recording_date}")
stimulation_times: dict = metadata["sessionMetaData"]["stimulationTimes"]
if silent_mode is False:
mylogger.info(f"meta data: stimulation times: {stimulation_times}")
experiment_names: dict = metadata["sessionMetaData"]["experimentNames"]
if silent_mode is False:
mylogger.info(f"meta data: experiment names: {experiment_names}")
trial_recording_duration: float = float(
metadata["sessionMetaData"]["trialRecordingDuration"]
)
if silent_mode is False:
mylogger.info(
f"meta data: trial recording duration: {trial_recording_duration} sec"
)
frame_time: float = float(metadata["sessionMetaData"]["frameTime"])
if silent_mode is False:
mylogger.info(
f"meta data: frame time: {frame_time} sec ; frame rate: {1.0/frame_time}Hz"
)
mouse: str = metadata["sessionMetaData"]["mouse"]
if silent_mode is False:
mylogger.info(f"meta data: mouse: {mouse}")
mylogger.info("-==- Done -==-")
return (
channels,
mouse_markings,
recording_date,
stimulation_times,
experiment_names,
trial_recording_duration,
frame_time,
mouse,
)

View file

@ -0,0 +1,140 @@
import torch
import torchvision as tv # type: ignore
import logging
from functions.calculate_rotation import calculate_rotation
from functions.ImageAlignment import ImageAlignment
@torch.no_grad()
def perform_donor_volume_rotation(
mylogger: logging.Logger,
acceptor: torch.Tensor,
donor: torch.Tensor,
oxygenation: torch.Tensor,
volume: torch.Tensor,
ref_image_donor: torch.Tensor,
ref_image_volume: torch.Tensor,
image_alignment: ImageAlignment,
batch_size: int,
config: dict,
fill_value: float = 0,
) -> tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
mylogger.info("Calculate rotation between donor data and donor ref image")
angle_donor = calculate_rotation(
input=donor,
reference_image=ref_image_donor,
image_alignment=image_alignment,
batch_size=batch_size,
)
mylogger.info("Calculate rotation between volume data and volume ref image")
angle_volume = calculate_rotation(
input=volume,
reference_image=ref_image_volume,
image_alignment=image_alignment,
batch_size=batch_size,
)
mylogger.info("Average over both rotations")
donor_threshold: torch.Tensor = torch.sort(torch.abs(angle_donor))[0]
donor_threshold = donor_threshold[
int(
donor_threshold.shape[0]
* float(config["rotation_stabilization_threshold_border"])
)
] * float(config["rotation_stabilization_threshold_factor"])
volume_threshold: torch.Tensor = torch.sort(torch.abs(angle_volume))[0]
volume_threshold = volume_threshold[
int(
volume_threshold.shape[0]
* float(config["rotation_stabilization_threshold_border"])
)
] * float(config["rotation_stabilization_threshold_factor"])
donor_idx = torch.where(torch.abs(angle_donor) > donor_threshold)[0]
volume_idx = torch.where(torch.abs(angle_volume) > volume_threshold)[0]
mylogger.info(
f"Border: {config['rotation_stabilization_threshold_border']}, "
f"factor {config['rotation_stabilization_threshold_factor']} "
)
mylogger.info(
f"Donor threshold: {donor_threshold:.3e}, volume threshold: {volume_threshold:.3e}"
)
mylogger.info(
f"Found broken rotation values: "
f"donor {int(donor_idx.shape[0])}, "
f"volume {int(volume_idx.shape[0])}"
)
angle_donor[donor_idx] = angle_volume[donor_idx]
angle_volume[volume_idx] = angle_donor[volume_idx]
donor_idx = torch.where(torch.abs(angle_donor) > donor_threshold)[0]
volume_idx = torch.where(torch.abs(angle_volume) > volume_threshold)[0]
mylogger.info(
f"After fill in these broken rotation values remain: "
f"donor {int(donor_idx.shape[0])}, "
f"volume {int(volume_idx.shape[0])}"
)
angle_donor[donor_idx] = 0.0
angle_volume[volume_idx] = 0.0
angle_donor_volume = (angle_donor + angle_volume) / 2.0
mylogger.info("Rotate acceptor data based on the average rotation")
for frame_id in range(0, angle_donor_volume.shape[0]):
acceptor[frame_id, ...] = tv.transforms.functional.affine(
img=acceptor[frame_id, ...].unsqueeze(0),
angle=-float(angle_donor_volume[frame_id]),
translate=[0, 0],
scale=1.0,
shear=0,
interpolation=tv.transforms.InterpolationMode.BILINEAR,
fill=fill_value,
).squeeze(0)
mylogger.info("Rotate donor data based on the average rotation")
for frame_id in range(0, angle_donor_volume.shape[0]):
donor[frame_id, ...] = tv.transforms.functional.affine(
img=donor[frame_id, ...].unsqueeze(0),
angle=-float(angle_donor_volume[frame_id]),
translate=[0, 0],
scale=1.0,
shear=0,
interpolation=tv.transforms.InterpolationMode.BILINEAR,
fill=fill_value,
).squeeze(0)
mylogger.info("Rotate oxygenation data based on the average rotation")
for frame_id in range(0, angle_donor_volume.shape[0]):
oxygenation[frame_id, ...] = tv.transforms.functional.affine(
img=oxygenation[frame_id, ...].unsqueeze(0),
angle=-float(angle_donor_volume[frame_id]),
translate=[0, 0],
scale=1.0,
shear=0,
interpolation=tv.transforms.InterpolationMode.BILINEAR,
fill=fill_value,
).squeeze(0)
mylogger.info("Rotate volume data based on the average rotation")
for frame_id in range(0, angle_donor_volume.shape[0]):
volume[frame_id, ...] = tv.transforms.functional.affine(
img=volume[frame_id, ...].unsqueeze(0),
angle=-float(angle_donor_volume[frame_id]),
translate=[0, 0],
scale=1.0,
shear=0,
interpolation=tv.transforms.InterpolationMode.BILINEAR,
fill=fill_value,
).squeeze(0)
return (acceptor, donor, oxygenation, volume, angle_donor_volume)

View file

@ -0,0 +1,143 @@
import torch
import torchvision as tv # type: ignore
import logging
from functions.calculate_translation import calculate_translation
from functions.ImageAlignment import ImageAlignment
@torch.no_grad()
def perform_donor_volume_translation(
mylogger: logging.Logger,
acceptor: torch.Tensor,
donor: torch.Tensor,
oxygenation: torch.Tensor,
volume: torch.Tensor,
ref_image_donor: torch.Tensor,
ref_image_volume: torch.Tensor,
image_alignment: ImageAlignment,
batch_size: int,
config: dict,
fill_value: float = 0,
) -> tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
mylogger.info("Calculate translation between donor data and donor ref image")
tvec_donor = calculate_translation(
input=donor,
reference_image=ref_image_donor,
image_alignment=image_alignment,
batch_size=batch_size,
)
mylogger.info("Calculate translation between volume data and volume ref image")
tvec_volume = calculate_translation(
input=volume,
reference_image=ref_image_volume,
image_alignment=image_alignment,
batch_size=batch_size,
)
mylogger.info("Average over both translations")
for i in range(0, 2):
mylogger.info(f"Processing dimension {i}")
donor_threshold: torch.Tensor = torch.sort(torch.abs(tvec_donor[:, i]))[0]
donor_threshold = donor_threshold[
int(
donor_threshold.shape[0]
* float(config["rotation_stabilization_threshold_border"])
)
] * float(config["rotation_stabilization_threshold_factor"])
volume_threshold: torch.Tensor = torch.sort(torch.abs(tvec_volume[:, i]))[0]
volume_threshold = volume_threshold[
int(
volume_threshold.shape[0]
* float(config["rotation_stabilization_threshold_border"])
)
] * float(config["rotation_stabilization_threshold_factor"])
donor_idx = torch.where(torch.abs(tvec_donor[:, i]) > donor_threshold)[0]
volume_idx = torch.where(torch.abs(tvec_volume[:, i]) > volume_threshold)[0]
mylogger.info(
f"Border: {config['rotation_stabilization_threshold_border']}, "
f"factor {config['rotation_stabilization_threshold_factor']} "
)
mylogger.info(
f"Donor threshold: {donor_threshold:.3e}, volume threshold: {volume_threshold:.3e}"
)
mylogger.info(
f"Found broken rotation values: "
f"donor {int(donor_idx.shape[0])}, "
f"volume {int(volume_idx.shape[0])}"
)
tvec_donor[donor_idx, i] = tvec_volume[donor_idx, i]
tvec_volume[volume_idx, i] = tvec_donor[volume_idx, i]
donor_idx = torch.where(torch.abs(tvec_donor[:, i]) > donor_threshold)[0]
volume_idx = torch.where(torch.abs(tvec_volume[:, i]) > volume_threshold)[0]
mylogger.info(
f"After fill in these broken rotation values remain: "
f"donor {int(donor_idx.shape[0])}, "
f"volume {int(volume_idx.shape[0])}"
)
tvec_donor[donor_idx, i] = 0.0
tvec_volume[volume_idx, i] = 0.0
tvec_donor_volume = (tvec_donor + tvec_volume) / 2.0
mylogger.info("Translate acceptor data based on the average translation vector")
for frame_id in range(0, tvec_donor_volume.shape[0]):
acceptor[frame_id, ...] = tv.transforms.functional.affine(
img=acceptor[frame_id, ...].unsqueeze(0),
angle=0,
translate=[tvec_donor_volume[frame_id, 1], tvec_donor_volume[frame_id, 0]],
scale=1.0,
shear=0,
interpolation=tv.transforms.InterpolationMode.BILINEAR,
fill=fill_value,
).squeeze(0)
mylogger.info("Translate donor data based on the average translation vector")
for frame_id in range(0, tvec_donor_volume.shape[0]):
donor[frame_id, ...] = tv.transforms.functional.affine(
img=donor[frame_id, ...].unsqueeze(0),
angle=0,
translate=[tvec_donor_volume[frame_id, 1], tvec_donor_volume[frame_id, 0]],
scale=1.0,
shear=0,
interpolation=tv.transforms.InterpolationMode.BILINEAR,
fill=fill_value,
).squeeze(0)
mylogger.info("Translate oxygenation data based on the average translation vector")
for frame_id in range(0, tvec_donor_volume.shape[0]):
oxygenation[frame_id, ...] = tv.transforms.functional.affine(
img=oxygenation[frame_id, ...].unsqueeze(0),
angle=0,
translate=[tvec_donor_volume[frame_id, 1], tvec_donor_volume[frame_id, 0]],
scale=1.0,
shear=0,
interpolation=tv.transforms.InterpolationMode.BILINEAR,
fill=fill_value,
).squeeze(0)
mylogger.info("Translate volume data based on the average translation vector")
for frame_id in range(0, tvec_donor_volume.shape[0]):
volume[frame_id, ...] = tv.transforms.functional.affine(
img=volume[frame_id, ...].unsqueeze(0),
angle=0,
translate=[tvec_donor_volume[frame_id, 1], tvec_donor_volume[frame_id, 0]],
scale=1.0,
shear=0,
interpolation=tv.transforms.InterpolationMode.BILINEAR,
fill=fill_value,
).squeeze(0)
return (acceptor, donor, oxygenation, volume, tvec_donor_volume)

117
functions/regression.py Normal file
View file

@ -0,0 +1,117 @@
import torch
import logging
from functions.regression_internal import regression_internal
@torch.no_grad()
def regression(
mylogger: logging.Logger,
target_camera_id: int,
regressor_camera_ids: list[int],
mask: torch.Tensor,
data: torch.Tensor,
data_filtered: torch.Tensor,
first_none_ramp_frame: int,
) -> tuple[torch.Tensor, torch.Tensor]:
assert len(regressor_camera_ids) > 0
mylogger.info("Prepare the target signal - 1.0 (from data_filtered)")
target_signals_train: torch.Tensor = (
data_filtered[target_camera_id, ..., first_none_ramp_frame:].clone() - 1.0
)
target_signals_train[target_signals_train < -1] = 0.0
# Check if everything is happy
assert target_signals_train.ndim == 3
assert target_signals_train.ndim == data[target_camera_id, ...].ndim
assert target_signals_train.shape[0] == data[target_camera_id, ...].shape[0]
assert target_signals_train.shape[1] == data[target_camera_id, ...].shape[1]
assert (target_signals_train.shape[2] + first_none_ramp_frame) == data[
target_camera_id, ...
].shape[2]
mylogger.info("Prepare the regressor signals (linear plus from data_filtered)")
regressor_signals_train: torch.Tensor = torch.zeros(
(
data_filtered.shape[1],
data_filtered.shape[2],
data_filtered.shape[3],
len(regressor_camera_ids) + 1,
),
device=data_filtered.device,
dtype=data_filtered.dtype,
)
mylogger.info("Copy the regressor signals - 1.0")
for matrix_id, id in enumerate(regressor_camera_ids):
regressor_signals_train[..., matrix_id] = data_filtered[id, ...] - 1.0
regressor_signals_train[regressor_signals_train < -1] = 0.0
mylogger.info("Create the linear regressor")
trend = torch.arange(
0, regressor_signals_train.shape[-2], device=data_filtered.device
) / float(regressor_signals_train.shape[-2] - 1)
trend -= trend.mean()
trend = trend.unsqueeze(0).unsqueeze(0)
trend = trend.tile(
(regressor_signals_train.shape[0], regressor_signals_train.shape[1], 1)
)
regressor_signals_train[..., -1] = trend
regressor_signals_train = regressor_signals_train[:, :, first_none_ramp_frame:, :]
mylogger.info("Calculating the regression coefficients")
coefficients, intercept = regression_internal(
input_regressor=regressor_signals_train, input_target=target_signals_train
)
del regressor_signals_train
del target_signals_train
mylogger.info("Prepare the target signal - 1.0 (from data)")
target_signals_perform: torch.Tensor = data[target_camera_id, ...].clone() - 1.0
mylogger.info("Prepare the regressor signals (linear plus from data)")
regressor_signals_perform: torch.Tensor = torch.zeros(
(
data.shape[1],
data.shape[2],
data.shape[3],
len(regressor_camera_ids) + 1,
),
device=data.device,
dtype=data.dtype,
)
mylogger.info("Copy the regressor signals - 1.0 ")
for matrix_id, id in enumerate(regressor_camera_ids):
regressor_signals_perform[..., matrix_id] = data[id] - 1.0
mylogger.info("Create the linear regressor")
trend = torch.arange(
0, regressor_signals_perform.shape[-2], device=data[0].device
) / float(regressor_signals_perform.shape[-2] - 1)
trend -= trend.mean()
trend = trend.unsqueeze(0).unsqueeze(0)
trend = trend.tile(
(regressor_signals_perform.shape[0], regressor_signals_perform.shape[1], 1)
)
regressor_signals_perform[..., -1] = trend
mylogger.info("Remove regressors")
target_signals_perform -= (
regressor_signals_perform * coefficients.unsqueeze(-2)
).sum(dim=-1)
mylogger.info("Remove offset")
target_signals_perform -= intercept.unsqueeze(-1)
mylogger.info("Remove masked pixels")
target_signals_perform[mask, :] = 0.0
mylogger.info("Add an offset of 1.0")
target_signals_perform += 1.0
return target_signals_perform, coefficients

View file

@ -0,0 +1,20 @@
import torch
def regression_internal(
input_regressor: torch.Tensor, input_target: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
regressor_offset = input_regressor.mean(keepdim=True, dim=-2)
target_offset = input_target.mean(keepdim=True, dim=-1)
regressor = input_regressor - regressor_offset
target = input_target - target_offset
coefficients, _, _, _ = torch.linalg.lstsq(regressor, target, rcond=None) # None ?
intercept = target_offset.squeeze(-1) - (
coefficients * regressor_offset.squeeze(-2)
).sum(dim=-1)
return coefficients, intercept