Add files via upload
Additional files used for analysis
This commit is contained in:
parent
81aea7fecd
commit
f7e931ba3d
36 changed files with 5376 additions and 0 deletions
49
thesis code/network analysis/freeParamCalc.py
Normal file
49
thesis code/network analysis/freeParamCalc.py
Normal file
|
@ -0,0 +1,49 @@
|
|||
import torch
|
||||
|
||||
|
||||
def calc_free_params(from_loaded_model: bool, model_name: str | None):
|
||||
"""
|
||||
* Calculates the number of free parameters of a CNN
|
||||
* either from trained model or by entering the respective parameters
|
||||
over command line
|
||||
"""
|
||||
|
||||
if from_loaded_model:
|
||||
# path to NN
|
||||
PATH = f"D:/Katha/Neuroscience/Semester 4/newCode/kk_contour_net_shallow-main/trained_models/{model_name}"
|
||||
|
||||
# load and evaluate model
|
||||
model = torch.load(PATH).to("cpu")
|
||||
model.eval()
|
||||
print(model)
|
||||
|
||||
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
print(f"Total number of free parameters: {total_params}")
|
||||
else:
|
||||
print("\n##########################")
|
||||
input_out_channel_size = input(
|
||||
"Enter output channel size (comma seperated, including output layer): "
|
||||
)
|
||||
out_channel_size = [1] + [int(x) for x in input_out_channel_size.split(",")]
|
||||
|
||||
input_kernel_sizes = input(
|
||||
"Enter kernel sizes of respective layers (comma seperated, including output layer): "
|
||||
)
|
||||
kernel_sizes = [int(x) for x in input_kernel_sizes.split(",")]
|
||||
|
||||
total_params = 0
|
||||
for i in range(1, len(out_channel_size)):
|
||||
input_size = out_channel_size[i - 1]
|
||||
out_size = out_channel_size[i]
|
||||
kernel = kernel_sizes[i - 1]
|
||||
bias = out_channel_size[i]
|
||||
num_free_params = input_size * kernel * kernel * out_size + bias
|
||||
total_params += num_free_params
|
||||
print(f"Total number of free parameters: {total_params}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# model name
|
||||
nn = "ArghCNN_numConvLayers3_outChannels[8, 8, 8]_kernelSize[7, 15]_leaky relu_stride1_trainFirstConvLayerTrue_seed291857_Natural_1351Epoch_3107-2121.pt"
|
||||
|
||||
calc_free_params(from_loaded_model=False, model_name=nn)
|
18
thesis code/network analysis/minimal_architecture/README.txt
Normal file
18
thesis code/network analysis/minimal_architecture/README.txt
Normal file
|
@ -0,0 +1,18 @@
|
|||
Folder minimal_architecture:
|
||||
|
||||
1. config.json:
|
||||
* json file with all configurations and cnn parameters
|
||||
|
||||
2. training_loop.sh:
|
||||
* bash script to train the 64 cnns
|
||||
|
||||
|
||||
3. get_trained_models:
|
||||
* searches for the saved trained models in a directory
|
||||
* chooses model based on the largest saved epoch in the save-name
|
||||
|
||||
|
||||
4. pfinkel_performance_test64:
|
||||
* load all models extracted by 'get_trained_models'
|
||||
* test them on all stimulus conditions
|
||||
* sort their performances either after number of free parameters, or architecture
|
368
thesis code/network analysis/minimal_architecture/config.json
Normal file
368
thesis code/network analysis/minimal_architecture/config.json
Normal file
|
@ -0,0 +1,368 @@
|
|||
{
|
||||
"data_path": "/home/kk/Documents/Semester4/code/RenderStimuli/Output/",
|
||||
"save_logging_messages": true, // (true), false
|
||||
"display_logging_messages": true, // (true), false
|
||||
"batch_size_train": 500,
|
||||
"batch_size_test": 250,
|
||||
"max_epochs": 2000,
|
||||
"save_model": true,
|
||||
"conv_0_kernel_size": 11,
|
||||
"mp_1_kernel_size": 3,
|
||||
"mp_1_stride": 2,
|
||||
"use_plot_intermediate": true, // true, (false)
|
||||
"stimuli_per_pfinkel": 10000,
|
||||
"num_pfinkel_start": 0,
|
||||
"num_pfinkel_stop": 100,
|
||||
"num_pfinkel_step": 10,
|
||||
"precision_100_percent": 0, // (4)
|
||||
"train_first_layer": true, // true, (false)
|
||||
"save_ever_x_epochs": 100, // (10)
|
||||
"activation_function": "leaky relu", // tanh, relu, (leaky relu), none
|
||||
"leak_relu_negative_slope": 0.1, // (0.1)
|
||||
"switch_leakyR_to_relu": false,
|
||||
// LR Scheduler ->
|
||||
"use_scheduler": true, // (true), false
|
||||
"scheduler_verbose": true,
|
||||
"scheduler_factor": 0.1, //(0.1)
|
||||
"scheduler_patience": 10, // (10)
|
||||
"scheduler_threshold": 1e-5, // (1e-4)
|
||||
"minimum_learning_rate": 1e-8,
|
||||
"learning_rate": 0.0001,
|
||||
// <- LR Scheduler
|
||||
"pooling_type": "max", // (max), average, none
|
||||
"conv_0_enable_softmax": false, // true, (false)
|
||||
"use_adam": true, // (true) => adam, false => SGD
|
||||
"condition": "Natural",
|
||||
"scale_data": 255.0, // (255.0)
|
||||
"conv_out_channels_list": [
|
||||
[
|
||||
8,
|
||||
8,
|
||||
8
|
||||
],
|
||||
[
|
||||
8,
|
||||
8,
|
||||
6
|
||||
],
|
||||
[
|
||||
8,
|
||||
8,
|
||||
4
|
||||
],
|
||||
[
|
||||
8,
|
||||
8,
|
||||
2
|
||||
],
|
||||
[
|
||||
8,
|
||||
6,
|
||||
8
|
||||
],
|
||||
[
|
||||
8,
|
||||
6,
|
||||
6
|
||||
],
|
||||
[
|
||||
8,
|
||||
6,
|
||||
4
|
||||
],
|
||||
[
|
||||
8,
|
||||
6,
|
||||
2
|
||||
],
|
||||
[
|
||||
8,
|
||||
4,
|
||||
8
|
||||
],
|
||||
[
|
||||
8,
|
||||
4,
|
||||
6
|
||||
],
|
||||
[
|
||||
8,
|
||||
4,
|
||||
4
|
||||
],
|
||||
[
|
||||
8,
|
||||
4,
|
||||
2
|
||||
],
|
||||
[
|
||||
8,
|
||||
2,
|
||||
8
|
||||
],
|
||||
[
|
||||
8,
|
||||
2,
|
||||
6
|
||||
],
|
||||
[
|
||||
8,
|
||||
2,
|
||||
4
|
||||
],
|
||||
[
|
||||
8,
|
||||
2,
|
||||
2
|
||||
],
|
||||
[
|
||||
6,
|
||||
8,
|
||||
8
|
||||
],
|
||||
[
|
||||
6,
|
||||
8,
|
||||
6
|
||||
],
|
||||
[
|
||||
6,
|
||||
8,
|
||||
4
|
||||
],
|
||||
[
|
||||
6,
|
||||
8,
|
||||
2
|
||||
],
|
||||
[
|
||||
6,
|
||||
6,
|
||||
8
|
||||
],
|
||||
[
|
||||
6,
|
||||
6,
|
||||
6
|
||||
],
|
||||
[
|
||||
6,
|
||||
6,
|
||||
4
|
||||
],
|
||||
[
|
||||
6,
|
||||
6,
|
||||
2
|
||||
],
|
||||
[
|
||||
6,
|
||||
4,
|
||||
8
|
||||
],
|
||||
[
|
||||
6,
|
||||
4,
|
||||
6
|
||||
],
|
||||
[
|
||||
6,
|
||||
4,
|
||||
4
|
||||
],
|
||||
[
|
||||
6,
|
||||
4,
|
||||
2
|
||||
],
|
||||
[
|
||||
6,
|
||||
2,
|
||||
8
|
||||
],
|
||||
[
|
||||
6,
|
||||
2,
|
||||
6
|
||||
],
|
||||
[
|
||||
6,
|
||||
2,
|
||||
4
|
||||
],
|
||||
[
|
||||
6,
|
||||
2,
|
||||
2
|
||||
],
|
||||
[
|
||||
4,
|
||||
8,
|
||||
8
|
||||
],
|
||||
[
|
||||
4,
|
||||
8,
|
||||
6
|
||||
],
|
||||
[
|
||||
4,
|
||||
8,
|
||||
4
|
||||
],
|
||||
[
|
||||
4,
|
||||
8,
|
||||
2
|
||||
],
|
||||
[
|
||||
4,
|
||||
6,
|
||||
8
|
||||
],
|
||||
[
|
||||
4,
|
||||
6,
|
||||
6
|
||||
],
|
||||
[
|
||||
4,
|
||||
6,
|
||||
4
|
||||
],
|
||||
[
|
||||
4,
|
||||
6,
|
||||
2
|
||||
],
|
||||
[
|
||||
4,
|
||||
4,
|
||||
8
|
||||
],
|
||||
[
|
||||
4,
|
||||
4,
|
||||
6
|
||||
],
|
||||
[
|
||||
4,
|
||||
4,
|
||||
4
|
||||
],
|
||||
[
|
||||
4,
|
||||
4,
|
||||
2
|
||||
],
|
||||
[
|
||||
4,
|
||||
2,
|
||||
8
|
||||
],
|
||||
[
|
||||
4,
|
||||
2,
|
||||
6
|
||||
],
|
||||
[
|
||||
4,
|
||||
2,
|
||||
4
|
||||
],
|
||||
[
|
||||
4,
|
||||
2,
|
||||
2
|
||||
],
|
||||
[
|
||||
2,
|
||||
8,
|
||||
8
|
||||
],
|
||||
[
|
||||
2,
|
||||
8,
|
||||
6
|
||||
],
|
||||
[
|
||||
2,
|
||||
8,
|
||||
4
|
||||
],
|
||||
[
|
||||
2,
|
||||
8,
|
||||
2
|
||||
],
|
||||
[
|
||||
2,
|
||||
6,
|
||||
8
|
||||
],
|
||||
[
|
||||
2,
|
||||
6,
|
||||
6
|
||||
],
|
||||
[
|
||||
2,
|
||||
6,
|
||||
4
|
||||
],
|
||||
[
|
||||
2,
|
||||
6,
|
||||
2
|
||||
],
|
||||
[
|
||||
2,
|
||||
4,
|
||||
8
|
||||
],
|
||||
[
|
||||
2,
|
||||
4,
|
||||
6
|
||||
],
|
||||
[
|
||||
2,
|
||||
4,
|
||||
4
|
||||
],
|
||||
[
|
||||
2,
|
||||
4,
|
||||
2
|
||||
],
|
||||
[
|
||||
2,
|
||||
2,
|
||||
8
|
||||
],
|
||||
[
|
||||
2,
|
||||
2,
|
||||
6
|
||||
],
|
||||
[
|
||||
2,
|
||||
2,
|
||||
4
|
||||
],
|
||||
[
|
||||
2,
|
||||
2,
|
||||
2
|
||||
]
|
||||
],
|
||||
"conv_kernel_sizes": [
|
||||
[
|
||||
7,
|
||||
15
|
||||
]
|
||||
],
|
||||
"conv_stride_sizes": [
|
||||
1
|
||||
]
|
||||
}
|
|
@ -0,0 +1,55 @@
|
|||
import glob
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
|
||||
"""
|
||||
get performances from .pt files
|
||||
"""
|
||||
|
||||
directory = "./trained_models"
|
||||
string = "Natural"
|
||||
final_path = "./trained_corners"
|
||||
|
||||
|
||||
# list of all files in the directory
|
||||
files = glob.glob(directory + "/*.pt")
|
||||
|
||||
# filter
|
||||
filtered_files = [f for f in files if string in f]
|
||||
|
||||
# group by seed
|
||||
seed_files = {}
|
||||
for f in filtered_files:
|
||||
# get seed from filename
|
||||
match = re.search(r"_seed(\d+)_", f)
|
||||
if match:
|
||||
seed = int(match.group(1))
|
||||
if seed not in seed_files:
|
||||
seed_files[seed] = []
|
||||
seed_files[seed].append(f)
|
||||
|
||||
|
||||
# get saved cnn largests epoch
|
||||
newest_files = {}
|
||||
for seed, files in seed_files.items():
|
||||
max_epoch = -1
|
||||
newest_file = None
|
||||
for f in files:
|
||||
# search for epoch
|
||||
match = re.search(r"_(\d+)Epoch_", f)
|
||||
if match:
|
||||
epoch = int(match.group(1))
|
||||
if epoch > max_epoch:
|
||||
max_epoch = epoch
|
||||
newest_file = f
|
||||
newest_files[seed] = newest_file
|
||||
|
||||
print(len(newest_files))
|
||||
|
||||
# move files to new folder
|
||||
os.makedirs(final_path, exist_ok=True)
|
||||
|
||||
# Copy the files to the new folder
|
||||
for seed, file in newest_files.items():
|
||||
shutil.copy(file, os.path.join(final_path, os.path.basename(file)))
|
|
@ -0,0 +1,282 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib as mpl
|
||||
import os
|
||||
import datetime
|
||||
import re
|
||||
|
||||
# import glob
|
||||
# from natsort import natsorted
|
||||
|
||||
mpl.rcParams["text.usetex"] = True
|
||||
mpl.rcParams["font.family"] = "serif"
|
||||
|
||||
from functions.alicorn_data_loader import alicorn_data_loader
|
||||
from functions.create_logger import create_logger
|
||||
|
||||
|
||||
def sort_and_plot(
|
||||
extracted_params,
|
||||
save: bool,
|
||||
plot_for_each_condition: bool,
|
||||
name: str,
|
||||
sort_by="params",
|
||||
):
|
||||
figure_path: str = "performance_pfinkel_0210"
|
||||
os.makedirs(figure_path, exist_ok=True)
|
||||
|
||||
architecture_params = extracted_params.copy()
|
||||
if sort_by == "params":
|
||||
architecture_params.sort(key=lambda x: x[1])
|
||||
elif sort_by == "accuracy":
|
||||
architecture_params.sort(key=lambda x: x[-1])
|
||||
|
||||
sorted_architectures, sorted_params, test_conditions, sorted_performances = zip(
|
||||
*architecture_params
|
||||
)
|
||||
final_labels = [
|
||||
f"{arch[1:-1]} - {params}"
|
||||
for arch, params in zip(sorted_architectures, sorted_params)
|
||||
]
|
||||
|
||||
plt.figure(figsize=(18, 9))
|
||||
|
||||
# performance for each condition
|
||||
if plot_for_each_condition:
|
||||
conditions = ["Coignless", "Natural", "Angular"]
|
||||
labels = ["Classic", "Corner", "Bridge"]
|
||||
shift_amounts = [-0.05, 0, 0.05]
|
||||
save_name = name + "_each_condition"
|
||||
for i, condition in enumerate(conditions):
|
||||
# x_vals = range(len(sorted_performances))
|
||||
jittered_x = np.arange(len(sorted_performances)) + shift_amounts[i]
|
||||
y_vals = [perf[condition] for perf in test_conditions]
|
||||
plt.errorbar(
|
||||
jittered_x,
|
||||
y_vals,
|
||||
fmt="D",
|
||||
markerfacecolor="none",
|
||||
markeredgewidth=1.5,
|
||||
label=labels[i],
|
||||
)
|
||||
else:
|
||||
save_name = name + "_mean"
|
||||
plt.plot(range(len(sorted_performances)), sorted_performances, marker="o")
|
||||
|
||||
plt.ylabel("Accuracy (in \\%)", fontsize=17)
|
||||
plt.xticks(range(len(sorted_performances)), final_labels, rotation=90, fontsize=15)
|
||||
plt.yticks(fontsize=16)
|
||||
plt.grid(True)
|
||||
plt.tight_layout()
|
||||
plt.legend(fontsize=15)
|
||||
|
||||
if save:
|
||||
plt.savefig(
|
||||
os.path.join(
|
||||
figure_path,
|
||||
f"minimalCNN_64sorted_{sort_by}_{save_name}.pdf",
|
||||
),
|
||||
dpi=300,
|
||||
bbox_inches="tight",
|
||||
)
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
training_con: str = "classic"
|
||||
model_path: str = "./trained_classic"
|
||||
print(model_path)
|
||||
data_path: str = "/home/kk/Documents/Semester4/code/RenderStimuli/Output/"
|
||||
|
||||
# num stimuli per Pfinkel and batch size
|
||||
stim_per_pfinkel: int = 10000
|
||||
batch_size: int = 1000
|
||||
|
||||
# stimulus condition:
|
||||
performances_list: list = []
|
||||
condition: list[str] = ["Coignless", "Natural", "Angular"]
|
||||
|
||||
# load test data:
|
||||
num_pfinkel: list = np.arange(0, 100, 10).tolist()
|
||||
image_scale: float = 255.0
|
||||
|
||||
# ------------------------------------------
|
||||
|
||||
# create logger:
|
||||
logger = create_logger(
|
||||
save_logging_messages=False,
|
||||
display_logging_messages=True,
|
||||
model_name=model_path,
|
||||
)
|
||||
|
||||
device_str: str = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||
logger.info(f"Using {device_str} device")
|
||||
device: torch.device = torch.device(device_str)
|
||||
torch.set_default_dtype(torch.float32)
|
||||
|
||||
# current time:
|
||||
current = datetime.datetime.now().strftime("%d%m-%H%M")
|
||||
|
||||
# save data
|
||||
cnn_data: list = []
|
||||
cnn_counter: int = 0
|
||||
|
||||
for filename in os.listdir(model_path):
|
||||
if filename.endswith(".pt"):
|
||||
model_filename = os.path.join(model_path, filename)
|
||||
model = torch.load(model_filename, map_location=device)
|
||||
model.eval()
|
||||
print(f"CNN {cnn_counter+1} :{model_filename}")
|
||||
|
||||
# number free parameters for current CNN
|
||||
num_free_params = sum(
|
||||
p.numel() for p in model.parameters() if p.requires_grad
|
||||
)
|
||||
|
||||
# save
|
||||
all_performances: dict = {
|
||||
condition_name: {pfinkel: [] for pfinkel in num_pfinkel}
|
||||
for condition_name in condition
|
||||
}
|
||||
|
||||
for selected_condition in condition:
|
||||
# save performances:
|
||||
logger.info(f"Condition: {selected_condition}")
|
||||
performances: dict = {}
|
||||
for pfinkel in num_pfinkel:
|
||||
test_loss: float = 0.0
|
||||
correct: int = 0
|
||||
pattern_count: int = 0
|
||||
|
||||
data_test = alicorn_data_loader(
|
||||
num_pfinkel=[pfinkel],
|
||||
load_stimuli_per_pfinkel=stim_per_pfinkel,
|
||||
condition=selected_condition,
|
||||
logger=logger,
|
||||
data_path=data_path,
|
||||
)
|
||||
loader = torch.utils.data.DataLoader(
|
||||
data_test, shuffle=False, batch_size=batch_size
|
||||
)
|
||||
|
||||
# start testing network on new stimuli:
|
||||
logger.info("")
|
||||
logger.info(
|
||||
f"-==- Start {selected_condition} " f"Pfinkel {pfinkel}° -==-"
|
||||
)
|
||||
with torch.no_grad():
|
||||
for batch_num, data in enumerate(loader):
|
||||
label = data[0].to(device)
|
||||
image = data[1].type(dtype=torch.float32).to(device)
|
||||
image /= image_scale
|
||||
|
||||
# compute prediction error;
|
||||
output = model(image)
|
||||
|
||||
# Label Typecast:
|
||||
label = label.to(device)
|
||||
|
||||
# loss and optimization
|
||||
loss = torch.nn.functional.cross_entropy(
|
||||
output, label, reduction="sum"
|
||||
)
|
||||
pattern_count += int(label.shape[0])
|
||||
test_loss += float(loss)
|
||||
prediction = output.argmax(dim=1)
|
||||
correct += prediction.eq(label).sum().item()
|
||||
|
||||
total_number_of_pattern: int = int(len(loader)) * int(
|
||||
label.shape[0]
|
||||
)
|
||||
|
||||
# logging:
|
||||
logger.info(
|
||||
(
|
||||
f"{selected_condition},{pfinkel}° "
|
||||
"Pfinkel: "
|
||||
f"[{int(pattern_count)}/{total_number_of_pattern} ({100.0 * pattern_count / total_number_of_pattern:.2f}%)],"
|
||||
f" Average loss: {test_loss / pattern_count:.3e}, "
|
||||
"Accuracy: "
|
||||
f"{100.0 * correct / pattern_count:.2f}% "
|
||||
)
|
||||
)
|
||||
|
||||
performances[pfinkel] = {
|
||||
"pfinkel": pfinkel,
|
||||
"test_accuracy": 100 * correct / pattern_count,
|
||||
"test_losses": float(loss) / pattern_count,
|
||||
}
|
||||
all_performances[selected_condition][pfinkel].append(
|
||||
100 * correct / pattern_count
|
||||
)
|
||||
|
||||
performances_list.append(performances)
|
||||
|
||||
# store num free params + performances
|
||||
avg_performance_per_condition = {
|
||||
cond: np.mean([np.mean(perfs) for perfs in pfinkel_dict.values()])
|
||||
for cond, pfinkel_dict in all_performances.items()
|
||||
}
|
||||
avg_performance_overall = np.mean(
|
||||
list(avg_performance_per_condition.values())
|
||||
)
|
||||
|
||||
# extract CNN config:
|
||||
match = re.search(r"_outChannels\[(\d+), (\d+), (\d+)\]_", filename)
|
||||
if match:
|
||||
out_channels = (
|
||||
[1] + [int(match.group(i)) for i in range(1, 3 + 1)] + [2]
|
||||
)
|
||||
|
||||
# number of free parameters and performances
|
||||
cnn_data.append(
|
||||
(
|
||||
out_channels,
|
||||
num_free_params,
|
||||
avg_performance_per_condition,
|
||||
avg_performance_overall,
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
print("No files found!")
|
||||
break
|
||||
|
||||
# save all 64 performances
|
||||
torch.save(
|
||||
cnn_data,
|
||||
f"{model_path}.pt",
|
||||
)
|
||||
|
||||
# plot
|
||||
sort_and_plot(
|
||||
cnn_data,
|
||||
save=True,
|
||||
plot_for_each_condition=True,
|
||||
name=training_con,
|
||||
sort_by="params",
|
||||
)
|
||||
sort_and_plot(
|
||||
cnn_data,
|
||||
save=True,
|
||||
plot_for_each_condition=False,
|
||||
name=training_con,
|
||||
sort_by="params",
|
||||
)
|
||||
sort_and_plot(
|
||||
cnn_data,
|
||||
save=True,
|
||||
plot_for_each_condition=True,
|
||||
name=training_con,
|
||||
sort_by="accuracy",
|
||||
)
|
||||
sort_and_plot(
|
||||
cnn_data,
|
||||
save=True,
|
||||
plot_for_each_condition=False,
|
||||
name=training_con,
|
||||
sort_by="accuracy",
|
||||
)
|
||||
|
||||
logger.info("-==- DONE -==-")
|
|
@ -0,0 +1,11 @@
|
|||
Directory="/home/kk/Documents/Semester4/code/Run64Variations"
|
||||
Priority="0"
|
||||
echo $Directory
|
||||
mkdir $Directory/argh_log_corner
|
||||
for out_channels_idx in {0..63}; do
|
||||
for kernel_size_idx in {0..0}; do
|
||||
for stride_idx in {0..0}; do
|
||||
echo "hostname; cd $Directory ; /home/kk/P3.10/bin/python3 cnn_training.py --idx-conv-out-channels-list $out_channels_idx --idx-conv-kernel-sizes $kernel_size_idx --idx-conv-stride-sizes $stride_idx -s \$JOB_ID" | qsub -o $Directory/argh_log_classic -j y -p $Priority -q gp4u,gp3u -N itsCorn
|
||||
done
|
||||
done
|
||||
done
|
8
thesis code/network analysis/optimal_stimulus/README.txt
Normal file
8
thesis code/network analysis/optimal_stimulus/README.txt
Normal file
|
@ -0,0 +1,8 @@
|
|||
Folder optimal_stimulus
|
||||
|
||||
1. optimal_stimulus:
|
||||
* for single trained model
|
||||
* generates optimal stimulus for neuron in selected layer
|
||||
|
||||
2. optimal_stimulus_20cnns:
|
||||
* generates stimulus for neuron in same layer of all 20 cnns
|
|
@ -0,0 +1,219 @@
|
|||
# %%
|
||||
import torch
|
||||
import random
|
||||
import re
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as patch
|
||||
import matplotlib as mpl
|
||||
|
||||
mpl.rcParams["text.usetex"] = True
|
||||
mpl.rcParams["font.family"] = "serif"
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.append(parent_dir)
|
||||
from functions.analyse_network import analyse_network
|
||||
from functions.set_seed import set_seed
|
||||
|
||||
# define parameters
|
||||
num_iterations: int = 100000
|
||||
learning_rate: float = 0.1
|
||||
apply_input_mask: bool = True
|
||||
mark_region_in_plot: bool = True
|
||||
sheduler_patience: int = 500
|
||||
sheduler_factor: float = 0.9
|
||||
sheduler_eps = 1e-08
|
||||
target_image_active: float = 1e4
|
||||
random_seed = random.randint(0, 100)
|
||||
save_final: bool = True
|
||||
model_str: str = "CORNER_888"
|
||||
|
||||
# set seet
|
||||
set_seed(random_seed)
|
||||
print(f"Random seed: {random_seed}")
|
||||
|
||||
# path to NN
|
||||
condition: str = "corner_888_poster"
|
||||
pattern = r"seed\d+_Natural_\d+Epoch"
|
||||
nn = "ArghCNN_numConvLayers3_outChannels[8, 8, 8]_kernelSize[7, 15]_leaky relu_stride1_trainFirstConvLayerTrue_seed291857_Natural_1351Epoch_3107-2121.pt"
|
||||
PATH = f"./trained_models/{nn}"
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# %%
|
||||
# load and eval model
|
||||
model = torch.load(PATH).to(device)
|
||||
model.eval()
|
||||
print("Full network:")
|
||||
print(model)
|
||||
print("")
|
||||
|
||||
|
||||
# enter index to plot:
|
||||
idx = int(input("Please select layer: "))
|
||||
print(f"Selected layer: {model[idx]}")
|
||||
assert idx < len(model)
|
||||
model = model[: idx + 1]
|
||||
|
||||
# random input
|
||||
input_img = torch.rand(1, 200, 200).to(device)
|
||||
input_img = input_img.unsqueeze(0)
|
||||
input_img.requires_grad_(True) # type: ignore
|
||||
print(input_img.min(), input_img.max())
|
||||
|
||||
input_shape = input_img.shape
|
||||
assert input_shape[-2] == input_shape[-1]
|
||||
coordinate_list, layer_type_list, pixel_used = analyse_network(
|
||||
model=model, input_shape=int(input_shape[-1])
|
||||
)
|
||||
|
||||
|
||||
output_shape = model(input_img).shape
|
||||
|
||||
|
||||
target_image = torch.zeros(
|
||||
(*output_shape,), dtype=input_img.dtype, device=input_img.device
|
||||
)
|
||||
|
||||
|
||||
# image to parameter (2B optimized)
|
||||
input_parameter = torch.nn.Parameter(input_img)
|
||||
|
||||
|
||||
if len(target_image.shape) == 2:
|
||||
print((f"Available max positions: f:{target_image.shape[1] - 1} "))
|
||||
|
||||
# select neuron and plot for all feature maps (?)
|
||||
neuron_f = int(input("Please select neuron_f: "))
|
||||
print(f"Selected neuron {neuron_f}")
|
||||
target_image[0, neuron_f] = 1e4
|
||||
else:
|
||||
print(
|
||||
(
|
||||
f"Available max positions: f:{target_image.shape[1] - 1} "
|
||||
f"x:{target_image.shape[2]} y:{target_image.shape[3]}"
|
||||
)
|
||||
)
|
||||
|
||||
# select neuron and plot for all feature maps (?)
|
||||
neuron_f = int(input("Please select neuron_f: "))
|
||||
neuron_x = target_image.shape[2] // 2
|
||||
neuron_y = target_image.shape[3] // 2
|
||||
print(f"Selected neuron {neuron_f}, {neuron_x}, {neuron_y}")
|
||||
target_image[0, neuron_f, neuron_x, neuron_y] = target_image_active
|
||||
|
||||
# Input mask ->
|
||||
active_input_x = coordinate_list[-1][:, neuron_x].clone()
|
||||
active_input_y = coordinate_list[-1][:, neuron_y].clone()
|
||||
|
||||
input_mask: torch.Tensor = torch.zeros_like(input_img)
|
||||
|
||||
input_mask[
|
||||
:,
|
||||
:,
|
||||
active_input_x.type(torch.int64).unsqueeze(-1),
|
||||
active_input_y.type(torch.int64).unsqueeze(0),
|
||||
] = 1
|
||||
|
||||
rect_x = [int(active_input_x.min()), int(active_input_x.max())]
|
||||
rect_y = [int(active_input_y.min()), int(active_input_y.max())]
|
||||
# <- Input mask
|
||||
|
||||
if apply_input_mask:
|
||||
with torch.no_grad():
|
||||
input_img *= input_mask
|
||||
|
||||
|
||||
optimizer = torch.optim.Adam([{"params": input_parameter}], lr=learning_rate)
|
||||
|
||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||
optimizer,
|
||||
patience=sheduler_patience,
|
||||
factor=sheduler_factor,
|
||||
eps=sheduler_eps * 0.1,
|
||||
)
|
||||
|
||||
|
||||
counter: int = 0
|
||||
while (optimizer.param_groups[0]["lr"] > sheduler_eps) and (counter < num_iterations):
|
||||
optimizer.zero_grad()
|
||||
|
||||
output = model(input_parameter)
|
||||
|
||||
loss = torch.nn.functional.mse_loss(output, target_image)
|
||||
loss.backward()
|
||||
|
||||
if counter % 1000 == 0:
|
||||
print(
|
||||
f"{counter} : loss={float(loss):.3e} lr={optimizer.param_groups[0]['lr']:.3e}"
|
||||
)
|
||||
|
||||
optimizer.step()
|
||||
|
||||
if apply_input_mask and len(target_image.shape) != 2:
|
||||
with torch.no_grad():
|
||||
input_parameter.data[torch.where(input_mask == 0)] = 0.0
|
||||
|
||||
with torch.no_grad():
|
||||
max_data = torch.abs(input_parameter.data).max()
|
||||
if max_data > 1.0:
|
||||
input_parameter.data /= max_data
|
||||
|
||||
if (
|
||||
torch.isfinite(input_parameter.data).sum().cpu()
|
||||
!= torch.tensor(input_parameter.data.size()).prod()
|
||||
):
|
||||
print(f"Found NaN in step: {counter}, use a smaller initial lr")
|
||||
exit()
|
||||
|
||||
scheduler.step(float(loss))
|
||||
counter += 1
|
||||
|
||||
# save image
|
||||
if save_final:
|
||||
# get short model name:
|
||||
matches = re.findall(pattern, nn)
|
||||
model_short = "".join(["".join(match) for match in matches])
|
||||
save_name = (
|
||||
f"optimal_model{model_short}_layer{idx}_feature{neuron_f}_seed{random_seed}.pt"
|
||||
)
|
||||
|
||||
# filepath:
|
||||
folderpath = f"./other_{condition}_optimal"
|
||||
os.makedirs(folderpath, exist_ok=True)
|
||||
torch.save(input_img.squeeze().detach().cpu(), os.path.join(folderpath, save_name))
|
||||
|
||||
# plot image:
|
||||
_, ax = plt.subplots()
|
||||
|
||||
ax.imshow(input_img.squeeze().detach().cpu().numpy(), cmap="gray")
|
||||
|
||||
plt.yticks(fontsize=15)
|
||||
plt.xticks(fontsize=15)
|
||||
|
||||
|
||||
if len(target_image.shape) != 2 and mark_region_in_plot:
|
||||
edgecolor = "sienna"
|
||||
kernel = patch.Rectangle(
|
||||
(rect_y[0], rect_x[0]),
|
||||
int(rect_y[1] - rect_y[0]),
|
||||
int(rect_x[1] - rect_x[0]),
|
||||
linewidth=1.2,
|
||||
edgecolor=edgecolor,
|
||||
facecolor="none",
|
||||
)
|
||||
ax.add_patch(kernel)
|
||||
|
||||
figure_path = f"./other_{condition}_optimal"
|
||||
os.makedirs(figure_path, exist_ok=True)
|
||||
plt.savefig(
|
||||
os.path.join(
|
||||
figure_path,
|
||||
f"{save_name}_{model_str}.pdf",
|
||||
),
|
||||
dpi=300,
|
||||
bbox_inches="tight",
|
||||
)
|
||||
|
||||
plt.show(block=True)
|
|
@ -0,0 +1,293 @@
|
|||
# %%
|
||||
import torch
|
||||
import numpy as np
|
||||
import random
|
||||
import re
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as patch
|
||||
import matplotlib as mpl
|
||||
|
||||
mpl.rcParams["text.usetex"] = True
|
||||
mpl.rcParams["font.family"] = "serif"
|
||||
mpl.rcParams["font.size"] = 15
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.append(parent_dir)
|
||||
from functions.analyse_network import analyse_network
|
||||
from functions.set_seed import set_seed
|
||||
|
||||
# set seet
|
||||
random_seed = random.randint(0, 100)
|
||||
set_seed(random_seed)
|
||||
print(f"Random seed: {random_seed}")
|
||||
|
||||
|
||||
def get_file_list_all_cnns(dir: str) -> list:
|
||||
all_results: list = []
|
||||
for filename in os.listdir(dir):
|
||||
if filename.endswith(".pt"):
|
||||
print(os.path.join(dir, filename))
|
||||
all_results.append(os.path.join(dir, filename))
|
||||
|
||||
return all_results
|
||||
|
||||
|
||||
def show_single_optimal_stimulus(model_list, save: bool = False, cnn: str = "CORNER"):
|
||||
first_run: bool = True
|
||||
chosen_layer_idx: int
|
||||
chosen_neuron_f_idx: int
|
||||
chosen_neuron_x_idx: int
|
||||
chosen_neuron_y_idx: int
|
||||
mean_opt_stim_list: list = []
|
||||
fig, axs = plt.subplots(4, 5, figsize=(15, 15))
|
||||
for i, load_model in enumerate(model_list):
|
||||
print(f"\nModel: {i} ")
|
||||
num_iterations: int = 100000
|
||||
learning_rate: float = 0.1
|
||||
apply_input_mask: bool = True
|
||||
mark_region_in_plot: bool = True
|
||||
sheduler_patience: int = 500
|
||||
sheduler_factor: float = 0.9
|
||||
sheduler_eps = 1e-08
|
||||
target_image_active: float = 1e4
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# load model
|
||||
model = torch.load(load_model).to(device)
|
||||
model.eval()
|
||||
|
||||
if first_run:
|
||||
print("Full network:")
|
||||
print(model)
|
||||
print("")
|
||||
|
||||
# enter index to plot:
|
||||
idx = int(input("Please select layer: "))
|
||||
assert idx < len(model)
|
||||
chosen_layer_idx = idx
|
||||
|
||||
print(f"Selected layer: {model[chosen_layer_idx]}")
|
||||
model = model[: chosen_layer_idx + 1]
|
||||
|
||||
# prepare random input image
|
||||
input_img = torch.rand(1, 200, 200).to(device)
|
||||
input_img = input_img.unsqueeze(0)
|
||||
input_img.requires_grad_(True) # type: ignore
|
||||
|
||||
input_shape = input_img.shape
|
||||
assert input_shape[-2] == input_shape[-1]
|
||||
coordinate_list, layer_type_list, pixel_used = analyse_network(
|
||||
model=model, input_shape=int(input_shape[-1])
|
||||
)
|
||||
|
||||
output_shape = model(input_img).shape
|
||||
|
||||
target_image = torch.zeros(
|
||||
(*output_shape,), dtype=input_img.dtype, device=input_img.device
|
||||
)
|
||||
|
||||
# image to parameter (2B optimized)
|
||||
input_parameter = torch.nn.Parameter(input_img)
|
||||
|
||||
# back to first run:
|
||||
if first_run:
|
||||
if len(target_image.shape) == 2:
|
||||
print((f"Available max positions: f:{target_image.shape[1] - 1} "))
|
||||
|
||||
# select neuron and plot for all feature maps (?)
|
||||
neuron_f = int(input("Please select neuron_f: "))
|
||||
print(f"Selected neuron {neuron_f}")
|
||||
chosen_neuron_f_idx = neuron_f
|
||||
else:
|
||||
print(
|
||||
(
|
||||
f"Available max positions: f:{target_image.shape[1] - 1} "
|
||||
f"x:{target_image.shape[2]} y:{target_image.shape[3]}"
|
||||
)
|
||||
)
|
||||
|
||||
# select neuron and plot for all feature maps (?)
|
||||
neuron_f = int(input("Please select neuron_f: "))
|
||||
neuron_x = target_image.shape[2] // 2
|
||||
neuron_y = target_image.shape[3] // 2
|
||||
chosen_neuron_f_idx = neuron_f
|
||||
chosen_neuron_x_idx = neuron_x
|
||||
chosen_neuron_y_idx = neuron_y
|
||||
print(
|
||||
f"Selected neuron {chosen_neuron_f_idx}, {chosen_neuron_x_idx}, {chosen_neuron_y_idx}"
|
||||
)
|
||||
|
||||
# keep settings for further runs:
|
||||
first_run = False
|
||||
|
||||
# keep input values for all cnns
|
||||
if len(target_image.shape) == 2:
|
||||
target_image[0, chosen_neuron_f_idx] = 1e4
|
||||
else:
|
||||
target_image[
|
||||
0, chosen_neuron_f_idx, chosen_neuron_x_idx, chosen_neuron_y_idx
|
||||
] = target_image_active
|
||||
|
||||
# Input mask ->
|
||||
active_input_x = coordinate_list[-1][:, neuron_x].clone()
|
||||
active_input_y = coordinate_list[-1][:, neuron_y].clone()
|
||||
|
||||
input_mask: torch.Tensor = torch.zeros_like(input_img)
|
||||
|
||||
input_mask[
|
||||
:,
|
||||
:,
|
||||
active_input_x.type(torch.int64).unsqueeze(-1),
|
||||
active_input_y.type(torch.int64).unsqueeze(0),
|
||||
] = 1
|
||||
|
||||
rect_x = [int(active_input_x.min()), int(active_input_x.max())]
|
||||
rect_y = [int(active_input_y.min()), int(active_input_y.max())]
|
||||
# <- Input mask
|
||||
|
||||
if apply_input_mask:
|
||||
with torch.no_grad():
|
||||
input_img *= input_mask
|
||||
|
||||
# start optimization:
|
||||
optimizer = torch.optim.Adam([{"params": input_parameter}], lr=learning_rate)
|
||||
|
||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||
optimizer,
|
||||
patience=sheduler_patience,
|
||||
factor=sheduler_factor,
|
||||
eps=sheduler_eps * 0.1,
|
||||
)
|
||||
|
||||
counter: int = 0
|
||||
while (optimizer.param_groups[0]["lr"] > sheduler_eps) and (
|
||||
counter < num_iterations
|
||||
):
|
||||
optimizer.zero_grad()
|
||||
|
||||
output = model(input_parameter)
|
||||
|
||||
loss = torch.nn.functional.mse_loss(output, target_image)
|
||||
loss.backward()
|
||||
|
||||
if counter % 1000 == 0:
|
||||
print(
|
||||
f"{counter} : loss={float(loss):.3e} lr={optimizer.param_groups[0]['lr']:.3e}"
|
||||
)
|
||||
|
||||
optimizer.step()
|
||||
|
||||
if apply_input_mask and len(target_image.shape) != 2:
|
||||
with torch.no_grad():
|
||||
input_parameter.data[torch.where(input_mask == 0)] = 0.0
|
||||
|
||||
with torch.no_grad():
|
||||
max_data = torch.abs(input_parameter.data).max()
|
||||
if max_data > 1.0:
|
||||
input_parameter.data /= max_data
|
||||
|
||||
if (
|
||||
torch.isfinite(input_parameter.data).sum().cpu()
|
||||
!= torch.tensor(input_parameter.data.size()).prod()
|
||||
):
|
||||
print(f"Found NaN in step: {counter}, use a smaller initial lr")
|
||||
exit()
|
||||
|
||||
scheduler.step(float(loss))
|
||||
counter += 1
|
||||
mean_opt_stim_list.append(input_img.squeeze().detach().cpu().numpy())
|
||||
|
||||
# plot image:
|
||||
ax = axs[i // 5, i % 5]
|
||||
im = ax.imshow(input_img.squeeze().detach().cpu().numpy(), cmap="gray")
|
||||
cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
|
||||
ax.set_title(f"Model {i+1}", fontsize=13)
|
||||
cbar.ax.tick_params(labelsize=12)
|
||||
|
||||
if len(target_image.shape) != 2 and mark_region_in_plot:
|
||||
edgecolor = "sienna"
|
||||
kernel = patch.Rectangle(
|
||||
(rect_y[0], rect_x[0]),
|
||||
int(rect_y[1] - rect_y[0]),
|
||||
int(rect_x[1] - rect_x[0]),
|
||||
linewidth=1.2,
|
||||
edgecolor=edgecolor,
|
||||
facecolor="none",
|
||||
)
|
||||
ax.add_patch(kernel)
|
||||
|
||||
plt.tight_layout()
|
||||
# save image
|
||||
if save:
|
||||
save_name = f"single_optimal_stimulus_{cnn}_layer{chosen_layer_idx}_feature{chosen_neuron_f_idx}"
|
||||
folderpath = "./all20_optimal_stimuli"
|
||||
os.makedirs(folderpath, exist_ok=True)
|
||||
torch.save(
|
||||
input_img.squeeze().detach().cpu(),
|
||||
os.path.join(folderpath, save_name) + ".pt",
|
||||
)
|
||||
plt.savefig(
|
||||
f"{os.path.join(folderpath, save_name)}.pdf",
|
||||
dpi=300,
|
||||
bbox_inches="tight",
|
||||
)
|
||||
|
||||
plt.show(block=True)
|
||||
|
||||
if len(target_image.shape) == 2:
|
||||
return mean_opt_stim_list, chosen_neuron_f_idx, chosen_layer_idx
|
||||
else:
|
||||
return (
|
||||
mean_opt_stim_list,
|
||||
(chosen_layer_idx, chosen_neuron_f_idx),
|
||||
(chosen_neuron_x_idx, chosen_neuron_y_idx),
|
||||
)
|
||||
|
||||
|
||||
def plot_mean_optimal_stimulus(
|
||||
overall_optimal_stimuli,
|
||||
chosen_layer_idx: int,
|
||||
chosen_neuron_f_idx: int,
|
||||
save: bool = False,
|
||||
cnn: str = "CORNER",
|
||||
):
|
||||
fig, axs = plt.subplots(figsize=(15, 15))
|
||||
mean_optimal_stimulus = np.mean(overall_optimal_stimuli, axis=0)
|
||||
im = axs.imshow(mean_optimal_stimulus, cmap="gray")
|
||||
cbar = fig.colorbar(im, ax=axs, fraction=0.046, pad=0.04)
|
||||
cbar.ax.tick_params(labelsize=15)
|
||||
|
||||
plt.tight_layout()
|
||||
# save image
|
||||
if save:
|
||||
save_name = f"overall_mean_optimal_stimulus_{cnn}_layer{chosen_layer_idx}_feature{chosen_neuron_f_idx}"
|
||||
folderpath = "./mean_optimal_stimulus"
|
||||
os.makedirs(folderpath, exist_ok=True)
|
||||
torch.save(mean_optimal_stimulus, os.path.join(folderpath, save_name) + ".pt")
|
||||
plt.savefig(
|
||||
f"{os.path.join(folderpath, save_name)}.pdf",
|
||||
dpi=300,
|
||||
)
|
||||
|
||||
plt.show(block=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# path to NN
|
||||
PATH_corner = "./classic_3288_fest"
|
||||
all_cnns_corner = get_file_list_all_cnns(dir=PATH_corner)
|
||||
opt_stim_list, feature_idx, layer_idx = show_single_optimal_stimulus(
|
||||
all_cnns_corner, save=True, cnn="CLASSIC_3288_fest"
|
||||
)
|
||||
|
||||
# average optimal stimulus:
|
||||
# plot_mean_optimal_stimulus(
|
||||
# opt_stim_list,
|
||||
# save=True,
|
||||
# cnn="CORNER_3288_fest",
|
||||
# chosen_layer_idx=layer_idx,
|
||||
# chosen_neuron_f_idx=feature_idx,
|
||||
# )
|
14
thesis code/network analysis/orientation_tuning/README.txt
Normal file
14
thesis code/network analysis/orientation_tuning/README.txt
Normal file
|
@ -0,0 +1,14 @@
|
|||
Folder orientation_tuning:
|
||||
|
||||
|
||||
1. orientation_tuning_curve:
|
||||
* generates the original tuning curve by convolving the Gabor patches with the weight matrices of C1
|
||||
* Gabor patches file: gabor_dict_32o_8p.py
|
||||
|
||||
2. fitkarotte:
|
||||
* implements the fitting procedure of the 3 von Mises functions
|
||||
* plots the fitted tuning curves
|
||||
|
||||
3. fit_statistics:
|
||||
* contains all statistical test for the 20 trained CNNs of each stimulus condition
|
||||
* calls the 'plot_fit_statistics' function to plot the data
|
|
@ -0,0 +1,475 @@
|
|||
import numpy as np
|
||||
import fitkarotte
|
||||
from orientation_tuning_curve import load_data_from_cnn # noqa
|
||||
import plot_fit_statistics
|
||||
import warnings
|
||||
from scipy.stats import ranksums
|
||||
|
||||
# suppress warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
|
||||
def get_general_data_info(data, print_mises_all_cnn: bool):
|
||||
num_cnns = len(data)
|
||||
num_weights_per_cnn = [len(cnn_results) for cnn_results in data]
|
||||
|
||||
num_fits_per_cnn = {1: [0] * num_cnns, 2: [0] * num_cnns, 3: [0] * num_cnns}
|
||||
|
||||
for idx, cnn_results in enumerate(data):
|
||||
for _, fit in cnn_results:
|
||||
curve = fit["curve"]
|
||||
num_fits_per_cnn[curve][idx] += 1
|
||||
|
||||
print("\n\nNumber of CNNs:", num_cnns)
|
||||
print("Number of weights saved for each CNN:", num_weights_per_cnn)
|
||||
print("Number of fits with 1 von Mises function per CNN:", num_fits_per_cnn[1])
|
||||
print("Number of fits with 2 von Mises functions per CNN:", num_fits_per_cnn[2])
|
||||
print("Number of fits with 3 von Mises functions per CNN:", num_fits_per_cnn[3])
|
||||
|
||||
# mean and stdev 4each type of fit
|
||||
mean_1_mises = np.mean(num_fits_per_cnn[1])
|
||||
std_1_mises = np.std(num_fits_per_cnn[1])
|
||||
mean_2_mises = np.mean(num_fits_per_cnn[2])
|
||||
std_2_mises = np.std(num_fits_per_cnn[2])
|
||||
mean_3_mises = np.mean(num_fits_per_cnn[3])
|
||||
std_3_mises = np.std(num_fits_per_cnn[3])
|
||||
|
||||
print(
|
||||
f"Mean number of fits with 1 von Mises function: {mean_1_mises:.2f} (std: {std_1_mises:.2f})"
|
||||
)
|
||||
print(
|
||||
f"Mean number of fits with 2 von Mises functions: {mean_2_mises:.2f} (std: {std_2_mises:.2f})"
|
||||
)
|
||||
print(
|
||||
f"Mean number of fits with 3 von Mises functions: {mean_3_mises:.2f} (std: {std_3_mises:.2f})"
|
||||
)
|
||||
|
||||
if print_mises_all_cnn:
|
||||
print("--================================--")
|
||||
for idx_cnn, (num_1_mises, num_2_mises, num_3_mises) in enumerate(
|
||||
zip(num_fits_per_cnn[1], num_fits_per_cnn[2], num_fits_per_cnn[3])
|
||||
):
|
||||
print(
|
||||
f"CNN {idx_cnn+1}:\t# 1 Mises: {num_1_mises},\t# 2 Mises: {num_2_mises},\t# 3 Mises: {num_3_mises}"
|
||||
)
|
||||
|
||||
return (
|
||||
num_fits_per_cnn,
|
||||
mean_1_mises,
|
||||
mean_2_mises,
|
||||
mean_3_mises,
|
||||
std_1_mises,
|
||||
std_2_mises,
|
||||
std_3_mises,
|
||||
)
|
||||
|
||||
|
||||
def ratio_amplitude_two_mises(data, mean_std: bool = False):
|
||||
"""
|
||||
* This function calculates the mean ratio of those weights
|
||||
of the first layer, which were fitted with 2 von Mises functions
|
||||
* It first calculates the mean ratio for every single CNN
|
||||
(of the overall 20 CNNs)
|
||||
* It then computes the overall mean ratio for the weights
|
||||
of all 20 CNNs that were fitted with 2 von Mises functions
|
||||
"""
|
||||
num_cnns = len(data)
|
||||
mean_ratio_per_cnn = [0] * num_cnns
|
||||
|
||||
for idx, cnn_results in enumerate(data):
|
||||
ratio_list: list = []
|
||||
count_num_2mises: int = 0
|
||||
for _, fit in cnn_results:
|
||||
curve = fit["curve"]
|
||||
if curve == 2 and fit["fit_params"] is not None:
|
||||
count_num_2mises += 1
|
||||
first_amp = fit["fit_params"][0]
|
||||
sec_amp = fit["fit_params"][3]
|
||||
|
||||
if sec_amp < first_amp:
|
||||
ratio = sec_amp / first_amp
|
||||
else:
|
||||
ratio = first_amp / sec_amp
|
||||
|
||||
if not (ratio > 1.0 or ratio < 0):
|
||||
ratio_list.append(ratio)
|
||||
else:
|
||||
print(f"\nRATIO OUT OF RANGE FOR: CNN:{idx}, weight{_}\n")
|
||||
|
||||
# print(f"CNN [{idx}]: num fits with 2 von mises = {count_num_2mises}")
|
||||
mean_ratio_per_cnn[idx] = np.mean(ratio_list)
|
||||
|
||||
# calc mean difference over all 20 CNNs:
|
||||
if mean_std:
|
||||
mean_all_cnns = np.mean(mean_ratio_per_cnn)
|
||||
std_all_cnns = np.std(mean_ratio_per_cnn)
|
||||
print("\n-=== Mean ratio between 2 amplitudes ===-")
|
||||
print(f"Mean ratio of all {len(mean_ratio_per_cnn)} CNNs: {mean_all_cnns}")
|
||||
print(f"Stdev of ratio of all {len(mean_ratio_per_cnn)} CNNs: {std_all_cnns}")
|
||||
|
||||
return mean_all_cnns, std_all_cnns
|
||||
|
||||
else: # get median and percentile
|
||||
percentiles = np.percentile(mean_ratio_per_cnn, [10, 25, 50, 75, 90])
|
||||
|
||||
print("\n-=== Percentiles of ratio between 2 amplitudes ===-")
|
||||
print(f"10th Percentile: {percentiles[0]}")
|
||||
print(f"25th Percentile: {percentiles[1]}")
|
||||
print(f"Median (50th Percentile): {percentiles[2]}")
|
||||
print(f"75th Percentile: {percentiles[3]}")
|
||||
print(f"90th Percentile: {percentiles[4]}")
|
||||
|
||||
# return mean_all_cnns, std_all_cnns
|
||||
return percentiles[2], (percentiles[1], percentiles[3])
|
||||
|
||||
|
||||
def ratio_amplitude_three_mises(data, mean_std: bool = False):
|
||||
"""
|
||||
* returns: mean21, std21, mean32, std32
|
||||
* This function calculates the mean ratio of those weights
|
||||
of the first layer, which were fitted with 2 von Mises functions
|
||||
* It first calculates the mean ratio for every single CNN
|
||||
(of the overall 20 CNNs)
|
||||
* It then computes the overall mean ratio for the weights
|
||||
of all 20 CNNs that were fitted with 2 von Mises functions
|
||||
"""
|
||||
num_cnns = len(data)
|
||||
mean_ratio_per_cnn21 = [0] * num_cnns
|
||||
mean_ratio_per_cnn32 = [0] * num_cnns
|
||||
|
||||
for idx, cnn_results in enumerate(data):
|
||||
ratio_list21: list = []
|
||||
ratio_list32: list = []
|
||||
count_num_2mises: int = 0
|
||||
for _, fit in cnn_results:
|
||||
curve = fit["curve"]
|
||||
if curve == 3 and fit["fit_params"] is not None:
|
||||
count_num_2mises += 1
|
||||
first_amp = fit["fit_params"][0]
|
||||
sec_amp = fit["fit_params"][3]
|
||||
third_amp = fit["fit_params"][6]
|
||||
|
||||
if sec_amp < first_amp:
|
||||
ratio21 = sec_amp / first_amp
|
||||
else:
|
||||
ratio21 = first_amp / sec_amp
|
||||
|
||||
if third_amp < sec_amp:
|
||||
ratio32 = third_amp / sec_amp
|
||||
else:
|
||||
ratio32 = sec_amp / third_amp
|
||||
|
||||
if not (ratio21 > 1.0 or ratio32 > 1.0 or ratio21 < 0 or ratio32 < 0):
|
||||
ratio_list21.append(ratio21)
|
||||
ratio_list32.append(ratio32)
|
||||
else:
|
||||
print(f"\nRATIO OUT OF RANGE FOR: CNN:{idx}, weight{_}\n")
|
||||
|
||||
# print(f"CNN [{idx}]: num fits with 2 von mises =
|
||||
# {count_num_2mises}")
|
||||
if len(ratio_list21) != 0:
|
||||
mean_ratio_per_cnn21[idx] = np.mean(ratio_list21)
|
||||
mean_ratio_per_cnn32[idx] = np.mean(ratio_list32)
|
||||
else:
|
||||
mean_ratio_per_cnn21[idx] = None # type: ignore
|
||||
mean_ratio_per_cnn32[idx] = None # type: ignore
|
||||
|
||||
mean_ratio_per_cnn21 = [x for x in mean_ratio_per_cnn21 if x is not None]
|
||||
mean_ratio_per_cnn32 = [x for x in mean_ratio_per_cnn32 if x is not None]
|
||||
|
||||
# calc mean difference over all 20 CNNs:
|
||||
|
||||
if mean_std:
|
||||
mean_all_cnns21 = np.mean(mean_ratio_per_cnn21)
|
||||
std_all_21 = np.std(mean_ratio_per_cnn21)
|
||||
mean_all_cnns32 = np.mean(mean_ratio_per_cnn32)
|
||||
std_all_32 = np.std(mean_ratio_per_cnn32)
|
||||
|
||||
print("\n-=== Mean ratio between 3 preferred orienations ===-")
|
||||
print(f"Ratio 2/1 of all {len(mean_ratio_per_cnn21)} CNNs: {mean_all_cnns21}")
|
||||
print(
|
||||
f"Stdev of ratio 2/1 of all {len(mean_ratio_per_cnn21)} CNNs: {std_all_21}"
|
||||
)
|
||||
print(f"Ratio 3/2 of all {len(mean_ratio_per_cnn32)} CNNs: {mean_all_cnns32}")
|
||||
print(
|
||||
f"Stdev of ratio 3/2 of all {len(mean_ratio_per_cnn32)} CNNs: {std_all_32}"
|
||||
)
|
||||
|
||||
return mean_all_cnns21, std_all_21, mean_all_cnns32, std_all_32
|
||||
|
||||
else: # get median and percentile:
|
||||
percentiles_21 = np.percentile(mean_ratio_per_cnn32, [10, 25, 50, 75, 90])
|
||||
percentiles_32 = np.percentile(mean_ratio_per_cnn21, [10, 25, 50, 75, 90])
|
||||
|
||||
# get percentile 25 and 75
|
||||
percentile25_32 = percentiles_32[1]
|
||||
percentile75_32 = percentiles_32[-2]
|
||||
percentile25_21 = percentiles_21[1]
|
||||
percentile75_21 = percentiles_21[-2]
|
||||
|
||||
print("\n-=== Percentiles of ratio between 2 amplitudes ===-")
|
||||
print(f"10th Percentile 3->2: {percentiles_32[0]}")
|
||||
print(f"10th Percentile 2->1: {percentiles_21[0]}")
|
||||
print(f"25th Percentile 3->2: {percentiles_32[1]}")
|
||||
print(f"25th Percentile 2->1: {percentiles_21[1]}")
|
||||
print(f"Median (50th Percentile 3->2): {percentiles_32[2]}")
|
||||
print(f"Median (50th Percentile 2->1): {percentiles_21[2]}")
|
||||
print(f"75th Percentile 3->2: {percentiles_32[3]}")
|
||||
print(f"75th Percentile 2->1: {percentiles_21[3]}")
|
||||
print(f"90th Percentile3->2: {percentiles_32[4]}")
|
||||
print(f"90th Percentile 2->1: {percentiles_21[4]}")
|
||||
|
||||
return (
|
||||
percentiles_21[2],
|
||||
(percentile25_21, percentile75_21),
|
||||
percentiles_32[2],
|
||||
(percentile25_32, percentile75_32),
|
||||
)
|
||||
|
||||
|
||||
def willy_is_not_whitney_test(data_classic, data_corner):
|
||||
from scipy.stats import mannwhitneyu
|
||||
|
||||
"""
|
||||
* Test does not assume normal distribution
|
||||
* Compares means between 2 indep groups
|
||||
"""
|
||||
|
||||
# call test
|
||||
statistic, p_value = mannwhitneyu(data_classic, data_corner)
|
||||
|
||||
# results
|
||||
print("\nMann-Whitney U Test Statistic:", statistic)
|
||||
print("Mann-Whitney U Test p-value:", p_value)
|
||||
|
||||
# check significance:
|
||||
print("Null-hypothesis: distributions are the same.")
|
||||
alpha = 0.05
|
||||
if p_value < alpha:
|
||||
print("The distributions are significantly different.")
|
||||
else:
|
||||
print("The distributions are not significantly different.")
|
||||
|
||||
return p_value
|
||||
|
||||
|
||||
def ks(data_classic, data_corner):
|
||||
from scipy.stats import ks_2samp
|
||||
|
||||
ks_statistic, p_value = ks_2samp(data_classic, data_corner)
|
||||
|
||||
print("\nKolmogorov-Smirnov Test - p-value:", p_value)
|
||||
print("Kolmogorov-Smirnov Test - ks_statistic:", ks_statistic)
|
||||
alpha = 0.05
|
||||
if p_value < alpha:
|
||||
print("The distributions for von Mises functions are significantly different.")
|
||||
|
||||
return p_value
|
||||
|
||||
|
||||
def shapiro(fits_per_mises, num_mises: int, alpha: float = 0.05):
|
||||
"""
|
||||
Tests if data has normal distribution
|
||||
* 0-hyp: data is normally distributed
|
||||
* low p-val: data not normally distributed
|
||||
"""
|
||||
from scipy.stats import shapiro
|
||||
|
||||
statistic, p_value = shapiro(fits_per_mises)
|
||||
print(f"\nShapiro-Wilk Test for {num_mises} von Mises function - p-val :", p_value)
|
||||
print(
|
||||
f"Shapiro-Wilk Test for {num_mises} von Mises function - statistic :", statistic
|
||||
)
|
||||
|
||||
# set alpha
|
||||
if p_value < alpha:
|
||||
print("P-val < alpha. Reject 0-hypothesis. Data is not normally distributed")
|
||||
else:
|
||||
print("P-val > alpha. Keep 0-hypothesis. Data is normally distributed")
|
||||
|
||||
return p_value
|
||||
|
||||
|
||||
def agostino_pearson(fits_per_mises, num_mises: int, alpha: float = 0.05):
|
||||
"""
|
||||
Tests if data has normal distribution
|
||||
* 0-hyp: data is normally distributed
|
||||
* low p-val: data not normally distributed
|
||||
"""
|
||||
from scipy import stats
|
||||
|
||||
statistic, p_value = stats.normaltest(fits_per_mises)
|
||||
print(
|
||||
f"\nD'Agostino-Pearson Test for {num_mises} von Mises function - p-val :",
|
||||
p_value,
|
||||
)
|
||||
print(
|
||||
f"D'Agostino-Pearson Test for {num_mises} von Mises function - statistic :",
|
||||
statistic,
|
||||
)
|
||||
|
||||
# set alpha
|
||||
if p_value < alpha:
|
||||
print("P-val < alpha. Reject 0-hypothesis. Data is not normally distributed")
|
||||
else:
|
||||
print("P-val > alpha. Keep 0-hypothesis. Data is normally distributed")
|
||||
|
||||
return p_value
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
num_thetas = 32
|
||||
dtheta = 2 * np.pi / num_thetas
|
||||
theta = dtheta * np.arange(num_thetas)
|
||||
threshold: float = 0.1
|
||||
|
||||
# to do statistics on corner
|
||||
directory_corner: str = "D:/Katha/Neuroscience/Semester 4/newCode/kk_contour_net_shallow-main/corner_888"
|
||||
all_results_corner = fitkarotte.analyze_cnns(dir=directory_corner)
|
||||
|
||||
# analyze
|
||||
print("-=== CORNER ===-")
|
||||
# amplitude ratios
|
||||
ratio_corner_21, std_corner_21 = ratio_amplitude_two_mises(data=all_results_corner)
|
||||
(
|
||||
ratio_corner_321,
|
||||
std_corner_321,
|
||||
ratio_corner_332,
|
||||
std_corner_332,
|
||||
) = ratio_amplitude_three_mises(data=all_results_corner)
|
||||
|
||||
# general data
|
||||
(
|
||||
corner_num_fits,
|
||||
mean_corner_1,
|
||||
mean_corner_2,
|
||||
mean_corner_3,
|
||||
std_corner_1,
|
||||
std_corner_2,
|
||||
std_corner_3,
|
||||
) = get_general_data_info(data=all_results_corner, print_mises_all_cnn=True)
|
||||
# analyze_num_curve_fits(data=all_results_corner)
|
||||
|
||||
# to do statistics: CLASSIC
|
||||
directory_classic: str = "D:/Katha/Neuroscience/Semester 4/newCode/kk_contour_net_shallow-main/classic_888"
|
||||
all_results_classic = fitkarotte.analyze_cnns(dir=directory_classic)
|
||||
|
||||
# analyze
|
||||
print("-=== CLASSIC ===-")
|
||||
# amplitude ratio
|
||||
ratio_classic_21, std_class_21 = ratio_amplitude_two_mises(data=all_results_classic)
|
||||
(
|
||||
ratio_classic_321,
|
||||
std_classic_321,
|
||||
ratio_classic_332,
|
||||
std_classic_332,
|
||||
) = ratio_amplitude_three_mises(data=all_results_classic)
|
||||
|
||||
# general data
|
||||
(
|
||||
classic_num_fits,
|
||||
mean_classic_1,
|
||||
mean_classic_2,
|
||||
mean_classic_3,
|
||||
std_classic_1,
|
||||
std_classic_2,
|
||||
std_classic_3,
|
||||
) = get_general_data_info(data=all_results_classic, print_mises_all_cnn=False)
|
||||
# analyze_num_curve_fits(data=all_results_classic)
|
||||
|
||||
print("################################")
|
||||
print("-==== plotting hists: compare amplitude ratios ====-")
|
||||
plot_fit_statistics.plot_mean_percentile_amplit_ratio(
|
||||
ratio_classic_21=ratio_classic_21,
|
||||
ratio_classic_321=ratio_classic_321,
|
||||
ratio_classic_332=ratio_classic_332,
|
||||
ratio_corner_21=ratio_corner_21,
|
||||
ratio_corner_321=ratio_corner_321,
|
||||
ratio_corner_332=ratio_corner_332,
|
||||
percentile_classic21=std_class_21,
|
||||
percentile_classic321=std_classic_321,
|
||||
percentile_classic_332=std_classic_332,
|
||||
percentile_corner_21=std_corner_21,
|
||||
percentile_corner_321=std_corner_321,
|
||||
percentile_corner_332=std_corner_332,
|
||||
saveplot=True,
|
||||
save_name="median_percentile_888",
|
||||
)
|
||||
|
||||
# p-value < 0.05: statistically significant difference between your two samples
|
||||
statistic21, pvalue21 = ranksums(ratio_classic_21, ratio_corner_21)
|
||||
print(
|
||||
f"Wilcoxon rank sum test 2 Mises for ratio 2->1: statistic={statistic21}, pvalue={pvalue21}"
|
||||
)
|
||||
|
||||
statistic321, pvalue321 = ranksums(ratio_classic_321, ratio_corner_321)
|
||||
print(
|
||||
f"Wilcoxon rank sum test 3 Mises for ratio 2->1: statistic={statistic321}, pvalue={pvalue321}"
|
||||
)
|
||||
|
||||
statistic332, pvalue332 = ranksums(ratio_classic_332, ratio_corner_332)
|
||||
print(
|
||||
f"Wilcoxon rank sum test 3 Mises for ratio 3->2: statistic={statistic332}, pvalue={pvalue332}"
|
||||
)
|
||||
|
||||
print("-==== plotting hists: CORNER ====-")
|
||||
# plot histogram
|
||||
# plot_hist(corner_num_fits[1], num_mises=1)
|
||||
# plot_hist(corner_num_fits[2], num_mises=2)
|
||||
# plot_hist(corner_num_fits[3], num_mises=3)
|
||||
|
||||
# test for normal distribution
|
||||
print("-== Shapiro test ==- ")
|
||||
# shapiro(corner_num_fits[1], num_mises=1)
|
||||
# shapiro(corner_num_fits[2], num_mises=2)
|
||||
# shapiro(corner_num_fits[3], num_mises=3)
|
||||
|
||||
print("\n-== D'Agostino-Pearson test ==- ")
|
||||
agostino_pearson(corner_num_fits[1], num_mises=1)
|
||||
agostino_pearson(corner_num_fits[2], num_mises=2)
|
||||
agostino_pearson(corner_num_fits[3], num_mises=3)
|
||||
|
||||
print("-==== plotting hists: CLASSIC ====-")
|
||||
# plot histogram
|
||||
# plot_hist(classic_num_fits[1], num_mises=1)
|
||||
# plot_hist(classic_num_fits[2], num_mises=2)
|
||||
# plot_hist(classic_num_fits[3], num_mises=3)
|
||||
|
||||
# test for normal distribution
|
||||
print("-== Shapiro test ==- ")
|
||||
# shapiro(classic_num_fits[1], num_mises=1)
|
||||
# shapiro(classic_num_fits[2], num_mises=2)
|
||||
# shapiro(classic_num_fits[3], num_mises=3)
|
||||
|
||||
print("\n-== D'Agostino-Pearson test ==- ")
|
||||
agostino_pearson(classic_num_fits[1], num_mises=1)
|
||||
agostino_pearson(classic_num_fits[2], num_mises=2)
|
||||
agostino_pearson(classic_num_fits[3], num_mises=3)
|
||||
|
||||
# statistics for each von mises:
|
||||
print("######################")
|
||||
print(" -=== CLASSIC vs CORNER ===-")
|
||||
# 1:
|
||||
willy_is_not_whitney_test(
|
||||
data_classic=classic_num_fits[1], data_corner=corner_num_fits[1]
|
||||
)
|
||||
|
||||
# 2:
|
||||
willy_is_not_whitney_test(
|
||||
data_classic=classic_num_fits[2], data_corner=corner_num_fits[2]
|
||||
)
|
||||
|
||||
# 3:
|
||||
willy_is_not_whitney_test(
|
||||
data_classic=classic_num_fits[3], data_corner=corner_num_fits[3]
|
||||
)
|
||||
|
||||
# visualize as bar plots:
|
||||
plot_fit_statistics.plot_means_std_corner_classic(
|
||||
means_classic=[mean_classic_1, mean_classic_2, mean_classic_3],
|
||||
means_corner=[mean_corner_1, mean_corner_2, mean_corner_3],
|
||||
std_classic=[std_classic_1, std_classic_2, std_classic_3],
|
||||
std_corner=[std_corner_1, std_corner_2, std_corner_3],
|
||||
saveplot=False,
|
||||
save_name="3288",
|
||||
)
|
373
thesis code/network analysis/orientation_tuning/fitkarotte.py
Normal file
373
thesis code/network analysis/orientation_tuning/fitkarotte.py
Normal file
|
@ -0,0 +1,373 @@
|
|||
# %%
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import os
|
||||
import scipy.optimize as sop
|
||||
import orientation_tuning_curve # import load_data_from_cnn
|
||||
import warnings
|
||||
import matplotlib as mpl
|
||||
|
||||
mpl.rcParams["text.usetex"] = True
|
||||
mpl.rcParams["font.family"] = "serif"
|
||||
mpl.rcParams["font.size"] = 15
|
||||
|
||||
# suppress warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
|
||||
def mises(orientation, a, mean, variance):
|
||||
k = 1 / variance**2
|
||||
return a * np.exp(k * np.cos(orientation - mean)) / np.exp(k)
|
||||
|
||||
|
||||
def biemlich_mieses_karma(orientation, a1, mean1, variance1, a2, mean2, variance2):
|
||||
m1 = mises(orientation, a1, mean1, variance1)
|
||||
m2 = mises(orientation, a2, mean2, variance2)
|
||||
return m1 + m2
|
||||
|
||||
|
||||
def triemlich_mieses_karma(
|
||||
orientation, a1, mean1, variance1, a2, mean2, variance2, a3, mean3, variance3
|
||||
):
|
||||
m1 = mises(orientation, a1, mean1, variance1)
|
||||
m2 = mises(orientation, a2, mean2, variance2)
|
||||
m3 = mises(orientation, a3, mean3, variance3)
|
||||
return m1 + m2 + m3
|
||||
|
||||
|
||||
def plot_reshaped(tune, fits, theta, save_name: str | None, save_plot: bool = False):
|
||||
"""
|
||||
Plot shows the original tuning with the best fits
|
||||
"""
|
||||
|
||||
num_rows = tune.shape[0] // 4
|
||||
num_cols = tune.shape[0] // num_rows
|
||||
# plt.figure(figsize=(12, 15))
|
||||
fig, axs = plt.subplots(num_rows, num_cols, figsize=(10, 7))
|
||||
|
||||
# plot the respective y-lims:
|
||||
overall_min = np.min(tune)
|
||||
overall_max = np.max(tune)
|
||||
|
||||
for i_tune in range(tune.shape[0]):
|
||||
ax = axs[i_tune // num_cols, i_tune % num_cols]
|
||||
ax.plot(np.rad2deg(theta), tune[i_tune], label="Original")
|
||||
|
||||
x_center = (np.rad2deg(theta).min() + np.rad2deg(theta).max()) / 2
|
||||
y_center = (tune[i_tune].min() + tune[i_tune].max()) / 2
|
||||
|
||||
fit = next((fit for key, fit in fits if key == i_tune))
|
||||
if fit["fitted_curve"] is not None:
|
||||
ax.plot(
|
||||
np.rad2deg(theta),
|
||||
fit["fitted_curve"] * fit["scaling_factor"],
|
||||
label="Fit",
|
||||
)
|
||||
ax.text(
|
||||
x_center,
|
||||
y_center,
|
||||
str(fit["curve"]),
|
||||
ha="center",
|
||||
va="center",
|
||||
size="xx-large",
|
||||
color="gray",
|
||||
)
|
||||
|
||||
# update again if there's a fit
|
||||
overall_min = min(
|
||||
overall_min, (fit["fitted_curve"] * fit["scaling_factor"]).min()
|
||||
)
|
||||
overall_max = max(
|
||||
overall_max, (fit["fitted_curve"] * fit["scaling_factor"]).max()
|
||||
)
|
||||
else:
|
||||
# plt.plot(np.rad2deg(theta), fit[i_tune], "--")
|
||||
ax.text(
|
||||
x_center,
|
||||
y_center,
|
||||
"*",
|
||||
ha="center",
|
||||
va="center",
|
||||
size="xx-large",
|
||||
color="gray",
|
||||
)
|
||||
# specified y lims: of no fit: min and max of tune
|
||||
ax.set_ylim([overall_min, overall_max + 0.05])
|
||||
|
||||
# x-ticks from 0°-360°
|
||||
ax.set_xticks(range(0, 361, 90))
|
||||
|
||||
# label them from 0° to 180°
|
||||
ax.set_xticklabels(range(0, 181, 45), fontsize=15)
|
||||
ax.set_xlabel("(in deg)", fontsize=16)
|
||||
|
||||
plt.yticks(fontsize=15)
|
||||
|
||||
plt.tight_layout()
|
||||
if save_plot:
|
||||
plt.savefig(
|
||||
f"additional thesis plots/saved_plots/fitkarotte/{save_name}.pdf",
|
||||
dpi=300,
|
||||
bbox_inches="tight",
|
||||
)
|
||||
|
||||
plt.show(block=True)
|
||||
|
||||
|
||||
def plot_fit(tune, fits, theta, save_name: str | None, save_plot: bool = False):
|
||||
"""
|
||||
Plot shows the original tuning with the best fits
|
||||
"""
|
||||
|
||||
if tune.shape[0] >= 8:
|
||||
num_rows = tune.shape[0] // 8
|
||||
num_cols = tune.shape[0] // num_rows
|
||||
else:
|
||||
num_rows = 2
|
||||
num_cols = tune.shape[0] // num_rows
|
||||
# plt.figure(figsize=(12, 15))
|
||||
fig, axs = plt.subplots(num_rows, num_cols, figsize=(10, 7))
|
||||
|
||||
# plot the respective y-lims:
|
||||
overall_min = np.min(tune)
|
||||
overall_max = np.max(tune)
|
||||
|
||||
for i_tune in range(tune.shape[0]):
|
||||
if axs.ndim == 1:
|
||||
ax = axs[i_tune]
|
||||
else:
|
||||
ax = axs[i_tune // num_cols, i_tune % num_cols]
|
||||
ax.plot(np.rad2deg(theta), tune[i_tune], label="Original")
|
||||
|
||||
x_center = (np.rad2deg(theta).min() + np.rad2deg(theta).max()) / 2
|
||||
y_center = (tune[i_tune].min() + tune[i_tune].max()) / 2
|
||||
|
||||
# fit = next((fit for key, fit in fits if key == i_tune), None)
|
||||
fit = next((fit for key, fit in fits if key == i_tune))
|
||||
if fit["fitted_curve"] is not None:
|
||||
ax.plot(
|
||||
np.rad2deg(theta),
|
||||
fit["fitted_curve"] * fit["scaling_factor"],
|
||||
label="Fit",
|
||||
)
|
||||
ax.text(
|
||||
x_center,
|
||||
y_center,
|
||||
str(fit["curve"]),
|
||||
ha="center",
|
||||
va="center",
|
||||
size="xx-large",
|
||||
color="gray",
|
||||
)
|
||||
|
||||
# update again if there's a fit
|
||||
overall_min = min(
|
||||
overall_min, (fit["fitted_curve"] * fit["scaling_factor"]).min()
|
||||
)
|
||||
overall_max = max(
|
||||
overall_max, (fit["fitted_curve"] * fit["scaling_factor"]).max()
|
||||
)
|
||||
else:
|
||||
ax.text(
|
||||
x_center,
|
||||
y_center,
|
||||
"*",
|
||||
ha="center",
|
||||
va="center",
|
||||
size="xx-large",
|
||||
color="gray",
|
||||
)
|
||||
# specified y lims: of no fit: min and max of tune
|
||||
ax.set_ylim([overall_min, overall_max + 0.05])
|
||||
|
||||
# x-ticks from 0°-360°
|
||||
ax.set_xticks(range(0, 361, 90))
|
||||
|
||||
# label them from 0° to 180°
|
||||
ax.set_xticklabels(range(0, 181, 45), fontsize=15)
|
||||
ax.set_xlabel("(in deg)", fontsize=16)
|
||||
|
||||
plt.yticks(fontsize=15)
|
||||
|
||||
plt.tight_layout()
|
||||
if save_plot:
|
||||
plt.savefig(
|
||||
f"additional thesis plots/saved_plots/fitkarotte/{save_name}.pdf", dpi=300
|
||||
)
|
||||
|
||||
plt.show(block=True)
|
||||
|
||||
|
||||
def fit_curves(tune, theta):
|
||||
# save all fits:
|
||||
save_fits: list = []
|
||||
scaling_factor: list = []
|
||||
for curve in range(1, 4):
|
||||
fit_possible: int = 0
|
||||
fit_impossible: int = 0
|
||||
for i_tune in range(tune.shape[0]):
|
||||
to_tune = tune[i_tune].copy()
|
||||
scale_fact = np.max(to_tune)
|
||||
scaling_factor.append(scale_fact)
|
||||
to_tune /= scale_fact
|
||||
|
||||
p10 = theta[np.argmax(to_tune)]
|
||||
a10 = 1
|
||||
s10 = 0.5
|
||||
|
||||
if curve == 1:
|
||||
function = mises
|
||||
p0 = [a10, p10, s10]
|
||||
elif curve == 2:
|
||||
function = biemlich_mieses_karma # type: ignore
|
||||
p20 = p10 + np.pi
|
||||
a20 = 1.0
|
||||
s20 = 0.4
|
||||
p0 = [a10, p10, s10, a20, p20, s20]
|
||||
else:
|
||||
function = triemlich_mieses_karma # type: ignore
|
||||
p20 = p10 + 2 * np.pi / 3
|
||||
a20 = 0.7
|
||||
s20 = 0.3
|
||||
p30 = p10 + 4 * np.pi / 3
|
||||
a30 = 0.4
|
||||
s30 = 0.3
|
||||
p0 = [a10, p10, s10, a20, p20, s20, a30, p30, s30]
|
||||
|
||||
try:
|
||||
popt = sop.curve_fit(function, theta, to_tune, p0=p0)
|
||||
fitted_curve = function(theta, *popt[0])
|
||||
quad_dist = np.sum((to_tune - fitted_curve) ** 2)
|
||||
|
||||
save_fits.append(
|
||||
{
|
||||
"weight_idx": i_tune,
|
||||
"curve": curve,
|
||||
"fit_params": popt[0],
|
||||
"fitted_curve": fitted_curve,
|
||||
"quad_dist": quad_dist,
|
||||
"scaling_factor": scale_fact,
|
||||
}
|
||||
)
|
||||
|
||||
# count:
|
||||
fit_possible += 1
|
||||
except:
|
||||
fit_impossible += 1
|
||||
fitted_curve = function(theta, *p0)
|
||||
quad_dist = np.sum((to_tune - fitted_curve) ** 2)
|
||||
save_fits.append(
|
||||
{
|
||||
"weight_idx": i_tune,
|
||||
"curve": curve,
|
||||
"fit_params": None,
|
||||
"fitted_curve": None,
|
||||
"quad_dist": quad_dist, # quad_dist
|
||||
"scaling_factor": scale_fact,
|
||||
}
|
||||
)
|
||||
print(
|
||||
"\n################",
|
||||
f" {curve} Mises\tPossible fits: {fit_possible}\tImpossible fits: {fit_impossible}",
|
||||
"################\n",
|
||||
)
|
||||
|
||||
return save_fits
|
||||
|
||||
|
||||
def sort_fits(fits, thresh1: float = 0.1, thresh2: float = 0.1): # , thresh3=0.5 | None
|
||||
filtered_fits: dict = {}
|
||||
|
||||
# search fits for 1 mises:
|
||||
for fit in fits:
|
||||
w_idx = fit["weight_idx"]
|
||||
quad_dist = fit["quad_dist"]
|
||||
curve = fit["curve"]
|
||||
|
||||
if curve == 1:
|
||||
if quad_dist <= thresh1:
|
||||
filtered_fits[w_idx] = fit
|
||||
|
||||
if w_idx not in filtered_fits:
|
||||
if curve == 2:
|
||||
if round(quad_dist, 2) <= thresh2:
|
||||
filtered_fits[w_idx] = fit
|
||||
elif curve == 3:
|
||||
filtered_fits[w_idx] = fit
|
||||
|
||||
sorted_filtered_fits = sorted(
|
||||
filtered_fits.items(), key=lambda x: x[1]["weight_idx"]
|
||||
)
|
||||
return sorted_filtered_fits
|
||||
|
||||
|
||||
def analyze_cnns(dir: str, thresh1: float = 0.1, thresh2: float = 0.1):
|
||||
# theta
|
||||
num_thetas = 32
|
||||
dtheta = 2 * np.pi / num_thetas
|
||||
theta = dtheta * np.arange(num_thetas)
|
||||
|
||||
all_results: list = []
|
||||
for filename in os.listdir(dir):
|
||||
if filename.endswith(".pt"):
|
||||
print(os.path.join(dir, filename))
|
||||
# load
|
||||
tune = orientation_tuning_curve.load_data_from_cnn(
|
||||
cnn_name=os.path.join(dir, filename),
|
||||
plot_responses=False,
|
||||
do_stats=True,
|
||||
)
|
||||
|
||||
# fit
|
||||
all_fits = fit_curves(tune=tune, theta=theta)
|
||||
|
||||
# sort
|
||||
filtered = sort_fits(fits=all_fits, thresh1=thresh1, thresh2=thresh2)
|
||||
|
||||
# store
|
||||
all_results.append(filtered)
|
||||
return all_results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
num_thetas = 32
|
||||
dtheta = 2 * np.pi / num_thetas
|
||||
theta = dtheta * np.arange(num_thetas)
|
||||
threshold: float = 0.1
|
||||
use_saved_tuning: bool = False
|
||||
|
||||
if use_saved_tuning:
|
||||
# load from file
|
||||
tune = np.load(
|
||||
"D:/Katha/Neuroscience/Semester 4/newCode/tuning_CORNER_32o_4p.npy"
|
||||
)
|
||||
else:
|
||||
# load cnn data
|
||||
nn = "ArghCNN_numConvLayers3_outChannels[2, 6, 8]_kernelSize[7, 15]_leaky relu_stride1_trainFirstConvLayerTrue_seed299624_Natural_921Epoch_1609-2307"
|
||||
PATH = f"D:/Katha/Neuroscience/Semester 4/newCode/kk_contour_net_shallow-main/trained_64er_models/{nn}.pt"
|
||||
|
||||
tune = orientation_tuning_curve.load_data_from_cnn(
|
||||
cnn_name=PATH, plot_responses=False, do_stats=True
|
||||
)
|
||||
|
||||
all_fits = fit_curves(tune=tune, theta=theta)
|
||||
filtered_fits = sort_fits(fits=all_fits)
|
||||
save_name: str = "CLASSIC_888_trained_4r8c"
|
||||
save_plot: bool = False
|
||||
|
||||
if tune.shape[0] == 8:
|
||||
plot_reshaped(
|
||||
tune=tune,
|
||||
fits=filtered_fits,
|
||||
theta=theta,
|
||||
save_name=save_name,
|
||||
save_plot=save_plot,
|
||||
)
|
||||
else:
|
||||
plot_fit(
|
||||
tune=tune,
|
||||
fits=filtered_fits,
|
||||
theta=theta,
|
||||
save_name=save_name,
|
||||
save_plot=save_plot,
|
||||
)
|
Binary file not shown.
|
@ -0,0 +1,244 @@
|
|||
import torch
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import matplotlib as mpl
|
||||
|
||||
mpl.rcParams["text.usetex"] = True
|
||||
mpl.rcParams["font.family"] = "serif"
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.append(parent_dir)
|
||||
from functions.make_cnn import make_cnn # noqa
|
||||
|
||||
|
||||
def plot_single_tuning_curve(mean_syn_input, mean_relu_response, theta):
|
||||
plt.figure()
|
||||
plt.plot(theta, mean_syn_input, label="Before ReLU")
|
||||
plt.plot(theta, mean_relu_response, label="After ReLU")
|
||||
plt.xlabel("orientation (degs)")
|
||||
plt.ylabel("activity")
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
plt.gca().set_xticks(theta)
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot_single_phase_orientation(syn_input, relu_response, theta, phi, j):
|
||||
plt.figure()
|
||||
plt.subplot(1, 2, 1, aspect="equal")
|
||||
plt.imshow(
|
||||
syn_input.T, # type: ignore
|
||||
cmap="viridis",
|
||||
aspect="auto",
|
||||
extent=[theta[0], theta[-1], phi[0], phi[-1]],
|
||||
)
|
||||
plt.xlabel("orientation (degs)")
|
||||
plt.ylabel("phase (degs)")
|
||||
plt.colorbar(label="activity")
|
||||
plt.title(f"Weight {j}", fontsize=16)
|
||||
|
||||
plt.subplot(1, 2, 2, aspect="equal")
|
||||
plt.imshow(
|
||||
relu_response.T, # type: ignore
|
||||
cmap="viridis",
|
||||
aspect="auto",
|
||||
extent=[theta[0], theta[-1], phi[0], phi[-1]],
|
||||
)
|
||||
plt.xlabel("orientation (degs)")
|
||||
plt.ylabel("phase (degs)")
|
||||
plt.colorbar(label="activity")
|
||||
plt.title(f"Weight {j}", fontsize=16)
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot_all_tuning_curves(response_array, theta, orientations=32, phases=4):
|
||||
# plot tuning curves
|
||||
plt.figure(figsize=(12, 15))
|
||||
for i in range(response_array.shape[0]):
|
||||
# synaptic input
|
||||
in_neuron = response_array[i].reshape(orientations, phases)
|
||||
mean_syn_in = in_neuron.mean(axis=1)
|
||||
|
||||
# after non linearity
|
||||
out_relu = torch.nn.functional.leaky_relu(torch.tensor(response_array[i]))
|
||||
out_relu = out_relu.numpy().reshape(orientations, phases)
|
||||
mean_out_relu = out_relu.mean(axis=1) # type: ignore
|
||||
|
||||
plt.subplot(8, 4, i + 1)
|
||||
plt.plot(theta, mean_syn_in)
|
||||
plt.plot(theta, mean_out_relu)
|
||||
plt.xlabel("Theta (degs)")
|
||||
plt.ylabel("Activity")
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
|
||||
def calculate_responses(weights, plot_single_responses):
|
||||
# load Gabor filter
|
||||
orientations = 32
|
||||
phases = 4
|
||||
filename: str = "gabor_dict_32o_4p.npy"
|
||||
filepath: str = os.path.join(
|
||||
"D:/Katha/Neuroscience/Semester 4/newCode/kk_contour_net_shallow-main/investigate",
|
||||
filename,
|
||||
)
|
||||
gabor_dict = np.load(filepath)
|
||||
|
||||
# collect data
|
||||
all_responses: list = []
|
||||
after_relu = np.zeros((weights.shape[0], orientations))
|
||||
for j in range(weights.shape[0]):
|
||||
w0 = weights[j, 0].detach().cpu() # .numpy()
|
||||
|
||||
response: list = []
|
||||
for i in range(gabor_dict.shape[0]):
|
||||
gabor = gabor_dict[i, 0]
|
||||
if w0.shape[0] != gabor.shape[0]:
|
||||
# TODO: for later layers
|
||||
# get number to pad
|
||||
pad = (gabor.shape[0] - w0.shape[0]) // 2
|
||||
|
||||
# pad:
|
||||
w_pad = torch.nn.functional.pad(
|
||||
w0, (pad, pad, pad, pad), mode="constant", value=0
|
||||
)
|
||||
w_pad = w_pad.numpy()
|
||||
|
||||
else:
|
||||
w_pad = w0.numpy()
|
||||
|
||||
dot = np.sum(gabor * w_pad)
|
||||
response.append(dot)
|
||||
|
||||
# axis for plotting:
|
||||
theta = np.rad2deg(np.arange(orientations) * np.pi / orientations)
|
||||
phi = np.rad2deg(np.arange(phases) * 2 * np.pi / phases)
|
||||
|
||||
# to array + mean
|
||||
syn_input = np.array(response)
|
||||
syn_input = syn_input.reshape(orientations, phases)
|
||||
mean_response_orient = syn_input.mean(axis=1)
|
||||
|
||||
# leaky relu:
|
||||
relu_response = torch.nn.functional.leaky_relu(
|
||||
torch.tensor(response), negative_slope=0.1
|
||||
)
|
||||
relu_response = relu_response.numpy().reshape(orientations, phases)
|
||||
mean_relu_orient = relu_response.mean(axis=1) # type: ignore
|
||||
|
||||
# append to save:
|
||||
after_relu[j] = mean_relu_orient
|
||||
|
||||
# plot 2D:
|
||||
if plot_single_responses:
|
||||
plot_single_phase_orientation(
|
||||
syn_input=syn_input,
|
||||
relu_response=relu_response,
|
||||
theta=theta,
|
||||
phi=phi,
|
||||
j=j,
|
||||
)
|
||||
|
||||
# plot tuning curve
|
||||
plot_single_tuning_curve(
|
||||
mean_syn_input=mean_response_orient,
|
||||
mean_relu_response=mean_relu_orient,
|
||||
theta=theta,
|
||||
)
|
||||
|
||||
# collect response for each weight
|
||||
all_responses.append(response)
|
||||
|
||||
# to array:
|
||||
response_array = np.array(all_responses)
|
||||
|
||||
return response_array, after_relu, theta
|
||||
|
||||
|
||||
def plot_mean_resp_after_relu(mean_response, theta):
|
||||
# plot tuning curves
|
||||
plt.figure(figsize=(12, 15))
|
||||
for i in range(mean_response.shape[0]):
|
||||
plt.subplot(8, 4, i + 1)
|
||||
plt.plot(theta, mean_response[i])
|
||||
plt.xlabel("Theta (degs)")
|
||||
plt.ylabel("Activity")
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
|
||||
def load_data_from_cnn(
|
||||
cnn_name: str,
|
||||
plot_responses: bool,
|
||||
do_stats: bool,
|
||||
plot_single_responses: bool = False,
|
||||
):
|
||||
# path to NN
|
||||
|
||||
if do_stats:
|
||||
PATH = cnn_name
|
||||
else:
|
||||
PATH = f"D:/Katha/Neuroscience/Semester 4/newCode/kk_contour_net_shallow-main/trained_models/{cnn_name}"
|
||||
|
||||
# load and evaluate model
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model = torch.load(PATH).to(device)
|
||||
|
||||
# Set the model to evaluation mode
|
||||
model.eval()
|
||||
print(model)
|
||||
|
||||
# load NNs conv1 weights:
|
||||
weights = model[0]._parameters["weight"].data
|
||||
|
||||
# call
|
||||
response_array, mean_response_after_relu, theta = calculate_responses(
|
||||
weights=weights, plot_single_responses=plot_single_responses
|
||||
)
|
||||
|
||||
# plot
|
||||
if plot_responses:
|
||||
plot_all_tuning_curves(response_array=response_array, theta=theta)
|
||||
plot_mean_resp_after_relu(mean_response=mean_response_after_relu, theta=theta)
|
||||
|
||||
return np.array(mean_response_after_relu)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# path to NN
|
||||
nn = "ArghCNN_numConvLayers3_outChannels[32, 8, 8]_kernelSize[7, 15]_leaky relu_stride1_trainFirstConvLayerTrue_seed291853_Natural_314Epoch_0908-1206.pt"
|
||||
_ = load_data_from_cnn(cnn_name=nn, plot_responses=True, do_stats=False)
|
||||
exit()
|
||||
|
||||
PATH = f"D:/Katha/Neuroscience/Semester 4/newCode/kk_contour_net_shallow-main/trained_models/{nn}"
|
||||
|
||||
# load and evaluate model
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model = torch.load(PATH).to(device)
|
||||
|
||||
# Set the model to evaluation mode
|
||||
model.eval()
|
||||
print(model)
|
||||
|
||||
# load NNs conv1 weights:
|
||||
weights = model[0]._parameters["weight"].data
|
||||
|
||||
# plot?
|
||||
plot_single_responses: bool = False
|
||||
|
||||
# call
|
||||
response_array, mean_response_after_relu, theta = calculate_responses(
|
||||
weights=weights, plot_single_responses=plot_single_responses
|
||||
)
|
||||
|
||||
# plot
|
||||
plot_all_tuning_curves(response_array=response_array, theta=theta)
|
||||
plot_mean_resp_after_relu(mean_response=mean_response_after_relu, theta=theta)
|
||||
print()
|
|
@ -0,0 +1,272 @@
|
|||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import warnings
|
||||
import matplotlib as mpl
|
||||
|
||||
mpl.rcParams["text.usetex"] = True
|
||||
mpl.rcParams["font.family"] = "serif"
|
||||
mpl.rcParams["font.size"] = 15
|
||||
|
||||
# suppress warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
|
||||
def autolabel(rects, ax):
|
||||
for rect in rects:
|
||||
height = rect.get_height()
|
||||
ax.annotate(
|
||||
f"{height:.2f}",
|
||||
xy=(rect.get_x() + rect.get_width() / 2, height),
|
||||
xytext=(-10, 3),
|
||||
textcoords="offset points",
|
||||
ha="right",
|
||||
va="bottom",
|
||||
fontsize=17,
|
||||
)
|
||||
|
||||
|
||||
def plot_mean_percentile_amplit_ratio(
|
||||
ratio_classic_21: float,
|
||||
ratio_corner_21: float,
|
||||
ratio_classic_321: float,
|
||||
ratio_corner_321: float,
|
||||
ratio_classic_332: float,
|
||||
ratio_corner_332: float,
|
||||
percentile_classic21: tuple,
|
||||
percentile_classic321: tuple,
|
||||
percentile_classic_332: tuple,
|
||||
percentile_corner_21: tuple,
|
||||
percentile_corner_321: tuple,
|
||||
percentile_corner_332: tuple,
|
||||
save_name: str,
|
||||
saveplot: bool,
|
||||
):
|
||||
num_von_mises = [2, 3, 3] # X-axis ticks
|
||||
|
||||
# bar setup
|
||||
bar_width = 0.35
|
||||
index = np.arange(len(num_von_mises))
|
||||
|
||||
# position error bars correctly:
|
||||
lower_err_classic = [
|
||||
ratio_classic_21 - percentile_classic21[0],
|
||||
ratio_classic_332 - percentile_classic_332[0],
|
||||
ratio_classic_321 - percentile_classic321[0],
|
||||
]
|
||||
upper_err_classic = [
|
||||
percentile_classic21[1] - ratio_classic_21,
|
||||
percentile_classic_332[1] - ratio_classic_332,
|
||||
percentile_classic321[1] - ratio_classic_321,
|
||||
]
|
||||
|
||||
lower_err_corner = [
|
||||
ratio_corner_21 - percentile_corner_21[0],
|
||||
ratio_corner_332 - percentile_corner_332[0],
|
||||
ratio_corner_321 - percentile_corner_321[0],
|
||||
]
|
||||
upper_err_corner = [
|
||||
percentile_corner_21[1] - ratio_corner_21,
|
||||
percentile_corner_332[1] - ratio_corner_332,
|
||||
percentile_corner_321[1] - ratio_corner_321,
|
||||
]
|
||||
|
||||
yerr_classic = [lower_err_classic, upper_err_classic]
|
||||
yerr_corner = [lower_err_corner, upper_err_corner]
|
||||
|
||||
# subplots
|
||||
fig, ax = plt.subplots(figsize=(7, 7))
|
||||
bars_classic = ax.bar(
|
||||
index - bar_width / 2,
|
||||
[ratio_classic_21, ratio_classic_332, ratio_classic_321],
|
||||
bar_width,
|
||||
yerr=yerr_classic,
|
||||
capsize=5,
|
||||
label="Classic",
|
||||
color="cornflowerblue",
|
||||
)
|
||||
bars_corner = ax.bar(
|
||||
index + bar_width / 2,
|
||||
[ratio_corner_21, ratio_corner_332, ratio_corner_321],
|
||||
bar_width,
|
||||
yerr=yerr_corner,
|
||||
capsize=5,
|
||||
label="Corner",
|
||||
color="coral",
|
||||
)
|
||||
|
||||
autolabel(bars_classic, ax)
|
||||
autolabel(bars_corner, ax)
|
||||
|
||||
ax.set_ylabel("Median ratio of amplitudes", fontsize=18)
|
||||
ax.set_xticks(index)
|
||||
ax.set_xticklabels(
|
||||
[
|
||||
"2 von Mises \n(min/max)",
|
||||
"3 von Mises \n(mid/max)",
|
||||
"3 von Mises\n(min/mid)",
|
||||
],
|
||||
fontsize=17,
|
||||
)
|
||||
ax.legend(fontsize=17)
|
||||
ax.set_ylim(bottom=0.0)
|
||||
|
||||
# plot
|
||||
plt.yticks(fontsize=17)
|
||||
plt.tight_layout()
|
||||
if saveplot:
|
||||
plt.savefig(
|
||||
f"additional thesis plots/saved_plots/fitkarotte/median_quartiles_ampli_ratio_{save_name}_corn_class.pdf",
|
||||
dpi=300,
|
||||
bbox_inches="tight",
|
||||
)
|
||||
plt.show(block=True)
|
||||
|
||||
|
||||
def plot_means_std_corner_classic(
|
||||
means_classic: list,
|
||||
means_corner: list,
|
||||
std_classic: list,
|
||||
std_corner: list,
|
||||
saveplot: bool,
|
||||
save_name: str,
|
||||
):
|
||||
num_von_mises = [1, 2, 3] # X-axis ticks
|
||||
|
||||
# bar setup
|
||||
bar_width = 0.35
|
||||
index = np.arange(len(num_von_mises))
|
||||
|
||||
# subplots
|
||||
fig, ax = plt.subplots(figsize=(7, 7))
|
||||
bars_classic = ax.bar(
|
||||
index - bar_width / 2,
|
||||
means_classic,
|
||||
bar_width,
|
||||
yerr=std_classic,
|
||||
capsize=5,
|
||||
label="Classic",
|
||||
color="cornflowerblue",
|
||||
)
|
||||
bars_corner = ax.bar(
|
||||
index + bar_width / 2,
|
||||
means_corner,
|
||||
bar_width,
|
||||
yerr=std_corner,
|
||||
capsize=5,
|
||||
label="Corner",
|
||||
color="coral",
|
||||
)
|
||||
|
||||
autolabel(bars_classic, ax)
|
||||
autolabel(bars_corner, ax)
|
||||
|
||||
ax.set_ylabel("Average number of fits", fontsize=17)
|
||||
ax.set_xticks(index)
|
||||
ax.set_xticklabels(["1 von Mises", "2 von Mises", "3 von Mises"], fontsize=17)
|
||||
ax.legend(fontsize=16)
|
||||
ax.set_ylim(bottom=0.0)
|
||||
|
||||
# plot
|
||||
plt.yticks(fontsize=17)
|
||||
plt.tight_layout()
|
||||
if saveplot:
|
||||
plt.savefig(
|
||||
f"additional thesis plots/saved_plots/fitkarotte/y_lim_mean_fits_{save_name}_corn_class.pdf",
|
||||
dpi=300,
|
||||
)
|
||||
plt.show(block=True)
|
||||
|
||||
|
||||
def plot_mean_std_amplit_ratio(
|
||||
ratio_classic_21: float,
|
||||
std_class_21: float,
|
||||
ratio_corner_21: float,
|
||||
std_corn_21: float,
|
||||
ratio_classic_321: float,
|
||||
std_class_321: float,
|
||||
ratio_corner_321: float,
|
||||
std_corn_321: float,
|
||||
ratio_classic_332: float,
|
||||
std_class_332: float,
|
||||
ratio_corner_332: float,
|
||||
std_corn_332: float,
|
||||
save_name: str,
|
||||
saveplot: bool,
|
||||
):
|
||||
num_von_mises = [2, 3, 3] # X-axis ticks
|
||||
|
||||
# bar setup
|
||||
bar_width = 0.35
|
||||
index = np.arange(len(num_von_mises))
|
||||
|
||||
# subplots
|
||||
fig, ax = plt.subplots(figsize=(12, 7))
|
||||
bars_classic = ax.bar(
|
||||
index - bar_width / 2,
|
||||
[ratio_classic_21, ratio_classic_332, ratio_classic_321],
|
||||
bar_width,
|
||||
yerr=[std_class_21, std_class_332, std_class_321],
|
||||
capsize=5,
|
||||
label="Classic",
|
||||
color="cornflowerblue",
|
||||
)
|
||||
bars_corner = ax.bar(
|
||||
index + bar_width / 2,
|
||||
[ratio_corner_21, ratio_corner_332, ratio_corner_321],
|
||||
bar_width,
|
||||
yerr=[std_corn_21, std_corn_332, std_corn_321],
|
||||
capsize=5,
|
||||
label="Corner",
|
||||
color="coral",
|
||||
)
|
||||
|
||||
autolabel(bars_classic, ax)
|
||||
autolabel(bars_corner, ax)
|
||||
|
||||
ax.set_ylabel("Mean ratio of amplitudes", fontsize=17)
|
||||
ax.set_xticks(index)
|
||||
ax.set_xticklabels(
|
||||
[
|
||||
"2 von Mises \n(max/min)",
|
||||
"3 von Mises \n(max/mid)",
|
||||
"3 von Mises\n(mid/min)",
|
||||
],
|
||||
fontsize=17,
|
||||
)
|
||||
ax.legend(fontsize=16)
|
||||
ax.set_ylim(bottom=0.0)
|
||||
|
||||
# plot
|
||||
plt.yticks(fontsize=17)
|
||||
plt.tight_layout()
|
||||
if saveplot:
|
||||
plt.savefig(
|
||||
f"additional thesis plots/saved_plots/fitkarotte/y_lim_mean_std_ampli_ratio_{save_name}_corn_class.pdf",
|
||||
dpi=300,
|
||||
bbox_inches="tight",
|
||||
)
|
||||
plt.show(block=True)
|
||||
|
||||
|
||||
def plot_hist(fits_per_mises, num_mises: int):
|
||||
"""
|
||||
Plot to see if data has normal distribution
|
||||
"""
|
||||
# get correct x-ticks
|
||||
x_ticks = np.arange(start=min(fits_per_mises), stop=max(fits_per_mises) + 1, step=1)
|
||||
|
||||
# plot
|
||||
plt.hist(
|
||||
fits_per_mises,
|
||||
# bins=bins,
|
||||
alpha=0.5,
|
||||
label=f"{num_mises} von Mises function",
|
||||
align="mid",
|
||||
)
|
||||
plt.xlabel(f"Number of weights fitted with {num_mises} ")
|
||||
plt.ylabel(f"Frequency of fit with {num_mises} for 20 CNNs")
|
||||
plt.title(f"Histogram of Fits with {num_mises} von Mises Function")
|
||||
plt.xticks(x_ticks)
|
||||
plt.legend()
|
||||
plt.tight_layout()
|
||||
plt.show(block=True)
|
|
@ -0,0 +1,4 @@
|
|||
Folder psychometric_curves:
|
||||
|
||||
1. error_bar_performance_pfinkel:
|
||||
* caculates the average performance of all 20 CNNs across all stimulus conditions and path angles within one stimulus condition
|
|
@ -0,0 +1,223 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib as mpl
|
||||
import os
|
||||
import datetime
|
||||
|
||||
# import re
|
||||
# import glob
|
||||
# from natsort import natsorted
|
||||
|
||||
mpl.rcParams["text.usetex"] = True
|
||||
mpl.rcParams["font.family"] = "serif"
|
||||
|
||||
from functions.alicorn_data_loader import alicorn_data_loader
|
||||
from functions.create_logger import create_logger
|
||||
|
||||
|
||||
def performance_pfinkel_plot(
|
||||
performances_list: list[dict],
|
||||
all_performances: dict,
|
||||
labels: list[str],
|
||||
save_name: str,
|
||||
logger,
|
||||
) -> None:
|
||||
figure_path: str = "rerun_errorbar_performance_pfinkel"
|
||||
os.makedirs(figure_path, exist_ok=True)
|
||||
|
||||
plt.figure(figsize=[10, 10])
|
||||
with open(f"./{figure_path}/performances_{save_name}_{current}.txt", "w") as f:
|
||||
for id, selected_condition in enumerate(condition):
|
||||
f.write(
|
||||
f"Condition:{selected_condition} Path angle (in °), Mean accuracy (\\%), Standard deviation (\\%)\n"
|
||||
)
|
||||
|
||||
x_values = np.array(num_pfinkel)
|
||||
y_values = np.array(
|
||||
[
|
||||
np.mean(all_performances[selected_condition][pfinkel])
|
||||
for pfinkel in num_pfinkel
|
||||
]
|
||||
)
|
||||
yerr_values = np.array(
|
||||
[
|
||||
np.std(all_performances[selected_condition][pfinkel])
|
||||
for pfinkel in num_pfinkel
|
||||
]
|
||||
)
|
||||
|
||||
for x, y, yerr in zip(x_values, y_values, yerr_values):
|
||||
f.write(f"{x}, {y/100.0:.3f}, {yerr/100.0:.3f}\n")
|
||||
f.write(f"{x}, {y}, {yerr}\n")
|
||||
|
||||
plt.errorbar(
|
||||
x_values,
|
||||
y_values / 100.0,
|
||||
yerr=yerr_values / 100.0,
|
||||
fmt="o",
|
||||
capsize=5,
|
||||
label=labels[id],
|
||||
)
|
||||
plt.xticks(x_values)
|
||||
plt.title("Average accuracy", fontsize=19)
|
||||
plt.xlabel("Path angle (in °)", fontsize=18)
|
||||
plt.ylabel("Accuracy (\\%)", fontsize=18)
|
||||
plt.ylim(0.5, 1.0)
|
||||
plt.legend(fontsize=15)
|
||||
|
||||
# Increase tick label font size
|
||||
plt.xticks(fontsize=17)
|
||||
plt.yticks(fontsize=17)
|
||||
plt.grid(True)
|
||||
plt.tight_layout()
|
||||
logger.info("")
|
||||
logger.info("Saved in:")
|
||||
|
||||
print(
|
||||
os.path.join(
|
||||
figure_path,
|
||||
f"ylim_ErrorBarPerformancePfinkel_{save_name}_{current}.pdf",
|
||||
)
|
||||
)
|
||||
plt.savefig(
|
||||
os.path.join(
|
||||
figure_path,
|
||||
f"ylim_ErrorBarPerformancePfinkel_{save_name}_{current}.pdf",
|
||||
),
|
||||
dpi=300,
|
||||
bbox_inches="tight",
|
||||
)
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model_path: str = "classic_3288_fest"
|
||||
print(model_path)
|
||||
data_path: str = "/home/kk/Documents/Semester4/code/RenderStimuli/Output/"
|
||||
|
||||
# num stimuli per Pfinkel and batch size
|
||||
stim_per_pfinkel: int = 10000
|
||||
batch_size: int = 1000
|
||||
# stimulus condition:
|
||||
performances_list: list = []
|
||||
|
||||
condition: list[str] = ["Coignless", "Natural", "Angular"]
|
||||
figure_label: list[str] = ["Classic", "Corner", "Bridge"]
|
||||
# load test data:
|
||||
num_pfinkel: list = np.arange(0, 100, 10).tolist()
|
||||
image_scale: float = 255.0
|
||||
|
||||
# ------------------------------------------
|
||||
|
||||
# create logger:
|
||||
logger = create_logger(
|
||||
save_logging_messages=False,
|
||||
display_logging_messages=True,
|
||||
model_name=model_path,
|
||||
)
|
||||
|
||||
device_str: str = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||
logger.info(f"Using {device_str} device")
|
||||
device: torch.device = torch.device(device_str)
|
||||
torch.set_default_dtype(torch.float32)
|
||||
|
||||
# current time:
|
||||
current = datetime.datetime.now().strftime("%d%m-%H%M")
|
||||
|
||||
all_performances: dict = {
|
||||
condition_name: {pfinkel: [] for pfinkel in num_pfinkel}
|
||||
for condition_name in condition
|
||||
}
|
||||
|
||||
for filename in os.listdir(model_path):
|
||||
if filename.endswith(".pt"):
|
||||
model_filename = os.path.join(model_path, filename)
|
||||
model = torch.load(model_filename, map_location=device)
|
||||
model.eval()
|
||||
print(model_filename)
|
||||
|
||||
for selected_condition in condition:
|
||||
# save performances:
|
||||
logger.info(f"Condition: {selected_condition}")
|
||||
performances: dict = {}
|
||||
for pfinkel in num_pfinkel:
|
||||
test_loss: float = 0.0
|
||||
correct: int = 0
|
||||
pattern_count: int = 0
|
||||
|
||||
data_test = alicorn_data_loader(
|
||||
num_pfinkel=[pfinkel],
|
||||
load_stimuli_per_pfinkel=stim_per_pfinkel,
|
||||
condition=selected_condition,
|
||||
logger=logger,
|
||||
data_path=data_path,
|
||||
)
|
||||
loader = torch.utils.data.DataLoader(
|
||||
data_test, shuffle=False, batch_size=batch_size
|
||||
)
|
||||
|
||||
# start testing network on new stimuli:
|
||||
logger.info("")
|
||||
logger.info(
|
||||
f"-==- Start {selected_condition} " f"Pfinkel {pfinkel}° -==-"
|
||||
)
|
||||
with torch.no_grad():
|
||||
for batch_num, data in enumerate(loader):
|
||||
label = data[0].to(device)
|
||||
image = data[1].type(dtype=torch.float32).to(device)
|
||||
image /= image_scale
|
||||
|
||||
# compute prediction error;
|
||||
output = model(image)
|
||||
|
||||
# Label Typecast:
|
||||
label = label.to(device)
|
||||
|
||||
# loss and optimization
|
||||
loss = torch.nn.functional.cross_entropy(
|
||||
output, label, reduction="sum"
|
||||
)
|
||||
pattern_count += int(label.shape[0])
|
||||
test_loss += float(loss)
|
||||
prediction = output.argmax(dim=1)
|
||||
correct += prediction.eq(label).sum().item()
|
||||
|
||||
total_number_of_pattern: int = int(len(loader)) * int(
|
||||
label.shape[0]
|
||||
)
|
||||
|
||||
# logging:
|
||||
logger.info(
|
||||
(
|
||||
f"{selected_condition},{pfinkel}° "
|
||||
"Pfinkel: "
|
||||
f"[{int(pattern_count)}/{total_number_of_pattern} ({100.0 * pattern_count / total_number_of_pattern:.2f}%)],"
|
||||
f" Average loss: {test_loss / pattern_count:.3e}, "
|
||||
"Accuracy: "
|
||||
f"{100.0 * correct / pattern_count:.2f}% "
|
||||
)
|
||||
)
|
||||
|
||||
performances[pfinkel] = {
|
||||
"pfinkel": pfinkel,
|
||||
"test_accuracy": 100 * correct / pattern_count,
|
||||
"test_losses": float(loss) / pattern_count,
|
||||
}
|
||||
all_performances[selected_condition][pfinkel].append(
|
||||
100 * correct / pattern_count
|
||||
)
|
||||
|
||||
performances_list.append(performances)
|
||||
else:
|
||||
print("No files found!")
|
||||
break
|
||||
|
||||
performance_pfinkel_plot(
|
||||
performances_list=performances_list,
|
||||
all_performances=all_performances,
|
||||
labels=figure_label,
|
||||
save_name=model_path,
|
||||
logger=logger,
|
||||
)
|
||||
logger.info("-==- DONE -==-")
|
|
@ -0,0 +1,603 @@
|
|||
# %%
|
||||
#
|
||||
# contours.py
|
||||
#
|
||||
# Tools for contour integration studies
|
||||
#
|
||||
# Version 1.0, 24.03.2023
|
||||
#
|
||||
|
||||
#
|
||||
# Coordinate system assumptions:
|
||||
#
|
||||
# for arrays:
|
||||
# [..., HEIGHT, WIDTH], origin is on TOP LEFT
|
||||
# HEIGHT indices *decrease* with increasing y-coordinates (reversed)
|
||||
# WIDTH indices *increase* with increasing x-coordinates (normal)
|
||||
#
|
||||
# Orientations:
|
||||
# 0 is horizontal, orientation *increase* counter-clockwise
|
||||
# Corner elements, quantified by [dir_source, dir_change]:
|
||||
# - consist of two legs
|
||||
# - contour *enters* corner from *source direction* at one leg
|
||||
# and goes from border to its center...
|
||||
# - contour path changes by *direction change* and goes
|
||||
# from center to the border
|
||||
#
|
||||
|
||||
import torch
|
||||
import time
|
||||
import matplotlib.pyplot as plt
|
||||
import math
|
||||
import scipy.io
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
torch_device = "cuda"
|
||||
default_dtype = torch.float32
|
||||
torch.set_default_dtype(default_dtype)
|
||||
torch.device(torch_device)
|
||||
|
||||
|
||||
#
|
||||
# performs a coordinate transform (rotation with phi around origin)
|
||||
# rotation is performed CLOCKWISE with increasing phi
|
||||
#
|
||||
# remark: rotating a mesh grid by phi and orienting an image element
|
||||
# along the new x-axis is EQUIVALENT to rotating the image element
|
||||
# by -phi (so this realizes a rotation COUNTER-CLOCKWISE with
|
||||
# increasing phi)
|
||||
#
|
||||
def rotate_CW(x: torch.Tensor, y: torch.Tensor, phi: torch.float32): # type: ignore
|
||||
xr = +x * torch.cos(phi) + y * torch.sin(phi)
|
||||
yr = -x * torch.sin(phi) + y * torch.cos(phi)
|
||||
|
||||
return xr, yr
|
||||
|
||||
|
||||
#
|
||||
# renders a Gabor with (or without) corner
|
||||
#
|
||||
def gaborner(
|
||||
r_gab: int, # radius, size will be 2*r_gab+1
|
||||
dir_source: float, # contour enters in this dir
|
||||
dir_change: float, # contour turns around by this dir
|
||||
lambdah: float, # wavelength of Gabor
|
||||
sigma: float, # half-width of Gabor
|
||||
phase: float, # phase of Gabor
|
||||
normalize: bool, # normalize patch to zero
|
||||
torch_device: str, # GPU or CPU...
|
||||
) -> torch.Tensor:
|
||||
# incoming dir: change to outgoing dir
|
||||
dir1 = dir_source + torch.pi
|
||||
nook = dir_change - torch.pi
|
||||
|
||||
# create coordinate grids
|
||||
d_gab = 2 * r_gab + 1
|
||||
x = -r_gab + torch.arange(d_gab, device=torch_device)
|
||||
yg, xg = torch.meshgrid(x, x, indexing="ij")
|
||||
|
||||
# put into tensor for performing vectorized scalar products
|
||||
xyg = torch.zeros([d_gab, d_gab, 1, 2], device=torch_device)
|
||||
xyg[:, :, 0, 0] = xg
|
||||
xyg[:, :, 0, 1] = yg
|
||||
|
||||
# create Gaussian hull
|
||||
gauss = torch.exp(-(xg**2 + yg**2) / 2 / sigma**2)
|
||||
gabor_corner = gauss.clone()
|
||||
|
||||
if (dir_change == 0) or (dir_change == torch.pi):
|
||||
# handle special case of straight Gabor or change by 180 deg
|
||||
|
||||
# vector orth to Gabor axis
|
||||
ev1_orth = torch.tensor(
|
||||
[math.cos(-dir1 + math.pi / 2), math.sin(-dir1 + math.pi / 2)],
|
||||
device=torch_device,
|
||||
)
|
||||
# project coords to orth vector to get distance
|
||||
legs = torch.cos(
|
||||
2
|
||||
* torch.pi
|
||||
* torch.matmul(xyg, ev1_orth.unsqueeze(1).unsqueeze(0).unsqueeze(0))
|
||||
/ lambdah
|
||||
+ phase
|
||||
)
|
||||
gabor_corner *= legs[:, :, 0, 0]
|
||||
|
||||
else:
|
||||
dir2 = dir1 + nook
|
||||
|
||||
# compute separation line between corner's legs
|
||||
ev1 = torch.tensor([math.cos(-dir1), math.sin(-dir1)], device=torch_device)
|
||||
ev2 = torch.tensor([math.cos(-dir2), math.sin(-dir2)], device=torch_device)
|
||||
v_towards_1 = (ev1 - ev2).unsqueeze(1).unsqueeze(0).unsqueeze(0)
|
||||
|
||||
# which coords belong to which leg?
|
||||
which_side = torch.matmul(xyg, v_towards_1)[:, :, 0, 0]
|
||||
towards_1y, towards_1x = torch.where(which_side > 0)
|
||||
towards_2y, towards_2x = torch.where(which_side <= 0)
|
||||
|
||||
# compute orth distance to legs
|
||||
side_sign = -1 + 2 * ((dir_change % 2 * torch.pi) > torch.pi)
|
||||
ev12 = ev1 + ev2
|
||||
v1_orth = ev12 - ev1 * torch.matmul(ev12, ev1)
|
||||
v2_orth = ev12 - ev2 * torch.matmul(ev12, ev2)
|
||||
ev1_orth = side_sign * v1_orth / torch.sqrt((v1_orth**2).sum())
|
||||
ev2_orth = side_sign * v2_orth / torch.sqrt((v2_orth**2).sum())
|
||||
|
||||
leg1 = torch.cos(
|
||||
2
|
||||
* torch.pi
|
||||
* torch.matmul(xyg, ev1_orth.unsqueeze(1).unsqueeze(0).unsqueeze(0))
|
||||
/ lambdah
|
||||
+ phase
|
||||
)
|
||||
leg2 = torch.cos(
|
||||
2
|
||||
* torch.pi
|
||||
* torch.matmul(xyg, ev2_orth.unsqueeze(1).unsqueeze(0).unsqueeze(0))
|
||||
/ lambdah
|
||||
+ phase
|
||||
)
|
||||
gabor_corner[towards_1y, towards_1x] *= leg1[towards_1y, towards_1x, 0, 0]
|
||||
gabor_corner[towards_2y, towards_2x] *= leg2[towards_2y, towards_2x, 0, 0]
|
||||
|
||||
# depending on phase, Gabor might not be normalized...
|
||||
if normalize:
|
||||
s = gabor_corner.sum()
|
||||
s0 = gauss.sum()
|
||||
gabor_corner -= s / s0 * gauss
|
||||
|
||||
return gabor_corner
|
||||
|
||||
|
||||
#
|
||||
# creates a filter bank of Gabor corners
|
||||
#
|
||||
# outputs:
|
||||
# filters: [n_source, n_change, HEIGHT, WIDTH]
|
||||
# dirs_source: [n_source]
|
||||
# dirs_change: [n_change]
|
||||
#
|
||||
def gaborner_filterbank(
|
||||
r_gab: int, # radius, size will be 2*r_gab+1
|
||||
n_source: int, # number of source orientations
|
||||
n_change: int, # number of direction changes
|
||||
lambdah: float, # wavelength of Gabor
|
||||
sigma: float, # half-width of Gabor
|
||||
phase: float, # phase of Gabor
|
||||
normalize: bool, # normalize patch to zero
|
||||
torch_device: str, # GPU or CPU...
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
kernels = torch.zeros(
|
||||
[n_source, n_change, 2 * r_gab + 1, 2 * r_gab + 1],
|
||||
device=torch_device,
|
||||
requires_grad=False,
|
||||
)
|
||||
dirs_source = 2 * torch.pi * torch.arange(n_source, device=torch_device) / n_source
|
||||
dirs_change = 2 * torch.pi * torch.arange(n_change, device=torch_device) / n_change
|
||||
|
||||
for i_source in range(n_source):
|
||||
for i_change in range(n_change):
|
||||
gabor_corner = gaborner(
|
||||
r_gab=r_gab,
|
||||
dir_source=dirs_source[i_source], # type: ignore
|
||||
dir_change=dirs_change[i_change], # type: ignore
|
||||
lambdah=lambdah,
|
||||
sigma=sigma,
|
||||
phase=phase,
|
||||
normalize=normalize,
|
||||
torch_device=torch_device,
|
||||
)
|
||||
kernels[i_source, i_change] = gabor_corner
|
||||
|
||||
# check = torch.isnan(gabor_corner).sum()
|
||||
# if check > 0:
|
||||
# print(i_source, i_change, check)
|
||||
# kernels[i_source, i_change] = 1
|
||||
|
||||
return kernels, dirs_source, dirs_change
|
||||
|
||||
|
||||
def discretize_stimuli(
|
||||
posori,
|
||||
x_range: tuple,
|
||||
y_range: tuple,
|
||||
scale_factor: float,
|
||||
r_gab_PIX: int,
|
||||
n_source: int,
|
||||
n_change: int,
|
||||
torch_device: str,
|
||||
) -> torch.Tensor:
|
||||
# check correct input size
|
||||
s = posori.shape
|
||||
assert len(s) == 2, "posori should be NDARRAY with N x 1 entries"
|
||||
assert s[1] == 1, "posori should be NDARRAY with N x 1 entries"
|
||||
|
||||
# determine size of (extended) canvas
|
||||
x_canvas_PIX = torch.tensor(
|
||||
(x_range[1] - x_range[0]) * scale_factor, device=torch_device
|
||||
).ceil()
|
||||
y_canvas_PIX = torch.tensor(
|
||||
(y_range[1] - y_range[0]) * scale_factor, device=torch_device
|
||||
).ceil()
|
||||
x_canvas_ext_PIX = int(x_canvas_PIX + 2 * r_gab_PIX)
|
||||
y_canvas_ext_PIX = int(y_canvas_PIX + 2 * r_gab_PIX)
|
||||
|
||||
# get number of contours
|
||||
n_contours = s[0]
|
||||
index_srcchg = []
|
||||
index_y = []
|
||||
index_x = []
|
||||
for i_contour in range(n_contours):
|
||||
x_y_src_chg = torch.asarray(posori[i_contour, 0][1:, :].copy())
|
||||
x_y_src_chg[2] += torch.pi
|
||||
|
||||
# if i_contour == 0:
|
||||
# print(x_y_src_chg[2][:3])
|
||||
|
||||
# compute integer coordinates and find all visible elements
|
||||
x = ((x_y_src_chg[0] - x_range[0]) * scale_factor + r_gab_PIX).type(torch.long)
|
||||
y = y_canvas_ext_PIX - (
|
||||
(x_y_src_chg[1] - y_range[0]) * scale_factor + r_gab_PIX
|
||||
).type(torch.long)
|
||||
i_visible = torch.where(
|
||||
(x >= 0) * (y >= 0) * (x < x_canvas_ext_PIX) * (y < y_canvas_ext_PIX)
|
||||
)[0]
|
||||
|
||||
# compute integer (changes of) directions
|
||||
i_source = (
|
||||
((((x_y_src_chg[2]) / (2 * torch.pi)) + 1 / (2 * n_source)) % 1) * n_source
|
||||
).type(torch.long)
|
||||
i_change = (
|
||||
(((x_y_src_chg[3] / (2 * torch.pi)) + 1 / (2 * n_change)) % 1) * n_change
|
||||
).type(torch.long)
|
||||
|
||||
# stimulus = torch.zeros(
|
||||
# (n_source, n_change, y_canvas_ext_PIX, x_canvas_ext_PIX), device=torch_device
|
||||
# )
|
||||
# stimulus[i_source[i_visible], i_change[i_visible], y[i_visible], x[i_visible]] = 1
|
||||
|
||||
index_srcchg.append(i_source[i_visible] * n_change + i_change[i_visible])
|
||||
# index_change.append(i_change[i_visible])
|
||||
index_y.append(y[i_visible])
|
||||
index_x.append(x[i_visible])
|
||||
|
||||
return ( # type: ignore
|
||||
index_srcchg,
|
||||
index_x,
|
||||
index_y,
|
||||
x_canvas_ext_PIX,
|
||||
y_canvas_ext_PIX,
|
||||
)
|
||||
|
||||
|
||||
def render_stimulus(
|
||||
kernels, index_srcchg, index_y, index_x, y_canvas, x_canvas, torch_device
|
||||
):
|
||||
s = kernels.shape
|
||||
kx = s[-1]
|
||||
ky = s[-2]
|
||||
|
||||
stimulus = torch.zeros((y_canvas + ky - 1, x_canvas + kx - 1), device=torch_device)
|
||||
n = index_srcchg.size()[0]
|
||||
for i in torch.arange(n, device=torch_device):
|
||||
x = index_x[i]
|
||||
y = index_y[i]
|
||||
stimulus[y : y + ky, x : x + kx] += kernels[index_srcchg[i]]
|
||||
|
||||
return stimulus[ky - 1 : -(ky - 1), kx - 1 : -(kx - 1)]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
VERBOSE = True
|
||||
BENCH_CONVOLVE = False
|
||||
BENCH_GPU = True
|
||||
BENCH_CPU = True
|
||||
BENCH_DAVID = True
|
||||
|
||||
print("Testing contour rendering speed:")
|
||||
print("================================")
|
||||
|
||||
# load contours, multiplex coordinates to simulate a larger set of contours
|
||||
n_multiplex = 1000
|
||||
mat = scipy.io.loadmat("z.mat")
|
||||
posori = np.tile(mat["z"], (n_multiplex, 1))
|
||||
n_contours = posori.shape[0]
|
||||
print(f"Processing {n_contours} contour stimuli")
|
||||
|
||||
# how many contours to render simultaneously?
|
||||
n_simultaneous = 5
|
||||
n_simultaneous_chunks, n_remaining = divmod(n_contours, n_simultaneous)
|
||||
assert n_remaining == 0, "Check parameters for simultaneous contour rendering!"
|
||||
|
||||
# repeat some times for speed testing
|
||||
n_repeat = 10
|
||||
t_dis = torch.zeros((n_repeat + 2), device=torch_device)
|
||||
t_con = torch.zeros((n_repeat + 2), device=torch_device)
|
||||
t_rsg = torch.zeros((n_repeat + 2), device=torch_device)
|
||||
t_rsc = torch.zeros((n_repeat + 2), device="cpu")
|
||||
t_rsd = torch.zeros((n_repeat + 2), device="cpu")
|
||||
|
||||
# cutout for stimuli, and gabor parameters
|
||||
x_range = [140, 940]
|
||||
y_range = [140, 940]
|
||||
d_gab = 40
|
||||
lambdah = 12
|
||||
sigma = 8
|
||||
phase = 0.0
|
||||
normalize = True
|
||||
|
||||
# scale to convert coordinates to pixel values
|
||||
scale_factor = 0.25
|
||||
|
||||
# number of directions for dictionary
|
||||
n_source = 32
|
||||
n_change = 32
|
||||
|
||||
# convert sizes to pixel units
|
||||
lambdah_PIX = lambdah * scale_factor
|
||||
sigma_PIX = sigma * scale_factor
|
||||
r_gab_PIX = int(d_gab * scale_factor / 2)
|
||||
d_gab_PIX = r_gab_PIX * 2 + 1
|
||||
|
||||
# make filterbank
|
||||
kernels, dirs_source, dirs_change = gaborner_filterbank(
|
||||
r_gab=r_gab_PIX,
|
||||
n_source=n_source,
|
||||
n_change=n_change,
|
||||
lambdah=lambdah_PIX,
|
||||
sigma=sigma_PIX,
|
||||
phase=phase,
|
||||
normalize=normalize,
|
||||
torch_device=torch_device,
|
||||
)
|
||||
kernels = kernels.reshape([1, n_source * n_change, d_gab_PIX, d_gab_PIX])
|
||||
kernels_flip = kernels.flip(dims=(-1, -2))
|
||||
|
||||
# define "network" and put to cuda
|
||||
conv = torch.nn.Conv2d(
|
||||
in_channels=n_source * n_change,
|
||||
out_channels=1,
|
||||
kernel_size=d_gab_PIX,
|
||||
stride=1,
|
||||
device=torch_device,
|
||||
)
|
||||
conv.weight.data = kernels_flip
|
||||
|
||||
print("Discretizing START!!!")
|
||||
t_dis[0] = time.perf_counter()
|
||||
for i_rep in range(n_repeat):
|
||||
# discretize
|
||||
(
|
||||
index_srcchg,
|
||||
index_x,
|
||||
index_y,
|
||||
x_canvas,
|
||||
y_canvas,
|
||||
) = discretize_stimuli(
|
||||
posori=posori,
|
||||
x_range=x_range, # type: ignore
|
||||
y_range=y_range, # type: ignore
|
||||
scale_factor=scale_factor,
|
||||
r_gab_PIX=r_gab_PIX,
|
||||
n_source=n_source,
|
||||
n_change=n_change,
|
||||
torch_device=torch_device,
|
||||
)
|
||||
t_dis[i_rep + 1] = time.perf_counter()
|
||||
t_dis[-1] = time.perf_counter()
|
||||
print("Discretizing END!!!")
|
||||
|
||||
if BENCH_CONVOLVE:
|
||||
print("Allocating!")
|
||||
stimuli = torch.zeros(
|
||||
[n_simultaneous, n_source * n_change, y_canvas, x_canvas],
|
||||
device=torch_device,
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
print("Generation by CONVOLUTION start!")
|
||||
t_con[0] = time.perf_counter()
|
||||
for i_rep in torch.arange(n_repeat):
|
||||
for i_simultaneous_chunks in torch.arange(n_simultaneous_chunks):
|
||||
i_ofs = i_simultaneous_chunks * n_simultaneous
|
||||
|
||||
for i_sim in torch.arange(n_simultaneous):
|
||||
stimuli[
|
||||
i_sim,
|
||||
index_srcchg[i_sim + i_ofs],
|
||||
index_y[i_sim + i_ofs],
|
||||
index_x[i_sim + i_ofs],
|
||||
] = 1
|
||||
|
||||
output = conv(stimuli)
|
||||
|
||||
for i_sim in range(n_simultaneous):
|
||||
stimuli[
|
||||
i_sim,
|
||||
index_srcchg[i_sim + i_ofs],
|
||||
index_y[i_sim + i_ofs],
|
||||
index_x[i_sim + i_ofs],
|
||||
] = 0
|
||||
|
||||
t_con[i_rep + 1] = time.perf_counter()
|
||||
t_con[-1] = time.perf_counter()
|
||||
print("Generation by CONVOLUTION stop!")
|
||||
|
||||
if BENCH_GPU:
|
||||
print("Generation by GPU start!")
|
||||
output_gpu = torch.zeros(
|
||||
(
|
||||
n_contours,
|
||||
y_canvas - d_gab_PIX + 1,
|
||||
x_canvas - d_gab_PIX + 1,
|
||||
),
|
||||
device=torch_device,
|
||||
)
|
||||
t_rsg[0] = time.perf_counter()
|
||||
for i_rep in torch.arange(n_repeat):
|
||||
for i_con in torch.arange(n_contours):
|
||||
output_gpu[i_con] = render_stimulus(
|
||||
kernels=kernels[0],
|
||||
index_srcchg=index_srcchg[i_con],
|
||||
index_y=index_y[i_con],
|
||||
index_x=index_x[i_con],
|
||||
y_canvas=y_canvas,
|
||||
x_canvas=x_canvas,
|
||||
torch_device=torch_device,
|
||||
)
|
||||
# output_gpu = torch.clip(output_gpu, -1, +1)
|
||||
|
||||
t_rsg[i_rep + 1] = time.perf_counter()
|
||||
t_rsg[-1] = time.perf_counter()
|
||||
print("Generation by GPU stop!")
|
||||
|
||||
if BENCH_CPU:
|
||||
print("Generation by CPU start!")
|
||||
output_cpu = torch.zeros(
|
||||
(
|
||||
n_contours,
|
||||
y_canvas - d_gab_PIX + 1,
|
||||
x_canvas - d_gab_PIX + 1,
|
||||
),
|
||||
device="cpu",
|
||||
)
|
||||
kernels_cpu = kernels.detach().cpu()
|
||||
t_rsc[0] = time.perf_counter()
|
||||
for i_rep in range(n_repeat):
|
||||
for i_con in range(n_contours):
|
||||
output_cpu[i_con] = render_stimulus(
|
||||
kernels=kernels_cpu[0],
|
||||
index_srcchg=index_srcchg[i_con],
|
||||
index_y=index_y[i_con],
|
||||
index_x=index_x[i_con],
|
||||
y_canvas=y_canvas,
|
||||
x_canvas=x_canvas,
|
||||
torch_device="cpu",
|
||||
)
|
||||
# output_cpu = torch.clip(output_cpu, -1, +1)
|
||||
|
||||
t_rsc[i_rep + 1] = time.perf_counter()
|
||||
t_rsc[-1] = time.perf_counter()
|
||||
print("Generation by CPU stop!")
|
||||
|
||||
if BENCH_DAVID:
|
||||
print("Generation by DAVID start!")
|
||||
from CPPExtensions.PyTCopyCPU import TCopyCPU as render_stimulus_CPP
|
||||
|
||||
copyier = render_stimulus_CPP()
|
||||
|
||||
number_of_cpu_processes = os.cpu_count()
|
||||
output_dav_tmp = torch.zeros(
|
||||
(
|
||||
n_contours,
|
||||
y_canvas + 2 * r_gab_PIX,
|
||||
x_canvas + 2 * r_gab_PIX,
|
||||
),
|
||||
device="cpu",
|
||||
dtype=torch.float,
|
||||
)
|
||||
gabor = kernels[0].detach().cpu()
|
||||
|
||||
# Umsort!
|
||||
n_elements_total = 0
|
||||
for i_con in range(n_contours):
|
||||
n_elements_total += len(index_x[i_con])
|
||||
sparse_matrix = torch.zeros(
|
||||
(n_elements_total, 4), device="cpu", dtype=torch.int64
|
||||
)
|
||||
i_elements_total = 0
|
||||
for i_con in range(n_contours):
|
||||
n_add = len(index_x[i_con])
|
||||
sparse_matrix[i_elements_total : i_elements_total + n_add, 0] = i_con
|
||||
sparse_matrix[
|
||||
i_elements_total : i_elements_total + n_add, 1
|
||||
] = index_srcchg[i_con]
|
||||
sparse_matrix[i_elements_total : i_elements_total + n_add, 2] = index_y[
|
||||
i_con
|
||||
]
|
||||
sparse_matrix[i_elements_total : i_elements_total + n_add, 3] = index_x[
|
||||
i_con
|
||||
]
|
||||
i_elements_total += n_add
|
||||
assert i_elements_total == n_elements_total, "UNBEHAGEN macht sich breit!"
|
||||
|
||||
t_dav = torch.zeros((n_repeat + 2), device="cpu")
|
||||
t_dav[0] = time.perf_counter()
|
||||
for i_rep in range(n_repeat):
|
||||
output_dav_tmp.fill_(0.0)
|
||||
copyier.process(
|
||||
sparse_matrix.data_ptr(),
|
||||
int(sparse_matrix.shape[0]),
|
||||
int(sparse_matrix.shape[1]),
|
||||
gabor.data_ptr(),
|
||||
int(gabor.shape[0]),
|
||||
int(gabor.shape[1]),
|
||||
int(gabor.shape[2]),
|
||||
output_dav_tmp.data_ptr(),
|
||||
int(output_dav_tmp.shape[0]),
|
||||
int(output_dav_tmp.shape[1]),
|
||||
int(output_dav_tmp.shape[2]),
|
||||
int(number_of_cpu_processes), # type: ignore
|
||||
)
|
||||
output_dav = output_dav_tmp[
|
||||
:,
|
||||
d_gab_PIX - 1 : -(d_gab_PIX - 1),
|
||||
d_gab_PIX - 1 : -(d_gab_PIX - 1),
|
||||
].clone()
|
||||
t_dav[i_rep + 1] = time.perf_counter()
|
||||
t_dav[-1] = time.perf_counter()
|
||||
print("Generation by DAVID done!")
|
||||
|
||||
if VERBOSE: # show last stimulus
|
||||
if BENCH_CONVOLVE:
|
||||
plt.subplot(2, 2, 1)
|
||||
plt.imshow(output[-1, 0].detach().cpu(), cmap="gray", vmin=-1, vmax=+1)
|
||||
plt.title("convolve")
|
||||
if BENCH_GPU:
|
||||
plt.subplot(2, 2, 2)
|
||||
plt.imshow(output_gpu[-1].detach().cpu(), cmap="gray", vmin=-1, vmax=+1)
|
||||
plt.title("gpu")
|
||||
if BENCH_CPU:
|
||||
plt.subplot(2, 2, 3)
|
||||
plt.imshow(output_cpu[-1], cmap="gray", vmin=-1, vmax=+1)
|
||||
plt.title("cpu")
|
||||
if BENCH_DAVID:
|
||||
plt.subplot(2, 2, 4)
|
||||
plt.imshow(output_dav[-1], cmap="gray", vmin=-1, vmax=+1)
|
||||
plt.title("david")
|
||||
plt.show()
|
||||
|
||||
dt_discretize = t_dis.diff() / n_contours
|
||||
plt.plot(dt_discretize.detach().cpu())
|
||||
dt_convolve = t_con.diff() / n_contours
|
||||
plt.plot(dt_convolve.detach().cpu())
|
||||
dt_gpu = t_rsg.diff() / n_contours
|
||||
plt.plot(dt_gpu.detach().cpu())
|
||||
dt_cpu = t_rsc.diff() / n_contours
|
||||
plt.plot(dt_cpu.detach().cpu())
|
||||
dt_david = t_dav.diff() / n_contours
|
||||
plt.plot(dt_david.detach().cpu())
|
||||
|
||||
plt.legend(["discretize", "convolve", "gpu", "cpu", "david"])
|
||||
plt.show()
|
||||
print(
|
||||
f"Average discretize for 1k stims: {1000*dt_discretize[:-1].detach().cpu().mean()} secs."
|
||||
)
|
||||
print(
|
||||
f"Average convolve for 1k stims: {1000*dt_convolve[:-1].detach().cpu().mean()} secs."
|
||||
)
|
||||
print(f"Average gpu for 1k stims: {1000*dt_gpu[:-1].detach().cpu().mean()} secs.")
|
||||
print(f"Average cpu for 1k stims: {1000*dt_cpu[:-1].detach().cpu().mean()} secs.")
|
||||
print(
|
||||
f"Average david for 1k stims: {1000*dt_david[:-1].detach().cpu().mean()} secs."
|
||||
)
|
||||
|
||||
if BENCH_GPU and BENCH_CPU and BENCH_DAVID:
|
||||
df1 = (torch.abs(output_gpu[-1].detach().cpu() - output_cpu[-1])).mean()
|
||||
df2 = (torch.abs(output_gpu[-1].detach().cpu() - output_dav[-1])).mean()
|
||||
df3 = (torch.abs(output_dav[-1].cpu() - output_cpu[-1])).mean()
|
||||
print(f"Differences: CPU-GPU:{df1}, GPU-David:{df2}, David-CPU:{df3}")
|
||||
|
||||
# %%
|
349
thesis code/network analysis/render_including_minDist/render.py
Normal file
349
thesis code/network analysis/render_including_minDist/render.py
Normal file
|
@ -0,0 +1,349 @@
|
|||
# %%
|
||||
|
||||
import torch
|
||||
import time
|
||||
import scipy
|
||||
import os
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import contours
|
||||
import glob
|
||||
|
||||
USE_CEXT_FROM_DAVID = False
|
||||
if USE_CEXT_FROM_DAVID:
|
||||
# from CPPExtensions.PyTCopyCPU import TCopyCPU
|
||||
from CPPExtensions.PyTCopyCPU import TCopyCPU as render_stimulus_CPP
|
||||
|
||||
|
||||
import matplotlib as mpl
|
||||
|
||||
|
||||
mpl.rcParams["text.usetex"] = True
|
||||
mpl.rcParams["font.family"] = "serif"
|
||||
|
||||
|
||||
def plot_single_gabor_filter(
|
||||
gabors,
|
||||
dirs_source,
|
||||
dirs_change,
|
||||
source_idx: int = 0,
|
||||
change_idx: int = 0,
|
||||
save_plot: bool = False,
|
||||
):
|
||||
print(
|
||||
f"dirs_source:{dirs_source[source_idx]:.2f}, dirs_change: {dirs_change[change_idx]:.2f}"
|
||||
)
|
||||
print(f"Inflection angle in deg:{torch.rad2deg(dirs_change[change_idx])}")
|
||||
plt.imshow(
|
||||
gabors[source_idx, change_idx],
|
||||
cmap="gray",
|
||||
vmin=gabors.min(),
|
||||
vmax=gabors.max(),
|
||||
)
|
||||
cbar = plt.colorbar()
|
||||
cbar.ax.tick_params(labelsize=14)
|
||||
plt.xticks(fontsize=16)
|
||||
plt.yticks(fontsize=16)
|
||||
plt.tight_layout()
|
||||
|
||||
if save_plot:
|
||||
if change_idx != 0:
|
||||
plt.savefig(
|
||||
f"additional thesis plots/saved_plots/gabor_in{torch.rad2deg(dirs_source[source_idx])}inflect{torch.rad2deg(dirs_change[change_idx])}.pdf",
|
||||
dpi=300,
|
||||
)
|
||||
else:
|
||||
plt.savefig(
|
||||
f"additional thesis plots/saved_plots/gabor_mono_{dirs_source[source_idx]:.2f}_deg{torch.rad2deg(dirs_source[source_idx])}.pdf",
|
||||
dpi=300,
|
||||
)
|
||||
plt.show(block=True)
|
||||
|
||||
|
||||
def render_gaborfield(posori, params, verbose=False):
|
||||
scale_factor = params["scale_factor"]
|
||||
n_source = params["n_source"]
|
||||
n_change = params["n_change"]
|
||||
|
||||
# convert sizes to pixel units
|
||||
lambda_PIX = params["lambda_gabor"] * scale_factor
|
||||
sigma_PIX = params["sigma_gabor"] * scale_factor
|
||||
r_gab_PIX = int(params["d_gabor"] * scale_factor / 2)
|
||||
d_gab_PIX = r_gab_PIX * 2 + 1
|
||||
|
||||
# make filterbank
|
||||
gabors, dirs_source, dirs_change = contours.gaborner_filterbank(
|
||||
r_gab=r_gab_PIX,
|
||||
n_source=n_source,
|
||||
n_change=n_change,
|
||||
lambdah=lambda_PIX,
|
||||
sigma=sigma_PIX,
|
||||
phase=params["phase_gabor"],
|
||||
normalize=params["normalize_gabor"],
|
||||
torch_device="cpu",
|
||||
)
|
||||
|
||||
gabors = gabors.reshape([n_source * n_change, d_gab_PIX, d_gab_PIX])
|
||||
|
||||
n_contours = posori.shape[0]
|
||||
|
||||
# discretize ALL stimuli
|
||||
if verbose:
|
||||
print("Discretizing START!!!")
|
||||
t_dis0 = time.perf_counter()
|
||||
(
|
||||
index_srcchg,
|
||||
index_x,
|
||||
index_y,
|
||||
x_canvas,
|
||||
y_canvas,
|
||||
) = contours.discretize_stimuli(
|
||||
posori=posori,
|
||||
x_range=params["x_range"],
|
||||
y_range=params["y_range"],
|
||||
scale_factor=scale_factor,
|
||||
r_gab_PIX=r_gab_PIX,
|
||||
n_source=n_source,
|
||||
n_change=n_change,
|
||||
torch_device="cpu",
|
||||
)
|
||||
|
||||
# find out minimal distance between neighboring gabors:
|
||||
mean_mean_dist: list = []
|
||||
for i in range(len(index_x)):
|
||||
xx_center = index_x[i] + r_gab_PIX
|
||||
yy_center = index_y[i] + r_gab_PIX
|
||||
|
||||
# calc mean distances within one image
|
||||
x = xx_center[:, np.newaxis] - xx_center[np.newaxis, :]
|
||||
y = yy_center[:, np.newaxis] - yy_center[np.newaxis, :]
|
||||
print(x.shape, y.shape)
|
||||
distances = np.sqrt(
|
||||
(xx_center[:, np.newaxis] - xx_center[np.newaxis, :]) ** 2
|
||||
+ (yy_center[:, np.newaxis] - yy_center[np.newaxis, :]) ** 2
|
||||
)
|
||||
distances = distances.numpy()
|
||||
|
||||
# diagonal elements to infinity to exclude self-distances
|
||||
np.fill_diagonal(distances, np.inf)
|
||||
|
||||
# nearest neighbor of each contour element
|
||||
nearest_neighbors = np.argmin(distances, axis=1)
|
||||
|
||||
# dist to nearest neighbors
|
||||
nearest_distances = distances[np.arange(distances.shape[0]), nearest_neighbors]
|
||||
|
||||
# mean distance
|
||||
mean_dist = np.mean(nearest_distances)
|
||||
print(f"Mean distance between contour elements: {mean_dist}")
|
||||
mean_mean_dist.append(mean_dist)
|
||||
|
||||
m = np.mean(mean_mean_dist)
|
||||
print(f"Mean distance between contour elements over all images: {m}")
|
||||
|
||||
t_dis1 = time.perf_counter()
|
||||
if verbose:
|
||||
print(f"Discretizing END, took {t_dis1-t_dis0} seconds.!!!")
|
||||
|
||||
if verbose:
|
||||
print("Generation START!!!")
|
||||
t0 = time.perf_counter()
|
||||
|
||||
if not USE_CEXT_FROM_DAVID:
|
||||
if verbose:
|
||||
print(" (using NUMPY...)")
|
||||
output = torch.zeros(
|
||||
(
|
||||
n_contours,
|
||||
y_canvas - d_gab_PIX + 1,
|
||||
x_canvas - d_gab_PIX + 1,
|
||||
),
|
||||
device="cpu",
|
||||
)
|
||||
kernels_cpu = gabors.detach().cpu()
|
||||
for i_con in range(n_contours):
|
||||
output[i_con] = contours.render_stimulus(
|
||||
kernels=kernels_cpu,
|
||||
index_srcchg=index_srcchg[i_con],
|
||||
index_y=index_y[i_con],
|
||||
index_x=index_x[i_con],
|
||||
y_canvas=y_canvas,
|
||||
x_canvas=x_canvas,
|
||||
torch_device="cpu",
|
||||
)
|
||||
output = torch.clip(output, -1, +1)
|
||||
|
||||
else:
|
||||
if verbose:
|
||||
print(" (using C++...)")
|
||||
copyier = render_stimulus_CPP()
|
||||
number_of_cpu_processes = os.cpu_count()
|
||||
output_dav_tmp = torch.zeros(
|
||||
(
|
||||
n_contours,
|
||||
y_canvas + 2 * r_gab_PIX,
|
||||
x_canvas + 2 * r_gab_PIX,
|
||||
),
|
||||
device="cpu",
|
||||
dtype=torch.float,
|
||||
)
|
||||
|
||||
# Umsort!
|
||||
n_elements_total = 0
|
||||
for i_con in range(n_contours):
|
||||
n_elements_total += len(index_x[i_con])
|
||||
sparse_matrix = torch.zeros(
|
||||
(n_elements_total, 4), device="cpu", dtype=torch.int64
|
||||
)
|
||||
i_elements_total = 0
|
||||
for i_con in range(n_contours):
|
||||
n_add = len(index_x[i_con])
|
||||
sparse_matrix[i_elements_total : i_elements_total + n_add, 0] = i_con
|
||||
sparse_matrix[
|
||||
i_elements_total : i_elements_total + n_add, 1
|
||||
] = index_srcchg[i_con]
|
||||
sparse_matrix[i_elements_total : i_elements_total + n_add, 2] = index_y[
|
||||
i_con
|
||||
]
|
||||
sparse_matrix[i_elements_total : i_elements_total + n_add, 3] = index_x[
|
||||
i_con
|
||||
]
|
||||
i_elements_total += n_add
|
||||
assert i_elements_total == n_elements_total, "UNBEHAGEN macht sich breit!"
|
||||
|
||||
# output_dav_tmp.fill_(0.0)
|
||||
copyier.process(
|
||||
sparse_matrix.data_ptr(),
|
||||
int(sparse_matrix.shape[0]),
|
||||
int(sparse_matrix.shape[1]),
|
||||
gabors.data_ptr(),
|
||||
int(gabors.shape[0]),
|
||||
int(gabors.shape[1]),
|
||||
int(gabors.shape[2]),
|
||||
output_dav_tmp.data_ptr(),
|
||||
int(output_dav_tmp.shape[0]),
|
||||
int(output_dav_tmp.shape[1]),
|
||||
int(output_dav_tmp.shape[2]),
|
||||
int(number_of_cpu_processes), # type: ignore
|
||||
)
|
||||
output = torch.clip(
|
||||
output_dav_tmp[
|
||||
:,
|
||||
d_gab_PIX - 1 : -(d_gab_PIX - 1),
|
||||
d_gab_PIX - 1 : -(d_gab_PIX - 1),
|
||||
],
|
||||
-1,
|
||||
+1,
|
||||
)
|
||||
|
||||
t1 = time.perf_counter()
|
||||
if verbose:
|
||||
print(f"Generating END, took {t1-t0} seconds.!!!")
|
||||
|
||||
if verbose:
|
||||
print("Showing first and last stimulus generated...")
|
||||
plt.imshow(output[0], cmap="gray", vmin=-1, vmax=+1)
|
||||
plt.show()
|
||||
plt.imshow(output[-1], cmap="gray", vmin=-1, vmax=+1)
|
||||
plt.show()
|
||||
print(f"Processed {n_contours} stimuli in {t1-t_dis0} seconds!")
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def render_gaborfield_frommatfiles(
|
||||
files, params, varname, varname_dist, altpath=None, verbose=False
|
||||
):
|
||||
n_total = 0
|
||||
n_files = len(files)
|
||||
print(f"Going through {n_files} contour files...")
|
||||
|
||||
for i_file in range(n_files):
|
||||
# get path, basename, suffix...
|
||||
full = files[i_file]
|
||||
path, file = os.path.split(full)
|
||||
base, suffix = os.path.splitext(file)
|
||||
|
||||
# load file
|
||||
mat = scipy.io.loadmat(full)
|
||||
if "dist" in full:
|
||||
posori = mat[varname_dist]
|
||||
else:
|
||||
posori = mat[varname]
|
||||
|
||||
n_contours = posori.shape[0]
|
||||
n_total += n_contours
|
||||
print(f" ...file {file} contains {n_contours} contours.")
|
||||
|
||||
# process...
|
||||
gaborfield = render_gaborfield(posori, params=params, verbose=verbose)
|
||||
|
||||
# save
|
||||
if altpath:
|
||||
savepath = altpath
|
||||
else:
|
||||
savepath = path
|
||||
savefull = savepath + os.sep + base + "_RENDERED.npz"
|
||||
print(f" ...saving under {savefull}...")
|
||||
gaborfield = (torch.clip(gaborfield, -1, 1) * 127 + 128).type(torch.uint8)
|
||||
# np.savez_compressed(savefull, gaborfield=gaborfield)
|
||||
|
||||
return n_total
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
TESTMODE = "files" # "files" or "posori"
|
||||
|
||||
# cutout for stimuli, and gabor parameters
|
||||
params = {
|
||||
"x_range": [140, 940],
|
||||
"y_range": [140, 940],
|
||||
"scale_factor": 0.25, # scale to convert coordinates to pixel values
|
||||
"d_gabor": 40,
|
||||
"lambda_gabor": 16,
|
||||
"sigma_gabor": 8,
|
||||
"phase_gabor": 0.0,
|
||||
"normalize_gabor": True,
|
||||
# number of directions for dictionary
|
||||
"n_source": 32,
|
||||
"n_change": 32,
|
||||
}
|
||||
|
||||
if TESTMODE == "files":
|
||||
# path = "/data_1/kk/StimulusGeneration/Alicorn/Natural/Corner000_n10000"
|
||||
path = "D:/Katha/Neuroscience/Semester 4/newCode/RenderAlicorns/Coignless"
|
||||
files = glob.glob(path + os.sep + "*.mat")
|
||||
|
||||
t0 = time.perf_counter()
|
||||
n_total = render_gaborfield_frommatfiles(
|
||||
files=files,
|
||||
params=params,
|
||||
varname="Table_base_crn090",
|
||||
varname_dist="Table_base_crn090_dist",
|
||||
altpath="D:/Katha/Neuroscience/Semester 4/newCode/RenderAlicorns/Output/Coignless",
|
||||
)
|
||||
t1 = time.perf_counter()
|
||||
dt = t1 - t0
|
||||
print(
|
||||
f"Rendered {n_total} contours in {dt} secs, yielding {n_total/dt} contours/sec."
|
||||
)
|
||||
|
||||
if TESTMODE == "posori":
|
||||
print("Sample stimulus generation:")
|
||||
print("===========================")
|
||||
|
||||
# load contours, multiplex coordinates to simulate a larger set of contours
|
||||
n_multiplex = 5
|
||||
mat = scipy.io.loadmat(
|
||||
"D:/Katha/Neuroscience/Semester 4/newCode/RenderAlicorns/corner_angle_090_dist_b001_n100.mat"
|
||||
)
|
||||
posori = np.tile(mat["Table_crn_crn090_dist"], (n_multiplex, 1))
|
||||
n_contours = posori.shape[0]
|
||||
print(f"Processing {n_contours} contour stimuli")
|
||||
|
||||
output = render_gaborfield(posori, params=params, verbose=True)
|
||||
output8 = (torch.clip(output, -1, 1) * 127 + 128).type(torch.uint8)
|
||||
np.savez_compressed("output8_compressed.npz", output8=output8)
|
||||
|
||||
|
||||
# %%
|
28
thesis code/network analysis/weights_correlation/README.txt
Normal file
28
thesis code/network analysis/weights_correlation/README.txt
Normal file
|
@ -0,0 +1,28 @@
|
|||
Folder "weights_correlation":
|
||||
|
||||
File:
|
||||
|
||||
1. create_gabor_dict:
|
||||
* contains the code to generate the Gabor dictionary used for the weights of convolutional layer 1
|
||||
* 32 Gabors: 8 orientations, 4 phases
|
||||
* Gabors have a diameter of 11 pixels
|
||||
|
||||
2. draw_input_fields:
|
||||
* used to calculate how much of the input the kernel of each CNN layer has access to
|
||||
* draws these sizes into a chosen image from the dataset
|
||||
|
||||
3. all_cnns_mean_correlation:
|
||||
* includes the code to plot the correlation matrices seen in the written thesis
|
||||
* includes statistical test
|
||||
* includes code to plot every single correlation matrix of the 20 CNNs
|
||||
|
||||
|
||||
|
||||
In folder "weight visualization":
|
||||
|
||||
1. plot_as_grid:
|
||||
* visualizes the weights and bias (optional)
|
||||
|
||||
2. plot_weights:
|
||||
* loads model
|
||||
* choose layer to visualize weights from (+ bias optionally)
|
|
@ -0,0 +1,274 @@
|
|||
import torch
|
||||
import sys
|
||||
import os
|
||||
import matplotlib.pyplot as plt # noqa
|
||||
import numpy as np
|
||||
import matplotlib as mpl
|
||||
|
||||
mpl.rcParams["text.usetex"] = True
|
||||
mpl.rcParams["font.family"] = "serif"
|
||||
mpl.rcParams["font.size"] = 15
|
||||
|
||||
# import files from parent dir
|
||||
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.append(parent_dir)
|
||||
|
||||
from functions.make_cnn import make_cnn # noqa
|
||||
|
||||
|
||||
def show_20mean_correlations(model_list, save: bool = False, cnn: str = "CORNER"):
|
||||
"""
|
||||
Displays a correlation matrix for every single of the 20 CNNs
|
||||
"""
|
||||
|
||||
fig, axs = plt.subplots(4, 5, figsize=(15, 15))
|
||||
for i, load_model in enumerate(model_list):
|
||||
# load model
|
||||
model = torch.load(load_model).to("cpu")
|
||||
model.eval()
|
||||
|
||||
# load 2nd convs weights
|
||||
weights = model[3].weight.cpu().detach().clone().numpy()
|
||||
corr_matrices = []
|
||||
for j in range(weights.shape[0]):
|
||||
w_j = weights[j]
|
||||
w = w_j.reshape(w_j.shape[0], -1)
|
||||
corr_matrix = np.corrcoef(w)
|
||||
corr_matrices.append(corr_matrix)
|
||||
|
||||
mean_corr_matrix = np.mean(corr_matrices, axis=0)
|
||||
ax = axs[i // 5, i % 5]
|
||||
im = ax.matshow(mean_corr_matrix, cmap="RdBu_r")
|
||||
cbar = fig.colorbar(
|
||||
im, ax=ax, fraction=0.046, pad=0.04, ticks=np.arange(-1.1, 1.1, 0.2)
|
||||
)
|
||||
ax.set_title(f"Model {i+1}")
|
||||
|
||||
# remove lower x-axis ticks
|
||||
ax.tick_params(axis="x", which="both", bottom=False)
|
||||
ax.tick_params(axis="both", which="major", labelsize=14)
|
||||
cbar.ax.tick_params(labelsize=13)
|
||||
|
||||
# fig.colorbar(im, ax=axs.ravel().tolist())
|
||||
plt.tight_layout()
|
||||
if save:
|
||||
plt.savefig(
|
||||
f"additional thesis plots/saved_plots/weight plots/all20cnn_mean_corr_{cnn}.pdf",
|
||||
dpi=300,
|
||||
)
|
||||
plt.show()
|
||||
|
||||
|
||||
def show_overall_mean_correlation(model_list, save: bool = False, cnn: str = "CORNER"):
|
||||
"""
|
||||
Displays the mean correlation across all 20 CNNs
|
||||
"""
|
||||
|
||||
fig, ax = plt.subplots(figsize=(7, 7))
|
||||
overall_corr_matrices = []
|
||||
for i, load_model in enumerate(model_list):
|
||||
# load model
|
||||
model = torch.load(load_model).to("cpu")
|
||||
model.eval()
|
||||
|
||||
# load 2nd convs weights
|
||||
weights = model[3].weight.cpu().detach().clone().numpy()
|
||||
corr_matrices = []
|
||||
for j in range(weights.shape[0]):
|
||||
w_j = weights[j]
|
||||
w = w_j.reshape(w_j.shape[0], -1)
|
||||
corr_matrix = np.corrcoef(w)
|
||||
corr_matrices.append(corr_matrix)
|
||||
|
||||
mean_corr_matrix = np.mean(corr_matrices, axis=0)
|
||||
overall_corr_matrices.append(mean_corr_matrix)
|
||||
|
||||
overall_mean_corr_matrix = np.mean(overall_corr_matrices, axis=0)
|
||||
im = ax.matshow(overall_mean_corr_matrix, cmap="RdBu_r")
|
||||
cbar = fig.colorbar(
|
||||
im, ax=ax, fraction=0.046, pad=0.04, ticks=np.arange(-1.1, 1.1, 0.1)
|
||||
)
|
||||
|
||||
# remove lower x-axis ticks
|
||||
ax.tick_params(axis="x", which="both", bottom=False)
|
||||
ax.tick_params(axis="both", which="major", labelsize=17)
|
||||
cbar.ax.tick_params(labelsize=15)
|
||||
|
||||
plt.tight_layout()
|
||||
if save:
|
||||
plt.savefig(
|
||||
f"additional thesis plots/saved_plots/weight plots/mean20cnn_mean_corr_{cnn}.pdf",
|
||||
dpi=300,
|
||||
)
|
||||
plt.show()
|
||||
return overall_mean_corr_matrix
|
||||
|
||||
|
||||
def get_file_list_all_cnns(dir: str) -> list:
|
||||
all_results: list = []
|
||||
for filename in os.listdir(dir):
|
||||
if filename.endswith(".pt"):
|
||||
# print(os.path.join(dir, filename))
|
||||
all_results.append(os.path.join(dir, filename))
|
||||
|
||||
return all_results
|
||||
|
||||
|
||||
def test_normality(correlation_data, condition: str, alpha: float = 0.05):
|
||||
"""
|
||||
Tests if data has normal distribution
|
||||
* 0-hyp: data is normally distributed
|
||||
* low p-val: data not normally distributed
|
||||
"""
|
||||
from scipy import stats
|
||||
|
||||
statistic, p_value = stats.normaltest(correlation_data)
|
||||
print(
|
||||
f"\nD'Agostino-Pearson Test for {condition} - p-val :",
|
||||
p_value,
|
||||
)
|
||||
print(
|
||||
f"D'Agostino-Pearson Test for {condition} - statistic :",
|
||||
statistic,
|
||||
)
|
||||
|
||||
# set alpha
|
||||
if p_value < alpha:
|
||||
print("P-val < alpha. Reject 0-hypothesis. Data is not normally distributed")
|
||||
else:
|
||||
print("P-val > alpha. Keep 0-hypothesis. Data is normally distributed")
|
||||
|
||||
return p_value
|
||||
|
||||
|
||||
def two_sample_ttest(corr_classic, corr_coner, alpha: float = 0.05):
|
||||
"""
|
||||
This is a test for the null hypothesis that 2 independent samples have identical average (expected) values. This test assumes that the populations have identical variances by default.
|
||||
"""
|
||||
|
||||
from scipy.stats import ttest_ind
|
||||
|
||||
t_stat, p_value = ttest_ind(corr_classic, corr_coner)
|
||||
print(f"t-statistic: {t_stat}")
|
||||
|
||||
# check if the p-value less than significance level
|
||||
if p_value < alpha:
|
||||
print(
|
||||
"There is a significant difference in the mean correlation values between the two groups."
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"There is no significant difference in the mean correlation values between the two groups."
|
||||
)
|
||||
|
||||
|
||||
def willy_is_not_whitney_test(data_classic, data_corner):
|
||||
from scipy.stats import mannwhitneyu
|
||||
|
||||
"""
|
||||
* Test does not assume normal distribution
|
||||
* Compares means between 2 indep groups
|
||||
"""
|
||||
|
||||
# call test
|
||||
statistic, p_value = mannwhitneyu(data_classic, data_corner)
|
||||
|
||||
# results
|
||||
print("\nMann-Whitney U Test Statistic:", statistic)
|
||||
print("Mann-Whitney U Test p-value:", p_value)
|
||||
|
||||
# check significance:
|
||||
alpha = 0.05
|
||||
if p_value < alpha:
|
||||
print("The distributions are significantly different.")
|
||||
else:
|
||||
print("The distributions are not significantly different.")
|
||||
|
||||
return p_value
|
||||
|
||||
|
||||
def visualize_differences(corr_class, corr_corn, save: bool = False):
|
||||
# calc mean, std, median
|
||||
mean_class = np.mean(corr_class)
|
||||
median_class = np.median(corr_class)
|
||||
std_class = np.std(corr_class)
|
||||
|
||||
mean_corn = np.mean(corr_corn)
|
||||
median_corn = np.median(corr_corn)
|
||||
std_corn = np.std(corr_corn)
|
||||
|
||||
# plot
|
||||
labels = ["Mean", "Median", "Standard Deviation"]
|
||||
condition_class = [mean_class, median_class, std_class]
|
||||
condition_corn = [mean_corn, median_corn, std_corn]
|
||||
|
||||
x = np.arange(len(labels))
|
||||
width = 0.35
|
||||
|
||||
_, ax = plt.subplots(figsize=(7, 7))
|
||||
rect_class = ax.bar(
|
||||
x - width / 2, condition_class, width, label="CLASSIC", color="cornflowerblue"
|
||||
)
|
||||
rect_corn = ax.bar(
|
||||
x + width / 2, condition_corn, width, label="CORNER", color="coral"
|
||||
)
|
||||
|
||||
# show bar values
|
||||
for i, rect in enumerate(rect_class + rect_corn):
|
||||
height = rect.get_height()
|
||||
ax.text(
|
||||
rect.get_x() + rect.get_width() / 2.0,
|
||||
height,
|
||||
f"{height:.3f}",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=15,
|
||||
)
|
||||
|
||||
# ax.set_ylabel('Value')
|
||||
ax.set_title("Summary Statistics by Condition")
|
||||
ax.set_xticks(x)
|
||||
ax.set_xticklabels(labels, fontsize=17)
|
||||
ax.legend()
|
||||
|
||||
plt.tight_layout()
|
||||
if save:
|
||||
plt.savefig(
|
||||
"additional thesis plots/saved_plots/weight plots/summary_stats_correlation_CLASSvsCORN.pdf",
|
||||
dpi=300,
|
||||
)
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# CLASSIC:
|
||||
directory_classic: str = "D:/Katha/Neuroscience/Semester 4/newCode/kk_contour_net_shallow-main/classic3288_fest"
|
||||
all_results_classic = get_file_list_all_cnns(dir=directory_classic)
|
||||
show_20mean_correlations(all_results_classic)
|
||||
mean_corr_classic = show_overall_mean_correlation(all_results_classic)
|
||||
|
||||
# CORNER:
|
||||
directory_corner: str = "D:/Katha/Neuroscience/Semester 4/newCode/kk_contour_net_shallow-main/corner3288_fest"
|
||||
all_results_corner = get_file_list_all_cnns(dir=directory_corner)
|
||||
show_20mean_correlations(all_results_corner)
|
||||
mean_corr_corner = show_overall_mean_correlation(all_results_corner)
|
||||
|
||||
# flatten
|
||||
corr_classic = mean_corr_classic.flatten()
|
||||
corr_corner = mean_corr_corner.flatten()
|
||||
|
||||
# test how data is distributed
|
||||
p_class = test_normality(correlation_data=corr_classic, condition="CLASSIC")
|
||||
p_corn = test_normality(correlation_data=corr_corner, condition="CORNER")
|
||||
|
||||
# perform statistical test:
|
||||
alpha: float = 0.05
|
||||
|
||||
if p_class < alpha and p_corn < alpha:
|
||||
willy_is_not_whitney_test(data_classic=corr_classic, data_corner=corr_corner)
|
||||
else:
|
||||
# do ttest:
|
||||
two_sample_ttest(corr_classic=corr_classic, corr_coner=corr_corner)
|
||||
|
||||
# visualize the differences:
|
||||
visualize_differences(corr_class=corr_classic, corr_corn=corr_corner, save=True)
|
|
@ -0,0 +1,87 @@
|
|||
import numpy as np
|
||||
import matplotlib.pyplot as plt # noqa
|
||||
|
||||
|
||||
def change_base(
|
||||
x: np.ndarray, y: np.ndarray, theta: float
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
x_theta: np.ndarray = x.astype(dtype=np.float32) * np.cos(theta) + y.astype(
|
||||
dtype=np.float32
|
||||
) * np.sin(theta)
|
||||
y_theta: np.ndarray = y.astype(dtype=np.float32) * np.cos(theta) - x.astype(
|
||||
dtype=np.float32
|
||||
) * np.sin(theta)
|
||||
return x_theta, y_theta
|
||||
|
||||
|
||||
def cos_gabor_function(
|
||||
x: np.ndarray, y: np.ndarray, theta: float, f: float, sigma: float, phi: float
|
||||
) -> np.ndarray:
|
||||
r_a: np.ndarray = change_base(x, y, theta)[0]
|
||||
r_b: np.ndarray = change_base(x, y, theta)[1]
|
||||
r2 = r_a**2 + r_b**2
|
||||
gauss: np.ndarray = np.exp(-0.5 * r2 / sigma**2)
|
||||
correction = np.exp(-2 * (np.pi * sigma * f) ** 2) * np.cos(phi)
|
||||
envelope = np.cos(2 * np.pi * f * change_base(x, y, theta)[0] + phi) - correction
|
||||
patch = gauss * envelope
|
||||
|
||||
return patch
|
||||
|
||||
|
||||
def weights(num_orients, num_phase, f, sigma, diameter, delta_x):
|
||||
dx = delta_x
|
||||
n = np.ceil(diameter / 2 / dx)
|
||||
x, y = np.mgrid[
|
||||
-n : n + 1,
|
||||
-n : n + 1,
|
||||
]
|
||||
|
||||
t = np.arange(num_orients) * np.pi / num_orients
|
||||
p = np.arange(num_phase) * 2 * np.pi / num_phase
|
||||
|
||||
w = np.zeros((num_orients, num_phase, x.shape[0], x.shape[0]))
|
||||
for i in range(num_orients):
|
||||
theta = t[i]
|
||||
for j in range(num_phase):
|
||||
phase = p[j]
|
||||
|
||||
w[i, j] = cos_gabor_function(
|
||||
x=x * dx, y=y * dx, theta=theta, f=f, sigma=sigma, phi=phase
|
||||
).T
|
||||
|
||||
return w
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
f = 0.25 # frequency = 1/lambda = 1/4
|
||||
sigma = 2.0
|
||||
diameter = 10
|
||||
num_orients = 8
|
||||
num_phase = 4
|
||||
we = weights(
|
||||
num_orients=num_orients,
|
||||
num_phase=num_phase,
|
||||
f=f,
|
||||
sigma=sigma,
|
||||
diameter=diameter,
|
||||
delta_x=1,
|
||||
)
|
||||
|
||||
# comment in for plotting as matrix :
|
||||
# fig = plt.figure(figsize=(5, 5))
|
||||
# for i in range(num_orients):
|
||||
# for j in range(num_phase):
|
||||
# plt.subplot(num_orients, num_phase, (i * num_phase) + j + 1)
|
||||
# plt.imshow(we[i, j], cmap="gray", vmin=we.min(), vmax=we.max())
|
||||
# plt.axis("off")
|
||||
# # plt.colorbar()
|
||||
# plt.tight_layout()
|
||||
# plt.show(block=True)
|
||||
|
||||
weights_flatten = np.ascontiguousarray(we)
|
||||
weights_flatten = np.reshape(
|
||||
weights_flatten, (we.shape[0] * we.shape[1], 1, we.shape[-2], we.shape[-1])
|
||||
)
|
||||
|
||||
# comment in for saving
|
||||
# np.save("gabor_dict_32o_8p.npy", weights_flatten)
|
|
@ -0,0 +1,156 @@
|
|||
# %%
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as patch
|
||||
import matplotlib as mpl
|
||||
from cycler import cycler
|
||||
from functions.analyse_network import analyse_network
|
||||
|
||||
mpl.rcParams["text.usetex"] = True
|
||||
mpl.rcParams["font.family"] = "serif"
|
||||
|
||||
|
||||
def draw_kernel(
|
||||
image: np.ndarray,
|
||||
coordinate_list: list,
|
||||
layer_type_list: list,
|
||||
ignore_output_conv_layer: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Call function after creating the model-to-be-trained.
|
||||
"""
|
||||
assert image.shape[0] == 200
|
||||
assert image.shape[1] == 200
|
||||
|
||||
# list of colors to choose from:
|
||||
prop_cycle = plt.rcParams["axes.prop_cycle"]
|
||||
colors = prop_cycle.by_key()["color"]
|
||||
edge_color_cycler = iter(
|
||||
cycler(color=["sienna", "orange", "gold", "bisque"] + colors)
|
||||
)
|
||||
|
||||
# position first kernel
|
||||
start_x: int = 4
|
||||
start_y: int = 15
|
||||
|
||||
# general plot structure:
|
||||
plt.ion()
|
||||
_, ax = plt.subplots()
|
||||
ax.imshow(image, cmap="gray")
|
||||
ax.tick_params(axis="both", which="major", labelsize=15)
|
||||
|
||||
if ignore_output_conv_layer:
|
||||
number_of_layers: int = len(layer_type_list) - 1
|
||||
else:
|
||||
number_of_layers = len(layer_type_list)
|
||||
|
||||
for i in range(0, number_of_layers):
|
||||
if layer_type_list[i] is not None:
|
||||
kernels = int(coordinate_list[i].shape[0])
|
||||
edgecolor = next(edge_color_cycler)["color"]
|
||||
# draw kernel
|
||||
kernel = patch.Rectangle(
|
||||
(start_x, start_y),
|
||||
kernels,
|
||||
kernels,
|
||||
linewidth=1.2,
|
||||
edgecolor=edgecolor,
|
||||
facecolor="none",
|
||||
label=layer_type_list[i],
|
||||
)
|
||||
ax.add_patch(kernel)
|
||||
|
||||
if coordinate_list[i].shape[1] > 1:
|
||||
strides = int(coordinate_list[i][0, 1]) - int(coordinate_list[i][0, 0])
|
||||
|
||||
# draw stride
|
||||
stride = patch.Rectangle(
|
||||
(start_x + strides, start_y + strides),
|
||||
kernels,
|
||||
kernels,
|
||||
linewidth=1.2,
|
||||
edgecolor=edgecolor,
|
||||
facecolor="none",
|
||||
linestyle="dashed",
|
||||
)
|
||||
ax.add_patch(stride)
|
||||
|
||||
# add distance of next drawing
|
||||
start_x += 14
|
||||
start_y += 10
|
||||
|
||||
# final plot
|
||||
plt.tight_layout()
|
||||
plt.legend(loc="upper right", fontsize=11)
|
||||
plt.show(block=True)
|
||||
|
||||
|
||||
# %%
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from jsmin import jsmin
|
||||
|
||||
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.append(parent_dir)
|
||||
from functions.alicorn_data_loader import alicorn_data_loader
|
||||
from functions.make_cnn_v2 import make_cnn
|
||||
from functions.create_logger import create_logger
|
||||
|
||||
ignore_output_conv_layer: bool = True
|
||||
|
||||
# get current path:
|
||||
cwd = os.path.dirname(os.path.realpath(__file__)).replace(os.sep, "/")
|
||||
|
||||
network_config_filename = f"{cwd}/network_0.json"
|
||||
config_filenname = f"{cwd}/config_v2.json"
|
||||
with open(config_filenname, "r") as file_handle:
|
||||
config = json.loads(jsmin(file_handle.read()))
|
||||
|
||||
logger = create_logger(
|
||||
save_logging_messages=False, display_logging_messages=False, model_name=None
|
||||
)
|
||||
|
||||
# test image:
|
||||
data_test = alicorn_data_loader(
|
||||
num_pfinkel=[0],
|
||||
load_stimuli_per_pfinkel=10,
|
||||
condition=str(config["condition"]),
|
||||
data_path=str(config["data_path"]),
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
assert data_test.__len__() > 0
|
||||
input_shape = data_test.__getitem__(0)[1].shape
|
||||
|
||||
model = make_cnn(
|
||||
network_config_filename=network_config_filename,
|
||||
logger=logger,
|
||||
input_shape=input_shape,
|
||||
)
|
||||
print(model)
|
||||
|
||||
assert input_shape[-2] == input_shape[-1]
|
||||
coordinate_list, layer_type_list, pixel_used = analyse_network(
|
||||
model=model, input_shape=int(input_shape[-1])
|
||||
)
|
||||
|
||||
for i in range(0, len(coordinate_list)):
|
||||
print(
|
||||
(
|
||||
f"Layer: {i}, Positions: {coordinate_list[i].shape[1]}, "
|
||||
f"Pixel per Positions: {coordinate_list[i].shape[0]}, "
|
||||
f"Type: {layer_type_list[i]}, Number of pixel used: {pixel_used[i]}"
|
||||
)
|
||||
)
|
||||
|
||||
image = data_test.__getitem__(6)[1].squeeze(0)
|
||||
|
||||
# call function for plotting input fields into image:
|
||||
draw_kernel(
|
||||
image=image.numpy(),
|
||||
coordinate_list=coordinate_list,
|
||||
layer_type_list=layer_type_list,
|
||||
ignore_output_conv_layer=ignore_output_conv_layer,
|
||||
)
|
|
@ -0,0 +1,194 @@
|
|||
# %%
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
||||
import matplotlib as mpl
|
||||
|
||||
|
||||
mpl.rcParams["text.usetex"] = True
|
||||
mpl.rcParams["font.family"] = "serif"
|
||||
|
||||
|
||||
def plot_weights(
|
||||
plot,
|
||||
s,
|
||||
grid_color,
|
||||
linewidth,
|
||||
idx,
|
||||
smallDim,
|
||||
swap_channels,
|
||||
activations,
|
||||
layer,
|
||||
title,
|
||||
colorbar,
|
||||
vmin,
|
||||
vmax,
|
||||
):
|
||||
plt.imshow(plot.T, cmap="RdBu_r", origin="lower", vmin=vmin, vmax=vmax)
|
||||
|
||||
ax = plt.gca()
|
||||
a = np.arange(0, plot.shape[1] + 1, s[3])
|
||||
b = np.arange(0, plot.shape[0] + 1, s[1])
|
||||
plt.hlines(a - 0.5, -0.5, plot.shape[0] - 0.5, colors=grid_color, lw=linewidth)
|
||||
plt.vlines(b - 0.5, -0.5, plot.shape[1] - 0.5, colors=grid_color, lw=linewidth)
|
||||
plt.ylim(-1, plot.shape[1])
|
||||
plt.xlim(-1, plot.shape[0])
|
||||
|
||||
ax.set_xticks(s[1] / 2 + np.arange(-0.5, plot.shape[0] - 1, s[1]))
|
||||
ax.set_yticks(s[3] / 2 + np.arange(-0.5, plot.shape[1] - 1, s[3]))
|
||||
|
||||
if (
|
||||
idx is not None
|
||||
and (smallDim is False and swap_channels is False)
|
||||
or (activations is True)
|
||||
):
|
||||
ax.set_xticklabels(idx, fontsize=19)
|
||||
ax.set_yticklabels(np.arange(s[2]), fontsize=19)
|
||||
elif idx is not None and layer == "FC1":
|
||||
ax.set_xticklabels(np.arange(s[0]), fontsize=19)
|
||||
ax.set_yticklabels(idx, fontsize=19)
|
||||
elif idx is not None and (smallDim is True or swap_channels is True):
|
||||
ax.set_xticklabels(np.arange(s[0]), fontsize=19)
|
||||
ax.set_yticklabels(idx, fontsize=19)
|
||||
else:
|
||||
ax.set_xticklabels(np.arange(s[0]), fontsize=19)
|
||||
ax.set_yticklabels(np.arange(s[2]), fontsize=19)
|
||||
ax.invert_yaxis()
|
||||
|
||||
ax.xaxis.set_label_position("top")
|
||||
ax.tick_params(axis="x", top=True, bottom=False, labeltop=True, labelbottom=False)
|
||||
|
||||
if title is not None:
|
||||
is_string = isinstance(title, str)
|
||||
if is_string is True:
|
||||
plt.title(title)
|
||||
|
||||
if colorbar is True:
|
||||
divider = make_axes_locatable(ax)
|
||||
cax = divider.append_axes("right", size="1.5%", pad=0.05)
|
||||
cbar = plt.colorbar(ax.get_images()[0], cax=cax)
|
||||
|
||||
# this was only for flattened conv1 weights cbar ticks!!
|
||||
# cbar.set_ticks([0.5, -0.5])
|
||||
# cbar.set_ticklabels([0.5, -0.5])
|
||||
|
||||
tick_font_size = 17
|
||||
cbar.ax.tick_params(labelsize=tick_font_size)
|
||||
|
||||
|
||||
def plot_in_grid(
|
||||
plot,
|
||||
fig_size=(10, 10),
|
||||
swap_channels=False,
|
||||
title=None,
|
||||
idx=None,
|
||||
colorbar=False,
|
||||
vmin=None,
|
||||
vmax=None,
|
||||
grid_color="k",
|
||||
linewidth=0.75,
|
||||
savetitle=None,
|
||||
activations=False,
|
||||
layer=None,
|
||||
format="pdf",
|
||||
bias=None,
|
||||
plot_bias: bool = False,
|
||||
):
|
||||
smallDim = False
|
||||
if plot.ndim < 4:
|
||||
smallDim = True
|
||||
plot = np.swapaxes(plot, 0, 1)
|
||||
plot = plot[:, :, np.newaxis, np.newaxis]
|
||||
if vmin is None and vmax is None:
|
||||
# plot_abs = np.amax(np.abs(plot))
|
||||
vmin = -(np.amax(np.abs(plot)))
|
||||
vmax = np.amax(np.abs(plot))
|
||||
|
||||
if swap_channels is True:
|
||||
plot = np.swapaxes(plot, 0, 1)
|
||||
|
||||
# print(plot.shape)
|
||||
plot = np.ascontiguousarray(np.moveaxis(plot, 1, 2))
|
||||
|
||||
for j in range(plot.shape[2]):
|
||||
for i in range(plot.shape[0]):
|
||||
plot[(i - 1), :, (j - 1), :] = plot[(i - 1), :, (j - 1), :].T
|
||||
|
||||
s = plot.shape
|
||||
plot = plot.reshape((s[0] * s[1], s[2] * s[3]))
|
||||
plt.figure(figsize=fig_size)
|
||||
|
||||
if plot_bias and bias is not None:
|
||||
if swap_channels:
|
||||
# If axes are swapped, arrange the plots side by side
|
||||
plt.subplot(1, 2, 1)
|
||||
plot_weights(
|
||||
plot=plot,
|
||||
s=s,
|
||||
grid_color=grid_color,
|
||||
linewidth=linewidth,
|
||||
idx=idx,
|
||||
smallDim=smallDim,
|
||||
swap_channels=swap_channels,
|
||||
activations=activations,
|
||||
layer=layer,
|
||||
title=title,
|
||||
colorbar=colorbar,
|
||||
vmin=vmin,
|
||||
vmax=vmax,
|
||||
)
|
||||
|
||||
plt.subplot(1, 2, 2)
|
||||
plt.plot(bias, np.arange(len(bias)))
|
||||
plt.ylim(len(bias) - 1, 0)
|
||||
plt.title("Bias", fontsize=14)
|
||||
plt.tight_layout()
|
||||
|
||||
else:
|
||||
plt.subplot(2, 1, 1)
|
||||
plot_weights(
|
||||
plot=plot,
|
||||
s=s,
|
||||
grid_color=grid_color,
|
||||
linewidth=linewidth,
|
||||
idx=idx,
|
||||
smallDim=smallDim,
|
||||
swap_channels=swap_channels,
|
||||
activations=activations,
|
||||
layer=layer,
|
||||
title=title,
|
||||
colorbar=colorbar,
|
||||
vmin=vmin,
|
||||
vmax=vmax,
|
||||
)
|
||||
|
||||
plt.subplot(2, 1, 2)
|
||||
plt.plot(np.arange(len(bias)), bias)
|
||||
plt.title("Bias", fontsize=14)
|
||||
|
||||
else:
|
||||
plot_weights(
|
||||
plot=plot,
|
||||
s=s,
|
||||
grid_color=grid_color,
|
||||
linewidth=linewidth,
|
||||
idx=idx,
|
||||
smallDim=smallDim,
|
||||
swap_channels=swap_channels,
|
||||
activations=activations,
|
||||
layer=layer,
|
||||
title=title,
|
||||
colorbar=colorbar,
|
||||
vmin=vmin,
|
||||
vmax=vmax,
|
||||
)
|
||||
|
||||
if savetitle is not None:
|
||||
plt.savefig(
|
||||
f"D:/Katha/Neuroscience/Semester 4/newCode/additional thesis plots/saved_plots/weight plots/{savetitle}.{format}",
|
||||
dpi=300,
|
||||
bbox_inches="tight",
|
||||
)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show(block=True)
|
|
@ -0,0 +1,101 @@
|
|||
import torch
|
||||
import sys
|
||||
import os
|
||||
import matplotlib.pyplot as plt # noqa
|
||||
import matplotlib as mpl
|
||||
|
||||
mpl.rcParams["text.usetex"] = True
|
||||
mpl.rcParams["font.family"] = "serif"
|
||||
mpl.rcParams["font.size"] = 14
|
||||
|
||||
# import files from parent dir
|
||||
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.append(parent_dir)
|
||||
|
||||
from plot_as_grid import plot_in_grid
|
||||
from functions.make_cnn import make_cnn # noqa
|
||||
|
||||
|
||||
# load on cpu
|
||||
device = torch.device("cpu")
|
||||
|
||||
# path to NN
|
||||
nn = "ArghCNN_numConvLayers3_outChannels[8, 8, 8]_kernelSize[7, 15]_leaky relu_stride1_trainFirstConvLayerTrue_seed293051_Natural_249Epoch_1308-1145"
|
||||
PATH = f"D:/Katha/Neuroscience/Semester 4/newCode/kk_contour_net_shallow-main/corner888/{nn}.pt"
|
||||
SAVE_PATH = "20 cnns weights/corner 888/seed293051_Natural_249Epoch_1308-1145"
|
||||
|
||||
# load and evaluate model
|
||||
model = torch.load(PATH).to(device)
|
||||
model.eval()
|
||||
print("Full network:")
|
||||
print(model)
|
||||
print("")
|
||||
# enter index to plot:
|
||||
idx = int(input("Please select layer: "))
|
||||
print(f"Selected layer: {idx, model[idx]}")
|
||||
|
||||
# bias
|
||||
bias_input = input("Plot bias (y/n): ")
|
||||
plot_bias: bool = False
|
||||
if bias_input == "y":
|
||||
plot_bias = True
|
||||
bias = model[idx]._parameters["bias"].data
|
||||
print(bias)
|
||||
else:
|
||||
bias = None
|
||||
|
||||
# show last layer's weights.
|
||||
if idx == len(model) - 1:
|
||||
linear_weights = model[idx].weight.cpu().detach().clone().numpy()
|
||||
|
||||
weights = linear_weights.reshape(2, 8, 74, 74)
|
||||
plot_in_grid(
|
||||
weights,
|
||||
fig_size=(10, 7),
|
||||
savetitle=f"{SAVE_PATH}_final_layer",
|
||||
colorbar=True,
|
||||
swap_channels=True,
|
||||
bias=bias,
|
||||
plot_bias=plot_bias,
|
||||
)
|
||||
|
||||
# visualize weights:
|
||||
elif idx > 0:
|
||||
weights = model[idx].weight.cpu().detach().clone().numpy()
|
||||
|
||||
if idx == 5:
|
||||
swap_channels = False
|
||||
layer = 3
|
||||
else:
|
||||
swap_channels = True
|
||||
layer = 2
|
||||
|
||||
# plot weights
|
||||
plot_in_grid(
|
||||
weights,
|
||||
fig_size=(11, 7),
|
||||
savetitle=f"{SAVE_PATH}_conv{layer}",
|
||||
colorbar=True,
|
||||
swap_channels=swap_channels,
|
||||
bias=bias,
|
||||
plot_bias=plot_bias,
|
||||
)
|
||||
else:
|
||||
first_weights = model[idx].weight.cpu().detach().clone().numpy()
|
||||
|
||||
# reshape first layer weights:
|
||||
reshape_input = input("Reshape weights to 4rows 8 cols (y/n): ")
|
||||
if reshape_input == "y":
|
||||
weights = first_weights.reshape(
|
||||
8, 4, first_weights.shape[-2], first_weights.shape[-1]
|
||||
)
|
||||
else:
|
||||
weights = first_weights
|
||||
plot_in_grid(
|
||||
weights,
|
||||
fig_size=(17, 17),
|
||||
savetitle=f"{SAVE_PATH}_conv1",
|
||||
colorbar=True,
|
||||
bias=bias,
|
||||
plot_bias=plot_bias,
|
||||
)
|
12
thesis code/shallow net/README.txt
Normal file
12
thesis code/shallow net/README.txt
Normal file
|
@ -0,0 +1,12 @@
|
|||
Folder shallow net:
|
||||
|
||||
1. config.json:
|
||||
* includes all cnn parameters and configurations
|
||||
* example for architecture: 32-8-8 (c1-c2-c3)
|
||||
|
||||
2. corner_loop_final.sh
|
||||
* bash script to train the 20 cnns of one cnn architecture
|
||||
|
||||
Folder functions:
|
||||
* contains the files do build the cnn, set the seeds, create a logging file, train and test the cnns
|
||||
* based on ---> Github: https://github.com/davrot/kk_contour_net_shallow.git
|
53
thesis code/shallow net/config.json
Normal file
53
thesis code/shallow net/config.json
Normal file
|
@ -0,0 +1,53 @@
|
|||
{
|
||||
"data_path": "/home/kk/Documents/Semester4/code/RenderStimuli/Output/",
|
||||
"save_logging_messages": true, // (true), false
|
||||
"display_logging_messages": true, // (true), false
|
||||
"batch_size_train": 500,
|
||||
"batch_size_test": 250,
|
||||
"max_epochs": 2000,
|
||||
"save_model": true,
|
||||
"conv_0_kernel_size": 11,
|
||||
"mp_1_kernel_size": 3,
|
||||
"mp_1_stride": 2,
|
||||
"use_plot_intermediate": true, // true, (false)
|
||||
"stimuli_per_pfinkel": 10000,
|
||||
"num_pfinkel_start": 0,
|
||||
"num_pfinkel_stop": 100,
|
||||
"num_pfinkel_step": 10,
|
||||
"precision_100_percent": 0, // (4)
|
||||
"train_first_layer": false, // true, (false)
|
||||
"save_ever_x_epochs": 10, // (10)
|
||||
"activation_function": "leaky relu", // tanh, relu, (leaky relu), none
|
||||
"leak_relu_negative_slope": 0.1, // (0.1)
|
||||
"switch_leakyR_to_relu": false,
|
||||
// LR Scheduler ->
|
||||
"use_scheduler": true, // (true), false
|
||||
"scheduler_verbose": true,
|
||||
"scheduler_factor": 0.1, //(0.1)
|
||||
"scheduler_patience": 10, // (10)
|
||||
"scheduler_threshold": 1e-5, // (1e-4)
|
||||
"minimum_learning_rate": 1e-8,
|
||||
"learning_rate": 0.0001,
|
||||
// <- LR Scheduler
|
||||
"pooling_type": "max", // (max), average, none
|
||||
"conv_0_enable_softmax": false, // true, (false)
|
||||
"use_adam": true, // (true) => adam, false => SGD
|
||||
"condition": "Natural",
|
||||
"scale_data": 255.0, // (255.0)
|
||||
"conv_out_channels_list": [
|
||||
[
|
||||
32,
|
||||
8,
|
||||
8
|
||||
]
|
||||
],
|
||||
"conv_kernel_sizes": [
|
||||
[
|
||||
7,
|
||||
15
|
||||
]
|
||||
],
|
||||
"conv_stride_sizes": [
|
||||
1
|
||||
]
|
||||
}
|
13
thesis code/shallow net/corner_loop_final.sh
Normal file
13
thesis code/shallow net/corner_loop_final.sh
Normal file
|
@ -0,0 +1,13 @@
|
|||
Directory="/home/kk/Documents/Semester4/code/Corner_contour_net_shallow"
|
||||
Priority="0"
|
||||
echo $Directory
|
||||
mkdir $Directory/argh_log_3288_fix
|
||||
for i in {0..20}; do
|
||||
for out_channels_idx in {0..0}; do
|
||||
for kernel_size_idx in {0..0}; do
|
||||
for stride_idx in {0..0}; do
|
||||
echo "hostname; cd $Directory ; /home/kk/P3.10/bin/python3 cnn_training.py --idx-conv-out-channels-list $out_channels_idx --idx-conv-kernel-sizes $kernel_size_idx --idx-conv-stride-sizes $stride_idx -s \$JOB_ID" | qsub -o $Directory/argh_log_3288_fix -j y -p $Priority -q gp4u,gp3u -N Corner3288fix
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
107
thesis code/shallow net/functions/alicorn_data_loader.py
Normal file
107
thesis code/shallow net/functions/alicorn_data_loader.py
Normal file
|
@ -0,0 +1,107 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def alicorn_data_loader(
|
||||
num_pfinkel: list[int] | None,
|
||||
load_stimuli_per_pfinkel: int,
|
||||
condition: str,
|
||||
data_path: str,
|
||||
logger=None,
|
||||
) -> torch.utils.data.TensorDataset:
|
||||
"""
|
||||
- num_pfinkel: list of the angles that should be loaded (ranging from
|
||||
0-90). If None: all pfinkels loaded
|
||||
- stimuli_per_pfinkel: defines amount of stimuli per path angle but
|
||||
for label 0 and label 1 seperatly (e.g., stimuli_per_pfinkel = 1000:
|
||||
1000 stimuli = label 1, 1000 stimuli = label 0)
|
||||
"""
|
||||
filename: str | None = None
|
||||
if condition == "Angular":
|
||||
filename = "angular_angle"
|
||||
elif condition == "Coignless":
|
||||
filename = "base_angle"
|
||||
elif condition == "Natural":
|
||||
filename = "corner_angle"
|
||||
else:
|
||||
filename = None
|
||||
assert filename is not None
|
||||
filepaths: str = os.path.join(data_path, f"{condition}")
|
||||
|
||||
stimuli_per_pfinkel: int = 100000
|
||||
|
||||
# ----------------------------
|
||||
|
||||
# for angles and batches
|
||||
if num_pfinkel is None:
|
||||
angle: list[int] = np.arange(0, 100, 10).tolist()
|
||||
else:
|
||||
angle = num_pfinkel
|
||||
|
||||
assert isinstance(angle, list)
|
||||
|
||||
batch: list[int] = np.arange(1, 11, 1).tolist()
|
||||
|
||||
if load_stimuli_per_pfinkel <= (stimuli_per_pfinkel // len(batch)):
|
||||
num_img_per_pfinkel: int = load_stimuli_per_pfinkel
|
||||
num_batches: int = 1
|
||||
else:
|
||||
# handle case where more than 10,000 stimuli per pfinkel needed
|
||||
num_batches = load_stimuli_per_pfinkel // (stimuli_per_pfinkel // len(batch))
|
||||
num_img_per_pfinkel = load_stimuli_per_pfinkel // num_batches
|
||||
|
||||
if logger is not None:
|
||||
logger.info(f"{num_batches} batches")
|
||||
logger.info(f"{num_img_per_pfinkel} stimuli per pfinkel.")
|
||||
|
||||
# initialize data and label tensors:
|
||||
num_stimuli: int = len(angle) * num_batches * num_img_per_pfinkel * 2
|
||||
data_tensor: torch.Tensor = torch.empty(
|
||||
(num_stimuli, 200, 200), dtype=torch.uint8, device=torch.device("cpu")
|
||||
)
|
||||
label_tensor: torch.Tensor = torch.empty(
|
||||
(num_stimuli), dtype=torch.int64, device=torch.device("cpu")
|
||||
)
|
||||
|
||||
if logger is not None:
|
||||
logger.info(f"data tensor shape: {data_tensor.shape}")
|
||||
logger.info(f"label tensor shape: {label_tensor.shape}")
|
||||
|
||||
# append data
|
||||
idx: int = 0
|
||||
for i in range(len(angle)):
|
||||
for j in range(num_batches):
|
||||
# load contour
|
||||
temp_filename: str = (
|
||||
f"{filename}_{angle[i]:03}_b{batch[j]:03}_n10000_RENDERED.npz"
|
||||
)
|
||||
contour_filename: str = os.path.join(filepaths, temp_filename)
|
||||
c_data = np.load(contour_filename)
|
||||
data_tensor[idx : idx + num_img_per_pfinkel, ...] = torch.tensor(
|
||||
c_data["gaborfield"][:num_img_per_pfinkel, ...],
|
||||
dtype=torch.uint8,
|
||||
device=torch.device("cpu"),
|
||||
)
|
||||
label_tensor[idx : idx + num_img_per_pfinkel] = int(1)
|
||||
idx += num_img_per_pfinkel
|
||||
|
||||
# next append distractor stimuli
|
||||
for i in range(len(angle)):
|
||||
for j in range(num_batches):
|
||||
# load distractor
|
||||
temp_filename = (
|
||||
f"{filename}_{angle[i]:03}_dist_b{batch[j]:03}_n10000_RENDERED.npz"
|
||||
)
|
||||
distractor_filename: str = os.path.join(filepaths, temp_filename)
|
||||
nc_data = np.load(distractor_filename)
|
||||
data_tensor[idx : idx + num_img_per_pfinkel, ...] = torch.tensor(
|
||||
nc_data["gaborfield"][:num_img_per_pfinkel, ...],
|
||||
dtype=torch.uint8,
|
||||
device=torch.device("cpu"),
|
||||
)
|
||||
label_tensor[idx : idx + num_img_per_pfinkel] = int(0)
|
||||
idx += num_img_per_pfinkel
|
||||
|
||||
return torch.utils.data.TensorDataset(label_tensor, data_tensor.unsqueeze(1))
|
103
thesis code/shallow net/functions/analyse_network.py
Normal file
103
thesis code/shallow net/functions/analyse_network.py
Normal file
|
@ -0,0 +1,103 @@
|
|||
import torch
|
||||
|
||||
|
||||
def unfold(
|
||||
layer: torch.nn.Conv2d | torch.nn.MaxPool2d | torch.nn.AvgPool2d, size: int
|
||||
) -> torch.Tensor:
|
||||
if isinstance(layer.kernel_size, tuple):
|
||||
assert layer.kernel_size[0] == layer.kernel_size[1]
|
||||
kernel_size: int = int(layer.kernel_size[0])
|
||||
else:
|
||||
kernel_size = int(layer.kernel_size)
|
||||
|
||||
if isinstance(layer.dilation, tuple):
|
||||
assert layer.dilation[0] == layer.dilation[1]
|
||||
dilation: int = int(layer.dilation[0])
|
||||
else:
|
||||
dilation = int(layer.dilation) # type: ignore
|
||||
|
||||
if isinstance(layer.padding, tuple):
|
||||
assert layer.padding[0] == layer.padding[1]
|
||||
padding: int = int(layer.padding[0])
|
||||
else:
|
||||
padding = int(layer.padding)
|
||||
|
||||
if isinstance(layer.stride, tuple):
|
||||
assert layer.stride[0] == layer.stride[1]
|
||||
stride: int = int(layer.stride[0])
|
||||
else:
|
||||
stride = int(layer.stride)
|
||||
|
||||
out = (
|
||||
torch.nn.functional.unfold(
|
||||
torch.arange(0, size, dtype=torch.float32)
|
||||
.unsqueeze(0)
|
||||
.unsqueeze(0)
|
||||
.unsqueeze(-1),
|
||||
kernel_size=(kernel_size, 1),
|
||||
dilation=(dilation, 1),
|
||||
padding=(padding, 0),
|
||||
stride=(stride, 1),
|
||||
)
|
||||
.squeeze(0)
|
||||
.type(torch.int64)
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def analyse_network(
|
||||
model: torch.nn.Sequential, input_shape: int
|
||||
) -> tuple[list, list, list]:
|
||||
combined_list: list = []
|
||||
coordinate_list: list = []
|
||||
layer_type_list: list = []
|
||||
pixel_used: list[int] = []
|
||||
|
||||
size: int = int(input_shape)
|
||||
|
||||
for layer_id in range(0, len(model)):
|
||||
if isinstance(
|
||||
model[layer_id], (torch.nn.Conv2d, torch.nn.MaxPool2d, torch.nn.AvgPool2d)
|
||||
):
|
||||
out = unfold(layer=model[layer_id], size=size)
|
||||
coordinate_list.append(out)
|
||||
layer_type_list.append(
|
||||
str(type(model[layer_id])).split(".")[-1].split("'")[0]
|
||||
)
|
||||
size = int(out.shape[-1])
|
||||
else:
|
||||
coordinate_list.append(None)
|
||||
layer_type_list.append(None)
|
||||
|
||||
assert coordinate_list[0] is not None
|
||||
combined_list.append(coordinate_list[0])
|
||||
|
||||
for i in range(1, len(coordinate_list)):
|
||||
if coordinate_list[i] is None:
|
||||
combined_list.append(combined_list[i - 1])
|
||||
else:
|
||||
for pos in range(0, coordinate_list[i].shape[-1]):
|
||||
idx_shape: int | None = None
|
||||
|
||||
idx = torch.unique(
|
||||
torch.flatten(combined_list[i - 1][:, coordinate_list[i][:, pos]])
|
||||
)
|
||||
if idx_shape is None:
|
||||
idx_shape = idx.shape[0]
|
||||
assert idx_shape == idx.shape[0]
|
||||
|
||||
assert idx_shape is not None
|
||||
|
||||
temp = torch.zeros((idx_shape, coordinate_list[i].shape[-1]))
|
||||
for pos in range(0, coordinate_list[i].shape[-1]):
|
||||
idx = torch.unique(
|
||||
torch.flatten(combined_list[i - 1][:, coordinate_list[i][:, pos]])
|
||||
)
|
||||
temp[:, pos] = idx
|
||||
combined_list.append(temp)
|
||||
|
||||
for i in range(0, len(combined_list)):
|
||||
pixel_used.append(int(torch.unique(torch.flatten(combined_list[i])).shape[0]))
|
||||
|
||||
return combined_list, layer_type_list, pixel_used
|
40
thesis code/shallow net/functions/create_logger.py
Normal file
40
thesis code/shallow net/functions/create_logger.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
import logging
|
||||
import datetime
|
||||
import os
|
||||
|
||||
|
||||
def create_logger(save_logging_messages: bool, display_logging_messages: bool, model_name: str | None):
|
||||
now = datetime.datetime.now()
|
||||
dt_string_filename = now.strftime("%Y_%m_%d_%H_%M_%S")
|
||||
|
||||
logger = logging.getLogger("MyLittleLogger")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
if save_logging_messages:
|
||||
if model_name:
|
||||
filename = os.path.join(
|
||||
"logs", f"log_{dt_string_filename}_{model_name}.txt"
|
||||
)
|
||||
else:
|
||||
filename = os.path.join("logs", f"log_{dt_string_filename}.txt")
|
||||
|
||||
time_format = "%b %-d %Y %H:%M:%S"
|
||||
logformat = "%(asctime)s %(message)s"
|
||||
file_formatter = logging.Formatter(fmt=logformat, datefmt=time_format)
|
||||
os.makedirs("logs", exist_ok=True)
|
||||
file_handler = logging.FileHandler(filename)
|
||||
file_handler.setLevel(logging.INFO)
|
||||
file_handler.setFormatter(file_formatter)
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
if display_logging_messages:
|
||||
time_format = "%H:%M:%S"
|
||||
logformat = "%(asctime)s %(message)s"
|
||||
stream_formatter = logging.Formatter(fmt=logformat, datefmt=time_format)
|
||||
|
||||
stream_handler = logging.StreamHandler()
|
||||
stream_handler.setLevel(logging.INFO)
|
||||
stream_handler.setFormatter(stream_formatter)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
return logger
|
114
thesis code/shallow net/functions/make_cnn.py
Normal file
114
thesis code/shallow net/functions/make_cnn.py
Normal file
|
@ -0,0 +1,114 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
def make_cnn(
|
||||
conv_out_channels_list: list[int],
|
||||
conv_kernel_size: list[int],
|
||||
conv_stride_size: int,
|
||||
conv_activation_function: str,
|
||||
train_conv_0: bool,
|
||||
logger,
|
||||
conv_0_kernel_size: int,
|
||||
mp_1_kernel_size: int,
|
||||
mp_1_stride: int,
|
||||
pooling_type: str,
|
||||
conv_0_enable_softmax: bool,
|
||||
l_relu_negative_slope: float,
|
||||
) -> torch.nn.Sequential:
|
||||
assert len(conv_out_channels_list) >= 1
|
||||
assert len(conv_out_channels_list) == len(conv_kernel_size) + 1
|
||||
|
||||
cnn = torch.nn.Sequential()
|
||||
|
||||
# Fixed structure
|
||||
cnn.append(
|
||||
torch.nn.Conv2d(
|
||||
in_channels=1,
|
||||
out_channels=conv_out_channels_list[0] if train_conv_0 else 32,
|
||||
kernel_size=conv_0_kernel_size,
|
||||
stride=1,
|
||||
bias=train_conv_0,
|
||||
)
|
||||
)
|
||||
|
||||
if conv_0_enable_softmax:
|
||||
cnn.append(torch.nn.Softmax(dim=1))
|
||||
|
||||
setting_understood: bool = False
|
||||
if conv_activation_function.upper() == str("relu").upper():
|
||||
cnn.append(torch.nn.ReLU())
|
||||
setting_understood = True
|
||||
elif conv_activation_function.upper() == str("leaky relu").upper():
|
||||
cnn.append(torch.nn.LeakyReLU(negative_slope=l_relu_negative_slope))
|
||||
setting_understood = True
|
||||
elif conv_activation_function.upper() == str("tanh").upper():
|
||||
cnn.append(torch.nn.Tanh())
|
||||
setting_understood = True
|
||||
elif conv_activation_function.upper() == str("none").upper():
|
||||
setting_understood = True
|
||||
assert setting_understood
|
||||
|
||||
setting_understood = False
|
||||
if pooling_type.upper() == str("max").upper():
|
||||
cnn.append(torch.nn.MaxPool2d(kernel_size=mp_1_kernel_size, stride=mp_1_stride))
|
||||
setting_understood = True
|
||||
elif pooling_type.upper() == str("average").upper():
|
||||
cnn.append(torch.nn.AvgPool2d(kernel_size=mp_1_kernel_size, stride=mp_1_stride))
|
||||
setting_understood = True
|
||||
elif pooling_type.upper() == str("none").upper():
|
||||
setting_understood = True
|
||||
assert setting_understood
|
||||
|
||||
# Changing structure
|
||||
for i in range(1, len(conv_out_channels_list)):
|
||||
if i == 1 and not train_conv_0:
|
||||
in_channels = 32
|
||||
else:
|
||||
in_channels = conv_out_channels_list[i - 1]
|
||||
cnn.append(
|
||||
torch.nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=conv_out_channels_list[i],
|
||||
kernel_size=conv_kernel_size[i - 1],
|
||||
stride=conv_stride_size,
|
||||
bias=True,
|
||||
)
|
||||
)
|
||||
setting_understood = False
|
||||
if conv_activation_function.upper() == str("relu").upper():
|
||||
cnn.append(torch.nn.ReLU())
|
||||
setting_understood = True
|
||||
elif conv_activation_function.upper() == str("leaky relu").upper():
|
||||
cnn.append(torch.nn.LeakyReLU(negative_slope=l_relu_negative_slope))
|
||||
setting_understood = True
|
||||
elif conv_activation_function.upper() == str("tanh").upper():
|
||||
cnn.append(torch.nn.Tanh())
|
||||
setting_understood = True
|
||||
elif conv_activation_function.upper() == str("none").upper():
|
||||
setting_understood = True
|
||||
|
||||
assert setting_understood
|
||||
|
||||
# Fixed structure
|
||||
# define fully connected layer:
|
||||
cnn.append(torch.nn.Flatten(start_dim=1))
|
||||
cnn.append(torch.nn.LazyLinear(2, bias=True))
|
||||
|
||||
# if conv1 not trained:
|
||||
filename_load_weight_0: str | None = None
|
||||
if train_conv_0 is False and cnn[0]._parameters["weight"].shape[0] == 32:
|
||||
filename_load_weight_0 = "weights_radius10_norm.npy"
|
||||
if train_conv_0 is False and cnn[0]._parameters["weight"].shape[0] == 16:
|
||||
filename_load_weight_0 = "8orient_2phase_weights.npy"
|
||||
|
||||
if filename_load_weight_0 is not None:
|
||||
logger.info(f"Replace weights in CNN 0 with {filename_load_weight_0}")
|
||||
cnn[0]._parameters["weight"] = torch.tensor(
|
||||
np.load(filename_load_weight_0),
|
||||
dtype=cnn[0]._parameters["weight"].dtype,
|
||||
requires_grad=False,
|
||||
device=cnn[0]._parameters["weight"].device,
|
||||
)
|
||||
|
||||
return cnn
|
84
thesis code/shallow net/functions/plot_intermediate.py
Normal file
84
thesis code/shallow net/functions/plot_intermediate.py
Normal file
|
@ -0,0 +1,84 @@
|
|||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib as mpl
|
||||
import os
|
||||
import re
|
||||
|
||||
mpl.rcParams["text.usetex"] = True
|
||||
mpl.rcParams["font.family"] = "serif"
|
||||
|
||||
|
||||
def plot_intermediate(
|
||||
train_accuracy: list[float],
|
||||
test_accuracy: list[float],
|
||||
train_losses: list[float],
|
||||
test_losses: list[float],
|
||||
save_name: str,
|
||||
reduction_factor: int = 1,
|
||||
) -> None:
|
||||
assert len(train_accuracy) == len(test_accuracy)
|
||||
assert len(train_accuracy) == len(train_losses)
|
||||
assert len(train_accuracy) == len(test_losses)
|
||||
|
||||
# legend:
|
||||
pattern = (
|
||||
r"(outChannels\[\d+(?:, \d+)*\]_kernelSize\[\d+(?:, \d+)*\]_)([^_]+)(?=_stride)"
|
||||
)
|
||||
matches = re.findall(pattern, save_name)
|
||||
legend_label = "".join(["".join(match) for match in matches])
|
||||
|
||||
max_epochs: int = len(train_accuracy)
|
||||
# set stepsize
|
||||
x = np.arange(1, max_epochs + 1)
|
||||
|
||||
stepsize = max_epochs // reduction_factor
|
||||
|
||||
# accuracies
|
||||
plt.figure(figsize=[12, 7])
|
||||
plt.subplot(2, 1, 1)
|
||||
|
||||
plt.plot(x, np.array(train_accuracy), label="Train: " + str(legend_label))
|
||||
plt.plot(x, np.array(test_accuracy), label="Test: " + str(legend_label))
|
||||
plt.title("Training and Testing Accuracy", fontsize=18)
|
||||
plt.xlabel("Epoch", fontsize=18)
|
||||
plt.ylabel("Accuracy (\\%)", fontsize=18)
|
||||
plt.legend(fontsize=14)
|
||||
plt.xticks(
|
||||
np.concatenate((np.array([1]), np.arange(stepsize, max_epochs + 1, stepsize))),
|
||||
np.concatenate((np.array([1]), np.arange(stepsize, max_epochs + 1, stepsize))),
|
||||
)
|
||||
|
||||
# Increase tick label font size
|
||||
plt.xticks(fontsize=16)
|
||||
plt.yticks(fontsize=16)
|
||||
plt.grid(True)
|
||||
|
||||
# losses
|
||||
plt.subplot(2, 1, 2)
|
||||
plt.plot(x, np.array(train_losses), label="Train: " + str(legend_label))
|
||||
plt.plot(x, np.array(test_losses), label="Test: " + str(legend_label))
|
||||
plt.title("Training and Testing Losses", fontsize=18)
|
||||
plt.xlabel("Epoch", fontsize=18)
|
||||
plt.ylabel("Loss", fontsize=18)
|
||||
plt.legend(fontsize=14)
|
||||
plt.xticks(
|
||||
np.concatenate((np.array([1]), np.arange(stepsize, max_epochs + 1, stepsize))),
|
||||
np.concatenate((np.array([1]), np.arange(stepsize, max_epochs + 1, stepsize))),
|
||||
)
|
||||
|
||||
# Increase tick label font size
|
||||
plt.xticks(fontsize=16)
|
||||
plt.yticks(fontsize=16)
|
||||
plt.grid(True)
|
||||
|
||||
plt.tight_layout()
|
||||
os.makedirs("performance_plots", exist_ok=True)
|
||||
plt.savefig(
|
||||
os.path.join(
|
||||
"performance_plots",
|
||||
f"performance_{save_name}.pdf",
|
||||
),
|
||||
dpi=300,
|
||||
bbox_inches="tight",
|
||||
)
|
||||
plt.show()
|
12
thesis code/shallow net/functions/set_seed.py
Normal file
12
thesis code/shallow net/functions/set_seed.py
Normal file
|
@ -0,0 +1,12 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
def set_seed(seed: int, logger) -> None:
|
||||
# set seed for all used modules
|
||||
if logger:
|
||||
logger.info(f"set seed to {seed}")
|
||||
torch.manual_seed(seed=seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed=seed)
|
||||
np.random.seed(seed=seed)
|
58
thesis code/shallow net/functions/test.py
Normal file
58
thesis code/shallow net/functions/test.py
Normal file
|
@ -0,0 +1,58 @@
|
|||
import torch
|
||||
import logging
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test(
|
||||
model: torch.nn.modules.container.Sequential,
|
||||
loader: torch.utils.data.dataloader.DataLoader,
|
||||
device: torch.device,
|
||||
tb,
|
||||
epoch: int,
|
||||
logger: logging.Logger,
|
||||
test_accuracy: list[float],
|
||||
test_losses: list[float],
|
||||
scale_data: float,
|
||||
) -> float:
|
||||
test_loss: float = 0.0
|
||||
correct: int = 0
|
||||
pattern_count: float = 0.0
|
||||
|
||||
model.eval()
|
||||
|
||||
for data in loader:
|
||||
label = data[0].to(device)
|
||||
image = data[1].type(dtype=torch.float32).to(device)
|
||||
if scale_data > 0:
|
||||
image /= scale_data
|
||||
|
||||
output = model(image)
|
||||
|
||||
# loss and optimization
|
||||
loss = torch.nn.functional.cross_entropy(output, label, reduction="sum")
|
||||
pattern_count += float(label.shape[0])
|
||||
test_loss += loss.item()
|
||||
prediction = output.argmax(dim=1)
|
||||
correct += prediction.eq(label).sum().item()
|
||||
|
||||
logger.info(
|
||||
(
|
||||
"Test set:"
|
||||
f" Average loss: {test_loss / pattern_count:.3e},"
|
||||
f" Accuracy: {correct}/{pattern_count},"
|
||||
f"({100.0 * correct / pattern_count:.2f}%)"
|
||||
)
|
||||
)
|
||||
logger.info("")
|
||||
|
||||
acc = 100.0 * correct / pattern_count
|
||||
test_losses.append(test_loss / pattern_count)
|
||||
test_accuracy.append(acc)
|
||||
|
||||
# add to tb:
|
||||
tb.add_scalar("Test Loss", (test_loss / pattern_count), epoch)
|
||||
tb.add_scalar("Test Performance", 100.0 * correct / pattern_count, epoch)
|
||||
tb.add_scalar("Test Number Correct", correct, epoch)
|
||||
tb.flush()
|
||||
|
||||
return acc
|
80
thesis code/shallow net/functions/train.py
Normal file
80
thesis code/shallow net/functions/train.py
Normal file
|
@ -0,0 +1,80 @@
|
|||
import torch
|
||||
import logging
|
||||
|
||||
|
||||
def train(
|
||||
model: torch.nn.modules.container.Sequential,
|
||||
loader: torch.utils.data.dataloader.DataLoader,
|
||||
optimizer: torch.optim.Adam | torch.optim.SGD,
|
||||
epoch: int,
|
||||
device: torch.device,
|
||||
tb,
|
||||
test_acc,
|
||||
logger: logging.Logger,
|
||||
train_accuracy: list[float],
|
||||
train_losses: list[float],
|
||||
train_loss: list[float],
|
||||
scale_data: float,
|
||||
) -> float:
|
||||
num_train_pattern: int = 0
|
||||
running_loss: float = 0.0
|
||||
correct: int = 0
|
||||
pattern_count: float = 0.0
|
||||
|
||||
model.train()
|
||||
for data in loader:
|
||||
label = data[0].to(device)
|
||||
image = data[1].type(dtype=torch.float32).to(device)
|
||||
if scale_data > 0:
|
||||
image /= scale_data
|
||||
|
||||
optimizer.zero_grad()
|
||||
output = model(image)
|
||||
loss = torch.nn.functional.cross_entropy(output, label, reduction="sum")
|
||||
loss.backward()
|
||||
|
||||
optimizer.step()
|
||||
|
||||
# for loss and accuracy plotting:
|
||||
num_train_pattern += int(label.shape[0])
|
||||
pattern_count += float(label.shape[0])
|
||||
running_loss += float(loss)
|
||||
train_loss.append(float(loss))
|
||||
prediction = output.argmax(dim=1)
|
||||
correct += prediction.eq(label).sum().item()
|
||||
|
||||
total_number_of_pattern: int = int(len(loader)) * int(label.shape[0])
|
||||
|
||||
# infos:
|
||||
logger.info(
|
||||
(
|
||||
"Train Epoch:"
|
||||
f" {epoch}"
|
||||
f" [{int(pattern_count)}/{total_number_of_pattern}"
|
||||
f" ({100.0 * pattern_count / total_number_of_pattern:.2f}%)],"
|
||||
f" Loss: {float(running_loss) / float(num_train_pattern):.4e},"
|
||||
f" Acc: {(100.0 * correct / num_train_pattern):.2f}"
|
||||
f" Test Acc: {test_acc:.2f}%,"
|
||||
f" LR: {optimizer.param_groups[0]['lr']:.2e}"
|
||||
)
|
||||
)
|
||||
|
||||
acc = 100.0 * correct / num_train_pattern
|
||||
train_accuracy.append(acc)
|
||||
|
||||
epoch_loss = running_loss / pattern_count
|
||||
train_losses.append(epoch_loss)
|
||||
|
||||
# add to tb:
|
||||
tb.add_scalar("Train Loss", loss.item(), epoch)
|
||||
tb.add_scalar("Train Performance", torch.tensor(acc), epoch)
|
||||
tb.add_scalar("Train Number Correct", torch.tensor(correct), epoch)
|
||||
|
||||
# for parameters:
|
||||
for name, param in model.named_parameters():
|
||||
if "weight" in name or "bias" in name:
|
||||
tb.add_histogram(f"{name}", param.data.clone(), epoch)
|
||||
|
||||
tb.flush()
|
||||
|
||||
return epoch_loss
|
Loading…
Reference in a new issue