Add files via upload

This commit is contained in:
David Rotermund 2024-02-28 15:15:37 +01:00 committed by GitHub
parent ffa70cfe7a
commit 861fd31620
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 363 additions and 15 deletions

View file

@ -0,0 +1,339 @@
import numpy as np
import torch
import os
import logging
import copy
from functions.get_experiments import get_experiments
from functions.get_trials import get_trials
from functions.get_parts import get_parts
from functions.load_meta_data import load_meta_data
def data_raw_loader(
raw_data_path: str,
mylogger: logging.Logger,
experiment_id: int,
trial_id: int,
device: torch.device,
force_to_cpu_memory: bool,
config: dict,
) -> tuple[list[str], str, str, dict, dict, float, float, str, torch.Tensor]:
meta_channels: list[str] = []
meta_mouse_markings: str = ""
meta_recording_date: str = ""
meta_stimulation_times: dict = {}
meta_experiment_names: dict = {}
meta_trial_recording_duration: float = 0.0
meta_frame_time: float = 0.0
meta_mouse: str = ""
data: torch.Tensor = torch.zeros((1))
dtype_str = config["dtype"]
mylogger.info(f"Data precision will be {dtype_str}")
dtype: torch.dtype = getattr(torch, dtype_str)
dtype_np: np.dtype = getattr(np, dtype_str)
if os.path.isdir(raw_data_path) is False:
mylogger.info(f"ERROR: could not find raw directory {raw_data_path}!!!!")
assert os.path.isdir(raw_data_path)
return (
meta_channels,
meta_mouse_markings,
meta_recording_date,
meta_stimulation_times,
meta_experiment_names,
meta_trial_recording_duration,
meta_frame_time,
meta_mouse,
data,
)
if (torch.where(get_experiments(raw_data_path) == experiment_id)[0].shape[0]) != 1:
mylogger.info(f"ERROR: could not find experiment id {experiment_id}!!!!")
assert (
torch.where(get_experiments(raw_data_path) == experiment_id)[0].shape[0]
) == 1
return (
meta_channels,
meta_mouse_markings,
meta_recording_date,
meta_stimulation_times,
meta_experiment_names,
meta_trial_recording_duration,
meta_frame_time,
meta_mouse,
data,
)
if (
torch.where(get_trials(raw_data_path, experiment_id) == trial_id)[0].shape[0]
) != 1:
mylogger.info(f"ERROR: could not find trial id {trial_id}!!!!")
assert (
torch.where(get_trials(raw_data_path, experiment_id) == trial_id)[0].shape[
0
]
) == 1
return (
meta_channels,
meta_mouse_markings,
meta_recording_date,
meta_stimulation_times,
meta_experiment_names,
meta_trial_recording_duration,
meta_frame_time,
meta_mouse,
data,
)
available_parts: torch.Tensor = get_parts(raw_data_path, experiment_id, trial_id)
if available_parts.shape[0] < 1:
mylogger.info("ERROR: could not find any part files")
assert available_parts.shape[0] >= 1
experiment_name = f"Exp{experiment_id:03d}_Trial{trial_id:03d}"
mylogger.info(f"Will work on: {experiment_name}")
mylogger.info(f"We found {int(available_parts.shape[0])} parts.")
first_run: bool = True
mylogger.info("Compare meta data of all parts")
for id in range(0, available_parts.shape[0]):
part_id = available_parts[id]
filename_meta: str = os.path.join(
raw_data_path,
f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part{part_id:03d}_meta.txt",
)
if os.path.isfile(filename_meta) is False:
mylogger.info(f"Could not load meta data... {filename_meta}")
assert os.path.isfile(filename_meta)
return (
meta_channels,
meta_mouse_markings,
meta_recording_date,
meta_stimulation_times,
meta_experiment_names,
meta_trial_recording_duration,
meta_frame_time,
meta_mouse,
data,
)
(
meta_channels,
meta_mouse_markings,
meta_recording_date,
meta_stimulation_times,
meta_experiment_names,
meta_trial_recording_duration,
meta_frame_time,
meta_mouse,
) = load_meta_data(
mylogger=mylogger, filename_meta=filename_meta, silent_mode=True
)
if first_run:
first_run = False
master_meta_channels: list[str] = copy.deepcopy(meta_channels)
master_meta_mouse_markings: str = meta_mouse_markings
master_meta_recording_date: str = meta_recording_date
master_meta_stimulation_times: dict = copy.deepcopy(meta_stimulation_times)
master_meta_experiment_names: dict = copy.deepcopy(meta_experiment_names)
master_meta_trial_recording_duration: float = meta_trial_recording_duration
master_meta_frame_time: float = meta_frame_time
master_meta_mouse: str = meta_mouse
meta_channels_check = master_meta_channels == meta_channels
# Check channel order
if meta_channels_check:
for channel_a, channel_b in zip(master_meta_channels, meta_channels):
if channel_a != channel_b:
meta_channels_check = False
meta_mouse_markings_check = master_meta_mouse_markings == meta_mouse_markings
meta_recording_date_check = master_meta_recording_date == meta_recording_date
meta_stimulation_times_check = (
master_meta_stimulation_times == meta_stimulation_times
)
meta_experiment_names_check = (
master_meta_experiment_names == meta_experiment_names
)
meta_trial_recording_duration_check = (
master_meta_trial_recording_duration == meta_trial_recording_duration
)
meta_frame_time_check = master_meta_frame_time == meta_frame_time
meta_mouse_check = master_meta_mouse == meta_mouse
if meta_channels_check is False:
mylogger.info(f"{filename_meta} failed: channels")
assert meta_channels_check
if meta_mouse_markings_check is False:
mylogger.info(f"{filename_meta} failed: mouse_markings")
assert meta_mouse_markings_check
if meta_recording_date_check is False:
mylogger.info(f"{filename_meta} failed: recording_date")
assert meta_recording_date_check
if meta_stimulation_times_check is False:
mylogger.info(f"{filename_meta} failed: stimulation_times")
assert meta_stimulation_times_check
if meta_experiment_names_check is False:
mylogger.info(f"{filename_meta} failed: experiment_names")
assert meta_experiment_names_check
if meta_trial_recording_duration_check is False:
mylogger.info(f"{filename_meta} failed: trial_recording_duration")
assert meta_trial_recording_duration_check
if meta_frame_time_check is False:
mylogger.info(f"{filename_meta} failed: frame_time_check")
assert meta_frame_time_check
if meta_mouse_check is False:
mylogger.info(f"{filename_meta} failed: mouse")
assert meta_mouse_check
mylogger.info("-==- Done -==-")
mylogger.info(f"Will use: {filename_meta} for meta data")
(
meta_channels,
meta_mouse_markings,
meta_recording_date,
meta_stimulation_times,
meta_experiment_names,
meta_trial_recording_duration,
meta_frame_time,
meta_mouse,
) = load_meta_data(mylogger=mylogger, filename_meta=filename_meta)
#################
# Meta data end #
#################
first_run = True
mylogger.info("Count the number of frames in the data of all parts")
frame_count: int = 0
for id in range(0, available_parts.shape[0]):
part_id = available_parts[id]
filename_data: str = os.path.join(
raw_data_path,
f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part{part_id:03d}.npy",
)
if os.path.isfile(filename_data) is False:
mylogger.info(f"Could not load data... {filename_data}")
assert os.path.isfile(filename_data)
return (
meta_channels,
meta_mouse_markings,
meta_recording_date,
meta_stimulation_times,
meta_experiment_names,
meta_trial_recording_duration,
meta_frame_time,
meta_mouse,
data,
)
data_np: np.ndarray = np.load(filename_data, mmap_mode="r")
if data_np.ndim != 4:
mylogger.info(f"ERROR: Data needs to have 4 dimensions {filename_data}")
assert data_np.ndim == 4
if first_run:
first_run = False
dim_0: int = int(data_np.shape[0])
dim_1: int = int(data_np.shape[1])
dim_3: int = int(data_np.shape[3])
frame_count += int(data_np.shape[2])
if int(data_np.shape[0]) != dim_0:
mylogger.info(
f"ERROR: Data dim 0 is broken {int(data_np.shape[0])} vs {dim_0} {filename_data}"
)
assert int(data_np.shape[0]) == dim_0
if int(data_np.shape[1]) != dim_1:
mylogger.info(
f"ERROR: Data dim 1 is broken {int(data_np.shape[1])} vs {dim_1} {filename_data}"
)
assert int(data_np.shape[1]) == dim_1
if int(data_np.shape[3]) != dim_3:
mylogger.info(
f"ERROR: Data dim 3 is broken {int(data_np.shape[3])} vs {dim_3} {filename_data}"
)
assert int(data_np.shape[3]) == dim_3
mylogger.info(
f"{filename_data}: {int(data_np.shape[2])} frames -> {frame_count} frames total"
)
if force_to_cpu_memory:
mylogger.info("Using CPU memory for data")
data = torch.empty(
(dim_0, dim_1, frame_count, dim_3), dtype=dtype, device=torch.device("cpu")
)
else:
mylogger.info("Using GPU memory for data")
data = torch.empty(
(dim_0, dim_1, frame_count, dim_3), dtype=dtype, device=device
)
start_position: int = 0
end_position: int = 0
for id in range(0, available_parts.shape[0]):
part_id = available_parts[id]
filename_data = os.path.join(
raw_data_path,
f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part{part_id:03d}.npy",
)
mylogger.info(f"Will work on {filename_data}")
mylogger.info("Loading data file")
data_np = np.load(filename_data).astype(dtype_np)
end_position = start_position + int(data_np.shape[2])
for i in range(0, len(config["required_order"])):
mylogger.info(f"Move raw data channel: {config['required_order'][i]}")
idx = meta_channels.index(config["required_order"][i])
data[..., start_position:end_position, i] = torch.tensor(
data_np[..., idx], dtype=dtype, device=data.device
)
start_position = end_position
if start_position != int(data.shape[2]):
mylogger.info("ERROR: data was not fulled fully!!!")
assert start_position == int(data.shape[2])
mylogger.info("-==- Done -==-")
#################
# Raw data end #
#################
return (
meta_channels,
meta_mouse_markings,
meta_recording_date,
meta_stimulation_times,
meta_experiment_names,
meta_trial_recording_duration,
meta_frame_time,
meta_mouse,
data,
)

View file

@ -3,44 +3,53 @@ import json
def load_meta_data(
mylogger: logging.Logger, filename_meta: str
mylogger: logging.Logger, filename_meta: str, silent_mode=False
) -> tuple[list[str], str, str, dict, dict, float, float, str]:
mylogger.info("Loading meta data")
if silent_mode is False:
mylogger.info("Loading meta data")
with open(filename_meta, "r") as file_handle:
metadata: dict = json.load(file_handle)
channels: list[str] = metadata["channelKey"]
mylogger.info(f"meta data: channel order: {channels}")
if silent_mode is False:
mylogger.info(f"meta data: channel order: {channels}")
mouse_markings: str = metadata["sessionMetaData"]["mouseMarkings"]
mylogger.info(f"meta data: mouse markings: {mouse_markings}")
if silent_mode is False:
mylogger.info(f"meta data: mouse markings: {mouse_markings}")
recording_date: str = metadata["sessionMetaData"]["date"]
mylogger.info(f"meta data: recording data: {recording_date}")
if silent_mode is False:
mylogger.info(f"meta data: recording data: {recording_date}")
stimulation_times: dict = metadata["sessionMetaData"]["stimulationTimes"]
mylogger.info(f"meta data: stimulation times: {stimulation_times}")
if silent_mode is False:
mylogger.info(f"meta data: stimulation times: {stimulation_times}")
experiment_names: dict = metadata["sessionMetaData"]["experimentNames"]
mylogger.info(f"meta data: experiment names: {experiment_names}")
if silent_mode is False:
mylogger.info(f"meta data: experiment names: {experiment_names}")
trial_recording_duration: float = float(
metadata["sessionMetaData"]["trialRecordingDuration"]
)
mylogger.info(
f"meta data: trial recording duration: {trial_recording_duration} sec"
)
if silent_mode is False:
mylogger.info(
f"meta data: trial recording duration: {trial_recording_duration} sec"
)
frame_time: float = float(metadata["sessionMetaData"]["frameTime"])
mylogger.info(
f"meta data: frame time: {frame_time} sec ; frame rate: {1.0/frame_time}Hz"
)
if silent_mode is False:
mylogger.info(
f"meta data: frame time: {frame_time} sec ; frame rate: {1.0/frame_time}Hz"
)
mouse: str = metadata["sessionMetaData"]["mouse"]
mylogger.info(f"meta data: mouse: {mouse}")
mylogger.info("-==- Done -==-")
if silent_mode is False:
mylogger.info(f"meta data: mouse: {mouse}")
mylogger.info("-==- Done -==-")
return (
channels,