2023-01-05 13:23:58 +01:00
|
|
|
import torch
|
|
|
|
|
|
|
|
from network.Parameter import Config
|
|
|
|
|
|
|
|
import numpy as np
|
2023-02-04 14:24:47 +01:00
|
|
|
from network.SbSLayer import SbSLayer
|
2023-02-21 14:37:51 +01:00
|
|
|
from network.NNMFLayer import NNMFLayer
|
2023-01-05 13:23:58 +01:00
|
|
|
from network.SplitOnOffLayer import SplitOnOffLayer
|
|
|
|
from network.Conv2dApproximation import Conv2dApproximation
|
|
|
|
|
2023-02-04 14:24:47 +01:00
|
|
|
import os
|
|
|
|
|
2023-01-05 13:23:58 +01:00
|
|
|
|
|
|
|
def _save_layer_array(
    cfg: Config,
    prefix: str,
    layer_id: int,
    iteration_number: int,
    post_fix: str,
    values: torch.Tensor,
) -> None:
    """Save one tensor as a ``.npy`` file inside ``cfg.weight_path``.

    The file is named ``{prefix}_L{layer_id}_S{iteration_number}{post_fix}.npy``.
    The tensor is detached and moved to the CPU before it is converted to a
    NumPy array.
    """
    np.save(
        os.path.join(
            cfg.weight_path,
            f"{prefix}_L{layer_id}_S{iteration_number}{post_fix}.npy",
        ),
        values.detach().cpu().numpy(),
    )


def save_weight_and_bias(
    cfg: Config,
    network: torch.nn.modules.container.Sequential,
    iteration_number: int,
    order_id: float | int | None = None,
) -> None:
    """Dump the parameters of the network's layers to ``.npy`` files.

    Every layer of ``network`` is inspected and, depending on its type, the
    following files are written into ``cfg.weight_path``:

    * ``SbSLayer`` / ``NNMFLayer`` with ``_w_trainable is True``:
      ``Weight_L{layer}_S{iteration}{post_fix}.npy`` from ``layer.weights``.
    * ``torch.nn.Conv2d`` with ``_w_trainable is True``: the ``Weight_...``
      file and, if the layer has a bias, the ``Bias_...`` file.
    * ``Conv2dApproximation`` with ``_w_trainable is True``: the
      ``Weight_...`` file and, if the layer has a bias, the ``Bias_...`` file.
    * ``SplitOnOffLayer`` whose ``mean`` is not ``None``: the ``Mean_...``
      file.

    Args:
        cfg: Configuration object; only ``cfg.weight_path`` is read here.
        network: The sequential container whose layers are saved.
        iteration_number: Number embedded in the file names after ``_S``.
        order_id: Optional extra tag. When given, ``_{order_id}`` is appended
            to every file name (before the ``.npy`` extension).

    Raises:
        AttributeError: If a ``torch.nn.Conv2d`` layer (or one of the project
            layers) does not carry the ``_w_trainable`` attribute that this
            function reads.
    """
    post_fix: str = "" if order_id is None else f"_{order_id}"

    # The type checks are deliberately independent ``if`` statements (not an
    # ``elif`` chain) so that a layer matching several types is handled by
    # every matching branch, exactly like before.
    for layer_id, layer in enumerate(network):
        # ################################################
        # Save the SbS Weights
        # ################################################
        if isinstance(layer, SbSLayer) and layer._w_trainable is True:
            _save_layer_array(
                cfg, "Weight", layer_id, iteration_number, post_fix, layer.weights
            )

        # ################################################
        # Save the NNMF Weights
        # ################################################
        if isinstance(layer, NNMFLayer) and layer._w_trainable is True:
            _save_layer_array(
                cfg, "Weight", layer_id, iteration_number, post_fix, layer.weights
            )

        # ################################################
        # Save the Conv2 Weights and Biases
        # ################################################
        if (
            isinstance(layer, torch.nn.modules.conv.Conv2d)
            and layer._w_trainable is True
        ):
            _save_layer_array(
                cfg, "Weight", layer_id, iteration_number, post_fix, layer.weight
            )
            # A Conv2d created with bias=False has no bias tensor; skip the
            # bias file instead of failing on ``None``.
            if layer.bias is not None:
                _save_layer_array(
                    cfg, "Bias", layer_id, iteration_number, post_fix, layer.bias
                )

        # ################################################
        # Save the Approximate Conv2 Weights and Biases
        # ################################################
        if isinstance(layer, Conv2dApproximation) and layer._w_trainable is True:
            _save_layer_array(
                cfg,
                "Weight",
                layer_id,
                iteration_number,
                post_fix,
                layer.weights.data,
            )
            if layer.bias is not None:
                _save_layer_array(
                    cfg,
                    "Bias",
                    layer_id,
                    iteration_number,
                    post_fix,
                    layer.bias.data,
                )

        # ################################################
        # Save the mean of the SplitOnOff layer
        # ################################################
        if isinstance(layer, SplitOnOffLayer) and layer.mean is not None:
            _save_layer_array(
                cfg, "Mean", layer_id, iteration_number, post_fix, layer.mean
            )
|