diff --git a/learn_it.py b/learn_it.py index b01d345..9431de3 100644 --- a/learn_it.py +++ b/learn_it.py @@ -465,6 +465,8 @@ with torch.no_grad(): network[id].threshold_epsilon_xy( cfg.learning_parameters.learning_rate_threshold_eps_xy ) + if cfg.network_structure.eps_xy_mean[id] is True: + network[id].mean_epsilon_xy() else: network[id].epsilon_xy = torch.tensor( eps_xy[id], dtype=torch.float64