# Labels used in the "<file> failed: <label>" log lines, in the same order as
# the tuple returned by load_meta_data.
_META_FIELD_LABELS: tuple[str, ...] = (
    "channels",
    "mouse_markings",
    "recording_date",
    "stimulation_times",
    "experiment_names",
    "trial_recording_duration",
    "frame_time_check",
    "mouse",
)


def _logged_error(mylogger: logging.Logger, message: str) -> AssertionError:
    """Log ``message`` and return the ``AssertionError`` the caller should raise.

    The exception type is kept as ``AssertionError`` because the validation
    was previously done with bare ``assert`` statements; raising explicitly
    keeps the checks active when Python runs with ``-O``.
    """
    mylogger.info(message)
    return AssertionError(message)


def _part_filename(
    raw_data_path: str, experiment_id: int, trial_id: int, part_id, suffix: str
) -> str:
    """Build the path of one part file; ``suffix`` is e.g. ``"_meta.txt"`` or ``".npy"``.

    ``part_id`` is an element of the tensor returned by ``get_parts`` and is
    formatted with ``03d`` like the experiment and trial ids.
    """
    return os.path.join(
        raw_data_path,
        f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part{part_id:03d}{suffix}",
    )


def _check_meta_data_consistency(
    raw_data_path: str,
    mylogger: logging.Logger,
    experiment_id: int,
    trial_id: int,
    available_parts: torch.Tensor,
) -> str:
    """Verify that all parts carry identical meta data; return the last meta filename."""
    reference_meta = None
    filename_meta: str = ""

    for part_id in available_parts:
        filename_meta = _part_filename(
            raw_data_path, experiment_id, trial_id, part_id, "_meta.txt"
        )
        if not os.path.isfile(filename_meta):
            raise _logged_error(
                mylogger, f"Could not load meta data... {filename_meta}"
            )

        part_meta = load_meta_data(
            mylogger=mylogger, filename_meta=filename_meta, silent_mode=True
        )

        # The first part defines the reference every part is compared against.
        if reference_meta is None:
            reference_meta = copy.deepcopy(part_meta)

        # List equality is order-sensitive, so the channel order is covered too.
        for label, expected, found in zip(
            _META_FIELD_LABELS, reference_meta, part_meta
        ):
            if expected != found:
                raise _logged_error(mylogger, f"{filename_meta} failed: {label}")

    return filename_meta


def _count_frames(
    raw_data_path: str,
    mylogger: logging.Logger,
    experiment_id: int,
    trial_id: int,
    available_parts: torch.Tensor,
) -> tuple[int, int, int, int]:
    """Return ``(dim_0, dim_1, frame_count, dim_3)`` for the concatenated parts.

    Frames are summed along axis 2; axes 0, 1 and 3 must agree between parts.
    """
    reference_shape = None
    frame_count: int = 0

    for part_id in available_parts:
        filename_data = _part_filename(
            raw_data_path, experiment_id, trial_id, part_id, ".npy"
        )
        if not os.path.isfile(filename_data):
            raise _logged_error(mylogger, f"Could not load data... {filename_data}")

        # Memory-mapped: only the header is needed to read the shape.
        part_view: np.ndarray = np.load(filename_data, mmap_mode="r")

        if part_view.ndim != 4:
            raise _logged_error(
                mylogger, f"ERROR: Data needs to have 4 dimensions {filename_data}"
            )

        if reference_shape is None:
            reference_shape = tuple(int(size) for size in part_view.shape)

        frame_count += int(part_view.shape[2])

        for axis in (0, 1, 3):
            if int(part_view.shape[axis]) != reference_shape[axis]:
                raise _logged_error(
                    mylogger,
                    f"ERROR: Data dim {axis} is broken {int(part_view.shape[axis])} "
                    f"vs {reference_shape[axis]} {filename_data}",
                )

        mylogger.info(
            f"{filename_data}: {int(part_view.shape[2])} frames -> {frame_count} frames total"
        )

    return reference_shape[0], reference_shape[1], frame_count, reference_shape[3]


def data_raw_loader(
    raw_data_path: str,
    mylogger: logging.Logger,
    experiment_id: int,
    trial_id: int,
    device: torch.device,
    force_to_cpu_memory: bool,
    config: dict,
) -> tuple[list[str], str, str, dict, dict, float, float, str, torch.Tensor]:
    """Load all parts of one trial into a single tensor, together with its meta data.

    The function validates that the raw directory, the experiment id and the
    trial id exist, checks that the meta data of every part file is identical,
    counts the frames of all parts, and then concatenates the parts along
    axis 2 while reordering the last axis according to
    ``config["required_order"]``.

    Args:
        raw_data_path: Directory containing the ``ExpXXX_TrialXXX_PartXXX``
            ``.npy`` and ``_meta.txt`` files.
        mylogger: Logger receiving progress and error messages.
        experiment_id: Experiment number; must occur exactly once in
            ``get_experiments(raw_data_path)``.
        trial_id: Trial number; must occur exactly once in
            ``get_trials(raw_data_path, experiment_id)``.
        device: Device for the result when ``force_to_cpu_memory`` is false.
        force_to_cpu_memory: If true, the result is allocated on the CPU
            regardless of ``device``.
        config: Needs ``"dtype"`` (attribute name valid for both ``torch`` and
            ``numpy``, e.g. ``"float32"``) and ``"required_order"`` (channel
            names in the order they should appear on the last axis).

    Returns:
        The eight meta data values from ``load_meta_data`` (channels, mouse
        markings, recording date, stimulation times, experiment names, trial
        recording duration, frame time, mouse) followed by the data tensor of
        shape ``(dim_0, dim_1, total_frames, dim_3)``.

    Raises:
        AssertionError: If the directory, experiment, trial, or a part file is
            missing, the parts disagree in meta data or shape, or the result
            was not filled completely. Raised explicitly, so the checks stay
            active under ``python -O``.
        ValueError: If a name in ``config["required_order"]`` is not among the
            channels listed in the meta data (from ``list.index``).
    """
    dtype_str = config["dtype"]
    mylogger.info(f"Data precision will be {dtype_str}")
    dtype: torch.dtype = getattr(torch, dtype_str)
    dtype_np: np.dtype = getattr(np, dtype_str)

    if not os.path.isdir(raw_data_path):
        raise _logged_error(
            mylogger, f"ERROR: could not find raw directory {raw_data_path}!!!!"
        )

    # Each id has to match exactly one entry; the lookup is evaluated once.
    experiment_hits = torch.where(get_experiments(raw_data_path) == experiment_id)[0]
    if experiment_hits.shape[0] != 1:
        raise _logged_error(
            mylogger, f"ERROR: could not find experiment id {experiment_id}!!!!"
        )

    trial_hits = torch.where(get_trials(raw_data_path, experiment_id) == trial_id)[0]
    if trial_hits.shape[0] != 1:
        raise _logged_error(mylogger, f"ERROR: could not find trial id {trial_id}!!!!")

    available_parts: torch.Tensor = get_parts(raw_data_path, experiment_id, trial_id)
    if available_parts.shape[0] < 1:
        raise _logged_error(mylogger, "ERROR: could not find any part files")

    experiment_name = f"Exp{experiment_id:03d}_Trial{trial_id:03d}"
    mylogger.info(f"Will work on: {experiment_name}")
    mylogger.info(f"We found {int(available_parts.shape[0])} parts.")

    mylogger.info("Compare meta data of all parts")
    filename_meta = _check_meta_data_consistency(
        raw_data_path, mylogger, experiment_id, trial_id, available_parts
    )
    mylogger.info("-==- Done -==-")

    # Load the (verified identical) meta data once more, this time with the
    # regular, non-silent logging of its content.
    mylogger.info(f"Will use: {filename_meta} for meta data")
    (
        meta_channels,
        meta_mouse_markings,
        meta_recording_date,
        meta_stimulation_times,
        meta_experiment_names,
        meta_trial_recording_duration,
        meta_frame_time,
        meta_mouse,
    ) = load_meta_data(mylogger=mylogger, filename_meta=filename_meta)

    #################
    # Meta data end #
    #################

    mylogger.info("Count the number of frames in the data of all parts")
    dim_0, dim_1, frame_count, dim_3 = _count_frames(
        raw_data_path, mylogger, experiment_id, trial_id, available_parts
    )

    if force_to_cpu_memory:
        mylogger.info("Using CPU memory for data")
        target_device = torch.device("cpu")
    else:
        mylogger.info("Using GPU memory for data")
        target_device = device

    # NOTE: torch.empty leaves the memory uninitialized and only the first
    # len(config["required_order"]) entries of the last axis are written below;
    # if that list is shorter than dim_3 the remaining entries stay undefined.
    data: torch.Tensor = torch.empty(
        (dim_0, dim_1, frame_count, dim_3), dtype=dtype, device=target_device
    )

    required_order = config["required_order"]
    start_position: int = 0
    for part_id in available_parts:
        filename_data = _part_filename(
            raw_data_path, experiment_id, trial_id, part_id, ".npy"
        )

        mylogger.info(f"Will work on {filename_data}")
        mylogger.info("Loading data file")
        part_data: np.ndarray = np.load(filename_data).astype(dtype_np)

        end_position = start_position + int(part_data.shape[2])

        for target_index, channel_name in enumerate(required_order):
            mylogger.info(f"Move raw data channel: {channel_name}")

            # Position of the channel in the file vs. its requested position.
            source_index = meta_channels.index(channel_name)
            data[..., start_position:end_position, target_index] = torch.tensor(
                part_data[..., source_index], dtype=dtype, device=data.device
            )
        start_position = end_position

    if start_position != int(data.shape[2]):
        raise _logged_error(mylogger, "ERROR: data was not fulled fully!!!")

    mylogger.info("-==- Done -==-")

    #################
    # Raw data end  #
    #################

    return (
        meta_channels,
        meta_mouse_markings,
        meta_recording_date,
        meta_stimulation_times,
        meta_experiment_names,
        meta_trial_recording_duration,
        meta_frame_time,
        meta_mouse,
        data,
    )
mylogger.info(f"meta data: recording data: {recording_date}") + if silent_mode is False: + mylogger.info(f"meta data: recording data: {recording_date}") stimulation_times: dict = metadata["sessionMetaData"]["stimulationTimes"] - mylogger.info(f"meta data: stimulation times: {stimulation_times}") + if silent_mode is False: + mylogger.info(f"meta data: stimulation times: {stimulation_times}") experiment_names: dict = metadata["sessionMetaData"]["experimentNames"] - mylogger.info(f"meta data: experiment names: {experiment_names}") + if silent_mode is False: + mylogger.info(f"meta data: experiment names: {experiment_names}") trial_recording_duration: float = float( metadata["sessionMetaData"]["trialRecordingDuration"] ) - mylogger.info( - f"meta data: trial recording duration: {trial_recording_duration} sec" - ) + if silent_mode is False: + mylogger.info( + f"meta data: trial recording duration: {trial_recording_duration} sec" + ) frame_time: float = float(metadata["sessionMetaData"]["frameTime"]) - mylogger.info( - f"meta data: frame time: {frame_time} sec ; frame rate: {1.0/frame_time}Hz" - ) + if silent_mode is False: + mylogger.info( + f"meta data: frame time: {frame_time} sec ; frame rate: {1.0/frame_time}Hz" + ) mouse: str = metadata["sessionMetaData"]["mouse"] - mylogger.info(f"meta data: mouse: {mouse}") - mylogger.info("-==- Done -==-") + if silent_mode is False: + mylogger.info(f"meta data: mouse: {mouse}") + mylogger.info("-==- Done -==-") return ( channels,