72 lines
2.6 KiB
Python
72 lines
2.6 KiB
Python
|
import torch
|
||
|
|
||
|
from network.Parameter import Config
|
||
|
|
||
|
import numpy as np
|
||
|
from network.SbS import SbS
|
||
|
from network.SplitOnOffLayer import SplitOnOffLayer
|
||
|
from network.Conv2dApproximation import Conv2dApproximation
|
||
|
|
||
|
|
||
|
def _save_tensor(path: str, tensor: torch.Tensor) -> None:
    """Write ``tensor`` to ``path`` with ``np.save``.

    The tensor is detached from the autograd graph and copied to host memory
    first, so this works for parameters that require grad and for tensors
    living on a GPU.
    """
    np.save(path, tensor.detach().cpu().numpy())


def save_weight_and_bias(
    cfg: Config, network: torch.nn.modules.container.Sequential, iteration_number: int
) -> None:
    """Save the parameters of every supported layer in ``network`` as ``.npy`` files.

    Files are written to ``cfg.weight_path`` and named
    ``<Kind>_L<layer index>_S<iteration_number>.npy``:

    * ``SbS``: ``Weight_...`` from ``layer.weights``, only when
      ``layer._w_trainable is True``.
    * ``torch.nn.Conv2d``: ``Weight_...`` from ``layer.weight`` and
      ``Bias_...`` from ``layer.bias``, only when ``layer._w_trainable is True``.
      The bias file is skipped when the layer has no bias (``bias=False``).
    * ``Conv2dApproximation``: ``Weight_...`` from ``layer.weights`` and
      ``Bias_...`` from ``layer.bias`` (if not ``None``), only when
      ``layer._w_trainable is True``.
    * ``SplitOnOffLayer``: ``Mean_...`` from ``layer.mean``, unconditionally.

    Layers of any other type are ignored. The type checks are independent
    ``if`` statements, so a layer that matches several types is handled by
    each matching branch.

    Args:
        cfg: Configuration object; only ``cfg.weight_path`` (used as the
            output directory prefix) is read.
        network: Sequential model whose layers are inspected in order.
        iteration_number: Tag embedded in every file name (``_S<n>``).
    """
    for layer_id, layer in enumerate(network):
        # Common file-name tail: layer index plus iteration tag.
        suffix = f"_L{layer_id}_S{iteration_number}.npy"

        # ################################################
        # Save the SbS Weights
        # ################################################
        if isinstance(layer, SbS):
            if layer._w_trainable is True:
                _save_tensor(f"{cfg.weight_path}/Weight{suffix}", layer.weights)

        # ################################################
        # Save the Conv2 Weights and Biases
        # ################################################
        if isinstance(layer, torch.nn.modules.conv.Conv2d):
            # NOTE(review): stock torch Conv2d has no ``_w_trainable``
            # attribute; presumably the network builder sets it — confirm.
            if layer._w_trainable is True:
                _save_tensor(f"{cfg.weight_path}/Weight{suffix}", layer.weight)

                # Conv2d(..., bias=False) registers ``bias`` as None.
                if layer.bias is not None:
                    _save_tensor(f"{cfg.weight_path}/Bias{suffix}", layer.bias)

        # ################################################
        # Save the Approximate Conv2 Weights and Biases
        # ################################################
        if isinstance(layer, Conv2dApproximation):
            if layer._w_trainable is True:
                _save_tensor(f"{cfg.weight_path}/Weight{suffix}", layer.weights)

                if layer.bias is not None:
                    _save_tensor(f"{cfg.weight_path}/Bias{suffix}", layer.bias)

        # ################################################
        # Save the SplitOnOff mean
        # ################################################
        if isinstance(layer, SplitOnOffLayer):
            _save_tensor(f"{cfg.weight_path}/Mean{suffix}", layer.mean)
|