Add files via upload
This commit is contained in:
parent 26673d3f31
commit 9bdadfef7c
1 changed file with 503 additions and 0 deletions
network/AutoNNMFLayer.py (new file)
@@ -0,0 +1,503 @@
import torch

from network.calculate_output_size import calculate_output_size


class AutoNNMFLayer(torch.nn.Module):
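    """Convolutional NNMF layer (description inferred from the code below,
    not from upstream docs).

    The input is rearranged into patches with unfold/fold, a latent
    activity h is estimated by a fixed number of multiplicative update
    iterations inside a custom autograd function (FunctionalAutoNNMF),
    and the output is re-synthesized from h with a detached copy of the
    encoding weights.
    """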
    _epsilon_0: float
    _weights: torch.nn.parameter.Parameter
    _weights_exists: bool = False
    _kernel_size: list[int]
    _stride: list[int]
    _dilation: list[int]
    _padding: list[int]
    _output_size: torch.Tensor
    _inbetween_size: torch.Tensor
    _number_of_neurons: int
    _number_of_input_neurons: int
    _h_initial: torch.Tensor | None = None
    _w_trainable: bool
    _weight_noise_range: list[float]
    _input_size: list[int]
    _output_layer: bool = False
    _number_of_iterations: int
    _local_learning: bool = False

    device: torch.device
    default_dtype: torch.dtype

    _number_of_grad_weight_contributions: float = 0.0

    last_input_store: bool = False
    last_input_data: torch.Tensor | None = None

    _layer_id: int = -1

    _skip_gradient_calculation: bool

    _keep_last_grad_scale: bool
    _disable_scale_grade: bool
    _last_grad_scale: torch.nn.parameter.Parameter

    def __init__(
        self,
        number_of_input_neurons: int,
        number_of_neurons: int,
        input_size: list[int],
        forward_kernel_size: list[int],
        number_of_iterations: int,
        epsilon_0: float = 1.0,
        weight_noise_range: list[float] = [0.0, 1.0],
        strides: list[int] = [1, 1],
        dilation: list[int] = [0, 0],  # torch unfold requires dilation >= 1; pass e.g. [1, 1]
        padding: list[int] = [0, 0],
        w_trainable: bool = False,
        device: torch.device | None = None,
        default_dtype: torch.dtype | None = None,
        layer_id: int = -1,
        local_learning: bool = False,
        output_layer: bool = False,
        skip_gradient_calculation: bool = False,
        keep_last_grad_scale: bool = False,
        disable_scale_grade: bool = True,
    ) -> None:
        super().__init__()

        assert device is not None
        assert default_dtype is not None
        self.device = device
        self.default_dtype = default_dtype

        self._w_trainable = bool(w_trainable)
        self._stride = strides
        self._dilation = dilation
        self._padding = padding
        self._kernel_size = forward_kernel_size
        self._number_of_input_neurons = int(number_of_input_neurons)
        self._number_of_neurons = int(number_of_neurons)
        self._epsilon_0 = float(epsilon_0)
        self._number_of_iterations = int(number_of_iterations)
        self._weight_noise_range = weight_noise_range
        self._layer_id = layer_id
        self._local_learning = local_learning
        self._output_layer = output_layer
        self._skip_gradient_calculation = skip_gradient_calculation

        self._keep_last_grad_scale = bool(keep_last_grad_scale)
        self._disable_scale_grade = bool(disable_scale_grade)
        self._last_grad_scale = torch.nn.parameter.Parameter(
            torch.tensor(-1.0, dtype=self.default_dtype),
            requires_grad=True,
        )

        assert len(input_size) == 2
        self._input_size = input_size
        self._output_size = torch.tensor(input_size)

        self._inbetween_size = calculate_output_size(
            value=input_size,
            kernel_size=self._kernel_size,
            stride=self._stride,
            dilation=self._dilation,
            padding=self._padding,
        )

        self.set_h_init_to_uniform()

        # ###############################################################
        # Initialize the weights
        # ###############################################################

        assert len(self._weight_noise_range) == 2
        weights = torch.empty(
            (
                int(self._kernel_size[0])
                * int(self._kernel_size[1])
                * int(self._number_of_input_neurons),
                int(self._number_of_neurons),
            ),
            dtype=self.default_dtype,
            device=self.device,
        )

        torch.nn.init.uniform_(
            weights,
            a=float(self._weight_noise_range[0]),
            b=float(self._weight_noise_range[1]),
        )
        # The weights setter below normalizes each column to sum to one.
        self.weights = weights

        self.functional_nnmf_sbs_bp = FunctionalAutoNNMF.apply

    @property
    def weights(self) -> torch.Tensor | None:
        if self._weights_exists is False:
            return None
        else:
            return self._weights

    @weights.setter
    def weights(self, value: torch.Tensor):
        assert value is not None
        assert torch.is_tensor(value) is True
        assert value.dim() == 2
        temp: torch.Tensor = (
            value.detach()
            .clone(memory_format=torch.contiguous_format)
            .type(dtype=self.default_dtype)
            .to(device=self.device)
        )
        # Normalize every column (dim 0) to sum to one.
        temp /= temp.sum(dim=0, keepdim=True, dtype=self.default_dtype)
        if self._weights_exists is False:
            self._weights = torch.nn.parameter.Parameter(temp, requires_grad=True)
            self._weights_exists = True
        else:
            self._weights.data = temp

    @property
    def h_initial(self) -> torch.Tensor | None:
        return self._h_initial

    @h_initial.setter
    def h_initial(self, value: torch.Tensor):
        assert value is not None
        assert torch.is_tensor(value) is True
        assert value.dim() == 1
        assert value.dtype == self.default_dtype
        self._h_initial = (
            value.detach()
            .clone(memory_format=torch.contiguous_format)
            .type(dtype=self.default_dtype)
            .to(device=self.device)
            .requires_grad_(False)
        )

    def update_pre_care(self):
        # Average the accumulated weight gradient over the number of
        # contributions counted during the forward passes.
        if self._weights.grad is not None:
            assert self._number_of_grad_weight_contributions > 0
            self._weights.grad /= self._number_of_grad_weight_contributions
            self._number_of_grad_weight_contributions = 0.0

    def update_after_care(self, threshold_weight: float):
        if self._w_trainable is True:
            self.norm_weights()
            self.threshold_weights(threshold_weight)
            self.norm_weights()

    def after_batch(self, new_state: bool = False):
        if self._keep_last_grad_scale is True:
            self._last_grad_scale.data = self._last_grad_scale.grad  # type: ignore
            self._keep_last_grad_scale = new_state

        self._last_grad_scale.grad = torch.zeros_like(self._last_grad_scale.grad)  # type: ignore

    def set_h_init_to_uniform(self) -> None:
        assert self._number_of_neurons > 2

        self.h_initial: torch.Tensor = torch.full(
            (self._number_of_neurons,),
            (1.0 / float(self._number_of_neurons)),
            dtype=self.default_dtype,
            device=self.device,
        )

    def norm_weights(self) -> None:
        assert self._weights_exists is True
        temp: torch.Tensor = (
            self._weights.data.detach()
            .clone(memory_format=torch.contiguous_format)
            .type(dtype=self.default_dtype)
            .to(device=self.device)
        )
        temp /= temp.sum(dim=0, keepdim=True, dtype=self.default_dtype)
        self._weights.data = temp

    def threshold_weights(self, threshold: float) -> None:
        assert self._weights_exists is True
        assert threshold >= 0

        torch.clamp(
            self._weights.data,
            min=float(threshold),
            max=None,
            out=self._weights.data,
        )

    ####################################################################
    # Forward                                                          #
    ####################################################################

    def forward(
        self,
        input: torch.Tensor,
    ) -> torch.Tensor:
        # Are we happy with the input?
        assert input is not None
        assert torch.is_tensor(input) is True
        assert input.dim() == 4
        assert input.dtype == self.default_dtype
        assert input.shape[1] == self._number_of_input_neurons
        assert input.shape[2] == self._input_size[0]
        assert input.shape[3] == self._input_size[1]

        # Are we happy with the rest of the network?
        assert self._epsilon_0 is not None
        assert self._h_initial is not None
        assert self._weights_exists is True
        assert self._weights is not None

        # Convolution of the input...
        # Well, this is a convolution layer;
        # there needs to be a convolution somewhere.
        input_convolved = torch.nn.functional.fold(
            torch.nn.functional.unfold(
                input.requires_grad_(True),
                kernel_size=(int(self._kernel_size[0]), int(self._kernel_size[1])),
                dilation=(int(self._dilation[0]), int(self._dilation[1])),
                padding=(int(self._padding[0]), int(self._padding[1])),
                stride=(int(self._stride[0]), int(self._stride[1])),
            ),
            output_size=tuple(self._inbetween_size.tolist()),
            kernel_size=(1, 1),
            dilation=(1, 1),
            padding=(0, 0),
            stride=(1, 1),
        )

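        # Editor's sketch of the shape bookkeeping above (not from the
        # original author): unfold turns (N, C, H, W) into patches of
        # shape (N, C * kh * kw, L), and the 1x1 fold only reshapes the
        # L patch positions back onto the (H', W') grid computed by
        # calculate_output_size, giving (N, C * kh * kw, H', W').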
        # We might need the convolved input for other layers;
        # let us keep it for the future.
        if self.last_input_store is True:
            self.last_input_data = input_convolved.detach().clone()
            self.last_input_data /= self.last_input_data.sum(dim=1, keepdim=True)
        else:
            self.last_input_data = None

        input_convolved = input_convolved / (
            input_convolved.sum(dim=1, keepdim=True) + 1e-20
        )

        # Pack the int-valued settings for the custom autograd function.
        parameter_list = torch.tensor(
            [
                int(self._inbetween_size[0]),  # 0
                int(self._inbetween_size[1]),  # 1
                int(self._number_of_iterations),  # 2
                int(self._w_trainable),  # 3
                int(self._skip_gradient_calculation),  # 4
                int(self._keep_last_grad_scale),  # 5
                int(self._disable_scale_grade),  # 6
                int(self._local_learning),  # 7
                int(self._output_layer),  # 8
            ],
            dtype=torch.int64,
        )

        epsilon_0 = torch.tensor(self._epsilon_0)

        h = self.functional_nnmf_sbs_bp(
            input_convolved,
            epsilon_0,
            self._weights,
            self._h_initial,
            parameter_list,
            self._last_grad_scale,
        )

        # Count how many positions contributed to the weight gradient,
        # so update_pre_care() can average it later.
        self._number_of_grad_weight_contributions += (
            h.shape[0] * h.shape[-2] * h.shape[-1]
        )

        # The decoding weight will not be optimized!!!
        # (self._weights.detach().clone())
        output_convolved = (
            h.unsqueeze(1)
            * (self._weights.detach().clone()).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
        ).sum(dim=2)

        # Undo the patch rearrangement: fold the per-patch reconstruction
        # back onto the original (H, W) grid; overlapping patches are summed.
        output = torch.nn.functional.fold(
            torch.nn.functional.unfold(
                output_convolved,
                kernel_size=(1, 1),
                dilation=1,
                padding=0,
                stride=1,
            ),
            output_size=(int(input.shape[-2]), int(input.shape[-1])),
            kernel_size=(int(self._kernel_size[0]), int(self._kernel_size[1])),
            dilation=(int(self._dilation[0]), int(self._dilation[1])),
            padding=(int(self._padding[0]), int(self._padding[1])),
            stride=(int(self._stride[0]), int(self._stride[1])),
        )

        return output


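# Editor's note (describing the torch.autograd.Function contract, not
# adding behavior): backward() must return one gradient per forward()
# input, which is why it hands back six values below, with None for the
# non-differentiable arguments.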
class FunctionalAutoNNMF(torch.autograd.Function):
    @staticmethod
    def forward(  # type: ignore
        ctx,
        input: torch.Tensor,
        epsilon_0: torch.Tensor,
        weights: torch.Tensor,
        h_initial: torch.Tensor,
        parameter_list: torch.Tensor,
        grad_output_scale: torch.Tensor,
    ) -> torch.Tensor:
        output_size_0: int = int(parameter_list[0])
        output_size_1: int = int(parameter_list[1])
        number_of_iterations: int = int(parameter_list[2])

        _epsilon_0 = epsilon_0.item()

        # Start every batch entry and spatial position from the same
        # initial h distribution.
        h = torch.tile(
            h_initial.unsqueeze(0).unsqueeze(-1).unsqueeze(-1),
            dims=[
                int(input.shape[0]),
                1,
                int(output_size_0),
                int(output_size_1),
            ],
        )

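        # Editor's sketch of the update rule implemented below (read off
        # the code, not from a paper): with column-normalized weights W
        # and input x, each iteration computes
        #   h_w[j] = sum_i x[i] * (W[i, j] * h[j]) / (sum_k W[i, k] * h[k])
        # and then either blends it in, h <- h + epsilon_0 * h_w, or
        # replaces h outright, before renormalizing h to sum to one.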
        for _ in range(0, number_of_iterations):
            h_w = h.unsqueeze(1) * weights.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            h_w = h_w / (h_w.sum(dim=2, keepdim=True) + 1e-20)
            h_w = (h_w * input.unsqueeze(2)).sum(dim=1)
            if _epsilon_0 > 0:
                h = h + _epsilon_0 * h_w
            else:
                h = h_w
            h = h / (h.sum(dim=1, keepdim=True) + 1e-20)

        ctx.save_for_backward(
            input,
            weights,
            h,
            parameter_list,
            grad_output_scale,
        )

        return h

    @staticmethod
    def backward(ctx, grad_output):
        # ##############################################
        # Get the variables back
        # ##############################################
        (
            input,
            weights,
            output,
            parameter_list,
            last_grad_scale,
        ) = ctx.saved_tensors

        # ##############################################
        # Default output
        # ##############################################
        grad_input = None
        grad_epsilon_0 = None
        grad_weights = None
        grad_h_initial = None
        grad_parameter_list = None

        # ##############################################
        # Parameters
        # ##############################################
        parameter_w_trainable: bool = bool(parameter_list[3])
        parameter_skip_gradient_calculation: bool = bool(parameter_list[4])
        parameter_keep_last_grad_scale: bool = bool(parameter_list[5])
        parameter_disable_scale_grade: bool = bool(parameter_list[6])
        parameter_local_learning: bool = bool(parameter_list[7])
        parameter_output_layer: bool = bool(parameter_list[8])

        # ##############################################
        # Dealing with overall scale of the gradient
        # ##############################################
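        # Editor's note (behavior read off the code): when scaling is
        # enabled, grad_output is divided by a running maximum of its
        # absolute value, which keeps the gradient magnitude comparable
        # across layers; the scale itself is returned through the
        # grad_output_scale slot so the layer can carry it over batches.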
        if parameter_disable_scale_grade is False:
            if parameter_keep_last_grad_scale is True:
                last_grad_scale = torch.tensor(
                    [torch.abs(grad_output).max(), last_grad_scale]
                ).max()
            grad_output /= last_grad_scale
        grad_output_scale = last_grad_scale.clone()

        # Re-normalize the saved input along the channel dimension.
        input /= input.sum(dim=1, keepdim=True, dtype=weights.dtype) + 1e-20

        # #################################################
        # User doesn't want us to calculate the gradients
        # #################################################

        if parameter_skip_gradient_calculation is True:
            return (
                grad_input,
                grad_epsilon_0,
                grad_weights,
                grad_h_initial,
                grad_parameter_list,
                grad_output_scale,
            )

        # #################################################
        # Calculate backprop error (grad_input)
        # #################################################

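        # Editor's sketch (derived from the code): R[i, j] = W[i, j] * h[j]
        # is the per-neuron reconstruction, bigR[i] = sum_j R[i, j] the
        # total reconstruction of input channel i, and Z = R / bigR the
        # responsibility each neuron takes for that channel; grad_input
        # sums the incoming gradient weighted by these responsibilities.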
        backprop_r: torch.Tensor = weights.unsqueeze(0).unsqueeze(-1).unsqueeze(
            -1
        ) * output.unsqueeze(1)

        backprop_bigr: torch.Tensor = backprop_r.sum(dim=2)

        backprop_z: torch.Tensor = backprop_r * (
            1.0 / (backprop_bigr + 1e-20)
        ).unsqueeze(2)
        grad_input: torch.Tensor = (backprop_z * grad_output.unsqueeze(1)).sum(2)
        del backprop_z

        # #################################################
        # Calculate weight gradient (grad_weights)
        # #################################################

        if parameter_w_trainable is False:
            # #################################################
            # We don't train this weight
            # #################################################
            grad_weights = None

        elif (parameter_output_layer is False) and (parameter_local_learning is True):
            # #################################################
            # Local learning
            # #################################################
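            # Editor's sketch (read off the expression below): this is a
            # local, gradient-free Hebbian-style term; up to the
            # normalization it is the derivative of the reconstruction
            # error ||input - bigR||^2 with respect to W, so it needs no
            # grad_output.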
            grad_weights = (
                (-2 * (input - backprop_bigr).unsqueeze(2) * output.unsqueeze(1))
                .sum(0)
                .sum(-1)
                .sum(-1)
            )

        else:
            # #################################################
            # Backprop
            # #################################################
            backprop_f: torch.Tensor = output.unsqueeze(1) * (
                input / (backprop_bigr**2 + 1e-20)
            ).unsqueeze(2)

            result_omega: torch.Tensor = backprop_bigr.unsqueeze(
                2
            ) * grad_output.unsqueeze(1)
            result_omega -= (backprop_r * grad_output.unsqueeze(1)).sum(2).unsqueeze(2)
            result_omega *= backprop_f
            del backprop_f
            grad_weights = result_omega.sum(0).sum(-1).sum(-1)
            del result_omega

        del backprop_bigr
        del backprop_r

        return (
            grad_input,
            grad_epsilon_0,
            grad_weights,
            grad_h_initial,
            grad_parameter_list,
            grad_output_scale,
        )
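
# ----------------------------------------------------------------------
# Editor's usage sketch (not part of the original commit). It assumes
# that network.calculate_output_size returns a 2-element tensor holding
# the spatial size after patch extraction, matching the call in
# __init__. dilation is set to [1, 1] here because
# torch.nn.functional.unfold rejects a dilation of zero.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    layer = AutoNNMFLayer(
        number_of_input_neurons=1,
        number_of_neurons=32,
        input_size=[24, 24],
        forward_kernel_size=[5, 5],
        number_of_iterations=20,
        strides=[1, 1],
        dilation=[1, 1],
        padding=[0, 0],
        w_trainable=True,
        device=device,
        default_dtype=torch.float32,
    )
    x = torch.rand((4, 1, 24, 24), dtype=torch.float32, device=device)
    y = layer(x)
    # The fold/unfold round trip restores the input's spatial size:
    print(y.shape)  # torch.Size([4, 1, 24, 24])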