pytorch-sbs/network/save_weight_and_bias.py

114 lines
3.9 KiB
Python
Raw Normal View History

2023-01-05 13:23:58 +01:00
import torch
from network.Parameter import Config
import numpy as np
2023-02-04 14:24:47 +01:00
from network.SbSLayer import SbSLayer
2023-02-21 14:37:51 +01:00
from network.NNMFLayer import NNMFLayer
2023-01-05 13:23:58 +01:00
from network.SplitOnOffLayer import SplitOnOffLayer
from network.Conv2dApproximation import Conv2dApproximation
2023-02-04 14:24:47 +01:00
import os
2023-01-05 13:23:58 +01:00
def _save_tensor(
    tensor: torch.Tensor,
    directory: str | os.PathLike,
    kind: str,
    layer_index: int,
    suffix: str,
) -> None:
    """Write ``tensor`` to ``{directory}/{kind}_L{layer_index}{suffix}.npy``.

    The tensor is detached from the autograd graph and copied to the CPU
    before being converted to a NumPy array and written with ``np.save``.
    """
    np.save(
        os.path.join(directory, f"{kind}_L{layer_index}{suffix}.npy"),
        tensor.detach().cpu().numpy(),
    )


def save_weight_and_bias(
    cfg: Config,
    network: torch.nn.modules.container.Sequential,
    iteration_number: int,
    order_id: float | int | None = None,
) -> None:
    """Save the trainable parameters of every layer in ``network`` as ``.npy`` files.

    Files are written into ``cfg.weight_path`` and are named
    ``{Kind}_L{layer}_S{iteration_number}{post_fix}.npy``. In that name:

    - ``Kind`` is ``Weight``, ``Bias`` or ``Mean``.
    - ``layer`` is the module's position inside ``network``.
    - ``post_fix`` is ``_{order_id}`` when ``order_id`` is given, and empty
      otherwise.

    What is saved per layer type:

    - ``SbSLayer`` / ``NNMFLayer``: ``weights``, if ``_w_trainable`` is set.
    - ``torch.nn.Conv2d``: ``weight`` and, when the layer has one, ``bias``,
      if ``_w_trainable`` is set.
    - ``Conv2dApproximation``: ``weights`` and, when not ``None``, ``bias``,
      if ``_w_trainable`` is set.
    - ``SplitOnOffLayer``: ``mean``, when it is not ``None``.

    All other modules are skipped. The checks are independent ``if``
    statements, so a module matching several types is handled by each
    matching branch.

    Args:
        cfg: Configuration; only ``cfg.weight_path`` (output directory) is read.
        network: Sequential container whose layers are inspected in order.
        iteration_number: Embedded in the file name as ``_S{iteration_number}``.
        order_id: Optional extra tag appended to the file name as ``_{order_id}``.
    """
    post_fix: str = "" if order_id is None else f"_{order_id}"
    # Common tail of every file name produced in this call.
    suffix = f"_S{iteration_number}{post_fix}"
    directory = cfg.weight_path

    for layer_index, layer in enumerate(network):
        # ################################################
        # SbS / NNMF weights (stored in ``.weights``)
        # ################################################
        if isinstance(layer, (SbSLayer, NNMFLayer)):
            if layer._w_trainable:
                _save_tensor(layer.weights, directory, "Weight", layer_index, suffix)

        # ################################################
        # Conv2d weights and biases
        # ################################################
        # NOTE(review): torch.nn.Conv2d has no ``_w_trainable`` attribute of
        # its own; this assumes the network builder attaches one — confirm.
        if isinstance(layer, torch.nn.modules.conv.Conv2d):
            if layer._w_trainable:
                _save_tensor(layer.weight, directory, "Weight", layer_index, suffix)
                # Conv2d(bias=False) registers ``bias`` as None; skip it
                # rather than crashing on ``None.detach()``.
                if layer.bias is not None:
                    _save_tensor(layer.bias, directory, "Bias", layer_index, suffix)

        # ################################################
        # Approximate Conv2d weights and biases
        # ################################################
        if isinstance(layer, Conv2dApproximation):
            if layer._w_trainable:
                _save_tensor(layer.weights, directory, "Weight", layer_index, suffix)
                if layer.bias is not None:
                    _save_tensor(layer.bias, directory, "Bias", layer_index, suffix)

        # ################################################
        # Split on/off layer mean
        # ################################################
        if isinstance(layer, SplitOnOffLayer):
            if layer.mean is not None:
                _save_tensor(layer.mean, directory, "Mean", layer_index, suffix)