# %%
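# Utility for loading previously exported layer parameters (weights, biases,
# means) from .npy files back into an existing torch.nn.Sequential network.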
import torch
import glob
import numpy as np
from network.SbSLayer import SbSLayer
from network.NNMFLayer import NNMFLayer
from network.NNMFLayerSbSBP import NNMFLayerSbSBP
from network.SplitOnOffLayer import SplitOnOffLayer
from network.Conv2dApproximation import Conv2dApproximation
import os


def load_previous_weights(
    network: torch.nn.Sequential,
    overload_path: str,
    logging,
    device: torch.device,
    default_dtype: torch.dtype,
    order_id: float | int | None = None,
) -> None:
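    """Load previously saved layer parameters from .npy files.

    For each layer in the Sequential network, look in overload_path for a
    single file matching Weight_L{id}_*{post_fix}.npy (or Bias_... / Mean_...
    for Conv2d biases and SplitOnOffLayer means). If exactly one match is
    found, its contents are copied into the layer using the given device and
    dtype; more than one match raises an Exception.
    """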
    if order_id is None:
        post_fix: str = ""
    else:
        post_fix = f"_{order_id}"
    for id in range(0, len(network)):

        # #################################################
        # SbS
        # #################################################
        if isinstance(network[id], SbSLayer):
            filename_wildcard = os.path.join(
                overload_path, f"Weight_L{id}_*{post_fix}.npy"
            )
            file_to_load = glob.glob(filename_wildcard)

            if len(file_to_load) > 1:
                raise Exception(f"Too many previous weights files {filename_wildcard}")

            if len(file_to_load) == 1:
                network[id].weights = torch.tensor(
                    np.load(file_to_load[0]),
                    dtype=default_dtype,
                    device=device,
                )
                logging.info(f"Weights file used for layer {id} : {file_to_load[0]}")
        # #################################################
        # NNMF
        # #################################################
        if isinstance(network[id], (NNMFLayer, NNMFLayerSbSBP)):
            filename_wildcard = os.path.join(
                overload_path, f"Weight_L{id}_*{post_fix}.npy"
            )
            file_to_load = glob.glob(filename_wildcard)

            if len(file_to_load) > 1:
                raise Exception(f"Too many previous weights files {filename_wildcard}")

            if len(file_to_load) == 1:
                network[id].weights = torch.tensor(
                    np.load(file_to_load[0]),
                    dtype=default_dtype,
                    device=device,
                )
                logging.info(f"Weights file used for layer {id} : {file_to_load[0]}")
        if isinstance(network[id], torch.nn.modules.conv.Conv2d):

            # #################################################
            # Conv2d weights
            # #################################################
            filename_wildcard = os.path.join(
                overload_path, f"Weight_L{id}_*{post_fix}.npy"
            )
            file_to_load = glob.glob(filename_wildcard)

            if len(file_to_load) > 1:
                raise Exception(f"Too many previous weights files {filename_wildcard}")

            if len(file_to_load) == 1:
                network[id]._parameters["weight"].data = torch.tensor(
                    np.load(file_to_load[0]),
                    dtype=default_dtype,
                    device=device,
                )
                logging.info(f"Weights file used for layer {id} : {file_to_load[0]}")

            # #################################################
            # Conv2d bias
            # #################################################
            filename_wildcard = os.path.join(
                overload_path, f"Bias_L{id}_*{post_fix}.npy"
            )
            file_to_load = glob.glob(filename_wildcard)

            if len(file_to_load) > 1:
                raise Exception(f"Too many previous bias files {filename_wildcard}")

            if len(file_to_load) == 1:
                network[id]._parameters["bias"].data = torch.tensor(
                    np.load(file_to_load[0]),
                    dtype=default_dtype,
                    device=device,
                )
                logging.info(f"Bias file used for layer {id} : {file_to_load[0]}")
        if isinstance(network[id], Conv2dApproximation):

            # #################################################
            # Approximate Conv2d weights
            # #################################################
            filename_wildcard = os.path.join(
                overload_path, f"Weight_L{id}_*{post_fix}.npy"
            )
            file_to_load = glob.glob(filename_wildcard)

            if len(file_to_load) > 1:
                raise Exception(f"Too many previous weights files {filename_wildcard}")

            if len(file_to_load) == 1:
                network[id].weights.data = torch.tensor(
                    np.load(file_to_load[0]),
                    dtype=default_dtype,
                    device=device,
                )
                logging.info(f"Weights file used for layer {id} : {file_to_load[0]}")

            # #################################################
            # Approximate Conv2d bias
            # #################################################
            filename_wildcard = os.path.join(
                overload_path, f"Bias_L{id}_*{post_fix}.npy"
            )
            file_to_load = glob.glob(filename_wildcard)

            if len(file_to_load) > 1:
                raise Exception(f"Too many previous bias files {filename_wildcard}")

            if len(file_to_load) == 1:
                network[id].bias.data = torch.tensor(
                    np.load(file_to_load[0]),
                    dtype=default_dtype,
                    device=device,
                )
                logging.info(f"Bias file used for layer {id} : {file_to_load[0]}")
        # #################################################
        # SplitOnOffLayer
        # #################################################
        if isinstance(network[id], SplitOnOffLayer):
            filename_wildcard = os.path.join(
                overload_path, f"Mean_L{id}_*{post_fix}.npy"
            )
            file_to_load = glob.glob(filename_wildcard)

            if len(file_to_load) > 1:
                raise Exception(f"Too many previous mean files {filename_wildcard}")

            if len(file_to_load) == 1:
                network[id].mean = torch.tensor(
                    np.load(file_to_load[0]),
                    dtype=default_dtype,
                    device=device,
                )
                logging.info(f"Mean file used for layer {id} : {file_to_load[0]}")