On/Off filter can be turned on/off
This commit is contained in:
parent
2929ba2a63
commit
92050f5933
2 changed files with 99 additions and 50 deletions
47
Dataset.py
47
Dataset.py
|
@ -135,10 +135,13 @@ class DatasetMNIST(DatasetMaster):
|
||||||
pattern = scripted_transforms(pattern)
|
pattern = scripted_transforms(pattern)
|
||||||
|
|
||||||
# => On/Off
|
# => On/Off
|
||||||
|
if cfg.augmentation.use_on_off_filter is True:
|
||||||
my_on_off_filter: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0])
|
my_on_off_filter: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0])
|
||||||
gray: torch.Tensor = my_on_off_filter(
|
gray: torch.Tensor = my_on_off_filter(
|
||||||
pattern[:, 0:1, :, :],
|
pattern[:, 0:1, :, :],
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
gray = pattern[:, 0:1, :, :] + torch.finfo(torch.float32).eps
|
||||||
|
|
||||||
return gray
|
return gray
|
||||||
|
|
||||||
|
@ -166,10 +169,13 @@ class DatasetMNIST(DatasetMaster):
|
||||||
pattern = scripted_transforms(pattern)
|
pattern = scripted_transforms(pattern)
|
||||||
|
|
||||||
# => On/Off
|
# => On/Off
|
||||||
|
if cfg.augmentation.use_on_off_filter is True:
|
||||||
my_on_off_filter: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0])
|
my_on_off_filter: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0])
|
||||||
gray: torch.Tensor = my_on_off_filter(
|
gray: torch.Tensor = my_on_off_filter(
|
||||||
pattern[:, 0:1, :, :],
|
pattern[:, 0:1, :, :],
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
gray = pattern[:, 0:1, :, :] + torch.finfo(torch.float32).eps
|
||||||
|
|
||||||
return gray
|
return gray
|
||||||
|
|
||||||
|
@ -225,10 +231,13 @@ class DatasetFashionMNIST(DatasetMaster):
|
||||||
pattern = scripted_transforms(pattern)
|
pattern = scripted_transforms(pattern)
|
||||||
|
|
||||||
# => On/Off
|
# => On/Off
|
||||||
|
if cfg.augmentation.use_on_off_filter is True:
|
||||||
my_on_off_filter: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0])
|
my_on_off_filter: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0])
|
||||||
gray: torch.Tensor = my_on_off_filter(
|
gray: torch.Tensor = my_on_off_filter(
|
||||||
pattern[:, 0:1, :, :],
|
pattern[:, 0:1, :, :],
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
gray = pattern[:, 0:1, :, :] + torch.finfo(torch.float32).eps
|
||||||
|
|
||||||
return gray
|
return gray
|
||||||
|
|
||||||
|
@ -263,10 +272,13 @@ class DatasetFashionMNIST(DatasetMaster):
|
||||||
pattern = scripted_transforms(pattern)
|
pattern = scripted_transforms(pattern)
|
||||||
|
|
||||||
# => On/Off
|
# => On/Off
|
||||||
|
if cfg.augmentation.use_on_off_filter is True:
|
||||||
my_on_off_filter: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0])
|
my_on_off_filter: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0])
|
||||||
gray: torch.Tensor = my_on_off_filter(
|
gray: torch.Tensor = my_on_off_filter(
|
||||||
pattern[:, 0:1, :, :],
|
pattern[:, 0:1, :, :],
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
gray = pattern[:, 0:1, :, :] + torch.finfo(torch.float32).eps
|
||||||
|
|
||||||
return gray
|
return gray
|
||||||
|
|
||||||
|
@ -321,10 +333,16 @@ class DatasetCIFAR(DatasetMaster):
|
||||||
pattern = scripted_transforms(pattern)
|
pattern = scripted_transforms(pattern)
|
||||||
|
|
||||||
# => On/Off
|
# => On/Off
|
||||||
|
if cfg.augmentation.use_on_off_filter is True:
|
||||||
my_on_off_filter_r: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0])
|
my_on_off_filter_r: OnOffFilter = OnOffFilter(
|
||||||
my_on_off_filter_g: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[1])
|
p=cfg.image_statistics.mean[0]
|
||||||
my_on_off_filter_b: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[2])
|
)
|
||||||
|
my_on_off_filter_g: OnOffFilter = OnOffFilter(
|
||||||
|
p=cfg.image_statistics.mean[1]
|
||||||
|
)
|
||||||
|
my_on_off_filter_b: OnOffFilter = OnOffFilter(
|
||||||
|
p=cfg.image_statistics.mean[2]
|
||||||
|
)
|
||||||
r: torch.Tensor = my_on_off_filter_r(
|
r: torch.Tensor = my_on_off_filter_r(
|
||||||
pattern[:, 0:1, :, :],
|
pattern[:, 0:1, :, :],
|
||||||
)
|
)
|
||||||
|
@ -334,6 +352,10 @@ class DatasetCIFAR(DatasetMaster):
|
||||||
b: torch.Tensor = my_on_off_filter_b(
|
b: torch.Tensor = my_on_off_filter_b(
|
||||||
pattern[:, 2:3, :, :],
|
pattern[:, 2:3, :, :],
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
r = pattern[:, 0:1, :, :] + torch.finfo(torch.float32).eps
|
||||||
|
g = pattern[:, 1:2, :, :] + torch.finfo(torch.float32).eps
|
||||||
|
b = pattern[:, 2:3, :, :] + torch.finfo(torch.float32).eps
|
||||||
|
|
||||||
new_tensor: torch.Tensor = torch.cat((r, g, b), dim=1)
|
new_tensor: torch.Tensor = torch.cat((r, g, b), dim=1)
|
||||||
return new_tensor
|
return new_tensor
|
||||||
|
@ -370,9 +392,16 @@ class DatasetCIFAR(DatasetMaster):
|
||||||
pattern = scripted_transforms(pattern)
|
pattern = scripted_transforms(pattern)
|
||||||
|
|
||||||
# => On/Off
|
# => On/Off
|
||||||
my_on_off_filter_r: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0])
|
if cfg.augmentation.use_on_off_filter is True:
|
||||||
my_on_off_filter_g: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[1])
|
my_on_off_filter_r: OnOffFilter = OnOffFilter(
|
||||||
my_on_off_filter_b: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[2])
|
p=cfg.image_statistics.mean[0]
|
||||||
|
)
|
||||||
|
my_on_off_filter_g: OnOffFilter = OnOffFilter(
|
||||||
|
p=cfg.image_statistics.mean[1]
|
||||||
|
)
|
||||||
|
my_on_off_filter_b: OnOffFilter = OnOffFilter(
|
||||||
|
p=cfg.image_statistics.mean[2]
|
||||||
|
)
|
||||||
r: torch.Tensor = my_on_off_filter_r(
|
r: torch.Tensor = my_on_off_filter_r(
|
||||||
pattern[:, 0:1, :, :],
|
pattern[:, 0:1, :, :],
|
||||||
)
|
)
|
||||||
|
@ -382,6 +411,10 @@ class DatasetCIFAR(DatasetMaster):
|
||||||
b: torch.Tensor = my_on_off_filter_b(
|
b: torch.Tensor = my_on_off_filter_b(
|
||||||
pattern[:, 2:3, :, :],
|
pattern[:, 2:3, :, :],
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
r = pattern[:, 0:1, :, :] + torch.finfo(torch.float32).eps
|
||||||
|
g = pattern[:, 1:2, :, :] + torch.finfo(torch.float32).eps
|
||||||
|
b = pattern[:, 2:3, :, :] + torch.finfo(torch.float32).eps
|
||||||
|
|
||||||
new_tensor: torch.Tensor = torch.cat((r, g, b), dim=1)
|
new_tensor: torch.Tensor = torch.cat((r, g, b), dim=1)
|
||||||
return new_tensor
|
return new_tensor
|
||||||
|
|
34
Parameter.py
34
Parameter.py
|
@ -42,12 +42,14 @@ class Network:
|
||||||
its layers and the number of output neurons."""
|
its layers and the number of output neurons."""
|
||||||
|
|
||||||
number_of_output_neurons: int = field(default=0)
|
number_of_output_neurons: int = field(default=0)
|
||||||
forward_kernel_size: list[list[int]] = field(default_factory=list)
|
|
||||||
forward_neuron_numbers: list[list[int]] = field(default_factory=list)
|
forward_neuron_numbers: list[list[int]] = field(default_factory=list)
|
||||||
|
is_pooling_layer: list[bool] = field(default_factory=list)
|
||||||
|
|
||||||
|
forward_kernel_size: list[list[int]] = field(default_factory=list)
|
||||||
strides: list[list[int]] = field(default_factory=list)
|
strides: list[list[int]] = field(default_factory=list)
|
||||||
dilation: list[list[int]] = field(default_factory=list)
|
dilation: list[list[int]] = field(default_factory=list)
|
||||||
padding: list[list[int]] = field(default_factory=list)
|
padding: list[list[int]] = field(default_factory=list)
|
||||||
is_pooling_layer: list[bool] = field(default_factory=list)
|
|
||||||
w_trainable: list[bool] = field(default_factory=list)
|
w_trainable: list[bool] = field(default_factory=list)
|
||||||
eps_xy_trainable: list[bool] = field(default_factory=list)
|
eps_xy_trainable: list[bool] = field(default_factory=list)
|
||||||
eps_xy_mean: list[bool] = field(default_factory=list)
|
eps_xy_mean: list[bool] = field(default_factory=list)
|
||||||
|
@ -57,24 +59,34 @@ class Network:
|
||||||
class LearningParameters:
|
class LearningParameters:
|
||||||
"""Parameter required for training"""
|
"""Parameter required for training"""
|
||||||
|
|
||||||
|
learning_active: bool = field(default=True)
|
||||||
|
|
||||||
loss_coeffs_mse: float = field(default=0.5)
|
loss_coeffs_mse: float = field(default=0.5)
|
||||||
loss_coeffs_kldiv: float = field(default=1.0)
|
loss_coeffs_kldiv: float = field(default=1.0)
|
||||||
|
|
||||||
|
optimizer_name: str = field(default="Adam")
|
||||||
learning_rate_gamma_w: float = field(default=-1.0)
|
learning_rate_gamma_w: float = field(default=-1.0)
|
||||||
learning_rate_gamma_eps_xy: float = field(default=-1.0)
|
learning_rate_gamma_eps_xy: float = field(default=-1.0)
|
||||||
learning_rate_threshold_w: float = field(default=0.00001)
|
learning_rate_threshold_w: float = field(default=0.00001)
|
||||||
learning_rate_threshold_eps_xy: float = field(default=0.00001)
|
learning_rate_threshold_eps_xy: float = field(default=0.00001)
|
||||||
learning_active: bool = field(default=True)
|
|
||||||
|
lr_schedule_name: str = field(default="ReduceLROnPlateau")
|
||||||
|
lr_scheduler_factor_w: float = field(default=0.75)
|
||||||
|
lr_scheduler_patience_w: int = field(default=-1)
|
||||||
|
|
||||||
|
lr_scheduler_factor_eps_xy: float = field(default=0.75)
|
||||||
|
lr_scheduler_patience_eps_xy: int = field(default=-1)
|
||||||
|
|
||||||
|
number_of_batches_for_one_update: int = field(default=1)
|
||||||
|
overload_path: str = field(default="./Previous")
|
||||||
|
|
||||||
weight_noise_amplitude: float = field(default=0.01)
|
weight_noise_amplitude: float = field(default=0.01)
|
||||||
eps_xy_intitial: float = field(default=0.1)
|
eps_xy_intitial: float = field(default=0.1)
|
||||||
|
|
||||||
test_every_x_learning_steps: int = field(default=50)
|
test_every_x_learning_steps: int = field(default=50)
|
||||||
test_during_learning: bool = field(default=True)
|
test_during_learning: bool = field(default=True)
|
||||||
lr_scheduler_factor: float = field(default=0.75)
|
|
||||||
lr_scheduler_patience: int = field(default=10)
|
|
||||||
optimizer_name: str = field(default="Adam")
|
|
||||||
lr_schedule_name: str = field(default="ReduceLROnPlateau")
|
|
||||||
number_of_batches_for_one_update: int = field(default=1)
|
|
||||||
alpha_number_of_iterations: int = field(default=0)
|
alpha_number_of_iterations: int = field(default=0)
|
||||||
overload_path: str = field(default="./Previous")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -82,12 +94,16 @@ class Augmentation:
|
||||||
"""Parameters used for data augmentation."""
|
"""Parameters used for data augmentation."""
|
||||||
|
|
||||||
crop_width_in_pixel: int = field(default=2)
|
crop_width_in_pixel: int = field(default=2)
|
||||||
|
|
||||||
flip_p: float = field(default=0.5)
|
flip_p: float = field(default=0.5)
|
||||||
|
|
||||||
jitter_brightness: float = field(default=0.5)
|
jitter_brightness: float = field(default=0.5)
|
||||||
jitter_contrast: float = field(default=0.1)
|
jitter_contrast: float = field(default=0.1)
|
||||||
jitter_saturation: float = field(default=0.1)
|
jitter_saturation: float = field(default=0.1)
|
||||||
jitter_hue: float = field(default=0.15)
|
jitter_hue: float = field(default=0.15)
|
||||||
|
|
||||||
|
use_on_off_filter: bool = field(default=True)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ImageStatistics:
|
class ImageStatistics:
|
||||||
|
|
Loading…
Reference in a new issue