pytorch-sbs/network/save_weight_and_bias.py

72 lines
2.6 KiB
Python
Raw Normal View History

2023-01-05 13:23:58 +01:00
import torch
from network.Parameter import Config
import numpy as np
from network.SbS import SbS
from network.SplitOnOffLayer import SplitOnOffLayer
from network.Conv2dApproximation import Conv2dApproximation
def save_weight_and_bias(
    cfg: Config, network: torch.nn.modules.container.Sequential, iteration_number: int
) -> None:
    """Dump the parameters of selected layers of ``network`` to ``.npy`` files.

    Every layer of ``network`` is inspected by position ``layer_id``. Matching
    layers write one or more arrays into ``cfg.weight_path``. File names follow
    the pattern ``<Kind>_L<layer_id>_S<iteration_number>.npy``:

    * ``SbS`` with ``_w_trainable`` set: ``Weight`` (from ``.weights``).
    * ``torch.nn.Conv2d`` with ``_w_trainable`` set: ``Weight`` and, only if
      the layer has a bias, ``Bias``.
    * ``Conv2dApproximation`` with ``_w_trainable`` set: ``Weight`` (from
      ``.weights``) and, only if the layer has a bias, ``Bias``.
    * ``SplitOnOffLayer``: ``Mean`` (from ``.mean``), regardless of
      ``_w_trainable``.

    Layers of any other type are skipped. The ``isinstance`` checks are
    independent, so a layer matching several types is saved by each branch.

    Args:
        cfg: Configuration object. Only ``cfg.weight_path`` (the target
            directory) is read.
        network: Sequential model whose layers are inspected in order.
        iteration_number: Value embedded after ``_S`` in every file name.
    """

    def _save(kind: str, layer_id: int, tensor: torch.Tensor) -> None:
        """Write ``tensor`` as ``<kind>_L<layer_id>_S<iteration_number>.npy``."""
        # Strip autograd history and move to host memory so numpy can read it.
        np.save(
            f"{cfg.weight_path}/{kind}_L{layer_id}_S{iteration_number}.npy",
            tensor.detach().cpu().numpy(),
        )

    for layer_id, layer in enumerate(network):
        # ################################################
        # Save the SbS Weights
        # ################################################
        if isinstance(layer, SbS) and layer._w_trainable:
            _save("Weight", layer_id, layer.weights)

        # ################################################
        # Save the Conv2 Weights and Biases
        # ################################################
        # NOTE(review): a stock torch Conv2d has no `_w_trainable` attribute;
        # this assumes the project attaches it to every Conv2d it builds -- confirm.
        if isinstance(layer, torch.nn.Conv2d) and layer._w_trainable:
            _save("Weight", layer_id, layer.weight)
            # A Conv2d created with bias=False registers its bias as None.
            if layer.bias is not None:
                _save("Bias", layer_id, layer.bias)

        # ################################################
        # Save the Approximate Conv2 Weights and Biases
        # ################################################
        if isinstance(layer, Conv2dApproximation) and layer._w_trainable:
            _save("Weight", layer_id, layer.weights)
            if layer.bias is not None:
                _save("Bias", layer_id, layer.bias)

        # ################################################
        # Save the SplitOnOff mean
        # ################################################
        if isinstance(layer, SplitOnOffLayer):
            _save("Mean", layer_id, layer.mean)