Add files via upload

This commit is contained in:
parent e268200501
commit b6c5c4d210

12 changed files with 1468 additions and 0 deletions
Functional2Layer.py (Normal file, 41 lines)
@@ -0,0 +1,41 @@
import torch
from typing import Callable, Any


class Functional2Layer(torch.nn.Module):
    """Wraps a plain function (e.g. torch.permute) as a torch.nn.Module."""

    def __init__(
        self, func: Callable[..., torch.Tensor], *args: Any, **kwargs: Any
    ) -> None:
        super().__init__()
        self.func = func
        self.args = args
        self.kwargs = kwargs

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self.func(input, *self.args, **self.kwargs)

    def extra_repr(self) -> str:
        func_name = (
            self.func.__name__ if hasattr(self.func, "__name__") else str(self.func)
        )
        args_repr = ", ".join(map(repr, self.args))
        kwargs_repr = ", ".join(f"{k}={v!r}" for k, v in self.kwargs.items())
        return f"func={func_name}, args=({args_repr}), kwargs={{{kwargs_repr}}}"


if __name__ == "__main__":

    print("Permute Example")
    test_layer_permute = Functional2Layer(func=torch.permute, dims=(0, 2, 3, 1))
    input = torch.zeros((10, 11, 12, 13))
    output = test_layer_permute(input)
    print(input.shape)
    print(output.shape)
    print(test_layer_permute)

    print()
    print("Clamp Example")
    test_layer_clamp = Functional2Layer(func=torch.clamp, min=5, max=100)
    output = test_layer_clamp(input)
    print(output[0, 0, 0, 0])
    print(test_layer_clamp)
NNMF2dGrouped.py (Normal file, 277 lines)
@@ -0,0 +1,277 @@
import torch
from non_linear_weigth_function import non_linear_weigth_function


# Grouped 2D non-negative matrix factorization layer with a custom
# autograd function for the backward pass.
class NNMF2dGrouped(torch.nn.Module):

    in_channels: int
    out_channels: int
    weight: torch.Tensor
    iterations: int
    epsilon: float | None
    init_min: float
    init_max: float
    beta: torch.Tensor | None
    positive_function_type: int
    local_learning: bool
    local_learning_kl: bool
    groups: int

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        groups: int = 1,
        device=None,
        dtype=None,
        iterations: int = 20,
        epsilon: float | None = None,
        init_min: float = 0.0,
        init_max: float = 1.0,
        beta: float | None = None,
        positive_function_type: int = 0,
        local_learning: bool = False,
        local_learning_kl: bool = False,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}

        super().__init__()

        self.positive_function_type = positive_function_type
        self.init_min = init_min
        self.init_max = init_max

        self.groups = groups
        assert (
            in_channels % self.groups == 0
        ), f"Can't divide without rest {in_channels} / {self.groups}"
        self.in_channels = in_channels // self.groups
        assert (
            out_channels % self.groups == 0
        ), f"Can't divide without rest {out_channels} / {self.groups}"
        self.out_channels = out_channels // self.groups

        self.iterations = iterations
        self.local_learning = local_learning
        self.local_learning_kl = local_learning_kl

        self.weight = torch.nn.parameter.Parameter(
            torch.empty(
                (self.groups, self.out_channels, self.in_channels), **factory_kwargs
            )
        )

        if beta is not None:
            self.beta = torch.nn.parameter.Parameter(torch.empty((1,), **factory_kwargs))
            self.beta.data[0] = beta
        else:
            self.beta = None

        self.reset_parameters()
        self.functional_nnmf2d_grouped = FunctionalNNMF2dGrouped.apply

        self.epsilon = epsilon

    def extra_repr(self) -> str:
        s: str = f"{self.in_channels}, {self.out_channels}"

        if self.epsilon is not None:
            s += f", epsilon={self.epsilon}"
        s += f", pfunctype={self.positive_function_type}"
        s += f", local_learning={self.local_learning}"
        s += f", groups={self.groups}"

        if self.local_learning:
            s += f", local_learning_kl={self.local_learning_kl}"

        return s

    def reset_parameters(self) -> None:
        torch.nn.init.uniform_(self.weight, a=self.init_min, b=self.init_max)

    def forward(self, input: torch.Tensor) -> torch.Tensor:

        # Make the weights positive and normalize them over the input dim
        positive_weights = non_linear_weigth_function(
            self.weight, self.beta, self.positive_function_type
        )
        positive_weights = positive_weights / (
            positive_weights.sum(dim=-1, keepdim=True) + 10e-20
        )
        assert self.groups * self.in_channels == input.shape[1]

        # Split the channel dim into (groups, channels per group)
        input = input.reshape(
            (
                input.shape[0],
                self.groups,
                self.in_channels,
                input.shape[-2],
                input.shape[-1],
            )
        )
        input = input / (input.sum(dim=2, keepdim=True) + 10e-20)

        h_dyn = self.functional_nnmf2d_grouped(
            input,
            positive_weights,
            self.out_channels,
            self.iterations,
            self.epsilon,
            self.local_learning,
            self.local_learning_kl,
        )

        # Merge (groups, channels per group) back into one channel dim
        h_dyn = h_dyn.reshape(
            (
                h_dyn.shape[0],
                h_dyn.shape[1] * h_dyn.shape[2],
                h_dyn.shape[3],
                h_dyn.shape[4],
            )
        )
        h_dyn = h_dyn / (h_dyn.sum(dim=1, keepdim=True) + 10e-20)

        return h_dyn


@torch.jit.script
def grouped_linear_einsum_h_weights(h, weights):
    return torch.einsum("bgoxy,goi->bgixy", h, weights)


@torch.jit.script
def grouped_linear_einsum_reconstruction_weights(reconstruction, weights):
    return torch.einsum("bgixy,goi->bgoxy", reconstruction, weights)


@torch.jit.script
def grouped_linear_einsum_h_input(h, reconstruction):
    return torch.einsum("bgoxy,bgixy->goi", h, reconstruction)


class FunctionalNNMF2dGrouped(torch.autograd.Function):

    @staticmethod
    def forward(  # type: ignore
        ctx,
        input: torch.Tensor,
        weight: torch.Tensor,
        out_channels: int,
        iterations: int,
        epsilon: float | None,
        local_learning: bool,
        local_learning_kl: bool,
    ) -> torch.Tensor:

        # Prepare h
        h = torch.full(
            (
                input.shape[0],
                input.shape[1],
                out_channels,
                input.shape[-2],
                input.shape[-1],
            ),
            1.0 / float(out_channels),
            device=input.device,
            dtype=input.dtype,
        )

        for _ in range(0, iterations):

            # Multiplicative NNMF update: reconstruct the input from h,
            # then scale h by how well the reconstruction matches the input.
            reconstruction = grouped_linear_einsum_h_weights(h, weight)
            reconstruction += 1e-20

            if epsilon is None:
                h *= grouped_linear_einsum_reconstruction_weights(
                    (input / reconstruction), weight
                )
            else:
                h *= 1 + epsilon * grouped_linear_einsum_reconstruction_weights(
                    (input / reconstruction), weight
                )
            h /= h.sum(2, keepdim=True) + 10e-20

        # ###########################################################
        # Save the necessary data for the backward pass
        # ###########################################################
        ctx.save_for_backward(input, weight, h)
        ctx.local_learning = local_learning
        ctx.local_learning_kl = local_learning_kl

        assert torch.isfinite(h).all()
        return h

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output: torch.Tensor) -> tuple[  # type: ignore
        torch.Tensor,
        torch.Tensor | None,
        None,
        None,
        None,
        None,
        None,
    ]:

        # ##############################################
        # Default values
        # ##############################################
        grad_weight: torch.Tensor | None = None

        # ##############################################
        # Get the variables back
        # ##############################################
        (input, weight, h) = ctx.saved_tensors

        # The back prop gradient
        big_r = grouped_linear_einsum_h_weights(h, weight)

        big_r_div = 1.0 / (big_r + 1e-20)

        factor_x_div_r = input * big_r_div

        grad_input: torch.Tensor = (
            grouped_linear_einsum_h_weights(h * grad_output, weight) * big_r_div
        )

        del big_r_div

        # The weight gradient
        if ctx.local_learning is False:
            del big_r

            grad_weight = -grouped_linear_einsum_h_input(
                h, (factor_x_div_r * grad_input)
            )

            grad_weight += grouped_linear_einsum_h_input(
                (h * grad_output),
                factor_x_div_r,
            )

        else:
            if ctx.local_learning_kl:

                grad_weight = -grouped_linear_einsum_h_input(
                    h,
                    factor_x_div_r,
                )

            else:
                grad_weight = -grouped_linear_einsum_h_input(
                    h,
                    (2 * (input - big_r)),
                )

        assert torch.isfinite(grad_input).all()
        assert torch.isfinite(grad_weight).all()

        return (
            grad_input,
            grad_weight,
            None,
            None,
            None,
            None,
            None,
        )
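Usage sketch (not part of the committed file): a quick shape and normalization check for the grouped NNMF layer; the channel/group sizes below are illustrative assumptions.

import torch

from NNMF2dGrouped import NNMF2dGrouped

# 8 input channels in 4 groups of 2; 16 output channels in 4 groups of 4
layer = NNMF2dGrouped(in_channels=8, out_channels=16, groups=4, iterations=5)
x = torch.rand(2, 8, 6, 6)  # batch, channels, height, width (non-negative input)
y = layer(x)

print(y.shape)                  # torch.Size([2, 16, 6, 6])
print(bool(torch.all(y >= 0)))  # True: h stays non-negative
print(y.sum(dim=1)[0, 0, 0])    # ~1.0: output channels are normalized per pixel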
PositionalEncoding.py (Normal file, 22 lines)
@@ -0,0 +1,22 @@
import torch


class PositionalEncoding(torch.nn.Module):

    init_std: float
    pos_embedding: torch.nn.Parameter

    def __init__(self, dim: list[int], init_std: float = 0.2):
        super().__init__()
        self.init_std = init_std
        assert len(dim) == 3
        self.pos_embedding: torch.nn.Parameter = torch.nn.Parameter(
            torch.randn(1, *dim)
        )
        self.init_parameters()

    def init_parameters(self):
        torch.nn.init.trunc_normal_(self.pos_embedding, std=self.init_std)

    def forward(self, input: torch.Tensor):
        return input + self.pos_embedding
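Usage sketch (not part of the committed file): the module stores one learned offset per channel/row/column position and broadcasts it over the batch; the dimensions below are illustrative.

import torch

from PositionalEncoding import PositionalEncoding

pe = PositionalEncoding(dim=[3, 7, 7])  # channels, height, width
x = torch.zeros(4, 3, 7, 7)
print(pe(x).shape)  # torch.Size([4, 3, 7, 7]): same shape, shifted values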
SequentialSplit.py (Normal file, 169 lines)
@@ -0,0 +1,169 @@
import torch
from typing import Callable


class SequentialSplit(torch.nn.Module):
    """
    A PyTorch module that splits the processing path of an input tensor,
    processes it through multiple torch.nn.Sequential segments, and then
    combines the outputs using a specified method.

    This module allows for creating split paths within a `torch.nn.Sequential`
    model, making it possible to implement architectures with skip connections
    or parallel paths without abandoning the sequential model structure.

    Attributes:
        segments (torch.nn.Sequential): A torch.nn.Sequential whose entries
            are themselves torch.nn.Sequential segments that process the
            input tensor.
        combine_func (Callable | None): A function to combine the outputs
            from the segments.
        dim (int | None): The dimension along which to concatenate
            the outputs if `combine_func` is `torch.cat`.

    Args:
        segments (torch.nn.Sequential): A torch.nn.Sequential
            with a list of sequential modules to process the input tensor.
        combine (str, optional): The method to combine the outputs.
            "cat" for concatenation (default), "sum" for a summation,
            or "func" to use a custom combine function.
        dim (int | None, optional): The dimension along which to
            concatenate the outputs if `combine` is "cat".
            Defaults to 1.
        combine_func (Callable | None, optional): A custom function
            to combine the outputs if `combine` is "func".
            Defaults to None.

    Example:
        A simple example for the `SequentialSplit` module with two
        torch.nn.Sequential segments:

                                ----- segment_a -----
            main_Sequential ----|                   |---- main_Sequential
                                ----- segment_b -----

        segments = [segment_a, segment_b]
        y_split = SequentialSplit(segments)
        result = y_split(input_tensor)

    Methods:
        forward(input: torch.Tensor) -> torch.Tensor:
            Processes the input tensor through the segments and
            combines the results.
    """

    segments: torch.nn.Sequential
    combine_func: Callable
    dim: int | None

    def __init__(
        self,
        segments: torch.nn.Sequential,
        combine: str = "cat",  # "cat", "sum", "func"
        dim: int | None = 1,
        combine_func: Callable | None = None,
    ):
        super().__init__()
        self.segments = segments
        self.dim = dim

        self.combine = combine

        if combine.upper() == "CAT":
            self.combine_func = torch.cat
        elif combine.upper() == "SUM":
            self.combine_func = self.sum
            self.dim = None
        else:
            assert combine_func is not None
            self.combine_func = combine_func

    def sum(self, input: list[torch.Tensor]) -> torch.Tensor | None:

        if len(input) == 0:
            return None

        if len(input) == 1:
            return input[0]

        output: torch.Tensor = input[0]

        for i in range(1, len(input)):
            output = output + input[i]

        return output

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        results: list[torch.Tensor] = []
        for segment in self.segments:
            results.append(segment(input))

        if self.dim is None:
            return self.combine_func(results)
        else:
            return self.combine_func(results, dim=self.dim)

    def extra_repr(self) -> str:
        return self.combine


if __name__ == "__main__":

    print("Example CAT")
    strain_a = torch.nn.Sequential(torch.nn.Identity())
    strain_b = torch.nn.Sequential(torch.nn.Identity())
    strain_c = torch.nn.Sequential(torch.nn.Identity())
    test_cat = SequentialSplit(
        torch.nn.Sequential(strain_a, strain_b, strain_c), combine="cat", dim=2
    )
    print(test_cat)
    input = torch.ones((10, 11, 12, 13))
    output = test_cat(input)
    print(input.shape)
    print(output.shape)
    print(input[0, 0, 0, 0])
    print(output[0, 0, 0, 0])
    print()

    print("Example SUM")
    strain_a = torch.nn.Sequential(torch.nn.Identity())
    strain_b = torch.nn.Sequential(torch.nn.Identity())
    strain_c = torch.nn.Sequential(torch.nn.Identity())
    test_sum = SequentialSplit(
        torch.nn.Sequential(strain_a, strain_b, strain_c), combine="sum", dim=2
    )
    print(test_sum)
    input = torch.ones((10, 11, 12, 13))
    output = test_sum(input)
    print(input.shape)
    print(output.shape)
    print(input[0, 0, 0, 0])
    print(output[0, 0, 0, 0])
    print()

    print("Example Labeling")
    strain_a = torch.nn.Sequential()
    strain_a.add_module("Label for first strain", torch.nn.Identity())
    strain_b = torch.nn.Sequential()
    strain_b.add_module("Label for second strain", torch.nn.Identity())
    strain_c = torch.nn.Sequential()
    strain_c.add_module("Label for third strain", torch.nn.Identity())
    test_label = SequentialSplit(torch.nn.Sequential(strain_a, strain_b, strain_c))
    print(test_label)
    print()

    print("Example Get Parameter")
    input = torch.ones((10, 11, 12, 13))
    strain_a = torch.nn.Sequential()
    strain_a.add_module("Identity", torch.nn.Identity())
    strain_b = torch.nn.Sequential()
    strain_b.add_module(
        "Conv2d",
        torch.nn.Conv2d(
            in_channels=input.shape[1],
            out_channels=input.shape[1],
            kernel_size=(1, 1),
        ),
    )
    test_parameter = SequentialSplit(torch.nn.Sequential(strain_a, strain_b))
    print(test_parameter)
    for name, param in test_parameter.named_parameters():
        print(f"Parameter name: {name}, Shape: {param.shape}")
convert_log_to_numpy.py (Normal file, 29 lines)
@@ -0,0 +1,29 @@
import os
import glob

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

from tensorboard.backend.event_processing import event_accumulator  # type: ignore
import numpy as np


def get_data(path: str = "log_cnn"):
    acc = event_accumulator.EventAccumulator(path)
    acc.Reload()

    which_scalar = "Test Number Correct"
    te = acc.Scalars(which_scalar)

    np_temp = np.zeros((len(te), 2))

    for id in range(0, len(te)):
        np_temp[id, 0] = te[id].step
        np_temp[id, 1] = te[id].value
    print(np_temp[:, 1] / 100)
    return np_temp


for path in glob.glob("log_*"):
    print(path)
    data = get_data(path)
    np.save("data_" + path + ".npy", data)
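Usage sketch (not part of the committed file): reloading one of the exported arrays; the file name is hypothetical and depends on which log_* directories exist.

import numpy as np

data = np.load("data_log_cnn.npy")  # hypothetical output of the loop above
print(data[:5, 0])  # steps
print(data[:5, 1])  # "Test Number Correct" values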
data_loader.py (Normal file, 31 lines)
@@ -0,0 +1,31 @@
import torch


def data_loader(
    pattern: torch.Tensor,
    labels: torch.Tensor,
    worker_init_fn,
    generator,
    batch_size: int = 128,
    shuffle: bool = True,
    torch_device: torch.device = torch.device("cpu"),
) -> torch.utils.data.dataloader.DataLoader:

    assert pattern.ndim >= 3

    pattern_storage: torch.Tensor = pattern.to(torch_device).type(torch.float32)
    if pattern_storage.ndim == 3:
        pattern_storage = pattern_storage.unsqueeze(1)
    pattern_storage /= pattern_storage.max()

    label_storage: torch.Tensor = labels.to(torch_device).type(torch.int64)

    dataloader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(pattern_storage, label_storage),
        batch_size=batch_size,
        shuffle=shuffle,
        worker_init_fn=worker_init_fn,
        generator=generator,
    )

    return dataloader
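Usage sketch (not part of the committed file): wrapping random MNIST-sized tensors; all sizes are illustrative.

import torch

from data_loader import data_loader

g = torch.Generator()
g.manual_seed(0)

loader = data_loader(
    pattern=torch.rand(100, 28, 28),      # 3D input gets a channel dim added
    labels=torch.randint(0, 10, (100,)),
    worker_init_fn=None,                  # fine here: num_workers defaults to 0
    generator=g,
    batch_size=10,
)
images, labels = next(iter(loader))
print(images.shape)  # torch.Size([10, 1, 28, 28])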
get_the_data.py (Normal file, 147 lines)
@@ -0,0 +1,147 @@
import torch
import torchvision  # type: ignore
from data_loader import data_loader

from torchvision.transforms import v2  # type: ignore
import numpy as np


def get_the_data(
    dataset: str,
    batch_size_train: int,
    batch_size_test: int,
    torch_device: torch.device,
    input_dim_x: int,
    input_dim_y: int,
    flip_p: float = 0.5,
    jitter_brightness: float = 0.5,
    jitter_contrast: float = 0.1,
    jitter_saturation: float = 0.1,
    jitter_hue: float = 0.15,
    da_auto_mode: bool = False,
) -> tuple[
    torch.utils.data.dataloader.DataLoader,
    torch.utils.data.dataloader.DataLoader,
    torchvision.transforms.Compose,
    torchvision.transforms.Compose,
]:
    if dataset == "MNIST":
        tv_dataset_train = torchvision.datasets.MNIST(
            root="data", train=True, download=True
        )
        tv_dataset_test = torchvision.datasets.MNIST(
            root="data", train=False, download=True
        )
    elif dataset == "FashionMNIST":
        tv_dataset_train = torchvision.datasets.FashionMNIST(
            root="data", train=True, download=True
        )
        tv_dataset_test = torchvision.datasets.FashionMNIST(
            root="data", train=False, download=True
        )
    elif dataset == "CIFAR10":
        tv_dataset_train = torchvision.datasets.CIFAR10(
            root="data", train=True, download=True
        )
        tv_dataset_test = torchvision.datasets.CIFAR10(
            root="data", train=False, download=True
        )
    else:
        raise NotImplementedError("This dataset is not implemented.")

    def seed_worker(worker_id):
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        # Seed the torch RNG for this worker
        # (torch.random.seed() takes no argument, so manual_seed is used)
        torch.manual_seed(worker_seed)

    g = torch.Generator()
    g.manual_seed(0)

    if dataset == "MNIST" or dataset == "FashionMNIST":

        train_dataloader = data_loader(
            torch_device=torch_device,
            batch_size=batch_size_train,
            pattern=tv_dataset_train.data,
            labels=tv_dataset_train.targets,
            shuffle=True,
            worker_init_fn=seed_worker,
            generator=g,
        )

        test_dataloader = data_loader(
            torch_device=torch_device,
            batch_size=batch_size_test,
            pattern=tv_dataset_test.data,
            labels=tv_dataset_test.targets,
            shuffle=False,
            worker_init_fn=seed_worker,
            generator=g,
        )

        # Data augmentation filter
        test_processing_chain = torchvision.transforms.Compose(
            transforms=[torchvision.transforms.CenterCrop((input_dim_x, input_dim_y))],
        )

        train_processing_chain = torchvision.transforms.Compose(
            transforms=[torchvision.transforms.RandomCrop((input_dim_x, input_dim_y))],
        )
    else:

        train_dataloader = data_loader(
            torch_device=torch_device,
            batch_size=batch_size_train,
            pattern=torch.tensor(tv_dataset_train.data).movedim(-1, 1),
            labels=torch.tensor(tv_dataset_train.targets),
            shuffle=True,
            worker_init_fn=seed_worker,
            generator=g,
        )

        test_dataloader = data_loader(
            torch_device=torch_device,
            batch_size=batch_size_test,
            pattern=torch.tensor(tv_dataset_test.data).movedim(-1, 1),
            labels=torch.tensor(tv_dataset_test.targets),
            shuffle=False,
            worker_init_fn=seed_worker,
            generator=g,
        )

        # Data augmentation filter
        test_processing_chain = torchvision.transforms.Compose(
            transforms=[torchvision.transforms.CenterCrop((input_dim_x, input_dim_y))],
        )

        if da_auto_mode:
            train_processing_chain = torchvision.transforms.Compose(
                transforms=[
                    v2.AutoAugment(
                        policy=torchvision.transforms.AutoAugmentPolicy(
                            v2.AutoAugmentPolicy.CIFAR10
                        )
                    ),
                    torchvision.transforms.CenterCrop((input_dim_x, input_dim_y)),
                ],
            )
        else:
            train_processing_chain = torchvision.transforms.Compose(
                transforms=[
                    torchvision.transforms.RandomCrop((input_dim_x, input_dim_y)),
                    torchvision.transforms.RandomHorizontalFlip(p=flip_p),
                    torchvision.transforms.ColorJitter(
                        brightness=jitter_brightness,
                        contrast=jitter_contrast,
                        saturation=jitter_saturation,
                        hue=jitter_hue,
                    ),
                ],
            )

    return (
        train_dataloader,
        test_dataloader,
        train_processing_chain,
        test_processing_chain,
    )
loss_function.py (Normal file, 64 lines)
@@ -0,0 +1,64 @@
import torch


# loss_mode == 0: "normal" SbS loss function mixture
# loss_mode == 1: cross_entropy
def loss_function(
    h: torch.Tensor,
    labels: torch.Tensor,
    loss_mode: int = 0,
    number_of_output_neurons: int = 10,
    loss_coeffs_mse: float = 0.0,
    loss_coeffs_kldiv: float = 0.0,
) -> torch.Tensor | None:

    assert loss_mode >= 0
    assert loss_mode <= 1

    assert h.ndim == 2

    if loss_mode == 0:

        # Convert label into one hot
        target_one_hot: torch.Tensor = torch.zeros(
            (
                labels.shape[0],
                number_of_output_neurons,
            ),
            device=h.device,
            dtype=h.dtype,
        )

        target_one_hot.scatter_(
            1,
            labels.to(h.device).unsqueeze(1),
            torch.ones(
                (labels.shape[0], 1),
                device=h.device,
                dtype=h.dtype,
            ),
        )

        my_loss: torch.Tensor = ((h - target_one_hot) ** 2).sum(dim=0).mean(
            dim=0
        ) * loss_coeffs_mse

        my_loss = (
            my_loss
            + (
                (target_one_hot * torch.log((target_one_hot + 1e-20) / (h + 1e-20)))
                .sum(dim=0)
                .mean(dim=0)
            )
            * loss_coeffs_kldiv
        )

        my_loss = my_loss / (abs(loss_coeffs_kldiv) + abs(loss_coeffs_mse))

        return my_loss

    elif loss_mode == 1:
        my_loss = torch.nn.functional.cross_entropy(h, labels.to(h.device))
        return my_loss
    else:
        return None
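Usage sketch (not part of the committed file): the two loss modes on dummy network output; shapes and coefficients are illustrative.

import torch

from loss_function import loss_function

h = torch.softmax(torch.randn(5, 10), dim=1)  # fake network output, rows sum to 1
labels = torch.randint(0, 10, (5,))

mixed = loss_function(
    h, labels, loss_mode=0, loss_coeffs_mse=0.5, loss_coeffs_kldiv=1.0
)
ce = loss_function(h, labels, loss_mode=1)
print(mixed, ce)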
make_network.py (Normal file, 367 lines)
@@ -0,0 +1,367 @@
import torch
from PositionalEncoding import PositionalEncoding
from SequentialSplit import SequentialSplit
from NNMF2dGrouped import NNMF2dGrouped
from Functional2Layer import Functional2Layer


def add_block(
    network: torch.nn.Sequential,
    embed_dim: int,
    num_heads: int,
    dtype: torch.dtype,
    device: torch.device,
    example_image: torch.Tensor,
    mlp_ratio: int = 4,
    block_id: int = 0,
    iterations: int = 20,
    padding: int = 1,
    kernel_size: tuple[int, int] = (3, 3),
) -> torch.Tensor | None:

    # ###########
    # Attention #
    # ###########

    example_image_a: torch.Tensor = example_image.clone()
    example_image_b: torch.Tensor = example_image.clone()

    attention_a_sequential = torch.nn.Sequential()

    attention_a_sequential.add_module(
        "Attention Layer Norm 1 [Pre-Permute]",
        Functional2Layer(func=torch.permute, dims=(0, 2, 3, 1)),
    )
    example_image_a = attention_a_sequential[-1](example_image_a)

    attention_a_sequential.add_module(
        "Attention Layer Norm 1",
        torch.nn.LayerNorm(
            normalized_shape=example_image_a.shape[-1],
            eps=1e-06,
            bias=True,
            dtype=dtype,
            device=device,
        ),
    )
    example_image_a = attention_a_sequential[-1](example_image_a)

    attention_a_sequential.add_module(
        "Attention Layer Norm 1 [Post-Permute]",
        Functional2Layer(func=torch.permute, dims=(0, 3, 1, 2)),
    )
    example_image_a = attention_a_sequential[-1](example_image_a)

    attention_a_sequential.add_module(
        "Attention Clamp Layer", Functional2Layer(func=torch.clamp, min=1e-6)
    )
    example_image_a = attention_a_sequential[-1](example_image_a)

    backup_image_dim = example_image_a.shape[1]

    attention_a_sequential.add_module(
        "Attention Zero Padding Layer", torch.nn.ZeroPad2d(padding=padding)
    )
    example_image_a = attention_a_sequential[-1](example_image_a)

    # I need the output size
    mock_output_shape = (
        torch.nn.functional.conv2d(
            torch.zeros(
                1,
                1,
                example_image_a.shape[2],
                example_image_a.shape[3],
            ),
            torch.zeros((1, 1, kernel_size[0], kernel_size[1])),
            stride=1,
            padding=0,
            dilation=1,
        )
        .squeeze(0)
        .squeeze(0)
    ).shape

    attention_a_sequential.add_module(
        "Attention Windowing [Part 1]",
        torch.nn.Unfold(
            kernel_size=(kernel_size[-2], kernel_size[-1]),
            dilation=1,
            padding=0,
            stride=1,
        ),
    )
    example_image_a = attention_a_sequential[-1](example_image_a)

    attention_a_sequential.add_module(
        "Attention Windowing [Part 2]",
        torch.nn.Fold(
            output_size=mock_output_shape,
            kernel_size=(1, 1),
            dilation=1,
            padding=0,
            stride=1,
        ),
    )
    example_image_a = attention_a_sequential[-1](example_image_a)

    # Distinct module name: the original used "Attention NNMFConv2d" here,
    # which the next add_module call with the same name would silently replace.
    attention_a_sequential.add_module("Attention ReLU", torch.nn.ReLU())
    example_image_a = attention_a_sequential[-1](example_image_a)

    attention_a_sequential.add_module(
        "Attention NNMFConv2d",
        NNMF2dGrouped(
            in_channels=example_image_a.shape[1],
            out_channels=embed_dim,
            groups=num_heads,
            device=device,
            dtype=dtype,
            iterations=iterations,
        ),
    )
    example_image_a = attention_a_sequential[-1](example_image_a)

    attention_a_sequential.add_module(
        "Attention Layer Norm 2 [Pre-Permute]",
        Functional2Layer(func=torch.permute, dims=(0, 2, 3, 1)),
    )
    example_image_a = attention_a_sequential[-1](example_image_a)

    attention_a_sequential.add_module(
        "Attention Layer Norm 2",
        torch.nn.LayerNorm(
            normalized_shape=example_image_a.shape[-1],
            eps=1e-06,
            bias=True,
            dtype=dtype,
            device=device,
        ),
    )
    example_image_a = attention_a_sequential[-1](example_image_a)

    attention_a_sequential.add_module(
        "Attention Layer Norm 2 [Post-Permute]",
        Functional2Layer(func=torch.permute, dims=(0, 3, 1, 2)),
    )
    example_image_a = attention_a_sequential[-1](example_image_a)

    attention_a_sequential.add_module(
        "Attention Conv2d Layer",
        torch.nn.Conv2d(
            in_channels=example_image_a.shape[1],
            out_channels=backup_image_dim,
            kernel_size=1,
            dtype=dtype,
            device=device,
        ),
    )
    example_image_a = attention_a_sequential[-1](example_image_a)

    attention_b_sequential = torch.nn.Sequential()
    attention_b_sequential.add_module(
        "Attention Identity for the skip", torch.nn.Identity()
    )
    example_image_b = attention_b_sequential[-1](example_image_b)

    assert example_image_b.shape == example_image_a.shape

    network.add_module(
        f"Block Number {block_id} [Attention]",
        SequentialSplit(
            torch.nn.Sequential(
                attention_a_sequential,
                attention_b_sequential,
            ),
            combine="SUM",
        ),
    )
    example_image = network[-1](example_image)

    # #####
    # MLP #
    # #####

    example_image_a = example_image.clone()
    example_image_b = example_image.clone()

    mlp_a_sequential = torch.nn.Sequential()

    mlp_a_sequential.add_module(
        "MLP [Pre-Permute]", Functional2Layer(func=torch.permute, dims=(0, 2, 3, 1))
    )
    example_image_a = mlp_a_sequential[-1](example_image_a)

    mlp_a_sequential.add_module(
        "MLP Layer Norm",
        torch.nn.LayerNorm(
            normalized_shape=example_image_a.shape[-1],
            eps=1e-06,
            bias=True,
            dtype=dtype,
            device=device,
        ),
    )
    example_image_a = mlp_a_sequential[-1](example_image_a)

    mlp_a_sequential.add_module(
        "MLP Linear Layer A",
        torch.nn.Linear(
            example_image_a.shape[-1],
            int(example_image_a.shape[-1] * mlp_ratio),
            dtype=dtype,
            device=device,
        ),
    )
    example_image_a = mlp_a_sequential[-1](example_image_a)

    mlp_a_sequential.add_module("MLP GELU", torch.nn.GELU())
    example_image_a = mlp_a_sequential[-1](example_image_a)

    mlp_a_sequential.add_module(
        "MLP Linear Layer B",
        torch.nn.Linear(
            example_image_a.shape[-1],
            int(example_image_a.shape[-1] // mlp_ratio),
            dtype=dtype,
            device=device,
        ),
    )
    example_image_a = mlp_a_sequential[-1](example_image_a)

    mlp_a_sequential.add_module(
        "MLP [Post-Permute]", Functional2Layer(func=torch.permute, dims=(0, 3, 1, 2))
    )
    example_image_a = mlp_a_sequential[-1](example_image_a)

    mlp_b_sequential = torch.nn.Sequential()
    mlp_b_sequential.add_module("MLP Identity for the skip", torch.nn.Identity())

    # The original traced this through attention_b_sequential (copy-paste slip)
    example_image_b = mlp_b_sequential[-1](example_image_b)

    assert example_image_b.shape == example_image_a.shape

    network.add_module(
        f"Block Number {block_id} [MLP]",
        SequentialSplit(
            torch.nn.Sequential(
                mlp_a_sequential,
                mlp_b_sequential,
            ),
            combine="SUM",
        ),
    )
    example_image = network[-1](example_image)

    return example_image


def make_network(
    in_channels: int = 3,
    dims: list[int] = [72, 72, 72],
    embed_dims: list[int] = [192, 192, 192],
    n_classes: int = 10,
    heads: int = 12,
    example_image_shape: list[int] = [1, 3, 28, 28],
    dtype: torch.dtype = torch.float32,
    device: torch.device | None = None,
    iterations: int = 20,
) -> torch.nn.Sequential:

    assert device is not None

    network = torch.nn.Sequential()

    example_image: torch.Tensor = torch.zeros(
        example_image_shape, dtype=dtype, device=device
    )

    network.add_module(
        "Encode Conv2d",
        torch.nn.Conv2d(
            in_channels,
            dims[0],
            kernel_size=4,
            stride=4,
            padding=0,
            dtype=dtype,
            device=device,
        ),
    )
    example_image = network[-1](example_image)

    network.add_module(
        "Encode Offset",
        PositionalEncoding(
            [example_image.shape[-3], example_image.shape[-2], example_image.shape[-1]]
        ).to(device=device),
    )
    example_image = network[-1](example_image)

    network.add_module(
        "Encode Layer Norm [Pre-Permute]",
        Functional2Layer(func=torch.permute, dims=(0, 2, 3, 1)),
    )
    example_image = network[-1](example_image)

    network.add_module(
        "Encode Layer Norm",
        torch.nn.LayerNorm(
            normalized_shape=example_image.shape[-1],
            eps=1e-06,
            bias=True,
            dtype=dtype,
            device=device,
        ),
    )
    example_image = network[-1](example_image)

    network.add_module(
        "Encode Layer Norm [Post-Permute]",
        Functional2Layer(func=torch.permute, dims=(0, 3, 1, 2)),
    )
    example_image = network[-1](example_image)

    for i in range(len(dims)):
        example_image = add_block(
            network=network,
            embed_dim=embed_dims[i],
            num_heads=heads,
            mlp_ratio=2,
            block_id=i,
            example_image=example_image,
            dtype=dtype,
            device=device,
            iterations=iterations,
        )

    network.add_module(
        "Spatial Mean Layer", Functional2Layer(func=torch.mean, dim=(-1, -2))
    )
    example_image = network[-1](example_image)

    network.add_module(
        "Final Linear Layer",
        torch.nn.Linear(example_image.shape[-1], n_classes, dtype=dtype, device=device),
    )
    example_image = network[-1](example_image)

    network.add_module("Final Softmax Layer", torch.nn.Softmax(dim=-1))
    example_image = network[-1](example_image)

    assert example_image.ndim == 2
    assert example_image.shape[0] == example_image_shape[0]
    assert example_image.shape[1] == n_classes

    return network


if __name__ == "__main__":
    network = make_network(device=torch.device("cuda:0"))
    print(network)

    number_of_parameter: int = 0
    for name, param in network.named_parameters():
        print(f"Parameter name: {name}, Shape: {param.shape}")
        number_of_parameter += param.numel()

    print("Number of total parameters:", number_of_parameter)
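Usage sketch (not part of the committed file): building the network on CPU and pushing one dummy batch through it; make_network itself already traces an example image, so this mainly shows the expected input/output shapes. The small iteration count is an illustrative assumption to keep the NNMF layers fast.

import torch

from make_network import make_network

net = make_network(device=torch.device("cpu"), iterations=3)
x = torch.rand(1, 3, 28, 28)  # non-negative input, CIFAR-like shape
print(net(x).shape)           # torch.Size([1, 10]): class probabilities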
make_optimize.py (Normal file, 32 lines)
@@ -0,0 +1,32 @@
import torch


def make_optimize(
    parameters: list[list[torch.nn.parameter.Parameter]],
    lr_initial: list[float],
    eps=1e-10,
) -> tuple[
    list[torch.optim.Adam | None],
    list[torch.optim.lr_scheduler.ReduceLROnPlateau | None],
]:
    list_optimizer: list[torch.optim.Adam | None] = []
    list_lr_scheduler: list[torch.optim.lr_scheduler.ReduceLROnPlateau | None] = []

    assert len(parameters) == len(lr_initial)

    for i in range(0, len(parameters)):
        if len(parameters[i]) > 0:
            list_optimizer.append(torch.optim.Adam(parameters[i], lr=lr_initial[i]))
        else:
            list_optimizer.append(None)

    for i in range(0, len(list_optimizer)):
        if list_optimizer[i] is not None:
            list_lr_scheduler.append(
                torch.optim.lr_scheduler.ReduceLROnPlateau(list_optimizer[i], eps=eps)  # type: ignore
            )
        else:
            list_lr_scheduler.append(None)

    return (list_optimizer, list_lr_scheduler)
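Usage sketch (not part of the committed file): one optimizer/scheduler pair per parameter group; empty groups yield None placeholders.

import torch

from make_optimize import make_optimize

w = torch.nn.Parameter(torch.rand(3, 3))
optimizers, schedulers = make_optimize(parameters=[[w], []], lr_initial=[0.01, 0.001])
print(optimizers)   # [Adam (...), None]
print(schedulers)   # [ReduceLROnPlateau object, None]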
non_linear_weigth_function.py (Normal file, 26 lines)
@@ -0,0 +1,26 @@
import torch


def non_linear_weigth_function(
    weight: torch.Tensor, beta: torch.Tensor | None, positive_function_type: int
) -> torch.Tensor:

    if positive_function_type == 0:
        positive_weights = torch.abs(weight)

    elif positive_function_type == 1:
        assert beta is not None
        positive_weights = weight
        max_value = torch.abs(positive_weights).max()
        if max_value > 80:
            positive_weights = 80.0 * positive_weights / max_value
        positive_weights = torch.exp((torch.tanh(beta) + 1.0) * 0.5 * positive_weights)

    elif positive_function_type == 2:
        assert beta is not None
        positive_weights = (torch.tanh(beta * weight) + 1.0) * 0.5

    else:
        positive_weights = weight

    return positive_weights
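Usage sketch (not part of the committed file): the three positivity transforms side by side; the beta value is an illustrative assumption.

import torch

from non_linear_weigth_function import non_linear_weigth_function

w = torch.linspace(-2.0, 2.0, 5)
beta = torch.tensor([0.5])

print(non_linear_weigth_function(w, None, 0))   # |w|
print(non_linear_weigth_function(w, beta, 1))   # exp((tanh(beta)+1)/2 * w)
print(non_linear_weigth_function(w, beta, 2))   # (tanh(beta*w)+1)/2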
run_network.py (Normal file, 263 lines)
@@ -0,0 +1,263 @@
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import argh

import time
import numpy as np
import torch

rand_seed: int = 21
torch.manual_seed(rand_seed)
torch.cuda.manual_seed(rand_seed)
np.random.seed(rand_seed)

from torch.utils.tensorboard import SummaryWriter

from make_network import make_network
from get_the_data import get_the_data
from loss_function import loss_function
from make_optimize import make_optimize


def main(
    lr_initial_nnmf: float = 0.01,
    lr_initial_cnn: float = 0.01,
    iterations: int = 25,
    heads: int = 12,
    dataset: str = "CIFAR10",  # "CIFAR10", "FashionMNIST", "MNIST"
    only_print_network: bool = False,
    da_auto_mode: bool = False,
) -> None:

    lr_limit: float = 1e-9

    torch_device: torch.device = (
        torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    )
    torch.set_default_dtype(torch.float32)

    # Some parameters
    batch_size_train: int = 500
    batch_size_test: int = 500
    number_of_epoch: int = 5000

    prefix = ""

    loss_mode: int = 0
    loss_coeffs_mse: float = 0.5
    loss_coeffs_kldiv: float = 1.0
    print(
        "loss_mode: ",
        loss_mode,
        "loss_coeffs_mse: ",
        loss_coeffs_mse,
        "loss_coeffs_kldiv: ",
        loss_coeffs_kldiv,
    )

    if dataset == "MNIST" or dataset == "FashionMNIST":
        input_number_of_channel: int = 1
        input_dim_x: int = 24
        input_dim_y: int = 24
    else:
        input_number_of_channel = 3
        input_dim_x = 28
        input_dim_y = 28

    train_dataloader, test_dataloader, train_processing_chain, test_processing_chain = (
        get_the_data(
            dataset,
            batch_size_train,
            batch_size_test,
            torch_device,
            input_dim_x,
            input_dim_y,
            flip_p=0.5,
            jitter_brightness=0.5,
            jitter_contrast=0.1,
            jitter_saturation=0.1,
            jitter_hue=0.15,
            da_auto_mode=da_auto_mode,
        )
    )

    network = make_network(
        in_channels=input_number_of_channel,
        dims=[72, 72, 72],
        embed_dims=[192, 192, 192],
        n_classes=10,
        heads=heads,
        example_image_shape=[1, input_number_of_channel, input_dim_x, input_dim_y],
        dtype=torch.float32,
        device=torch_device,
        iterations=iterations,
    )
    print(network)

    print()
    print("Information about used parameters:")

    parameter_list: list[list] = []
    parameter_list.append([])
    parameter_list.append([])

    number_of_parameter: int = 0
    for name, param in network.named_parameters():

        if name.find("NNMF") == -1:
            parameter_list[0].append(param)
        else:
            parameter_list[1].append(param)
            print("!!! NNMF !!! ", end=" ")

        print(f"Parameter name: {name}, Shape: {param.shape}")
        number_of_parameter += param.numel()
    print()
    print("Number of total parameters:", number_of_parameter)
    print("Number of parameter sets in CNN:", len(parameter_list[0]))
    print("Number of parameter sets in NNMF:", len(parameter_list[1]))

    if only_print_network:
        exit()

    (
        optimizers,
        lr_schedulers,
    ) = make_optimize(
        parameters=parameter_list,
        lr_initial=[
            lr_initial_cnn,
            lr_initial_nnmf,
        ],
    )

    my_string: str = "_lr_"
    for i in range(0, len(lr_schedulers)):
        if lr_schedulers[i] is not None:
            my_string += f"{lr_schedulers[i].get_last_lr()[0]:.4e}_"  # type: ignore
        else:
            my_string += "-_"

    default_path: str = f"{prefix}_iter{iterations}{my_string}"
    log_dir: str = f"log_{default_path}"

    tb = SummaryWriter(log_dir=log_dir)

    for epoch_id in range(0, number_of_epoch):
        print()
        print(f"Epoch: {epoch_id}")
        t_start: float = time.perf_counter()

        train_loss: float = 0.0
        train_correct: int = 0
        train_number: int = 0
        test_correct: int = 0
        test_number: int = 0

        # Switch the network into training mode
        network.train()

        # This runs in total for one epoch split up into mini-batches
        for image, target in train_dataloader:

            # Clean the gradient
            for i in range(0, len(optimizers)):
                if optimizers[i] is not None:
                    optimizers[i].zero_grad()  # type: ignore

            output = network(train_processing_chain(image))

            loss = loss_function(
                h=output,
                labels=target,
                number_of_output_neurons=output.shape[1],
                loss_mode=loss_mode,
                loss_coeffs_mse=loss_coeffs_mse,
                loss_coeffs_kldiv=loss_coeffs_kldiv,
            )

            assert loss is not None
            train_loss += loss.item()
            train_correct += (output.argmax(dim=1) == target).sum().cpu().numpy()
            train_number += target.shape[0]

            # Calculate backprop
            loss.backward()

            # Update the parameters
            for i in range(0, len(optimizers)):
                if optimizers[i] is not None:
                    optimizers[i].step()  # type: ignore

        performance_train_correct: float = 100.0 * train_correct / train_number
        # Update the learning rate
        for i in range(0, len(lr_schedulers)):
            if lr_schedulers[i] is not None:
                lr_schedulers[i].step(train_loss)  # type: ignore

        my_string = "Actual lr: "
        for i in range(0, len(lr_schedulers)):
            if lr_schedulers[i] is not None:
                my_string += f" {lr_schedulers[i].get_last_lr()[0]:.4e} "  # type: ignore
            else:
                my_string += " --- "

        print(my_string)
        t_training: float = time.perf_counter()

        # Switch the network into evaluation mode
        network.eval()

        with torch.no_grad():

            for image, target in test_dataloader:
                output = network(test_processing_chain(image))

                test_correct += (output.argmax(dim=1) == target).sum().cpu().numpy()
                test_number += target.shape[0]

        t_testing = time.perf_counter()

        performance_test_correct: float = 100.0 * test_correct / test_number

        tb.add_scalar("Train Loss", train_loss / float(train_number), epoch_id)
        tb.add_scalar("Train Number Correct", train_correct, epoch_id)
        tb.add_scalar("Test Number Correct", test_correct, epoch_id)

        print(
            f"Training: Loss={train_loss / float(train_number):.5f} Correct={performance_train_correct:.2f}%"
        )
        print(f"Testing: Correct={performance_test_correct:.2f}%")
        print(
            f"Time: Training={(t_training - t_start):.1f}sec, Testing={(t_testing - t_training):.1f}sec"
        )

        tb.flush()

        lr_check: list[float] = []
        for i in range(0, len(lr_schedulers)):
            if lr_schedulers[i] is not None:
                lr_check.append(lr_schedulers[i].get_last_lr()[0])  # type: ignore

        lr_check_max = float(torch.tensor(lr_check).max())

        if lr_check_max < lr_limit:
            torch.save(network, f"Model_{default_path}.pt")
            tb.close()
            print("Done (lr_limit)")
            return

        torch.save(network, f"Model_{default_path}.pt")
        print()

    tb.close()
    print("Done (loop end)")

    return


if __name__ == "__main__":
    argh.dispatch_command(main)
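Invocation sketch (not part of the committed file): argh.dispatch_command maps the keyword arguments of main onto command-line flags, with underscores becoming dashes, so typical runs would look like:

python run_network.py --only-print-network
python run_network.py --dataset FashionMNIST --iterations 25 --heads 12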