Add files via upload

This commit is contained in:
David Rotermund 2023-07-27 20:13:44 +02:00 committed by GitHub
parent fa1f80ba70
commit 842225ae29
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 124 additions and 31 deletions

View file

@ -1,5 +1,6 @@
import torch import torch
import numpy as np
# import numpy as np
from functions.SoftmaxPower import SoftmaxPower from functions.SoftmaxPower import SoftmaxPower
@ -19,22 +20,44 @@ def make_cnn(
conv_0_meanmode_softmax: bool, conv_0_meanmode_softmax: bool,
conv_0_no_input_mode_softmax: bool, conv_0_no_input_mode_softmax: bool,
l_relu_negative_slope: float, l_relu_negative_slope: float,
input_shape: torch.Size,
) -> torch.nn.Sequential: ) -> torch.nn.Sequential:
assert len(conv_out_channels_list) >= 1 assert len(conv_out_channels_list) >= 1
assert len(conv_out_channels_list) == len(conv_kernel_size) + 1 assert len(conv_out_channels_list) == len(conv_kernel_size) + 1
cnn = torch.nn.Sequential() cnn = torch.nn.Sequential()
temp_image: torch.Tensor = torch.zeros(
(1, *input_shape), dtype=torch.float32, device=torch.device("cpu")
)
logger.info(
(
f"Input shape: {int(temp_image.shape[1])}, "
f"{int(temp_image.shape[2])}, "
f"{int(temp_image.shape[3])}"
)
)
layer_counter: int = 0
# Fixed structure # Fixed structure
cnn.append( cnn.append(
torch.nn.Conv2d( torch.nn.Conv2d(
in_channels=1, in_channels=int(temp_image.shape[0]),
out_channels=conv_out_channels_list[0] if train_conv_0 else 32, out_channels=conv_out_channels_list[0] if train_conv_0 else 32,
kernel_size=conv_0_kernel_size, kernel_size=conv_0_kernel_size,
stride=1, stride=1,
bias=train_conv_0, bias=train_conv_0,
) )
) )
temp_image = cnn[layer_counter](temp_image)
logger.info(
(
f"After layer {layer_counter}: {int(temp_image.shape[1])}, "
f"{int(temp_image.shape[2])}, "
f"{int(temp_image.shape[3])}"
)
)
layer_counter += 1
setting_understood: bool = False setting_understood: bool = False
if conv_activation_function.upper() == str("relu").upper(): if conv_activation_function.upper() == str("relu").upper():
@ -49,16 +72,15 @@ def make_cnn(
elif conv_activation_function.upper() == str("none").upper(): elif conv_activation_function.upper() == str("none").upper():
setting_understood = True setting_understood = True
assert setting_understood assert setting_understood
temp_image = cnn[layer_counter](temp_image)
if conv_0_enable_softmax: logger.info(
cnn.append( (
SoftmaxPower( f"After layer {layer_counter}: {int(temp_image.shape[1])}, "
dim=1, f"{int(temp_image.shape[2])}, "
power=conv_0_power_softmax, f"{int(temp_image.shape[3])}"
mean_mode=conv_0_meanmode_softmax,
no_input_mode=conv_0_no_input_mode_softmax,
) )
) )
layer_counter += 1
setting_understood = False setting_understood = False
if pooling_type.upper() == str("max").upper(): if pooling_type.upper() == str("max").upper():
@ -70,7 +92,34 @@ def make_cnn(
elif pooling_type.upper() == str("none").upper(): elif pooling_type.upper() == str("none").upper():
setting_understood = True setting_understood = True
assert setting_understood assert setting_understood
temp_image = cnn[layer_counter](temp_image)
logger.info(
(
f"After layer {layer_counter}: {int(temp_image.shape[1])}, "
f"{int(temp_image.shape[2])}, "
f"{int(temp_image.shape[3])}"
)
)
layer_counter += 1
if conv_0_enable_softmax:
cnn.append(
SoftmaxPower(
dim=1,
power=conv_0_power_softmax,
mean_mode=conv_0_meanmode_softmax,
no_input_mode=conv_0_no_input_mode_softmax,
)
)
temp_image = cnn[layer_counter](temp_image)
logger.info(
(
f"After layer {layer_counter}: {int(temp_image.shape[1])}, "
f"{int(temp_image.shape[2])}, "
f"{int(temp_image.shape[3])}"
)
)
layer_counter += 1
# Changing structure # Changing structure
for i in range(1, len(conv_out_channels_list)): for i in range(1, len(conv_out_channels_list)):
@ -87,6 +136,16 @@ def make_cnn(
bias=True, bias=True,
) )
) )
temp_image = cnn[layer_counter](temp_image)
logger.info(
(
f"After layer {layer_counter}: {int(temp_image.shape[1])}, "
f"{int(temp_image.shape[2])}, "
f"{int(temp_image.shape[3])}"
)
)
layer_counter += 1
setting_understood = False setting_understood = False
if conv_activation_function.upper() == str("relu").upper(): if conv_activation_function.upper() == str("relu").upper():
cnn.append(torch.nn.ReLU()) cnn.append(torch.nn.ReLU())
@ -101,26 +160,53 @@ def make_cnn(
setting_understood = True setting_understood = True
assert setting_understood assert setting_understood
temp_image = cnn[layer_counter](temp_image)
# Fixed structure logger.info(
# define fully connected layer: (
cnn.append(torch.nn.Flatten(start_dim=1)) f"After layer {layer_counter}: {int(temp_image.shape[1])}, "
cnn.append(torch.nn.LazyLinear(2, bias=True)) f"{int(temp_image.shape[2])}, "
f"{int(temp_image.shape[3])}"
# if conv1 not trained:
filename_load_weight_0: str | None = None
if train_conv_0 is False and cnn[0]._parameters["weight"].shape[0] == 32:
filename_load_weight_0 = "weights_radius10.npy"
if train_conv_0 is False and cnn[0]._parameters["weight"].shape[0] == 16:
filename_load_weight_0 = "8orient_2phase_weights.npy"
if filename_load_weight_0 is not None:
logger.info(f"Replace weights in CNN 0 with {filename_load_weight_0}")
cnn[0]._parameters["weight"] = torch.tensor(
np.load(filename_load_weight_0),
dtype=cnn[0]._parameters["weight"].dtype,
requires_grad=False,
device=cnn[0]._parameters["weight"].device,
) )
)
layer_counter += 1
# Output layer
cnn.append(
torch.nn.Conv2d(
in_channels=int(temp_image.shape[1]),
out_channels=2,
kernel_size=(int(temp_image.shape[2]), int(temp_image.shape[3])),
stride=1,
bias=True,
)
)
temp_image = cnn[layer_counter](temp_image)
logger.info(
(
f"After layer {layer_counter}: {int(temp_image.shape[1])}, "
f"{int(temp_image.shape[2])}, "
f"{int(temp_image.shape[3])}"
)
)
layer_counter += 1
# Need to repair loading data
assert train_conv_0 is True
# # if conv1 not trained:
# filename_load_weight_0: str | None = None
# if train_conv_0 is False and cnn[0]._parameters["weight"].shape[0] == 32:
# filename_load_weight_0 = "weights_radius10.npy"
# if train_conv_0 is False and cnn[0]._parameters["weight"].shape[0] == 16:
# filename_load_weight_0 = "8orient_2phase_weights.npy"
# if filename_load_weight_0 is not None:
# logger.info(f"Replace weights in CNN 0 with {filename_load_weight_0}")
# cnn[0]._parameters["weight"] = torch.tensor(
# np.load(filename_load_weight_0),
# dtype=cnn[0]._parameters["weight"].dtype,
# requires_grad=False,
# device=cnn[0]._parameters["weight"].device,
# )
return cnn return cnn

View file

@ -27,6 +27,9 @@ def test(
image /= scale_data image /= scale_data
output = model(image) output = model(image)
if output.ndim == 4:
output = output.squeeze(-1).squeeze(-1)
assert output.ndim == 2
# loss and optimization # loss and optimization
loss = torch.nn.functional.cross_entropy(output, label, reduction="sum") loss = torch.nn.functional.cross_entropy(output, label, reduction="sum")

View file

@ -30,6 +30,10 @@ def train(
optimizer.zero_grad() optimizer.zero_grad()
output = model(image) output = model(image)
if output.ndim == 4:
output = output.squeeze(-1).squeeze(-1)
assert output.ndim == 2
loss = torch.nn.functional.cross_entropy(output, label, reduction="sum") loss = torch.nn.functional.cross_entropy(output, label, reduction="sum")
loss.backward() loss.backward()