On/Off filter can be turned on/off

This commit is contained in:
David Rotermund 2022-05-01 17:03:19 +02:00 committed by GitHub
parent 2929ba2a63
commit 92050f5933
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 99 additions and 50 deletions

View file

@ -135,10 +135,13 @@ class DatasetMNIST(DatasetMaster):
pattern = scripted_transforms(pattern) pattern = scripted_transforms(pattern)
# => On/Off # => On/Off
my_on_off_filter: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0]) if cfg.augmentation.use_on_off_filter is True:
gray: torch.Tensor = my_on_off_filter( my_on_off_filter: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0])
pattern[:, 0:1, :, :], gray: torch.Tensor = my_on_off_filter(
) pattern[:, 0:1, :, :],
)
else:
gray = pattern[:, 0:1, :, :] + torch.finfo(torch.float32).eps
return gray return gray
@ -166,10 +169,13 @@ class DatasetMNIST(DatasetMaster):
pattern = scripted_transforms(pattern) pattern = scripted_transforms(pattern)
# => On/Off # => On/Off
my_on_off_filter: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0]) if cfg.augmentation.use_on_off_filter is True:
gray: torch.Tensor = my_on_off_filter( my_on_off_filter: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0])
pattern[:, 0:1, :, :], gray: torch.Tensor = my_on_off_filter(
) pattern[:, 0:1, :, :],
)
else:
gray = pattern[:, 0:1, :, :] + torch.finfo(torch.float32).eps
return gray return gray
@ -225,10 +231,13 @@ class DatasetFashionMNIST(DatasetMaster):
pattern = scripted_transforms(pattern) pattern = scripted_transforms(pattern)
# => On/Off # => On/Off
my_on_off_filter: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0]) if cfg.augmentation.use_on_off_filter is True:
gray: torch.Tensor = my_on_off_filter( my_on_off_filter: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0])
pattern[:, 0:1, :, :], gray: torch.Tensor = my_on_off_filter(
) pattern[:, 0:1, :, :],
)
else:
gray = pattern[:, 0:1, :, :] + torch.finfo(torch.float32).eps
return gray return gray
@ -263,10 +272,13 @@ class DatasetFashionMNIST(DatasetMaster):
pattern = scripted_transforms(pattern) pattern = scripted_transforms(pattern)
# => On/Off # => On/Off
my_on_off_filter: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0]) if cfg.augmentation.use_on_off_filter is True:
gray: torch.Tensor = my_on_off_filter( my_on_off_filter: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0])
pattern[:, 0:1, :, :], gray: torch.Tensor = my_on_off_filter(
) pattern[:, 0:1, :, :],
)
else:
gray = pattern[:, 0:1, :, :] + torch.finfo(torch.float32).eps
return gray return gray
@ -321,19 +333,29 @@ class DatasetCIFAR(DatasetMaster):
pattern = scripted_transforms(pattern) pattern = scripted_transforms(pattern)
# => On/Off # => On/Off
if cfg.augmentation.use_on_off_filter is True:
my_on_off_filter_r: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0]) my_on_off_filter_r: OnOffFilter = OnOffFilter(
my_on_off_filter_g: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[1]) p=cfg.image_statistics.mean[0]
my_on_off_filter_b: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[2]) )
r: torch.Tensor = my_on_off_filter_r( my_on_off_filter_g: OnOffFilter = OnOffFilter(
pattern[:, 0:1, :, :], p=cfg.image_statistics.mean[1]
) )
g: torch.Tensor = my_on_off_filter_g( my_on_off_filter_b: OnOffFilter = OnOffFilter(
pattern[:, 1:2, :, :], p=cfg.image_statistics.mean[2]
) )
b: torch.Tensor = my_on_off_filter_b( r: torch.Tensor = my_on_off_filter_r(
pattern[:, 2:3, :, :], pattern[:, 0:1, :, :],
) )
g: torch.Tensor = my_on_off_filter_g(
pattern[:, 1:2, :, :],
)
b: torch.Tensor = my_on_off_filter_b(
pattern[:, 2:3, :, :],
)
else:
r = pattern[:, 0:1, :, :] + torch.finfo(torch.float32).eps
g = pattern[:, 1:2, :, :] + torch.finfo(torch.float32).eps
b = pattern[:, 2:3, :, :] + torch.finfo(torch.float32).eps
new_tensor: torch.Tensor = torch.cat((r, g, b), dim=1) new_tensor: torch.Tensor = torch.cat((r, g, b), dim=1)
return new_tensor return new_tensor
@ -370,18 +392,29 @@ class DatasetCIFAR(DatasetMaster):
pattern = scripted_transforms(pattern) pattern = scripted_transforms(pattern)
# => On/Off # => On/Off
my_on_off_filter_r: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0]) if cfg.augmentation.use_on_off_filter is True:
my_on_off_filter_g: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[1]) my_on_off_filter_r: OnOffFilter = OnOffFilter(
my_on_off_filter_b: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[2]) p=cfg.image_statistics.mean[0]
r: torch.Tensor = my_on_off_filter_r( )
pattern[:, 0:1, :, :], my_on_off_filter_g: OnOffFilter = OnOffFilter(
) p=cfg.image_statistics.mean[1]
g: torch.Tensor = my_on_off_filter_g( )
pattern[:, 1:2, :, :], my_on_off_filter_b: OnOffFilter = OnOffFilter(
) p=cfg.image_statistics.mean[2]
b: torch.Tensor = my_on_off_filter_b( )
pattern[:, 2:3, :, :], r: torch.Tensor = my_on_off_filter_r(
) pattern[:, 0:1, :, :],
)
g: torch.Tensor = my_on_off_filter_g(
pattern[:, 1:2, :, :],
)
b: torch.Tensor = my_on_off_filter_b(
pattern[:, 2:3, :, :],
)
else:
r = pattern[:, 0:1, :, :] + torch.finfo(torch.float32).eps
g = pattern[:, 1:2, :, :] + torch.finfo(torch.float32).eps
b = pattern[:, 2:3, :, :] + torch.finfo(torch.float32).eps
new_tensor: torch.Tensor = torch.cat((r, g, b), dim=1) new_tensor: torch.Tensor = torch.cat((r, g, b), dim=1)
return new_tensor return new_tensor

View file

@ -42,12 +42,14 @@ class Network:
its layers and the number of output neurons.""" its layers and the number of output neurons."""
number_of_output_neurons: int = field(default=0) number_of_output_neurons: int = field(default=0)
forward_kernel_size: list[list[int]] = field(default_factory=list)
forward_neuron_numbers: list[list[int]] = field(default_factory=list) forward_neuron_numbers: list[list[int]] = field(default_factory=list)
is_pooling_layer: list[bool] = field(default_factory=list)
forward_kernel_size: list[list[int]] = field(default_factory=list)
strides: list[list[int]] = field(default_factory=list) strides: list[list[int]] = field(default_factory=list)
dilation: list[list[int]] = field(default_factory=list) dilation: list[list[int]] = field(default_factory=list)
padding: list[list[int]] = field(default_factory=list) padding: list[list[int]] = field(default_factory=list)
is_pooling_layer: list[bool] = field(default_factory=list)
w_trainable: list[bool] = field(default_factory=list) w_trainable: list[bool] = field(default_factory=list)
eps_xy_trainable: list[bool] = field(default_factory=list) eps_xy_trainable: list[bool] = field(default_factory=list)
eps_xy_mean: list[bool] = field(default_factory=list) eps_xy_mean: list[bool] = field(default_factory=list)
@ -57,24 +59,34 @@ class Network:
class LearningParameters: class LearningParameters:
"""Parameter required for training""" """Parameter required for training"""
learning_active: bool = field(default=True)
loss_coeffs_mse: float = field(default=0.5) loss_coeffs_mse: float = field(default=0.5)
loss_coeffs_kldiv: float = field(default=1.0) loss_coeffs_kldiv: float = field(default=1.0)
optimizer_name: str = field(default="Adam")
learning_rate_gamma_w: float = field(default=-1.0) learning_rate_gamma_w: float = field(default=-1.0)
learning_rate_gamma_eps_xy: float = field(default=-1.0) learning_rate_gamma_eps_xy: float = field(default=-1.0)
learning_rate_threshold_w: float = field(default=0.00001) learning_rate_threshold_w: float = field(default=0.00001)
learning_rate_threshold_eps_xy: float = field(default=0.00001) learning_rate_threshold_eps_xy: float = field(default=0.00001)
learning_active: bool = field(default=True)
lr_schedule_name: str = field(default="ReduceLROnPlateau")
lr_scheduler_factor_w: float = field(default=0.75)
lr_scheduler_patience_w: int = field(default=-1)
lr_scheduler_factor_eps_xy: float = field(default=0.75)
lr_scheduler_patience_eps_xy: int = field(default=-1)
number_of_batches_for_one_update: int = field(default=1)
overload_path: str = field(default="./Previous")
weight_noise_amplitude: float = field(default=0.01) weight_noise_amplitude: float = field(default=0.01)
eps_xy_intitial: float = field(default=0.1) eps_xy_intitial: float = field(default=0.1)
test_every_x_learning_steps: int = field(default=50) test_every_x_learning_steps: int = field(default=50)
test_during_learning: bool = field(default=True) test_during_learning: bool = field(default=True)
lr_scheduler_factor: float = field(default=0.75)
lr_scheduler_patience: int = field(default=10)
optimizer_name: str = field(default="Adam")
lr_schedule_name: str = field(default="ReduceLROnPlateau")
number_of_batches_for_one_update: int = field(default=1)
alpha_number_of_iterations: int = field(default=0) alpha_number_of_iterations: int = field(default=0)
overload_path: str = field(default="./Previous")
@dataclass @dataclass
@ -82,12 +94,16 @@ class Augmentation:
"""Parameters used for data augmentation.""" """Parameters used for data augmentation."""
crop_width_in_pixel: int = field(default=2) crop_width_in_pixel: int = field(default=2)
flip_p: float = field(default=0.5) flip_p: float = field(default=0.5)
jitter_brightness: float = field(default=0.5) jitter_brightness: float = field(default=0.5)
jitter_contrast: float = field(default=0.1) jitter_contrast: float = field(default=0.1)
jitter_saturation: float = field(default=0.1) jitter_saturation: float = field(default=0.1)
jitter_hue: float = field(default=0.15) jitter_hue: float = field(default=0.15)
use_on_off_filter: bool = field(default=True)
@dataclass @dataclass
class ImageStatistics: class ImageStatistics: