Add files via upload
This commit is contained in:
parent
fa1f80ba70
commit
842225ae29
3 changed files with 124 additions and 31 deletions
|
@ -1,5 +1,6 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
|
||||
# import numpy as np
|
||||
from functions.SoftmaxPower import SoftmaxPower
|
||||
|
||||
|
||||
|
@ -19,22 +20,44 @@ def make_cnn(
|
|||
conv_0_meanmode_softmax: bool,
|
||||
conv_0_no_input_mode_softmax: bool,
|
||||
l_relu_negative_slope: float,
|
||||
input_shape: torch.Size,
|
||||
) -> torch.nn.Sequential:
|
||||
assert len(conv_out_channels_list) >= 1
|
||||
assert len(conv_out_channels_list) == len(conv_kernel_size) + 1
|
||||
|
||||
cnn = torch.nn.Sequential()
|
||||
|
||||
temp_image: torch.Tensor = torch.zeros(
|
||||
(1, *input_shape), dtype=torch.float32, device=torch.device("cpu")
|
||||
)
|
||||
logger.info(
|
||||
(
|
||||
f"Input shape: {int(temp_image.shape[1])}, "
|
||||
f"{int(temp_image.shape[2])}, "
|
||||
f"{int(temp_image.shape[3])}"
|
||||
)
|
||||
)
|
||||
layer_counter: int = 0
|
||||
|
||||
# Fixed structure
|
||||
cnn.append(
|
||||
torch.nn.Conv2d(
|
||||
in_channels=1,
|
||||
in_channels=int(temp_image.shape[0]),
|
||||
out_channels=conv_out_channels_list[0] if train_conv_0 else 32,
|
||||
kernel_size=conv_0_kernel_size,
|
||||
stride=1,
|
||||
bias=train_conv_0,
|
||||
)
|
||||
)
|
||||
temp_image = cnn[layer_counter](temp_image)
|
||||
logger.info(
|
||||
(
|
||||
f"After layer {layer_counter}: {int(temp_image.shape[1])}, "
|
||||
f"{int(temp_image.shape[2])}, "
|
||||
f"{int(temp_image.shape[3])}"
|
||||
)
|
||||
)
|
||||
layer_counter += 1
|
||||
|
||||
setting_understood: bool = False
|
||||
if conv_activation_function.upper() == str("relu").upper():
|
||||
|
@ -49,16 +72,15 @@ def make_cnn(
|
|||
elif conv_activation_function.upper() == str("none").upper():
|
||||
setting_understood = True
|
||||
assert setting_understood
|
||||
|
||||
if conv_0_enable_softmax:
|
||||
cnn.append(
|
||||
SoftmaxPower(
|
||||
dim=1,
|
||||
power=conv_0_power_softmax,
|
||||
mean_mode=conv_0_meanmode_softmax,
|
||||
no_input_mode=conv_0_no_input_mode_softmax,
|
||||
temp_image = cnn[layer_counter](temp_image)
|
||||
logger.info(
|
||||
(
|
||||
f"After layer {layer_counter}: {int(temp_image.shape[1])}, "
|
||||
f"{int(temp_image.shape[2])}, "
|
||||
f"{int(temp_image.shape[3])}"
|
||||
)
|
||||
)
|
||||
layer_counter += 1
|
||||
|
||||
setting_understood = False
|
||||
if pooling_type.upper() == str("max").upper():
|
||||
|
@ -70,7 +92,34 @@ def make_cnn(
|
|||
elif pooling_type.upper() == str("none").upper():
|
||||
setting_understood = True
|
||||
assert setting_understood
|
||||
temp_image = cnn[layer_counter](temp_image)
|
||||
logger.info(
|
||||
(
|
||||
f"After layer {layer_counter}: {int(temp_image.shape[1])}, "
|
||||
f"{int(temp_image.shape[2])}, "
|
||||
f"{int(temp_image.shape[3])}"
|
||||
)
|
||||
)
|
||||
layer_counter += 1
|
||||
|
||||
if conv_0_enable_softmax:
|
||||
cnn.append(
|
||||
SoftmaxPower(
|
||||
dim=1,
|
||||
power=conv_0_power_softmax,
|
||||
mean_mode=conv_0_meanmode_softmax,
|
||||
no_input_mode=conv_0_no_input_mode_softmax,
|
||||
)
|
||||
)
|
||||
temp_image = cnn[layer_counter](temp_image)
|
||||
logger.info(
|
||||
(
|
||||
f"After layer {layer_counter}: {int(temp_image.shape[1])}, "
|
||||
f"{int(temp_image.shape[2])}, "
|
||||
f"{int(temp_image.shape[3])}"
|
||||
)
|
||||
)
|
||||
layer_counter += 1
|
||||
|
||||
# Changing structure
|
||||
for i in range(1, len(conv_out_channels_list)):
|
||||
|
@ -87,6 +136,16 @@ def make_cnn(
|
|||
bias=True,
|
||||
)
|
||||
)
|
||||
temp_image = cnn[layer_counter](temp_image)
|
||||
logger.info(
|
||||
(
|
||||
f"After layer {layer_counter}: {int(temp_image.shape[1])}, "
|
||||
f"{int(temp_image.shape[2])}, "
|
||||
f"{int(temp_image.shape[3])}"
|
||||
)
|
||||
)
|
||||
layer_counter += 1
|
||||
|
||||
setting_understood = False
|
||||
if conv_activation_function.upper() == str("relu").upper():
|
||||
cnn.append(torch.nn.ReLU())
|
||||
|
@ -101,26 +160,53 @@ def make_cnn(
|
|||
setting_understood = True
|
||||
|
||||
assert setting_understood
|
||||
|
||||
# Fixed structure
|
||||
# define fully connected layer:
|
||||
cnn.append(torch.nn.Flatten(start_dim=1))
|
||||
cnn.append(torch.nn.LazyLinear(2, bias=True))
|
||||
|
||||
# if conv1 not trained:
|
||||
filename_load_weight_0: str | None = None
|
||||
if train_conv_0 is False and cnn[0]._parameters["weight"].shape[0] == 32:
|
||||
filename_load_weight_0 = "weights_radius10.npy"
|
||||
if train_conv_0 is False and cnn[0]._parameters["weight"].shape[0] == 16:
|
||||
filename_load_weight_0 = "8orient_2phase_weights.npy"
|
||||
|
||||
if filename_load_weight_0 is not None:
|
||||
logger.info(f"Replace weights in CNN 0 with {filename_load_weight_0}")
|
||||
cnn[0]._parameters["weight"] = torch.tensor(
|
||||
np.load(filename_load_weight_0),
|
||||
dtype=cnn[0]._parameters["weight"].dtype,
|
||||
requires_grad=False,
|
||||
device=cnn[0]._parameters["weight"].device,
|
||||
temp_image = cnn[layer_counter](temp_image)
|
||||
logger.info(
|
||||
(
|
||||
f"After layer {layer_counter}: {int(temp_image.shape[1])}, "
|
||||
f"{int(temp_image.shape[2])}, "
|
||||
f"{int(temp_image.shape[3])}"
|
||||
)
|
||||
)
|
||||
layer_counter += 1
|
||||
|
||||
# Output layer
|
||||
cnn.append(
|
||||
torch.nn.Conv2d(
|
||||
in_channels=int(temp_image.shape[1]),
|
||||
out_channels=2,
|
||||
kernel_size=(int(temp_image.shape[2]), int(temp_image.shape[3])),
|
||||
stride=1,
|
||||
bias=True,
|
||||
)
|
||||
)
|
||||
temp_image = cnn[layer_counter](temp_image)
|
||||
logger.info(
|
||||
(
|
||||
f"After layer {layer_counter}: {int(temp_image.shape[1])}, "
|
||||
f"{int(temp_image.shape[2])}, "
|
||||
f"{int(temp_image.shape[3])}"
|
||||
)
|
||||
)
|
||||
layer_counter += 1
|
||||
|
||||
# Need to repair loading data
|
||||
assert train_conv_0 is True
|
||||
|
||||
# # if conv1 not trained:
|
||||
# filename_load_weight_0: str | None = None
|
||||
# if train_conv_0 is False and cnn[0]._parameters["weight"].shape[0] == 32:
|
||||
# filename_load_weight_0 = "weights_radius10.npy"
|
||||
# if train_conv_0 is False and cnn[0]._parameters["weight"].shape[0] == 16:
|
||||
# filename_load_weight_0 = "8orient_2phase_weights.npy"
|
||||
|
||||
# if filename_load_weight_0 is not None:
|
||||
# logger.info(f"Replace weights in CNN 0 with {filename_load_weight_0}")
|
||||
# cnn[0]._parameters["weight"] = torch.tensor(
|
||||
# np.load(filename_load_weight_0),
|
||||
# dtype=cnn[0]._parameters["weight"].dtype,
|
||||
# requires_grad=False,
|
||||
# device=cnn[0]._parameters["weight"].device,
|
||||
# )
|
||||
|
||||
return cnn
|
||||
|
|
|
@ -27,6 +27,9 @@ def test(
|
|||
image /= scale_data
|
||||
|
||||
output = model(image)
|
||||
if output.ndim == 4:
|
||||
output = output.squeeze(-1).squeeze(-1)
|
||||
assert output.ndim == 2
|
||||
|
||||
# loss and optimization
|
||||
loss = torch.nn.functional.cross_entropy(output, label, reduction="sum")
|
||||
|
|
|
@ -30,6 +30,10 @@ def train(
|
|||
|
||||
optimizer.zero_grad()
|
||||
output = model(image)
|
||||
if output.ndim == 4:
|
||||
output = output.squeeze(-1).squeeze(-1)
|
||||
assert output.ndim == 2
|
||||
|
||||
loss = torch.nn.functional.cross_entropy(output, label, reduction="sum")
|
||||
loss.backward()
|
||||
|
||||
|
|
Loading…
Reference in a new issue