526 lines
20 KiB
Python
526 lines
20 KiB
Python
|
import torch
|
||
|
from tools.append_block import append_block
|
||
|
from tools.L1NormLayer import L1NormLayer
|
||
|
from tools.NNMF2d import NNMF2d
|
||
|
from tools.append_parameter import append_parameter
|
||
|
|
||
|
import json
|
||
|
from jsmin import jsmin
|
||
|
|
||
|
|
||
|
# Config keys that must provide exactly one entry per block. Their lengths are
# compared against "number_of_neurons_a" in this order.
_PER_BLOCK_CONFIG_KEYS: tuple[str, ...] = (
    "number_of_neurons_b",
    "kernel_size_conv",
    "stride_conv",
    "padding_conv",
    "dilation_conv",
    "kernel_size_pool",
    "stride_pool",
    "padding_pool",
    "dilation_pool",
    "type_of_pooling",
    "local_learning_pooling",
    "local_learning_use_kl_pooling",
    "type_of_neuron_a",
    "type_of_neuron_b",
    "batch_norm_neuron_a",
    "batch_norm_neuron_b",
    "bias_norm_neuron_a",
    "bias_norm_neuron_b",
)


def _int_pair(values) -> tuple[int, int]:
    """Return the first two entries of *values* converted to ``int``."""
    entries = list(values)
    return (int(entries[0]), int(entries[1]))


def _append_and_run(
    network: torch.nn.Sequential,
    layer: torch.nn.Module,
    test_image: torch.Tensor,
) -> torch.Tensor:
    """Append *layer* to *network* and return its output for *test_image*."""
    network.append(layer)
    return network[-1](test_image)


def make_network(
    input_dim_x: int,
    input_dim_y: int,
    input_number_of_channel: int,
    device: torch.device,
    config_network_filename: str = "config_network.json",
) -> tuple[
    torch.nn.Sequential,
    list[list[torch.nn.parameter.Parameter]],
    list[str],
]:
    """Build a ``torch.nn.Sequential`` network from a JSON configuration file.

    The configuration file is minified with ``jsmin`` before it is parsed as
    JSON. Every key listed in ``_PER_BLOCK_CONFIG_KEYS`` must hold one entry
    per block, i.e. as many entries as ``"number_of_neurons_a"``.

    For every block ``append_block`` is called first. Afterwards, if both
    entries of the block's ``"kernel_size_pool"`` are positive, a pooling
    stage selected by ``"type_of_pooling"`` is appended:

    * ``0``: nothing is appended
    * ``1``: ``torch.nn.AvgPool2d``
    * ``2``: ``torch.nn.MaxPool2d``
    * ``3``: ReLU -> Unfold -> Fold -> ``L1NormLayer`` -> 1x1 ``Conv2d``
      (no bias) -> ``BatchNorm2d``
    * ``4``: same as ``3`` but with ``NNMF2d`` instead of the ``Conv2d``
    * ``> 4``: ``torch.nn.Softmax(dim=1)``

    A ``torch.nn.Flatten`` layer is appended after the last block. While the
    network is assembled, a test image of ones with shape
    ``(1, input_number_of_channel, input_dim_x, input_dim_y)`` is passed
    through every appended layer; its shape is used to size the layers.

    Args:
        input_dim_x: Size of the third dimension of the test image.
        input_dim_y: Size of the fourth dimension of the test image.
        input_number_of_channel: Number of channels of the test image.
        device: Device for the test image and the parametrised layers.
        config_network_filename: Path of the JSON configuration file.

    Returns:
        A tuple ``(network, parameters, name_list)``. ``parameters`` holds
        four parameter lists (neuron a, neuron b, batchnorm2d, neuron pool)
        and ``name_list`` holds the matching names in the same order.

    Raises:
        AssertionError: If a per-block configuration entry does not have the
            same length as ``"number_of_neurons_a"`` or if a
            ``"type_of_pooling"`` entry is negative. The exception is raised
            explicitly so the checks also run under ``python -O``.
    """
    with open(config_network_filename, "r") as file:
        config_network = json.loads(jsmin(file.read()))

    # Materialise every per-block list once and validate its length.
    per_block: dict[str, list] = {
        "number_of_neurons_a": list(config_network["number_of_neurons_a"])
    }
    number_of_blocks = len(per_block["number_of_neurons_a"])
    for key in _PER_BLOCK_CONFIG_KEYS:
        entries = list(config_network[key])
        if len(entries) != number_of_blocks:
            raise AssertionError(
                f'Config entry "{key}" has {len(entries)} entries but '
                f'"number_of_neurons_a" has {number_of_blocks}.'
            )
        per_block[key] = entries

    parameter_neuron_b: list[torch.nn.parameter.Parameter] = []
    parameter_neuron_a: list[torch.nn.parameter.Parameter] = []
    parameter_batchnorm2d: list[torch.nn.parameter.Parameter] = []
    parameter_neuron_pool: list[torch.nn.parameter.Parameter] = []

    test_image = torch.ones(
        (1, input_number_of_channel, input_dim_x, input_dim_y), device=device
    )

    network = torch.nn.Sequential()
    network = network.to(device)

    # NOTE(review): only a JSON float is accepted here; an integer-valued
    # entry such as `1` is parsed as int and leaves epsilon at None.
    epsilon: float | None = None
    if isinstance(config_network["epsilon"], float):
        epsilon = float(config_network["epsilon"])

    for block_id in range(number_of_blocks):
        test_image = append_block(
            network=network,
            number_of_neurons_a=int(per_block["number_of_neurons_a"][block_id]),
            number_of_neurons_b=int(per_block["number_of_neurons_b"][block_id]),
            test_image=test_image,
            dilation=list(per_block["dilation_conv"][block_id]),
            padding=list(per_block["padding_conv"][block_id]),
            stride=list(per_block["stride_conv"][block_id]),
            kernel_size=list(per_block["kernel_size_conv"][block_id]),
            epsilon=epsilon,
            iterations=int(config_network["iterations"]),
            device=device,
            parameter_neuron_a=parameter_neuron_a,
            parameter_neuron_b=parameter_neuron_b,
            parameter_batchnorm2d=parameter_batchnorm2d,
            type_of_neuron_a=int(per_block["type_of_neuron_a"][block_id]),
            type_of_neuron_b=int(per_block["type_of_neuron_b"][block_id]),
            batch_norm_neuron_a=bool(per_block["batch_norm_neuron_a"][block_id]),
            batch_norm_neuron_b=bool(per_block["batch_norm_neuron_b"][block_id]),
            bias_norm_neuron_a=bool(per_block["bias_norm_neuron_a"][block_id]),
            bias_norm_neuron_b=bool(per_block["bias_norm_neuron_b"][block_id]),
        )

        # A non-positive pooling kernel size disables pooling for this block.
        # The second entry is only inspected when the first one is positive.
        pool_kernel_cfg = list(per_block["kernel_size_pool"][block_id])
        if int(pool_kernel_cfg[0]) <= 0 or int(pool_kernel_cfg[1]) <= 0:
            continue
        pool_kernel = (int(pool_kernel_cfg[0]), int(pool_kernel_cfg[1]))

        pooling_type = int(per_block["type_of_pooling"][block_id])

        if pooling_type == 0:
            continue

        if pooling_type in (1, 2):
            # NOTE(review): "dilation_pool" is not used for average/max
            # pooling; only kernel size, stride and padding are applied.
            pool_stride = _int_pair(per_block["stride_pool"][block_id])
            pool_padding = _int_pair(per_block["padding_pool"][block_id])
            pool_class = (
                torch.nn.AvgPool2d if pooling_type == 1 else torch.nn.MaxPool2d
            )
            test_image = _append_and_run(
                network,
                pool_class(
                    kernel_size=pool_kernel,
                    stride=pool_stride,
                    padding=pool_padding,
                ),
                test_image,
            )

        elif pooling_type in (3, 4):
            test_image = _append_and_run(network, torch.nn.ReLU(), test_image)

            pool_stride = _int_pair(per_block["stride_pool"][block_id])
            pool_padding = _int_pair(per_block["padding_pool"][block_id])
            pool_dilation = _int_pair(per_block["dilation_pool"][block_id])

            # A dummy convolution with the pooling geometry yields the
            # spatial size that Fold has to restore after Unfold.
            mock_output = (
                torch.nn.functional.conv2d(
                    torch.zeros(1, 1, test_image.shape[2], test_image.shape[3]),
                    torch.zeros((1, 1, pool_kernel[0], pool_kernel[1])),
                    stride=pool_stride,
                    padding=pool_padding,
                    dilation=pool_dilation,
                )
                .squeeze(0)
                .squeeze(0)
            )

            test_image = _append_and_run(
                network,
                torch.nn.Unfold(
                    kernel_size=pool_kernel,
                    stride=pool_stride,
                    padding=pool_padding,
                    dilation=pool_dilation,
                ),
                test_image,
            )

            test_image = _append_and_run(
                network,
                torch.nn.Fold(
                    output_size=mock_output.shape,
                    kernel_size=(1, 1),
                    dilation=1,
                    padding=0,
                    stride=1,
                ),
                test_image,
            )

            test_image = _append_and_run(network, L1NormLayer(), test_image)

            # Unfold multiplied the channel count by the kernel area; the
            # following layer divides it by the kernel area again.
            in_channels = test_image.shape[1]
            out_channels = test_image.shape[1] // (pool_kernel[0] * pool_kernel[1])

            if pooling_type == 3:
                pool_layer = torch.nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=(1, 1),
                    bias=False,
                ).to(device)
            else:
                pool_layer = NNMF2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    epsilon=epsilon,
                    local_learning=bool(
                        per_block["local_learning_pooling"][block_id]
                    ),
                    local_learning_kl=bool(
                        per_block["local_learning_use_kl_pooling"][block_id]
                    ),
                ).to(device)

            test_image = _append_and_run(network, pool_layer, test_image)
            append_parameter(
                module=network[-1], parameter_list=parameter_neuron_pool
            )

            test_image = _append_and_run(
                network,
                torch.nn.BatchNorm2d(
                    num_features=test_image.shape[1],
                    device=device,
                    momentum=0.1,
                    track_running_stats=False,
                ),
                test_image,
            )
            append_parameter(
                module=network[-1], parameter_list=parameter_batchnorm2d
            )

        else:
            # Only values above 4 are valid here; anything else is negative.
            if pooling_type <= 4:
                raise AssertionError(
                    f'Config entry "type_of_pooling" of block {block_id} must '
                    f"not be negative, got {pooling_type}."
                )
            test_image = _append_and_run(
                network, torch.nn.Softmax(dim=1), test_image
            )

    # The output is not needed any more; the call only runs the test image
    # through the final layer.
    _append_and_run(network, torch.nn.Flatten(), test_image)

    parameters: list[list[torch.nn.parameter.Parameter]] = [
        parameter_neuron_a,
        parameter_neuron_b,
        parameter_batchnorm2d,
        parameter_neuron_pool,
    ]

    name_list: list[str] = ["neuron a", "neuron b", "batchnorm2d", "neuron pool"]

    return (
        network,
        parameters,
        name_list,
    )