diff --git a/reproduction_effort/functions/get_experiments.py b/reproduction_effort/functions/get_experiments.py new file mode 100644 index 0000000..d92b936 --- /dev/null +++ b/reproduction_effort/functions/get_experiments.py @@ -0,0 +1,19 @@ +import torch +import os +import glob + + +@torch.no_grad() +def get_experiments(path: str) -> torch.Tensor: + filename_np: str = os.path.join( + path, + "Exp*_Part001.npy", + ) + + list_str = glob.glob(filename_np) + list_int: list[int] = [] + for i in range(0, len(list_str)): + list_int.append(int(list_str[i].split("Exp")[-1].split("_Trial")[0])) + list_int = sorted(list_int) + + return torch.tensor(list_int).unique() diff --git a/reproduction_effort/functions/get_parts.py b/reproduction_effort/functions/get_parts.py new file mode 100644 index 0000000..d68e1ae --- /dev/null +++ b/reproduction_effort/functions/get_parts.py @@ -0,0 +1,18 @@ +import torch +import os +import glob + + +@torch.no_grad() +def get_parts(path: str, experiment_id: int, trial_id: int) -> torch.Tensor: + filename_np: str = os.path.join( + path, + f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part*.npy", + ) + + list_str = glob.glob(filename_np) + list_int: list[int] = [] + for i in range(0, len(list_str)): + list_int.append(int(list_str[i].split("_Part")[-1].split(".npy")[0])) + list_int = sorted(list_int) + return torch.tensor(list_int).unique() diff --git a/reproduction_effort/functions/get_trials.py b/reproduction_effort/functions/get_trials.py new file mode 100644 index 0000000..8c687d9 --- /dev/null +++ b/reproduction_effort/functions/get_trials.py @@ -0,0 +1,18 @@ +import torch +import os +import glob + + +@torch.no_grad() +def get_trials(path: str, experiment_id: int) -> torch.Tensor: + filename_np: str = os.path.join( + path, + f"Exp{experiment_id:03d}_Trial*_Part001.npy", + ) + + list_str = glob.glob(filename_np) + list_int: list[int] = [] + for i in range(0, len(list_str)): + list_int.append(int(list_str[i].split("_Trial")[-1].split("_Part")[0])) + list_int = sorted(list_int) + return torch.tensor(list_int).unique()