Add files via upload
This commit is contained in:
parent 1a807b44df
commit a85805e92c
16 changed files with 659 additions and 46 deletions
@@ -151,11 +151,7 @@ class Adam(torch.optim.Optimizer):
             if sbs_setting[i] is False:
                 param -= step_size * (exp_avg / denom)
             else:
-                # delta = torch.exp(-step_size * (exp_avg / denom))
-                delta = torch.tanh(-step_size * (exp_avg / denom))
-                delta += 1.0
-                delta *= 0.5
-                delta += 0.5
+                delta = 0.5 * torch.tanh(-step_size * (exp_avg / denom)) + 1.0
                 self._logging.info(
                     f"ADAM: Layer {i} -> dw_min:{float(delta.min()):.4e} dw_max:{float(delta.max()):.4e} lr:{lr:.4e}"
                 )
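The removed incremental construction and the new one-liner are algebraically identical: (tanh(x) + 1.0) * 0.5 + 0.5 = 0.5 * tanh(x) + 1.0, so the multiplicative factor delta stays within (0.5, 1.5) in both versions. A minimal sketch checking the equivalence (x stands in for -step_size * (exp_avg / denom)):

    import torch

    x = torch.linspace(-5.0, 5.0, steps=11)

    old_delta = torch.tanh(x)
    old_delta += 1.0
    old_delta *= 0.5
    old_delta += 0.5

    new_delta = 0.5 * torch.tanh(x) + 1.0

    assert torch.allclose(old_delta, new_delta)  # both stay within (0.5, 1.5)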
@@ -15,6 +15,8 @@ class DatasetMaster(torch.utils.data.Dataset, ABC):
     initial_size: list[int]
     channel_size: int
 
+    alpha: float
+
     # Initialize
     def __init__(
         self,
network/DatasetMix.py (new file, 90 lines)
@@ -0,0 +1,90 @@
import torch

from network.Dataset import DatasetMNIST, DatasetFashionMNIST, DatasetCIFAR

import math


class DatasetMNISTMix(DatasetMNIST):
    def __init__(
        self,
        train: bool = False,
        path_pattern: str = "./",
        path_label: str = "./",
        alpha: float = 1.0,
    ) -> None:
        super().__init__(train, path_pattern, path_label)
        self.alpha = alpha

    def __getitem__(self, index: int) -> tuple[torch.Tensor, list[int]]:  # type: ignore

        assert self.alpha >= 0.0
        assert self.alpha <= 1.0

        image_a, target_a = super().__getitem__(index)

        target_b: int = target_a
        while target_b == target_a:
            image_b, target_b = super().__getitem__(
                int(math.floor(self.number_of_pattern * torch.rand((1)).item()))
            )

        image = self.alpha * image_a + (1.0 - self.alpha) * image_b
        target = [target_a, target_b]
        return image, target


class DatasetFashionMNISTMix(DatasetFashionMNIST):
    def __init__(
        self,
        train: bool = False,
        path_pattern: str = "./",
        path_label: str = "./",
        alpha: float = 1.0,
    ) -> None:
        super().__init__(train, path_pattern, path_label)
        self.alpha = alpha

    def __getitem__(self, index: int) -> tuple[torch.Tensor, list[int]]:  # type: ignore

        assert self.alpha >= 0.0
        assert self.alpha <= 1.0

        image_a, target_a = super().__getitem__(index)

        target_b: int = target_a
        while target_b == target_a:
            image_b, target_b = super().__getitem__(
                int(math.floor(self.number_of_pattern * torch.rand((1)).item()))
            )

        image = self.alpha * image_a + (1.0 - self.alpha) * image_b
        target = [target_a, target_b]
        return image, target


class DatasetCIFARMix(DatasetCIFAR):
    def __init__(
        self,
        train: bool = False,
        path_pattern: str = "./",
        path_label: str = "./",
        alpha: float = 1.0,
    ) -> None:
        super().__init__(train, path_pattern, path_label)
        self.alpha = alpha

    def __getitem__(self, index: int) -> tuple[torch.Tensor, list[int]]:  # type: ignore

        assert self.alpha >= 0.0
        assert self.alpha <= 1.0

        image_a, target_a = super().__getitem__(index)

        target_b: int = target_a
        while target_b == target_a:
            image_b, target_b = super().__getitem__(
                int(math.floor(self.number_of_pattern * torch.rand((1)).item()))
            )

        image = self.alpha * image_a + (1.0 - self.alpha) * image_b
        target = [target_a, target_b]
        return image, target
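All three Mix datasets blend a sample with a second sample drawn from a different class, in the spirit of mixup: image = alpha * image_a + (1 - alpha) * image_b, and both labels are returned. A hypothetical usage sketch (paths and the alpha value are placeholders, not from the repository):

    import torch
    from network.DatasetMix import DatasetMNISTMix

    # alpha = 0.7 keeps 70% of image_a and blends in 30% of a second
    # image; the while-loop above guarantees the two classes differ.
    dataset = DatasetMNISTMix(
        train=True,
        path_pattern="./data",  # placeholder path
        path_label="./data",    # placeholder path
        alpha=0.7,
    )

    image, (target_a, target_b) = dataset[0]
    assert target_a != target_b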
@@ -443,7 +443,6 @@ class FunctionalSbS(torch.autograd.Function):
             )
-
         elif (parameter_output_layer is True) and (parameter_local_learning is True):

             target_one_hot: torch.Tensor = torch.zeros(
                 (
                     labels.shape[0],
@@ -51,7 +51,13 @@ class InputSpikeImage(torch.nn.Module):
     def forward(self, input: torch.Tensor) -> torch.Tensor:

         if self.number_of_spikes < 1:
-            return input
+            output = input
+            output = output.type(dtype=input.dtype)
+            if self._normalize is True:
+                output = output * output.shape[-1] * output.shape[-2] * output.shape[-3] / output.sum(dim=-1, keepdim=True).sum(
+                    dim=-2, keepdim=True
+                ).sum(dim=-3, keepdim=True)
+            return output

         input_shape: list[int] = [
             int(input.shape[0]),
@@ -95,9 +101,10 @@ class InputSpikeImage(torch.nn.Module):
             )
         )

-        if self._normalize is True:
         output = output.type(dtype=input_work.dtype)
-        output = output / output.sum(dim=-1, keepdim=True).sum(
+
+        if self._normalize is True:
+            output = output * output.shape[-1] * output.shape[-2] * output.shape[-3] / output.sum(dim=-1, keepdim=True).sum(
                 dim=-2, keepdim=True
             ).sum(dim=-3, keepdim=True)

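Both InputSpikeImage hunks replace the plain division by the per-sample sum with a rescaling so that the mean over the channel and spatial dimensions becomes 1.0 (the sum then equals the number of elements). A small sketch of the new rescaling on a stand-in tensor:

    import torch

    output = torch.rand((2, 3, 8, 8))  # stand-in for a batch of spike images

    elements = output.shape[-1] * output.shape[-2] * output.shape[-3]
    total = output.sum(dim=-1, keepdim=True).sum(dim=-2, keepdim=True).sum(
        dim=-3, keepdim=True
    )
    output = output * elements / total

    # Every sample now sums to C*H*W, i.e. has mean 1.0.
    assert torch.allclose(output.mean(dim=(-3, -2, -1)), torch.ones(2))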
network/NNMFLayer.py (new file, 280 lines)
@@ -0,0 +1,280 @@
import torch

from network.calculate_output_size import calculate_output_size


class NNMFLayer(torch.nn.Module):

    _epsilon_0: float
    _weights: torch.nn.parameter.Parameter
    _weights_exists: bool = False
    _kernel_size: list[int]
    _stride: list[int]
    _dilation: list[int]
    _padding: list[int]
    _output_size: torch.Tensor
    _number_of_neurons: int
    _number_of_input_neurons: int
    _h_initial: torch.Tensor | None = None
    _w_trainable: bool
    _weight_noise_range: list[float]
    _input_size: list[int]
    _output_layer: bool = False
    _number_of_iterations: int
    _local_learning: bool = False

    device: torch.device
    default_dtype: torch.dtype

    _number_of_grad_weight_contributions: float = 0.0

    last_input_store: bool = False
    last_input_data: torch.Tensor | None = None

    _layer_id: int = -1

    def __init__(
        self,
        number_of_input_neurons: int,
        number_of_neurons: int,
        input_size: list[int],
        forward_kernel_size: list[int],
        number_of_iterations: int,
        epsilon_0: float = 1.0,
        weight_noise_range: list[float] = [0.0, 1.0],
        strides: list[int] = [1, 1],
        dilation: list[int] = [0, 0],
        padding: list[int] = [0, 0],
        w_trainable: bool = False,
        device: torch.device | None = None,
        default_dtype: torch.dtype | None = None,
        layer_id: int = -1,
        local_learning: bool = False,
        output_layer: bool = False,
    ) -> None:
        super().__init__()

        assert device is not None
        assert default_dtype is not None
        self.device = device
        self.default_dtype = default_dtype

        self._w_trainable = bool(w_trainable)
        self._stride = strides
        self._dilation = dilation
        self._padding = padding
        self._kernel_size = forward_kernel_size
        self._number_of_input_neurons = int(number_of_input_neurons)
        self._number_of_neurons = int(number_of_neurons)
        self._epsilon_0 = float(epsilon_0)
        self._number_of_iterations = int(number_of_iterations)
        self._weight_noise_range = weight_noise_range
        self._layer_id = layer_id
        self._local_learning = local_learning
        self._output_layer = output_layer

        assert len(input_size) == 2
        self._input_size = input_size

        self._output_size = calculate_output_size(
            value=input_size,
            kernel_size=self._kernel_size,
            stride=self._stride,
            dilation=self._dilation,
            padding=self._padding,
        )

        self.set_h_init_to_uniform()

        # ###############################################################
        # Initialize the weights
        # ###############################################################

        assert len(self._weight_noise_range) == 2
        weights = torch.empty(
            (
                int(self._kernel_size[0])
                * int(self._kernel_size[1])
                * int(self._number_of_input_neurons),
                int(self._number_of_neurons),
            ),
            dtype=self.default_dtype,
            device=self.device,
        )

        torch.nn.init.uniform_(
            weights,
            a=float(self._weight_noise_range[0]),
            b=float(self._weight_noise_range[1]),
        )
        self.weights = weights

    @property
    def weights(self) -> torch.Tensor | None:
        if self._weights_exists is False:
            return None
        else:
            return self._weights

    @weights.setter
    def weights(self, value: torch.Tensor):
        assert value is not None
        assert torch.is_tensor(value) is True
        assert value.dim() == 2
        temp: torch.Tensor = (
            value.detach()
            .clone(memory_format=torch.contiguous_format)
            .type(dtype=self.default_dtype)
            .to(device=self.device)
        )
        temp /= temp.sum(dim=0, keepdim=True, dtype=self.default_dtype)
        if self._weights_exists is False:
            self._weights = torch.nn.parameter.Parameter(temp, requires_grad=True)
            self._weights_exists = True
        else:
            self._weights.data = temp

    @property
    def h_initial(self) -> torch.Tensor | None:
        return self._h_initial

    @h_initial.setter
    def h_initial(self, value: torch.Tensor):
        assert value is not None
        assert torch.is_tensor(value) is True
        assert value.dim() == 1
        assert value.dtype == self.default_dtype
        self._h_initial = (
            value.detach()
            .clone(memory_format=torch.contiguous_format)
            .type(dtype=self.default_dtype)
            .to(device=self.device)
            .requires_grad_(False)
        )

    def update_pre_care(self):

        if self._weights.grad is not None:
            assert self._number_of_grad_weight_contributions > 0
            self._weights.grad /= self._number_of_grad_weight_contributions
            self._number_of_grad_weight_contributions = 0.0

    def update_after_care(self, threshold_weight: float):

        if self._w_trainable is True:
            self.norm_weights()
            self.threshold_weights(threshold_weight)
            self.norm_weights()

    def set_h_init_to_uniform(self) -> None:

        assert self._number_of_neurons > 2

        self.h_initial: torch.Tensor = torch.full(
            (self._number_of_neurons,),
            (1.0 / float(self._number_of_neurons)),
            dtype=self.default_dtype,
            device=self.device,
        )

    def norm_weights(self) -> None:
        assert self._weights_exists is True
        temp: torch.Tensor = (
            self._weights.data.detach()
            .clone(memory_format=torch.contiguous_format)
            .type(dtype=self.default_dtype)
            .to(device=self.device)
        )
        temp /= temp.sum(dim=0, keepdim=True, dtype=self.default_dtype)
        self._weights.data = temp

    def threshold_weights(self, threshold: float) -> None:
        assert self._weights_exists is True
        assert threshold >= 0

        torch.clamp(
            self._weights.data,
            min=float(threshold),
            max=None,
            out=self._weights.data,
        )

    ####################################################################
    # Forward                                                          #
    ####################################################################

    def forward(
        self,
        input: torch.Tensor,
    ) -> torch.Tensor:

        # Are we happy with the input?
        assert input is not None
        assert torch.is_tensor(input) is True
        assert input.dim() == 4
        assert input.dtype == self.default_dtype
        assert input.shape[1] == self._number_of_input_neurons
        assert input.shape[2] == self._input_size[0]
        assert input.shape[3] == self._input_size[1]

        # Are we happy with the rest of the network?
        assert self._epsilon_0 is not None
        assert self._h_initial is not None
        assert self._weights_exists is True
        assert self._weights is not None

        # Convolution of the input...
        # Well, this is a convolution layer;
        # there needs to be convolution somewhere
        input_convolved = torch.nn.functional.fold(
            torch.nn.functional.unfold(
                input.requires_grad_(True),
                kernel_size=(int(self._kernel_size[0]), int(self._kernel_size[1])),
                dilation=(int(self._dilation[0]), int(self._dilation[1])),
                padding=(int(self._padding[0]), int(self._padding[1])),
                stride=(int(self._stride[0]), int(self._stride[1])),
            ),
            output_size=tuple(self._output_size.tolist()),
            kernel_size=(1, 1),
            dilation=(1, 1),
            padding=(0, 0),
            stride=(1, 1),
        )

        # We might need the convolved input for other layers;
        # let us keep it for the future
        if self.last_input_store is True:
            self.last_input_data = input_convolved.detach().clone()
            self.last_input_data /= self.last_input_data.sum(dim=1, keepdim=True)
        else:
            self.last_input_data = None

        input_convolved = input_convolved / input_convolved.sum(dim=1, keepdim=True)

        h = torch.tile(
            self._h_initial.unsqueeze(0).unsqueeze(-1).unsqueeze(-1),
            dims=[
                int(input.shape[0]),
                1,
                int(self._output_size[0]),
                int(self._output_size[1]),
            ],
        ).requires_grad_(True)

        for _ in range(0, self._number_of_iterations):
            h_w = h.unsqueeze(1) * self._weights.unsqueeze(0).unsqueeze(-1).unsqueeze(
                -1
            )
            h_w = h_w / (h_w.sum(dim=2, keepdim=True) + 1e-20)
            h_w = (h_w * input_convolved.unsqueeze(2)).sum(dim=1)
            if self._epsilon_0 > 0:
                h = h + self._epsilon_0 * h_w
            else:
                h = h_w
            h = h / (h.sum(dim=1, keepdim=True) + 1e-20)

        self._number_of_grad_weight_contributions += (
            h.shape[0] * h.shape[-2] * h.shape[-1]
        )

        return h
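The forward pass runs a fixed number of NNMF-style multiplicative updates: the weights are column-normalized by the setter, the unfolded input is sum-normalized over its channel dimension, and each iteration redistributes the input mass according to h * W before renormalizing h. A stripped-down sketch of the update for a single spatial position (shapes and names are illustrative, not the layer's API):

    import torch

    n_in, n_out = 25, 32
    weights = torch.rand((n_in, n_out))
    weights /= weights.sum(dim=0, keepdim=True)  # columns sum to 1, as in the setter

    x = torch.rand((n_in,))
    x /= x.sum()                                 # sum-normalized input
    h = torch.full((n_out,), 1.0 / n_out)        # uniform h_initial
    epsilon_0 = 1.0                              # the default used above

    for _ in range(20):
        h_w = h.unsqueeze(0) * weights           # (n_in, n_out), mirrors dim=2 above
        h_w = h_w / (h_w.sum(dim=1, keepdim=True) + 1e-20)
        h_w = (h_w * x.unsqueeze(1)).sum(dim=0)  # collapse the input dimension
        h = h + epsilon_0 * h_w                  # epsilon_0 <= 0 would use h = h_w
        h = h / (h.sum() + 1e-20)

    # h is a non-negative code over n_out neurons that sums to 1.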
@@ -44,7 +44,7 @@ class LearningParameters:
     overload_path: str = field(default="Previous")

     weight_noise_range: list[float] = field(default_factory=list)
-    eps_xy_intitial: float = field(default=0.1)
+    eps_xy_intitial: float = field(default=1.0)

     disable_scale_grade: bool = field(default=False)
     kepp_last_grad_scale: bool = field(default=True)
@@ -55,7 +55,6 @@ class LearningParameters:

     w_trainable: list[bool] = field(default_factory=list)

-
 @dataclass
 class Augmentation:
     """Parameters used for data augmentation."""
@@ -87,6 +87,8 @@ class SbSLayer(torch.nn.Module):
         spike_full_layer_input_distribution: bool = False,
         force_forward_spike_on_cpu: bool = False,
         force_forward_spike_output_on_cpu: bool = False,
+        local_learning: bool = False,
+        output_layer: bool = False,
     ) -> None:
         super().__init__()

@@ -117,6 +119,8 @@ class SbSLayer(torch.nn.Module):
         self._epsilon_xy_use = epsilon_xy_use
         self._force_forward_h_dynamic_on_cpu = force_forward_h_dynamic_on_cpu
         self._spike_full_layer_input_distribution = spike_full_layer_input_distribution
+        self._local_learning = local_learning
+        self._output_layer = output_layer

         assert len(input_size) == 2
         self._input_size = input_size
@@ -28,27 +28,28 @@ class SplitOnOffLayer(torch.nn.Module):
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         assert input.ndim == 4

-        # self.training is switched by network.eval() and network.train()
-        if self.training is True:
-            mean_temp = (
-                input.mean(dim=0, keepdim=True)
-                .mean(dim=1, keepdim=True)
-                .detach()
-                .clone()
-            )
-
-            if self.mean is None:
-                self.mean = mean_temp
-            else:
-                self.mean = (1.0 - self.epsilon) * self.mean + self.epsilon * mean_temp
-
-        assert self.mean is not None
-
-        temp = input - self.mean.detach().clone()
+        # # self.training is switched by network.eval() and network.train()
+        # if self.training is True:
+        #     mean_temp = (
+        #         input.mean(dim=0, keepdim=True)
+        #         .mean(dim=1, keepdim=True)
+        #         .detach()
+        #         .clone()
+        #     )
+        #
+        #     if self.mean is None:
+        #         self.mean = mean_temp
+        #     else:
+        #         self.mean = (1.0 - self.epsilon) * self.mean + self.epsilon * mean_temp
+        #
+        # assert self.mean is not None
+
+        # temp = input - self.mean.detach().clone()
+        temp = input - 0.5
         temp_a = torch.nn.functional.relu(temp)
         temp_b = torch.nn.functional.relu(-temp)
         output = torch.cat((temp_a, temp_b), dim=1)

-        output /= output.sum(dim=1, keepdim=True) + 1e-20
+        # output /= output.sum(dim=1, keepdim=True) + 1e-20

         return output
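With this change the on/off split no longer tracks a running mean over the training batches: the input is centered at a fixed 0.5, and the final sum-normalization is disabled as well. A minimal sketch of the resulting forward behavior (shapes are illustrative):

    import torch

    input = torch.rand((4, 1, 28, 28))           # values in [0, 1], e.g. MNIST

    temp = input - 0.5                           # fixed center instead of a running mean
    temp_a = torch.nn.functional.relu(temp)      # "on" channel: excess above 0.5
    temp_b = torch.nn.functional.relu(-temp)     # "off" channel: deficit below 0.5
    output = torch.cat((temp_a, temp_b), dim=1)  # channel count doubles

    assert output.shape == (4, 2, 28, 28)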
@@ -6,6 +6,11 @@ from network.Dataset import (
     DatasetMNIST,
     DatasetFashionMNIST,
 )
+from network.DatasetMix import (
+    DatasetCIFARMix,
+    DatasetMNISTMix,
+    DatasetFashionMNISTMix,
+)
 from network.Parameter import Config

@@ -42,6 +47,27 @@ def build_datasets(
         the_dataset_test = DatasetFashionMNIST(
             train=False, path_pattern=cfg.data_path, path_label=cfg.data_path
         )
+    elif cfg.data_mode == "MIX_CIFAR10":
+        the_dataset_train = DatasetCIFARMix(
+            train=True, path_pattern=cfg.data_path, path_label=cfg.data_path
+        )
+        the_dataset_test = DatasetCIFARMix(
+            train=False, path_pattern=cfg.data_path, path_label=cfg.data_path
+        )
+    elif cfg.data_mode == "MIX_MNIST":
+        the_dataset_train = DatasetMNISTMix(
+            train=True, path_pattern=cfg.data_path, path_label=cfg.data_path
+        )
+        the_dataset_test = DatasetMNISTMix(
+            train=False, path_pattern=cfg.data_path, path_label=cfg.data_path
+        )
+    elif cfg.data_mode == "MIX_MNIST_FASHION":
+        the_dataset_train = DatasetFashionMNISTMix(
+            train=True, path_pattern=cfg.data_path, path_label=cfg.data_path
+        )
+        the_dataset_test = DatasetFashionMNISTMix(
+            train=False, path_pattern=cfg.data_path, path_label=cfg.data_path
+        )
     else:
         raise Exception("data_mode unknown")

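Three new data_mode strings select the blended datasets; any other value still raises the existing exception. A hypothetical configuration line, assuming cfg is the Config instance used above:

    # One of "MIX_CIFAR10", "MIX_MNIST", "MIX_MNIST_FASHION"
    cfg.data_mode = "MIX_MNIST"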
@@ -39,7 +39,8 @@ def build_lr_scheduler(
     ):
         lr_scheduler_list.append(
             torch.optim.lr_scheduler.ReduceLROnPlateau(
-                optimizer[id_optimizer],eps=1e-14,
+                optimizer[id_optimizer],
+                eps=1e-14,
             )
         )
     else:
@@ -4,6 +4,7 @@ import torch
 from network.calculate_output_size import calculate_output_size
 from network.Parameter import Config
 from network.SbSLayer import SbSLayer
+from network.NNMFLayer import NNMFLayer
 from network.SplitOnOffLayer import SplitOnOffLayer
 from network.Conv2dApproximation import Conv2dApproximation
 from network.SbSReconstruction import SbSReconstruction
@@ -153,6 +154,14 @@ def build_network(
         if cfg.network_structure.layer_type[layer_id].upper().find("POOLING") != -1:
             is_pooling_layer = True

+        local_learning = False
+        if cfg.network_structure.layer_type[layer_id].upper().find("LOCAL") != -1:
+            local_learning = True
+
+        output_layer = False
+        if layer_id == len(cfg.network_structure.layer_type) - 1:
+            output_layer = True
+
         network.append(
             SbSLayer(
                 number_of_input_neurons=in_channels,
@@ -180,19 +189,13 @@ def build_network(
                 reduction_cooldown=cfg.reduction_cooldown,
                 force_forward_h_dynamic_on_cpu=cfg.force_forward_h_dynamic_on_cpu,
                 spike_full_layer_input_distribution=spike_full_layer_input_distribution,
+                local_learning=local_learning,
+                output_layer=output_layer,
             )
         )
         # Adding the x,y output dimensions
         input_size.append(network[-1]._output_size.tolist())

-        network[-1]._output_layer = False
-        if layer_id == len(cfg.network_structure.layer_type) - 1:
-            network[-1]._output_layer = True
-
-        network[-1]._local_learning = False
-        if cfg.network_structure.layer_type[layer_id].upper().find("LOCAL") != -1:
-            network[-1]._local_learning = True
-
     elif (
         cfg.network_structure.layer_type[layer_id]
         .upper()
@@ -276,6 +279,8 @@ def build_network(
     ):
         logging.info(f"Layer: {layer_id} -> RELU Layer")
         network.append(torch.nn.ReLU())
+        network[-1]._w_trainable = False
+
         input_size.append(input_size[-1])

     # #############################################################
@@ -296,6 +301,8 @@ def build_network(
             )
         )

+        network[-1]._w_trainable = False
+
         # Calculate the x,y output dimensions
         input_size_temp = calculate_output_size(
             value=input_size[-1],
@@ -323,6 +330,9 @@ def build_network(
                 padding=(int(padding[0]), int(padding[1])),
             )
         )
+
+        network[-1]._w_trainable = False
+
         # Calculate the x,y output dimensions
         input_size_temp = calculate_output_size(
             value=input_size[-1],
@@ -405,8 +415,67 @@ def build_network(
             )
         )
+
+        network[-1]._w_trainable = False
+
         input_size.append(input_size[-1])

+    # #############################################################
+    # NNMF:
+    # #############################################################
+
+    elif (
+        cfg.network_structure.layer_type[layer_id].upper().startswith("NNMF")
+        is True
+    ):
+
+        assert in_channels > 0
+        assert out_channels > 0
+
+        number_of_iterations: int = -1
+        if len(cfg.number_of_spikes) > layer_id:
+            number_of_iterations = cfg.number_of_spikes[layer_id]
+        elif len(cfg.number_of_spikes) == 1:
+            number_of_iterations = cfg.number_of_spikes[0]
+
+        assert number_of_iterations > 0
+
+        logging.info(
+            (
+                f"Layer: {layer_id} -> NNMF Layer with {number_of_iterations} iterations "
+            )
+        )
+
+        local_learning = False
+        if cfg.network_structure.layer_type[layer_id].upper().find("LOCAL") != -1:
+            local_learning = True
+
+        output_layer = False
+        if layer_id == len(cfg.network_structure.layer_type) - 1:
+            output_layer = True
+
+        network.append(
+            NNMFLayer(
+                number_of_input_neurons=in_channels,
+                number_of_neurons=out_channels,
+                input_size=input_size[-1],
+                forward_kernel_size=kernel_size,
+                number_of_iterations=number_of_iterations,
+                epsilon_0=cfg.epsilon_0,
+                weight_noise_range=weight_noise_range,
+                strides=strides,
+                dilation=dilation,
+                padding=padding,
+                w_trainable=w_trainable,
+                device=device,
+                default_dtype=default_dtype,
+                layer_id=layer_id,
+                local_learning=local_learning,
+                output_layer=output_layer,
+            )
+        )
+        # Adding the x,y output dimensions
+        input_size.append(network[-1]._output_size.tolist())
+
     # #############################################################
     # Failure because we didn't find the selected layer type
     # #############################################################
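Layer selection stays string-driven: an entry in layer_type that starts with "NNMF" builds an NNMFLayer, the substring "LOCAL" anywhere in the entry enables local learning, and the last entry is automatically flagged as the output layer. A hypothetical structure list illustrating the convention (the exact vocabulary of other entries comes from the rest of build_network):

    layer_type = [
        "NNMF",        # NNMFLayer trained via backprop
        "POOLING",     # pooling layer, as handled earlier in build_network
        "NNMF LOCAL",  # NNMFLayer with local learning enabled
    ]
    # Index len(layer_type) - 1 becomes the output layer.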
@@ -2,6 +2,8 @@
 import torch
 from network.Parameter import Config
 from network.SbSLayer import SbSLayer
+from network.NNMFLayer import NNMFLayer
+
 from network.Conv2dApproximation import Conv2dApproximation
 from network.Adam import Adam

@@ -26,6 +28,12 @@ def build_optimizer(
             parameter_list_weights.append(network[id]._weights)
             parameter_list_sbs.append(True)

+        if (isinstance(network[id], NNMFLayer) is True) and (
+            network[id]._w_trainable is True
+        ):
+            parameter_list_weights.append(network[id]._weights)
+            parameter_list_sbs.append(True)
+
         if (isinstance(network[id], torch.nn.modules.conv.Conv2d) is True) and (
             network[id]._w_trainable is True
         ):
@@ -4,6 +4,8 @@ import glob
 import numpy as np

 from network.SbSLayer import SbSLayer
+from network.NNMFLayer import NNMFLayer
+
 from network.SplitOnOffLayer import SplitOnOffLayer
 from network.Conv2dApproximation import Conv2dApproximation
 import os
@@ -46,6 +48,23 @@ def load_previous_weights(
             )
             logging.info(f"Weights file used for layer {id} : {file_to_load[0]}")

+        if isinstance(network[id], NNMFLayer) is True:
+            filename_wilcard = os.path.join(
+                overload_path, f"Weight_L{id}_*{post_fix}.npy"
+            )
+            file_to_load = glob.glob(filename_wilcard)
+
+            if len(file_to_load) > 1:
+                raise Exception(f"Too many previous weights files {filename_wilcard}")
+
+            if len(file_to_load) == 1:
+                network[id].weights = torch.tensor(
+                    np.load(file_to_load[0]),
+                    dtype=default_dtype,
+                    device=device,
+                )
+                logging.info(f"Weights file used for layer {id} : {file_to_load[0]}")
+
         if isinstance(network[id], torch.nn.modules.conv.Conv2d) is True:

             # #################################################
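Because NNMFLayer exposes weights as a property, this assignment runs the setter defined in the layer: the loaded array is moved to the requested device and dtype and every column is re-normalized to sum to 1. A sketch of what the load amounts to (file name and layer variable are hypothetical):

    import numpy as np
    import torch

    w = np.load("Weight_L2_S100.npy")  # hypothetical file written by save_weight_and_bias
    layer.weights = torch.tensor(w, dtype=torch.float32, device=torch.device("cpu"))
    # The setter normalizes: layer.weights.sum(dim=0) is now all ones.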
@@ -4,6 +4,7 @@ from network.Parameter import Config
 from torch.utils.tensorboard import SummaryWriter

 from network.SbSLayer import SbSLayer
+from network.NNMFLayer import NNMFLayer
 from network.save_weight_and_bias import save_weight_and_bias
 from network.SbSReconstruction import SbSReconstruction

@@ -19,7 +20,9 @@ def add_weight_and_bias_to_histogram(
         # ################################################
         # Log the SbS Weights
         # ################################################
-        if isinstance(network[id], SbSLayer) is True:
+        if (isinstance(network[id], SbSLayer) is True) or (
+            isinstance(network[id], NNMFLayer) is True
+        ):
             if network[id]._w_trainable is True:

                 try:
@@ -228,7 +231,9 @@ def run_optimizer(
     cfg: Config,
 ) -> None:
     for id in range(0, len(network)):
-        if isinstance(network[id], SbSLayer) is True:
+        if (isinstance(network[id], SbSLayer) is True) or (
+            isinstance(network[id], NNMFLayer) is True
+        ):
             network[id].update_pre_care()

     for optimizer_item in optimizer:
@@ -236,7 +241,9 @@ def run_optimizer(
         optimizer_item.step()

     for id in range(0, len(network)):
-        if isinstance(network[id], SbSLayer) is True:
+        if (isinstance(network[id], SbSLayer) is True) or (
+            isinstance(network[id], NNMFLayer) is True
+        ):
             network[id].update_after_care(
                 cfg.learning_parameters.learning_rate_threshold_w
                 / float(
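The same (isinstance(..., SbSLayer)) or (isinstance(..., NNMFLayer)) test now appears in several functions; since isinstance accepts a tuple of types, the checks could be written more compactly, for example:

    if isinstance(network[id], (SbSLayer, NNMFLayer)):
        network[id].update_pre_care()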
@@ -618,6 +625,94 @@ def loop_test(
     return performance


+def loop_test_mix(
+    epoch_id: int,
+    cfg: Config,
+    network: torch.nn.modules.container.Sequential,
+    my_loader_test: torch.utils.data.dataloader.DataLoader,
+    the_dataset_test,
+    device: torch.device,
+    default_dtype: torch.dtype,
+    logging,
+    tb: SummaryWriter | None,
+    overwrite_number_of_spikes: int = -1,
+) -> tuple[float, float, float, float]:
+
+    test_correct_a_0: int = 0
+    test_correct_a_1: int = 0
+    test_correct_b_0: int = 0
+    test_correct_b_1: int = 0
+
+    test_count: int = 0
+    test_complete: int = the_dataset_test.__len__()
+
+    logging.info("")
+    logging.info("Testing:")
+    mini_batch_id: int = 0
+
+    for h_x, h_x_labels in my_loader_test:
+        assert len(h_x_labels) == 2
+        label_a = h_x_labels[0]
+        label_b = h_x_labels[1]
+        assert label_a.shape[0] == label_b.shape[0]
+        assert h_x.shape[0] == label_b.shape[0]
+
+        time_0 = time.perf_counter()
+
+        h_collection = forward_pass_test(
+            input=h_x,
+            labels=label_a,
+            the_dataset_test=the_dataset_test,
+            cfg=cfg,
+            network=network,
+            device=device,
+            default_dtype=default_dtype,
+            mini_batch_id=mini_batch_id,
+            overwrite_number_of_spikes=overwrite_number_of_spikes,
+        )
+        h_h: torch.Tensor = h_collection[-1].detach().clone().cpu()
+
+        # -------------
+
+        for id in range(0, h_h.shape[0]):
+            temp = h_h[id, ...].squeeze().argsort(descending=True)
+            test_correct_a_0 += float(label_a[id] == int(temp[0]))
+            test_correct_a_1 += float(label_a[id] == int(temp[1]))
+            test_correct_b_0 += float(label_b[id] == int(temp[0]))
+            test_correct_b_1 += float(label_b[id] == int(temp[1]))
+
+        test_count += h_h.shape[0]
+        performance_a_0: float = 100.0 * test_correct_a_0 / test_count
+        performance_a_1: float = 100.0 * test_correct_a_1 / test_count
+        performance_b_0: float = 100.0 * test_correct_b_0 / test_count
+        performance_b_1: float = 100.0 * test_correct_b_1 / test_count
+        time_1 = time.perf_counter()
+        time_measure_a = time_1 - time_0
+
+        logging.info(
+            (
+                f"\t\t{test_count} of {test_complete}"
+                f" with {performance_a_0/100:^6.2%}, "
+                f"{performance_a_1/100:^6.2%}, "
+                f"{performance_b_0/100:^6.2%}, "
+                f"{performance_b_1/100:^6.2%} \t "
+                f"Time used: {time_measure_a:^6.2f}sec"
+            )
+        )
+        mini_batch_id += 1
+
+    logging.info("")
+
+    if tb is not None:
+        tb.add_scalar("Test Error A0", 100.0 - performance_a_0, epoch_id)
+        tb.add_scalar("Test Error A1", 100.0 - performance_a_1, epoch_id)
+        tb.add_scalar("Test Error B0", 100.0 - performance_b_0, epoch_id)
+        tb.add_scalar("Test Error B1", 100.0 - performance_b_1, epoch_id)
+        tb.flush()
+
+    return performance_a_0, performance_a_1, performance_b_0, performance_b_1
+
+
 def loop_test_reconstruction(
     epoch_id: int,
     cfg: Config,
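loop_test_mix scores each blended input with a top-2 criterion: temp holds the class indices sorted by descending activation, so the a_0/b_0 counters record hits in the strongest slot and a_1/b_1 hits in the runner-up slot, once for each of the two mixed labels. A small sketch of the per-sample bookkeeping:

    import torch

    h = torch.tensor([0.05, 0.60, 0.30, 0.05])    # stand-in output over 4 classes
    label_a, label_b = 1, 2                       # the two classes mixed into the input

    temp = h.argsort(descending=True)             # tensor([1, 2, 0, 3])
    correct_a_0 = float(label_a == int(temp[0]))  # 1.0: label_a is the strongest class
    correct_b_1 = float(label_b == int(temp[1]))  # 1.0: label_b is the runner-up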
@@ -4,6 +4,7 @@ from network.Parameter import Config

 import numpy as np
 from network.SbSLayer import SbSLayer
+from network.NNMFLayer import NNMFLayer
 from network.SplitOnOffLayer import SplitOnOffLayer
 from network.Conv2dApproximation import Conv2dApproximation

@@ -38,6 +39,21 @@ def save_weight_and_bias(
                 network[id].weights.detach().cpu().numpy(),
             )

+    # ################################################
+    # Save the NNMF Weights
+    # ################################################
+
+    if isinstance(network[id], NNMFLayer) is True:
+        if network[id]._w_trainable is True:
+
+            np.save(
+                os.path.join(
+                    cfg.weight_path,
+                    f"Weight_L{id}_S{iteration_number}{post_fix}.npy",
+                ),
+                network[id].weights.detach().cpu().numpy(),
+            )
+
     # ################################################
     # Save the Conv2 Weights and Biases
     # ################################################
@@ -88,6 +104,7 @@ def save_weight_and_bias(

     if isinstance(network[id], SplitOnOffLayer) is True:

+        if network[id].mean is not None:
         np.save(
             os.path.join(
                 cfg.weight_path, f"Mean_L{id}_S{iteration_number}{post_fix}.npy"