Dateien nach „tools“ hochladen
This commit is contained in:
commit
8b432157d8
5 changed files with 1390 additions and 0 deletions
237
tools/NNMF2d.py
Normal file
237
tools/NNMF2d.py
Normal file
|
@ -0,0 +1,237 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class NNMF2d(torch.nn.Module):
|
||||||
|
|
||||||
|
in_channels: int
|
||||||
|
out_channels: int
|
||||||
|
weight: torch.Tensor
|
||||||
|
iterations: int
|
||||||
|
epsilon: float | None
|
||||||
|
init_min: float
|
||||||
|
init_max: float
|
||||||
|
local_learning: bool
|
||||||
|
local_learning_kl: bool
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
iterations: int = 20,
|
||||||
|
epsilon: float | None = None,
|
||||||
|
init_min: float = 0.0,
|
||||||
|
init_max: float = 1.0,
|
||||||
|
local_learning: bool = False,
|
||||||
|
local_learning_kl: bool = False,
|
||||||
|
) -> None:
|
||||||
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.init_min = init_min
|
||||||
|
self.init_max = init_max
|
||||||
|
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
|
||||||
|
self.iterations = iterations
|
||||||
|
self.local_learning = local_learning
|
||||||
|
self.local_learning_kl = local_learning_kl
|
||||||
|
|
||||||
|
self.weight = torch.nn.parameter.Parameter(
|
||||||
|
torch.empty((out_channels, in_channels), **factory_kwargs)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.reset_parameters()
|
||||||
|
self.functional_nnmf2d = FunctionalNNMF2d.apply
|
||||||
|
|
||||||
|
self.epsilon = epsilon
|
||||||
|
|
||||||
|
def extra_repr(self) -> str:
|
||||||
|
s: str = f"{self.in_channels}, {self.out_channels}"
|
||||||
|
|
||||||
|
if self.epsilon is not None:
|
||||||
|
s += f", epsilon={self.epsilon}"
|
||||||
|
s += f", local_learning={self.local_learning}"
|
||||||
|
|
||||||
|
if self.local_learning:
|
||||||
|
s += f", local_learning_kl={self.local_learning_kl}"
|
||||||
|
|
||||||
|
return s
|
||||||
|
|
||||||
|
def reset_parameters(self) -> None:
|
||||||
|
torch.nn.init.uniform_(self.weight, a=self.init_min, b=self.init_max)
|
||||||
|
|
||||||
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
|
|
||||||
|
positive_weights = torch.abs(self.weight)
|
||||||
|
positive_weights = positive_weights / (
|
||||||
|
positive_weights.sum(dim=1, keepdim=True) + 10e-20
|
||||||
|
)
|
||||||
|
|
||||||
|
h_dyn = self.functional_nnmf2d(
|
||||||
|
input,
|
||||||
|
positive_weights,
|
||||||
|
self.out_channels,
|
||||||
|
self.iterations,
|
||||||
|
self.epsilon,
|
||||||
|
self.local_learning,
|
||||||
|
self.local_learning_kl,
|
||||||
|
)
|
||||||
|
|
||||||
|
return h_dyn
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionalNNMF2d(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward( # type: ignore
|
||||||
|
ctx,
|
||||||
|
input: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
out_channels: int,
|
||||||
|
iterations: int,
|
||||||
|
epsilon: float | None,
|
||||||
|
local_learning: bool,
|
||||||
|
local_learning_kl: bool,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
# Prepare h
|
||||||
|
h = torch.full(
|
||||||
|
(input.shape[0], out_channels, input.shape[-2], input.shape[-1]),
|
||||||
|
1.0 / float(out_channels),
|
||||||
|
device=input.device,
|
||||||
|
dtype=input.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
h = h.movedim(1, -1)
|
||||||
|
input = input.movedim(1, -1)
|
||||||
|
for _ in range(0, iterations):
|
||||||
|
reconstruction = torch.nn.functional.linear(h, weight.T)
|
||||||
|
reconstruction += 1e-20
|
||||||
|
if epsilon is None:
|
||||||
|
h *= torch.nn.functional.linear((input / reconstruction), weight)
|
||||||
|
else:
|
||||||
|
h *= 1 + epsilon * torch.nn.functional.linear(
|
||||||
|
(input / reconstruction), weight
|
||||||
|
)
|
||||||
|
h /= h.sum(-1, keepdim=True) + 10e-20
|
||||||
|
h = h.movedim(-1, 1)
|
||||||
|
input = input.movedim(-1, 1)
|
||||||
|
|
||||||
|
# ###########################################################
|
||||||
|
# Save the necessary data for the backward pass
|
||||||
|
# ###########################################################
|
||||||
|
ctx.save_for_backward(input, weight, h)
|
||||||
|
ctx.local_learning = local_learning
|
||||||
|
ctx.local_learning_kl = local_learning_kl
|
||||||
|
|
||||||
|
assert torch.isfinite(h).all()
|
||||||
|
return h
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@torch.autograd.function.once_differentiable
|
||||||
|
def backward(ctx, grad_output: torch.Tensor) -> tuple[ # type: ignore
|
||||||
|
torch.Tensor,
|
||||||
|
torch.Tensor | None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
]:
|
||||||
|
|
||||||
|
# ##############################################
|
||||||
|
# Default values
|
||||||
|
# ##############################################
|
||||||
|
grad_weight: torch.Tensor | None = None
|
||||||
|
|
||||||
|
# ##############################################
|
||||||
|
# Get the variables back
|
||||||
|
# ##############################################
|
||||||
|
(input, weight, h) = ctx.saved_tensors
|
||||||
|
|
||||||
|
# The back prop gradient
|
||||||
|
h = h.movedim(1, -1)
|
||||||
|
grad_output = grad_output.movedim(1, -1)
|
||||||
|
input = input.movedim(1, -1)
|
||||||
|
big_r = torch.nn.functional.linear(h, weight.T)
|
||||||
|
big_r_div = 1.0 / (big_r + 1e-20)
|
||||||
|
|
||||||
|
factor_x_div_r = input * big_r_div
|
||||||
|
|
||||||
|
grad_input: torch.Tensor = (
|
||||||
|
torch.nn.functional.linear(h * grad_output, weight.T) * big_r_div
|
||||||
|
)
|
||||||
|
|
||||||
|
del big_r_div
|
||||||
|
|
||||||
|
# The weight gradient
|
||||||
|
if ctx.local_learning is False:
|
||||||
|
del big_r
|
||||||
|
|
||||||
|
grad_weight = -torch.nn.functional.linear(
|
||||||
|
h.reshape(
|
||||||
|
grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
|
||||||
|
h.shape[3],
|
||||||
|
).T,
|
||||||
|
(factor_x_div_r * grad_input)
|
||||||
|
.reshape(
|
||||||
|
grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
|
||||||
|
grad_input.shape[3],
|
||||||
|
)
|
||||||
|
.T,
|
||||||
|
)
|
||||||
|
|
||||||
|
grad_weight += torch.nn.functional.linear(
|
||||||
|
(h * grad_output)
|
||||||
|
.reshape(
|
||||||
|
grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
|
||||||
|
h.shape[3],
|
||||||
|
)
|
||||||
|
.T,
|
||||||
|
factor_x_div_r.reshape(
|
||||||
|
grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
|
||||||
|
grad_input.shape[3],
|
||||||
|
).T,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
if ctx.local_learning_kl:
|
||||||
|
grad_weight = -torch.nn.functional.linear(
|
||||||
|
h.reshape(
|
||||||
|
grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
|
||||||
|
h.shape[3],
|
||||||
|
).T,
|
||||||
|
factor_x_div_r.reshape(
|
||||||
|
grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
|
||||||
|
grad_input.shape[3],
|
||||||
|
).T,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
grad_weight = -torch.nn.functional.linear(
|
||||||
|
h.reshape(
|
||||||
|
grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
|
||||||
|
h.shape[3],
|
||||||
|
).T,
|
||||||
|
(2 * (input - big_r))
|
||||||
|
.reshape(
|
||||||
|
grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
|
||||||
|
grad_input.shape[3],
|
||||||
|
)
|
||||||
|
.T,
|
||||||
|
)
|
||||||
|
grad_input = grad_input.movedim(-1, 1)
|
||||||
|
assert torch.isfinite(grad_input).all()
|
||||||
|
assert torch.isfinite(grad_weight).all()
|
||||||
|
|
||||||
|
return (
|
||||||
|
grad_input,
|
||||||
|
grad_weight,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
234
tools/append_block.py
Normal file
234
tools/append_block.py
Normal file
|
@ -0,0 +1,234 @@
|
||||||
|
import torch
|
||||||
|
from tools.L1NormLayer import L1NormLayer
|
||||||
|
from tools.NNMF2d import NNMF2d
|
||||||
|
from tools.append_parameter import append_parameter
|
||||||
|
|
||||||
|
|
||||||
|
def append_block(
|
||||||
|
network: torch.nn.Sequential,
|
||||||
|
number_of_neurons_a: int,
|
||||||
|
number_of_neurons_b: int,
|
||||||
|
test_image: torch.Tensor,
|
||||||
|
parameter_neuron_a: list[torch.nn.parameter.Parameter],
|
||||||
|
parameter_neuron_b: list[torch.nn.parameter.Parameter],
|
||||||
|
parameter_batchnorm2d: list[torch.nn.parameter.Parameter],
|
||||||
|
device: torch.device,
|
||||||
|
dilation: tuple[int, int] | int = 1,
|
||||||
|
padding: tuple[int, int] | int = 0,
|
||||||
|
stride: tuple[int, int] | int = 1,
|
||||||
|
kernel_size: tuple[int, int] = (5, 5),
|
||||||
|
epsilon: float | None = None,
|
||||||
|
iterations: int = 20,
|
||||||
|
local_learning: bool = False,
|
||||||
|
local_learning_kl: bool = False,
|
||||||
|
momentum: float = 0.1,
|
||||||
|
track_running_stats: bool = False,
|
||||||
|
type_of_neuron_a: int = 0,
|
||||||
|
type_of_neuron_b: int = 0,
|
||||||
|
batch_norm_neuron_a: bool = True,
|
||||||
|
batch_norm_neuron_b: bool = True,
|
||||||
|
bias_norm_neuron_a: bool = False,
|
||||||
|
bias_norm_neuron_b: bool = True,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
assert (type_of_neuron_a > 0) or (type_of_neuron_b > 0)
|
||||||
|
|
||||||
|
if number_of_neurons_b <= 0:
|
||||||
|
number_of_neurons_b = number_of_neurons_a
|
||||||
|
|
||||||
|
if number_of_neurons_a <= 0:
|
||||||
|
number_of_neurons_a = number_of_neurons_b
|
||||||
|
|
||||||
|
assert (type_of_neuron_a == 1) or (type_of_neuron_a == 2)
|
||||||
|
assert (type_of_neuron_b == 0) or (type_of_neuron_b == 1) or (type_of_neuron_b == 2)
|
||||||
|
|
||||||
|
kernel_size_internal: list[int] = [kernel_size[-2], kernel_size[-1]]
|
||||||
|
|
||||||
|
if kernel_size[0] < 1:
|
||||||
|
kernel_size_internal[0] = test_image.shape[-2]
|
||||||
|
|
||||||
|
if kernel_size[1] < 1:
|
||||||
|
kernel_size_internal[1] = test_image.shape[-1]
|
||||||
|
|
||||||
|
network.append(torch.nn.ReLU())
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
|
||||||
|
# I need the output size
|
||||||
|
mock_output = (
|
||||||
|
torch.nn.functional.conv2d(
|
||||||
|
torch.zeros(
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
test_image.shape[2],
|
||||||
|
test_image.shape[3],
|
||||||
|
),
|
||||||
|
torch.zeros((1, 1, kernel_size_internal[0], kernel_size_internal[1])),
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
)
|
||||||
|
.squeeze(0)
|
||||||
|
.squeeze(0)
|
||||||
|
)
|
||||||
|
network.append(
|
||||||
|
torch.nn.Unfold(
|
||||||
|
kernel_size=(kernel_size_internal[-2], kernel_size_internal[-1]),
|
||||||
|
dilation=dilation,
|
||||||
|
padding=padding,
|
||||||
|
stride=stride,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
|
||||||
|
network.append(
|
||||||
|
torch.nn.Fold(
|
||||||
|
output_size=mock_output.shape,
|
||||||
|
kernel_size=(1, 1),
|
||||||
|
dilation=1,
|
||||||
|
padding=0,
|
||||||
|
stride=1,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
|
||||||
|
network.append(L1NormLayer())
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
|
||||||
|
if type_of_neuron_a == 1:
|
||||||
|
network.append(
|
||||||
|
NNMF2d(
|
||||||
|
in_channels=test_image.shape[1],
|
||||||
|
out_channels=number_of_neurons_a,
|
||||||
|
epsilon=epsilon,
|
||||||
|
iterations=iterations,
|
||||||
|
local_learning=local_learning,
|
||||||
|
local_learning_kl=local_learning_kl,
|
||||||
|
).to(device)
|
||||||
|
)
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
append_parameter(module=network[-1], parameter_list=parameter_neuron_a)
|
||||||
|
|
||||||
|
elif type_of_neuron_a == 2:
|
||||||
|
network.append(
|
||||||
|
torch.nn.Conv2d(
|
||||||
|
in_channels=test_image.shape[1],
|
||||||
|
out_channels=number_of_neurons_a,
|
||||||
|
kernel_size=(1, 1),
|
||||||
|
bias=bias_norm_neuron_a,
|
||||||
|
).to(device)
|
||||||
|
)
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
append_parameter(module=network[-1], parameter_list=parameter_neuron_a)
|
||||||
|
else:
|
||||||
|
assert (type_of_neuron_a == 1) or (type_of_neuron_a == 2)
|
||||||
|
|
||||||
|
if batch_norm_neuron_a:
|
||||||
|
if (test_image.shape[-1] > 1) or (test_image.shape[-2] > 1):
|
||||||
|
network.append(
|
||||||
|
torch.nn.BatchNorm2d(
|
||||||
|
num_features=test_image.shape[1],
|
||||||
|
momentum=momentum,
|
||||||
|
track_running_stats=track_running_stats,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
append_parameter(module=network[-1], parameter_list=parameter_batchnorm2d)
|
||||||
|
|
||||||
|
if type_of_neuron_b == 0:
|
||||||
|
pass
|
||||||
|
elif type_of_neuron_b == 1:
|
||||||
|
|
||||||
|
network.append(torch.nn.ReLU())
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
|
||||||
|
network.append(L1NormLayer())
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
|
||||||
|
network.append(
|
||||||
|
NNMF2d(
|
||||||
|
in_channels=test_image.shape[1],
|
||||||
|
out_channels=number_of_neurons_b,
|
||||||
|
epsilon=epsilon,
|
||||||
|
iterations=iterations,
|
||||||
|
local_learning=local_learning,
|
||||||
|
local_learning_kl=local_learning_kl,
|
||||||
|
).to(device)
|
||||||
|
)
|
||||||
|
# Init the cnn top layers 1x1 conv2d layers
|
||||||
|
for name, param in network[-1].named_parameters():
|
||||||
|
with torch.no_grad():
|
||||||
|
print(param.shape)
|
||||||
|
if name == "weight":
|
||||||
|
if number_of_neurons_a >= param.shape[0]:
|
||||||
|
param.data[: param.shape[0], : param.shape[0]] = torch.eye(
|
||||||
|
param.shape[0], dtype=param.dtype, device=param.device
|
||||||
|
)
|
||||||
|
param.data[param.shape[0] :, :] = 0
|
||||||
|
param.data[:, param.shape[0] :] = 0
|
||||||
|
param.data += 1.0 / 10000.0
|
||||||
|
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
append_parameter(module=network[-1], parameter_list=parameter_neuron_b)
|
||||||
|
|
||||||
|
elif type_of_neuron_b == 2:
|
||||||
|
|
||||||
|
network.append(torch.nn.ReLU())
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
|
||||||
|
network.append(L1NormLayer())
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
|
||||||
|
network.append(
|
||||||
|
torch.nn.Conv2d(
|
||||||
|
in_channels=test_image.shape[1],
|
||||||
|
out_channels=number_of_neurons_b,
|
||||||
|
kernel_size=(1, 1),
|
||||||
|
stride=(1, 1),
|
||||||
|
padding=(0, 0),
|
||||||
|
bias=bias_norm_neuron_b,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Init the cnn top layers 1x1 conv2d layers
|
||||||
|
for name, param in network[-1].named_parameters():
|
||||||
|
with torch.no_grad():
|
||||||
|
if name == "bias":
|
||||||
|
param.data *= 0
|
||||||
|
param.data += (torch.rand_like(param) - 0.5) / 10000.0
|
||||||
|
if name == "weight":
|
||||||
|
if number_of_neurons_b >= param.shape[0]:
|
||||||
|
assert param.shape[-2] == 1
|
||||||
|
assert param.shape[-1] == 1
|
||||||
|
param.data[: param.shape[0], : param.shape[0], 0, 0] = (
|
||||||
|
torch.eye(
|
||||||
|
param.shape[0], dtype=param.dtype, device=param.device
|
||||||
|
)
|
||||||
|
)
|
||||||
|
param.data[param.shape[0] :, :, 0, 0] = 0
|
||||||
|
param.data[:, param.shape[0] :, 0, 0] = 0
|
||||||
|
param.data += (torch.rand_like(param) - 0.5) / 10000.0
|
||||||
|
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
append_parameter(module=network[-1], parameter_list=parameter_neuron_b)
|
||||||
|
else:
|
||||||
|
assert (
|
||||||
|
(type_of_neuron_b == 0)
|
||||||
|
or (type_of_neuron_b == 1)
|
||||||
|
or (type_of_neuron_b == 2)
|
||||||
|
)
|
||||||
|
|
||||||
|
if (test_image.shape[-1] > 1) or (test_image.shape[-2] > 1):
|
||||||
|
if (batch_norm_neuron_b) and (type_of_neuron_b > 0):
|
||||||
|
network.append(
|
||||||
|
torch.nn.BatchNorm2d(
|
||||||
|
num_features=test_image.shape[1],
|
||||||
|
device=device,
|
||||||
|
momentum=momentum,
|
||||||
|
track_running_stats=track_running_stats,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
append_parameter(module=network[-1], parameter_list=parameter_batchnorm2d)
|
||||||
|
|
||||||
|
return test_image
|
163
tools/get_the_data.py
Normal file
163
tools/get_the_data.py
Normal file
|
@ -0,0 +1,163 @@
|
||||||
|
import torch
|
||||||
|
import torchvision # type: ignore
|
||||||
|
from tools.data_loader import data_loader
|
||||||
|
|
||||||
|
from torchvision.transforms import v2 # type: ignore
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def get_the_data(
|
||||||
|
dataset: str,
|
||||||
|
batch_size_train: int,
|
||||||
|
batch_size_test: int,
|
||||||
|
torch_device: torch.device,
|
||||||
|
input_dim_x: int,
|
||||||
|
input_dim_y: int,
|
||||||
|
flip_p: float = 0.5,
|
||||||
|
jitter_brightness: float = 0.5,
|
||||||
|
jitter_contrast: float = 0.1,
|
||||||
|
jitter_saturation: float = 0.1,
|
||||||
|
jitter_hue: float = 0.15,
|
||||||
|
da_auto_mode: bool = False,
|
||||||
|
disable_da: bool = False,
|
||||||
|
) -> tuple[
|
||||||
|
torch.utils.data.dataloader.DataLoader,
|
||||||
|
torch.utils.data.dataloader.DataLoader,
|
||||||
|
torchvision.transforms.Compose,
|
||||||
|
torchvision.transforms.Compose,
|
||||||
|
]:
|
||||||
|
if dataset == "MNIST":
|
||||||
|
tv_dataset_train = torchvision.datasets.MNIST(
|
||||||
|
root="data", train=True, download=True
|
||||||
|
)
|
||||||
|
tv_dataset_test = torchvision.datasets.MNIST(
|
||||||
|
root="data", train=False, download=True
|
||||||
|
)
|
||||||
|
elif dataset == "FashionMNIST":
|
||||||
|
tv_dataset_train = torchvision.datasets.FashionMNIST(
|
||||||
|
root="data", train=True, download=True
|
||||||
|
)
|
||||||
|
tv_dataset_test = torchvision.datasets.FashionMNIST(
|
||||||
|
root="data", train=False, download=True
|
||||||
|
)
|
||||||
|
elif dataset == "CIFAR10":
|
||||||
|
tv_dataset_train = torchvision.datasets.CIFAR10(
|
||||||
|
root="data", train=True, download=True
|
||||||
|
)
|
||||||
|
tv_dataset_test = torchvision.datasets.CIFAR10(
|
||||||
|
root="data", train=False, download=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("This dataset is not implemented.")
|
||||||
|
|
||||||
|
def seed_worker(worker_id):
|
||||||
|
worker_seed = torch.initial_seed() % 2**32
|
||||||
|
np.random.seed(worker_seed)
|
||||||
|
torch.random.seed(worker_seed)
|
||||||
|
|
||||||
|
g = torch.Generator()
|
||||||
|
g.manual_seed(0)
|
||||||
|
|
||||||
|
if dataset == "MNIST" or dataset == "FashionMNIST":
|
||||||
|
|
||||||
|
train_dataloader = data_loader(
|
||||||
|
torch_device=torch_device,
|
||||||
|
batch_size=batch_size_train,
|
||||||
|
pattern=tv_dataset_train.data,
|
||||||
|
labels=tv_dataset_train.targets,
|
||||||
|
shuffle=True,
|
||||||
|
worker_init_fn=seed_worker,
|
||||||
|
generator=g,
|
||||||
|
)
|
||||||
|
|
||||||
|
test_dataloader = data_loader(
|
||||||
|
torch_device=torch_device,
|
||||||
|
batch_size=batch_size_test,
|
||||||
|
pattern=tv_dataset_test.data,
|
||||||
|
labels=tv_dataset_test.targets,
|
||||||
|
shuffle=False,
|
||||||
|
worker_init_fn=seed_worker,
|
||||||
|
generator=g,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Data augmentation filter
|
||||||
|
test_processing_chain = torchvision.transforms.Compose(
|
||||||
|
transforms=[torchvision.transforms.CenterCrop((input_dim_x, input_dim_y))],
|
||||||
|
)
|
||||||
|
if disable_da:
|
||||||
|
train_processing_chain = torchvision.transforms.Compose(
|
||||||
|
transforms=[
|
||||||
|
torchvision.transforms.CenterCrop((input_dim_x, input_dim_y))
|
||||||
|
],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
train_processing_chain = torchvision.transforms.Compose(
|
||||||
|
transforms=[
|
||||||
|
torchvision.transforms.RandomCrop((input_dim_x, input_dim_y))
|
||||||
|
],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
|
||||||
|
train_dataloader = data_loader(
|
||||||
|
torch_device=torch_device,
|
||||||
|
batch_size=batch_size_train,
|
||||||
|
pattern=torch.tensor(tv_dataset_train.data).movedim(-1, 1),
|
||||||
|
labels=torch.tensor(tv_dataset_train.targets),
|
||||||
|
shuffle=True,
|
||||||
|
worker_init_fn=seed_worker,
|
||||||
|
generator=g,
|
||||||
|
)
|
||||||
|
|
||||||
|
test_dataloader = data_loader(
|
||||||
|
torch_device=torch_device,
|
||||||
|
batch_size=batch_size_test,
|
||||||
|
pattern=torch.tensor(tv_dataset_test.data).movedim(-1, 1),
|
||||||
|
labels=torch.tensor(tv_dataset_test.targets),
|
||||||
|
shuffle=False,
|
||||||
|
worker_init_fn=seed_worker,
|
||||||
|
generator=g,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Data augmentation filter
|
||||||
|
test_processing_chain = torchvision.transforms.Compose(
|
||||||
|
transforms=[torchvision.transforms.CenterCrop((input_dim_x, input_dim_y))],
|
||||||
|
)
|
||||||
|
|
||||||
|
if disable_da:
|
||||||
|
train_processing_chain = torchvision.transforms.Compose(
|
||||||
|
transforms=[
|
||||||
|
torchvision.transforms.CenterCrop((input_dim_x, input_dim_y))
|
||||||
|
],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if da_auto_mode:
|
||||||
|
train_processing_chain = torchvision.transforms.Compose(
|
||||||
|
transforms=[
|
||||||
|
v2.AutoAugment(
|
||||||
|
policy=torchvision.transforms.AutoAugmentPolicy(
|
||||||
|
v2.AutoAugmentPolicy.CIFAR10
|
||||||
|
)
|
||||||
|
),
|
||||||
|
torchvision.transforms.CenterCrop((input_dim_x, input_dim_y)),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
train_processing_chain = torchvision.transforms.Compose(
|
||||||
|
transforms=[
|
||||||
|
torchvision.transforms.RandomCrop((input_dim_x, input_dim_y)),
|
||||||
|
torchvision.transforms.RandomHorizontalFlip(p=flip_p),
|
||||||
|
torchvision.transforms.ColorJitter(
|
||||||
|
brightness=jitter_brightness,
|
||||||
|
contrast=jitter_contrast,
|
||||||
|
saturation=jitter_saturation,
|
||||||
|
hue=jitter_hue,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
train_dataloader,
|
||||||
|
test_dataloader,
|
||||||
|
train_processing_chain,
|
||||||
|
test_processing_chain,
|
||||||
|
)
|
525
tools/make_network.py
Normal file
525
tools/make_network.py
Normal file
|
@ -0,0 +1,525 @@
|
||||||
|
import torch
|
||||||
|
from tools.append_block import append_block
|
||||||
|
from tools.L1NormLayer import L1NormLayer
|
||||||
|
from tools.NNMF2d import NNMF2d
|
||||||
|
from tools.append_parameter import append_parameter
|
||||||
|
|
||||||
|
import json
|
||||||
|
from jsmin import jsmin
|
||||||
|
|
||||||
|
|
||||||
|
def make_network(
|
||||||
|
input_dim_x: int,
|
||||||
|
input_dim_y: int,
|
||||||
|
input_number_of_channel: int,
|
||||||
|
device: torch.device,
|
||||||
|
config_network_filename: str = "config_network.json",
|
||||||
|
) -> tuple[
|
||||||
|
torch.nn.Sequential,
|
||||||
|
list[list[torch.nn.parameter.Parameter]],
|
||||||
|
list[str],
|
||||||
|
]:
|
||||||
|
|
||||||
|
with open(config_network_filename, "r") as file:
|
||||||
|
minified = jsmin(file.read())
|
||||||
|
config_network = json.loads(minified)
|
||||||
|
|
||||||
|
assert len(list(config_network["number_of_neurons_a"])) == len(
|
||||||
|
list(config_network["number_of_neurons_b"])
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(list(config_network["number_of_neurons_a"])) == len(
|
||||||
|
list(config_network["kernel_size_conv"])
|
||||||
|
)
|
||||||
|
assert len(list(config_network["number_of_neurons_a"])) == len(
|
||||||
|
list(config_network["stride_conv"])
|
||||||
|
)
|
||||||
|
assert len(list(config_network["number_of_neurons_a"])) == len(
|
||||||
|
list(config_network["padding_conv"])
|
||||||
|
)
|
||||||
|
assert len(list(config_network["number_of_neurons_a"])) == len(
|
||||||
|
list(config_network["dilation_conv"])
|
||||||
|
)
|
||||||
|
assert len(list(config_network["number_of_neurons_a"])) == len(
|
||||||
|
list(config_network["kernel_size_pool"])
|
||||||
|
)
|
||||||
|
assert len(list(config_network["number_of_neurons_a"])) == len(
|
||||||
|
list(config_network["stride_pool"])
|
||||||
|
)
|
||||||
|
assert len(list(config_network["number_of_neurons_a"])) == len(
|
||||||
|
list(config_network["padding_pool"])
|
||||||
|
)
|
||||||
|
assert len(list(config_network["number_of_neurons_a"])) == len(
|
||||||
|
list(config_network["dilation_pool"])
|
||||||
|
)
|
||||||
|
assert len(list(config_network["number_of_neurons_a"])) == len(
|
||||||
|
list(config_network["type_of_pooling"])
|
||||||
|
)
|
||||||
|
assert len(list(config_network["number_of_neurons_a"])) == len(
|
||||||
|
list(config_network["local_learning_pooling"])
|
||||||
|
)
|
||||||
|
assert len(list(config_network["number_of_neurons_a"])) == len(
|
||||||
|
list(config_network["local_learning_use_kl_pooling"])
|
||||||
|
)
|
||||||
|
assert len(list(config_network["number_of_neurons_a"])) == len(
|
||||||
|
list(config_network["type_of_neuron_a"])
|
||||||
|
)
|
||||||
|
assert len(list(config_network["number_of_neurons_a"])) == len(
|
||||||
|
list(config_network["type_of_neuron_b"])
|
||||||
|
)
|
||||||
|
assert len(list(config_network["number_of_neurons_a"])) == len(
|
||||||
|
list(config_network["batch_norm_neuron_a"])
|
||||||
|
)
|
||||||
|
assert len(list(config_network["number_of_neurons_a"])) == len(
|
||||||
|
list(config_network["batch_norm_neuron_b"])
|
||||||
|
)
|
||||||
|
assert len(list(config_network["number_of_neurons_a"])) == len(
|
||||||
|
list(config_network["bias_norm_neuron_a"])
|
||||||
|
)
|
||||||
|
assert len(list(config_network["number_of_neurons_a"])) == len(
|
||||||
|
list(config_network["bias_norm_neuron_b"])
|
||||||
|
)
|
||||||
|
|
||||||
|
parameter_neuron_b: list[torch.nn.parameter.Parameter] = []
|
||||||
|
parameter_neuron_a: list[torch.nn.parameter.Parameter] = []
|
||||||
|
parameter_batchnorm2d: list[torch.nn.parameter.Parameter] = []
|
||||||
|
parameter_neuron_pool: list[torch.nn.parameter.Parameter] = []
|
||||||
|
|
||||||
|
test_image = torch.ones(
|
||||||
|
(1, input_number_of_channel, input_dim_x, input_dim_y), device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
network = torch.nn.Sequential()
|
||||||
|
network = network.to(device)
|
||||||
|
|
||||||
|
epsilon: float | None = None
|
||||||
|
|
||||||
|
if isinstance(config_network["epsilon"], float):
|
||||||
|
epsilon = float(config_network["epsilon"])
|
||||||
|
|
||||||
|
for block_id in range(0, len(list(config_network["number_of_neurons_a"]))):
|
||||||
|
|
||||||
|
test_image = append_block(
|
||||||
|
network=network,
|
||||||
|
number_of_neurons_a=int(
|
||||||
|
list(config_network["number_of_neurons_a"])[block_id]
|
||||||
|
),
|
||||||
|
number_of_neurons_b=int(
|
||||||
|
list(config_network["number_of_neurons_b"])[block_id]
|
||||||
|
),
|
||||||
|
test_image=test_image,
|
||||||
|
dilation=list(list(config_network["dilation_conv"])[block_id]),
|
||||||
|
padding=list(list(config_network["padding_conv"])[block_id]),
|
||||||
|
stride=list(list(config_network["stride_conv"])[block_id]),
|
||||||
|
kernel_size=list(list(config_network["kernel_size_conv"])[block_id]),
|
||||||
|
epsilon=epsilon,
|
||||||
|
iterations=int(config_network["iterations"]),
|
||||||
|
device=device,
|
||||||
|
parameter_neuron_a=parameter_neuron_a,
|
||||||
|
parameter_neuron_b=parameter_neuron_b,
|
||||||
|
parameter_batchnorm2d=parameter_batchnorm2d,
|
||||||
|
type_of_neuron_a=int(list(config_network["type_of_neuron_a"])[block_id]),
|
||||||
|
type_of_neuron_b=int(list(config_network["type_of_neuron_b"])[block_id]),
|
||||||
|
batch_norm_neuron_a=bool(
|
||||||
|
list(config_network["batch_norm_neuron_a"])[block_id]
|
||||||
|
),
|
||||||
|
batch_norm_neuron_b=bool(
|
||||||
|
list(config_network["batch_norm_neuron_b"])[block_id]
|
||||||
|
),
|
||||||
|
bias_norm_neuron_a=bool(
|
||||||
|
list(config_network["bias_norm_neuron_a"])[block_id]
|
||||||
|
),
|
||||||
|
bias_norm_neuron_b=bool(
|
||||||
|
list(config_network["bias_norm_neuron_b"])[block_id]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if (int(list(list(config_network["kernel_size_pool"])[block_id])[0]) > 0) and (
|
||||||
|
(int(list(list(config_network["kernel_size_pool"])[block_id])[1]) > 0)
|
||||||
|
):
|
||||||
|
if int(list(config_network["type_of_pooling"])[block_id]) == 0:
|
||||||
|
pass
|
||||||
|
|
||||||
|
elif int(list(config_network["type_of_pooling"])[block_id]) == 1:
|
||||||
|
network.append(
|
||||||
|
torch.nn.AvgPool2d(
|
||||||
|
kernel_size=(
|
||||||
|
(
|
||||||
|
int(
|
||||||
|
list(
|
||||||
|
list(config_network["kernel_size_pool"])[
|
||||||
|
block_id
|
||||||
|
]
|
||||||
|
)[0]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
(
|
||||||
|
int(
|
||||||
|
list(
|
||||||
|
list(config_network["kernel_size_pool"])[
|
||||||
|
block_id
|
||||||
|
]
|
||||||
|
)[1]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
),
|
||||||
|
stride=(
|
||||||
|
(
|
||||||
|
int(
|
||||||
|
list(list(config_network["stride_pool"])[block_id])[
|
||||||
|
0
|
||||||
|
]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
(
|
||||||
|
int(
|
||||||
|
list(list(config_network["stride_pool"])[block_id])[
|
||||||
|
1
|
||||||
|
]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
),
|
||||||
|
padding=(
|
||||||
|
(
|
||||||
|
int(
|
||||||
|
list(
|
||||||
|
list(config_network["padding_pool"])[block_id]
|
||||||
|
)[0]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
(
|
||||||
|
int(
|
||||||
|
list(
|
||||||
|
list(config_network["padding_pool"])[block_id]
|
||||||
|
)[1]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
|
||||||
|
elif int(list(config_network["type_of_pooling"])[block_id]) == 2:
|
||||||
|
network.append(
|
||||||
|
torch.nn.MaxPool2d(
|
||||||
|
kernel_size=(
|
||||||
|
(
|
||||||
|
int(
|
||||||
|
list(
|
||||||
|
list(config_network["kernel_size_pool"])[
|
||||||
|
block_id
|
||||||
|
]
|
||||||
|
)[0]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
(
|
||||||
|
int(
|
||||||
|
list(
|
||||||
|
list(config_network["kernel_size_pool"])[
|
||||||
|
block_id
|
||||||
|
]
|
||||||
|
)[1]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
),
|
||||||
|
stride=(
|
||||||
|
(
|
||||||
|
int(
|
||||||
|
list(list(config_network["stride_pool"])[block_id])[
|
||||||
|
0
|
||||||
|
]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
(
|
||||||
|
int(
|
||||||
|
list(list(config_network["stride_pool"])[block_id])[
|
||||||
|
1
|
||||||
|
]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
),
|
||||||
|
padding=(
|
||||||
|
(
|
||||||
|
int(
|
||||||
|
list(
|
||||||
|
list(config_network["padding_pool"])[block_id]
|
||||||
|
)[0]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
(
|
||||||
|
int(
|
||||||
|
list(
|
||||||
|
list(config_network["padding_pool"])[block_id]
|
||||||
|
)[1]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
elif (int(list(config_network["type_of_pooling"])[block_id]) == 3) or (
|
||||||
|
int(list(config_network["type_of_pooling"])[block_id]) == 4
|
||||||
|
):
|
||||||
|
|
||||||
|
network.append(torch.nn.ReLU())
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
|
||||||
|
mock_output = (
|
||||||
|
torch.nn.functional.conv2d(
|
||||||
|
torch.zeros(
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
test_image.shape[2],
|
||||||
|
test_image.shape[3],
|
||||||
|
),
|
||||||
|
torch.zeros(
|
||||||
|
(
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
int(
|
||||||
|
list(
|
||||||
|
list(config_network["kernel_size_pool"])[
|
||||||
|
block_id
|
||||||
|
]
|
||||||
|
)[0]
|
||||||
|
),
|
||||||
|
int(
|
||||||
|
list(
|
||||||
|
list(config_network["kernel_size_pool"])[
|
||||||
|
block_id
|
||||||
|
]
|
||||||
|
)[1]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
stride=(
|
||||||
|
(
|
||||||
|
int(
|
||||||
|
list(list(config_network["stride_pool"])[block_id])[
|
||||||
|
0
|
||||||
|
]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
(
|
||||||
|
int(
|
||||||
|
list(list(config_network["stride_pool"])[block_id])[
|
||||||
|
1
|
||||||
|
]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
),
|
||||||
|
padding=(
|
||||||
|
(
|
||||||
|
int(
|
||||||
|
list(
|
||||||
|
list(config_network["padding_pool"])[block_id]
|
||||||
|
)[0]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
(
|
||||||
|
int(
|
||||||
|
list(
|
||||||
|
list(config_network["padding_pool"])[block_id]
|
||||||
|
)[1]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
),
|
||||||
|
dilation=(
|
||||||
|
(
|
||||||
|
int(
|
||||||
|
list(
|
||||||
|
list(config_network["dilation_pool"])[block_id]
|
||||||
|
)[0]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
(
|
||||||
|
int(
|
||||||
|
list(
|
||||||
|
list(config_network["dilation_pool"])[block_id]
|
||||||
|
)[1]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.squeeze(0)
|
||||||
|
.squeeze(0)
|
||||||
|
)
|
||||||
|
|
||||||
|
network.append(
|
||||||
|
torch.nn.Unfold(
|
||||||
|
kernel_size=(
|
||||||
|
int(
|
||||||
|
list(
|
||||||
|
list(config_network["kernel_size_pool"])[block_id]
|
||||||
|
)[0]
|
||||||
|
),
|
||||||
|
int(
|
||||||
|
list(
|
||||||
|
list(config_network["kernel_size_pool"])[block_id]
|
||||||
|
)[1]
|
||||||
|
),
|
||||||
|
),
|
||||||
|
stride=(
|
||||||
|
(
|
||||||
|
int(
|
||||||
|
list(list(config_network["stride_pool"])[block_id])[
|
||||||
|
0
|
||||||
|
]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
(
|
||||||
|
int(
|
||||||
|
list(list(config_network["stride_pool"])[block_id])[
|
||||||
|
1
|
||||||
|
]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
),
|
||||||
|
padding=(
|
||||||
|
(
|
||||||
|
int(
|
||||||
|
list(
|
||||||
|
list(config_network["padding_pool"])[block_id]
|
||||||
|
)[0]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
(
|
||||||
|
int(
|
||||||
|
list(
|
||||||
|
list(config_network["padding_pool"])[block_id]
|
||||||
|
)[1]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
),
|
||||||
|
dilation=(
|
||||||
|
(
|
||||||
|
int(
|
||||||
|
list(
|
||||||
|
list(config_network["dilation_pool"])[block_id]
|
||||||
|
)[0]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
(
|
||||||
|
int(
|
||||||
|
list(
|
||||||
|
list(config_network["dilation_pool"])[block_id]
|
||||||
|
)[1]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
|
||||||
|
network.append(
|
||||||
|
torch.nn.Fold(
|
||||||
|
output_size=mock_output.shape,
|
||||||
|
kernel_size=(1, 1),
|
||||||
|
dilation=1,
|
||||||
|
padding=0,
|
||||||
|
stride=1,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
|
||||||
|
network.append(L1NormLayer())
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
|
||||||
|
if int(list(config_network["type_of_pooling"])[block_id]) == 3:
|
||||||
|
network.append(
|
||||||
|
torch.nn.Conv2d(
|
||||||
|
in_channels=test_image.shape[1],
|
||||||
|
out_channels=test_image.shape[1]
|
||||||
|
// (
|
||||||
|
int(
|
||||||
|
list(
|
||||||
|
list(config_network["kernel_size_pool"])[
|
||||||
|
block_id
|
||||||
|
]
|
||||||
|
)[0]
|
||||||
|
)
|
||||||
|
* int(
|
||||||
|
list(
|
||||||
|
list(config_network["kernel_size_pool"])[
|
||||||
|
block_id
|
||||||
|
]
|
||||||
|
)[1]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
kernel_size=(1, 1),
|
||||||
|
bias=False,
|
||||||
|
).to(device)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
network.append(
|
||||||
|
NNMF2d(
|
||||||
|
in_channels=test_image.shape[1],
|
||||||
|
out_channels=test_image.shape[1]
|
||||||
|
// (
|
||||||
|
int(
|
||||||
|
list(
|
||||||
|
list(config_network["kernel_size_pool"])[
|
||||||
|
block_id
|
||||||
|
]
|
||||||
|
)[0]
|
||||||
|
)
|
||||||
|
* int(
|
||||||
|
list(
|
||||||
|
list(config_network["kernel_size_pool"])[
|
||||||
|
block_id
|
||||||
|
]
|
||||||
|
)[1]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
epsilon=epsilon,
|
||||||
|
local_learning=bool(
|
||||||
|
list(config_network["local_learning_pooling"])[block_id]
|
||||||
|
),
|
||||||
|
local_learning_kl=bool(
|
||||||
|
list(config_network["local_learning_use_kl_pooling"])[
|
||||||
|
block_id
|
||||||
|
]
|
||||||
|
),
|
||||||
|
).to(device)
|
||||||
|
)
|
||||||
|
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
append_parameter(
|
||||||
|
module=network[-1], parameter_list=parameter_neuron_pool
|
||||||
|
)
|
||||||
|
|
||||||
|
network.append(
|
||||||
|
torch.nn.BatchNorm2d(
|
||||||
|
num_features=test_image.shape[1],
|
||||||
|
device=device,
|
||||||
|
momentum=0.1,
|
||||||
|
track_running_stats=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
append_parameter(
|
||||||
|
module=network[-1], parameter_list=parameter_batchnorm2d
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
assert int(list(config_network["type_of_pooling"])[block_id]) > 4
|
||||||
|
network.append(torch.nn.Softmax(dim=1))
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
|
||||||
|
network.append(torch.nn.Flatten())
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
|
||||||
|
parameters: list[list[torch.nn.parameter.Parameter]] = [
|
||||||
|
parameter_neuron_a,
|
||||||
|
parameter_neuron_b,
|
||||||
|
parameter_batchnorm2d,
|
||||||
|
parameter_neuron_pool,
|
||||||
|
]
|
||||||
|
|
||||||
|
name_list: list[str] = ["neuron a", "neuron b", "batchnorm2d", "neuron pool"]
|
||||||
|
|
||||||
|
return (
|
||||||
|
network,
|
||||||
|
parameters,
|
||||||
|
name_list,
|
||||||
|
)
|
231
tools/run_network_train.py
Normal file
231
tools/run_network_train.py
Normal file
|
@ -0,0 +1,231 @@
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import json
|
||||||
|
from jsmin import jsmin
|
||||||
|
import os
|
||||||
|
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
from tools.make_network import make_network
|
||||||
|
from tools.get_the_data import get_the_data
|
||||||
|
from tools.loss_function import loss_function
|
||||||
|
from tools.make_optimize import make_optimize
|
||||||
|
|
||||||
|
|
||||||
|
def main(
|
||||||
|
rand_seed: int = 21,
|
||||||
|
only_print_network: bool = False,
|
||||||
|
config_network_filename: str = "config_network.json",
|
||||||
|
config_data_filename: str = "config_data.json",
|
||||||
|
config_lr_parameter_filename: str = "config_lr_parameter.json",
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
os.makedirs("Models", exist_ok=True)
|
||||||
|
|
||||||
|
device: torch.device = (
|
||||||
|
torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
||||||
|
)
|
||||||
|
torch.set_default_dtype(torch.float32)
|
||||||
|
|
||||||
|
# Some parameters
|
||||||
|
with open(config_data_filename, "r") as file:
|
||||||
|
minified = jsmin(file.read())
|
||||||
|
config_data = json.loads(minified)
|
||||||
|
|
||||||
|
with open(config_lr_parameter_filename, "r") as file:
|
||||||
|
minified = jsmin(file.read())
|
||||||
|
config_lr_parameter = json.loads(minified)
|
||||||
|
|
||||||
|
torch.manual_seed(rand_seed)
|
||||||
|
torch.cuda.manual_seed(rand_seed)
|
||||||
|
np.random.seed(rand_seed)
|
||||||
|
|
||||||
|
if (
|
||||||
|
str(config_data["dataset"]) == "MNIST"
|
||||||
|
or str(config_data["dataset"]) == "FashionMNIST"
|
||||||
|
):
|
||||||
|
input_number_of_channel: int = 1
|
||||||
|
input_dim_x: int = 24
|
||||||
|
input_dim_y: int = 24
|
||||||
|
else:
|
||||||
|
input_number_of_channel = 3
|
||||||
|
input_dim_x = 28
|
||||||
|
input_dim_y = 28
|
||||||
|
|
||||||
|
train_dataloader, test_dataloader, train_processing_chain, test_processing_chain = (
|
||||||
|
get_the_data(
|
||||||
|
str(config_data["dataset"]),
|
||||||
|
int(config_data["batch_size_train"]),
|
||||||
|
int(config_data["batch_size_test"]),
|
||||||
|
device,
|
||||||
|
input_dim_x,
|
||||||
|
input_dim_y,
|
||||||
|
flip_p=float(config_data["flip_p"]),
|
||||||
|
jitter_brightness=float(config_data["jitter_brightness"]),
|
||||||
|
jitter_contrast=float(config_data["jitter_contrast"]),
|
||||||
|
jitter_saturation=float(config_data["jitter_saturation"]),
|
||||||
|
jitter_hue=float(config_data["jitter_hue"]),
|
||||||
|
da_auto_mode=bool(config_data["da_auto_mode"]),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
(
|
||||||
|
network,
|
||||||
|
parameters,
|
||||||
|
name_list,
|
||||||
|
) = make_network(
|
||||||
|
input_dim_x=input_dim_x,
|
||||||
|
input_dim_y=input_dim_y,
|
||||||
|
input_number_of_channel=input_number_of_channel,
|
||||||
|
device=device,
|
||||||
|
config_network_filename=config_network_filename,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(network)
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("Information about used parameters:")
|
||||||
|
number_of_parameter: int = 0
|
||||||
|
for i, parameter_list in enumerate(parameters):
|
||||||
|
count_parameter: int = 0
|
||||||
|
for parameter_element in parameter_list:
|
||||||
|
count_parameter += parameter_element.numel()
|
||||||
|
print(f"{name_list[i]}: {count_parameter}")
|
||||||
|
number_of_parameter += count_parameter
|
||||||
|
print(f"total number of parameter: {number_of_parameter}")
|
||||||
|
|
||||||
|
if only_print_network:
|
||||||
|
exit()
|
||||||
|
|
||||||
|
(
|
||||||
|
optimizers,
|
||||||
|
lr_schedulers,
|
||||||
|
) = make_optimize(
|
||||||
|
parameters=parameters,
|
||||||
|
lr_initial=[
|
||||||
|
float(config_lr_parameter["lr_initial_neuron_a"]),
|
||||||
|
float(config_lr_parameter["lr_initial_neuron_b"]),
|
||||||
|
float(config_lr_parameter["lr_initial_norm"]),
|
||||||
|
float(config_lr_parameter["lr_initial_batchnorm2d"]),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
my_string: str = f"seed_{rand_seed}"
|
||||||
|
default_path: str = f"{my_string}"
|
||||||
|
log_dir: str = f"log_{default_path}"
|
||||||
|
|
||||||
|
tb = SummaryWriter(log_dir=log_dir)
|
||||||
|
|
||||||
|
for epoch_id in range(0, int(config_lr_parameter["number_of_epoch"])):
|
||||||
|
print()
|
||||||
|
print(f"Epoch: {epoch_id}")
|
||||||
|
t_start: float = time.perf_counter()
|
||||||
|
|
||||||
|
train_loss: float = 0.0
|
||||||
|
train_correct: int = 0
|
||||||
|
train_number: int = 0
|
||||||
|
test_correct: int = 0
|
||||||
|
test_number: int = 0
|
||||||
|
|
||||||
|
# Switch the network into training mode
|
||||||
|
network.train()
|
||||||
|
|
||||||
|
# This runs in total for one epoch split up into mini-batches
|
||||||
|
for image, target in train_dataloader:
|
||||||
|
|
||||||
|
# Clean the gradient
|
||||||
|
for i in range(0, len(optimizers)):
|
||||||
|
if optimizers[i] is not None:
|
||||||
|
optimizers[i].zero_grad() # type: ignore
|
||||||
|
|
||||||
|
output = network(train_processing_chain(image))
|
||||||
|
|
||||||
|
loss = loss_function(
|
||||||
|
h=output,
|
||||||
|
labels=target,
|
||||||
|
number_of_output_neurons=output.shape[1],
|
||||||
|
loss_mode=int(config_lr_parameter["loss_mode"]),
|
||||||
|
loss_coeffs_mse=float(config_lr_parameter["loss_coeffs_mse"]),
|
||||||
|
loss_coeffs_kldiv=float(config_lr_parameter["loss_coeffs_kldiv"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert loss is not None
|
||||||
|
train_loss += loss.item()
|
||||||
|
train_correct += (output.argmax(dim=1) == target).sum().cpu().numpy()
|
||||||
|
train_number += target.shape[0]
|
||||||
|
|
||||||
|
# Calculate backprop
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
# Update the parameter
|
||||||
|
# Clean the gradient
|
||||||
|
for i in range(0, len(optimizers)):
|
||||||
|
if optimizers[i] is not None:
|
||||||
|
optimizers[i].step() # type: ignore
|
||||||
|
|
||||||
|
perfomance_train_correct: float = 100.0 * train_correct / train_number
|
||||||
|
# Update the learning rate
|
||||||
|
for i in range(0, len(lr_schedulers)):
|
||||||
|
if lr_schedulers[i] is not None:
|
||||||
|
lr_schedulers[i].step(train_loss) # type: ignore
|
||||||
|
|
||||||
|
my_string = "Actual lr: "
|
||||||
|
for i in range(0, len(lr_schedulers)):
|
||||||
|
if lr_schedulers[i] is not None:
|
||||||
|
my_string += f" {lr_schedulers[i].get_last_lr()[0]:.4e} " # type: ignore
|
||||||
|
else:
|
||||||
|
my_string += " --- "
|
||||||
|
|
||||||
|
print(my_string)
|
||||||
|
t_training: float = time.perf_counter()
|
||||||
|
|
||||||
|
# Switch the network into evalution mode
|
||||||
|
network.eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
|
||||||
|
for image, target in test_dataloader:
|
||||||
|
output = network(test_processing_chain(image))
|
||||||
|
|
||||||
|
test_correct += (output.argmax(dim=1) == target).sum().cpu().numpy()
|
||||||
|
test_number += target.shape[0]
|
||||||
|
|
||||||
|
t_testing = time.perf_counter()
|
||||||
|
|
||||||
|
perfomance_test_correct: float = 100.0 * test_correct / test_number
|
||||||
|
|
||||||
|
tb.add_scalar("Train Loss", train_loss / float(train_number), epoch_id)
|
||||||
|
tb.add_scalar("Train Number Correct", train_correct, epoch_id)
|
||||||
|
tb.add_scalar("Test Number Correct", test_correct, epoch_id)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Training: Loss={train_loss / float(train_number):.5f} Correct={perfomance_train_correct:.2f}%"
|
||||||
|
)
|
||||||
|
print(f"Testing: Correct={perfomance_test_correct:.2f}%")
|
||||||
|
print(
|
||||||
|
f"Time: Training={(t_training - t_start):.1f}sec, Testing={(t_testing - t_training):.1f}sec"
|
||||||
|
)
|
||||||
|
|
||||||
|
tb.flush()
|
||||||
|
|
||||||
|
lr_check: list[float] = []
|
||||||
|
for i in range(0, len(lr_schedulers)):
|
||||||
|
if lr_schedulers[i] is not None:
|
||||||
|
lr_check.append(lr_schedulers[i].get_last_lr()[0]) # type: ignore
|
||||||
|
|
||||||
|
lr_check_max = float(torch.tensor(lr_check).max())
|
||||||
|
|
||||||
|
if lr_check_max < float(config_lr_parameter["lr_limit"]):
|
||||||
|
torch.save(network, f"Models/Model_{default_path}.pt")
|
||||||
|
tb.close()
|
||||||
|
print("Done (lr_limit)")
|
||||||
|
return
|
||||||
|
|
||||||
|
torch.save(network, f"Models/Model_{default_path}.pt")
|
||||||
|
print()
|
||||||
|
|
||||||
|
tb.close()
|
||||||
|
print("Done (loop end)")
|
||||||
|
|
||||||
|
return
|
Loading…
Reference in a new issue