# pytorch-sbs
SbS Extension for PyTorch
# Based on these scientific papers
**Back-Propagation Learning in Deep Spike-By-Spike Networks**
David Rotermund and Klaus R. Pawelzik
Front. Comput. Neurosci., https://doi.org/10.3389/fncom.2019.00055
https://www.frontiersin.org/articles/10.3389/fncom.2019.00055/full

**Efficient Computation Based on Stochastic Spikes**
Udo Ernst, David Rotermund, and Klaus Pawelzik
Neural Computation (2007) 19 (5): 1313-1343. https://doi.org/10.1162/neco.2007.19.5.1313
https://direct.mit.edu/neco/article-abstract/19/5/1313/7183/Efficient-Computation-Based-on-Stochastic-Spikes
# Python
The code was written with Python 3.10.4 and uses some language features introduced in Python 3.10, so older Python versions may not work.
# C++
The code also works without compiling the C++ modules, but it is then about 10x slower.
To build the modules, adapt the Makefile in the C++ directory to your Python installation.
Your Python installation also needs the PyBind11 package, which you can install with `pip install pybind11`.
The Makefile uses clang as the compiler. If you want to use a different compiler, you need to change the Makefile.

SbS.py automatically detects whether the required C++ .so modules are in the same directory as the SbS.py file.
# SbS layer class
## Variables
```
epsilon_xy
epsilon_0
epsilon_t
weights
kernel_size
stride
dilation
padding
output_size
number_of_spikes
number_of_cpu_processes
number_of_neurons
number_of_input_neurons
h_initial
alpha_number_of_iterations
```
## Constructor
```python
def __init__(
    self,
    number_of_input_neurons: int,
    number_of_neurons: int,
    input_size: list[int],
    forward_kernel_size: list[int],
    number_of_spikes: int,
    epsilon_t: torch.Tensor,
    epsilon_xy_intitial: float = 0.1,
    epsilon_0: float = 1.0,
    weight_noise_amplitude: float = 0.01,
    is_pooling_layer: bool = False,
    strides: list[int] = [1, 1],
    dilation: list[int] = [0, 0],
    padding: list[int] = [0, 0],
    alpha_number_of_iterations: int = 0,
    number_of_cpu_processes: int = 1,
) -> None:
    ...
```
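`epsilon_t` is passed to the constructor as a tensor. Below is a minimal sketch of how it could be built from the JSON parameters `number_of_spikes`, `epsilon_0`, `cooldown_after_number_of_spikes` and `reduction_cooldown`. It assumes one value per spike, `epsilon_0` until the cooldown begins and `epsilon_0 / reduction_cooldown` afterwards. This is a guess from the parameter names, not the repository's code, and `make_epsilon_t` is a hypothetical helper.

```python
def make_epsilon_t(
    number_of_spikes: int,
    epsilon_0: float = 1.0,
    cooldown_after_number_of_spikes: int = 0,
    reduction_cooldown: float = 25.0,
) -> list[float]:
    # ASSUMPTION: one epsilon value per spike. Before the cooldown starts the value
    # is epsilon_0; from spike index cooldown_after_number_of_spikes on it is
    # epsilon_0 / reduction_cooldown. A cooldown start <= 0 means "no cooldown".
    eps_t = []
    for t in range(number_of_spikes):
        in_cooldown = 0 < cooldown_after_number_of_spikes <= t
        eps_t.append(epsilon_0 / reduction_cooldown if in_cooldown else epsilon_0)
    return eps_t
```

The result can then be converted into the tensor the constructor expects, for example with `torch.tensor(make_epsilon_t(...), dtype=torch.float32)`.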
## Methods
```python
def initialize_weights(
    self,
    is_pooling_layer: bool = False,
    noise_amplitude: float = 0.01,
) -> None:
    ...
```
Generates the initial weights. Depending on `is_pooling_layer`, it creates either random initial weights or pooling weights.
---
```python
def initialize_epsilon_xy(
    self,
    eps_xy_intitial: float,
) -> None:
    ...
```
Creates the initial epsilon_xy matrices from the initial value `eps_xy_intitial`.
---
```python
def set_h_init_to_uniform(self) -> None:
    ...
```
Sets `h_initial` to a uniform distribution.
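For orientation, here is a plain-Python sketch of a uniform `h_initial` followed by a single spike-by-spike update, following the update rule in the Ernst, Rotermund and Pawelzik (2007) paper cited above. After an input spike `s`, each `h_i` is moved towards `h_i * p(s|i) / sum_j h_j * p(s|j)` with step size `epsilon`. The sketch uses one scalar `epsilon` and ignores how the layer combines `epsilon_0`, `epsilon_xy` and `epsilon_t`. The function names are hypothetical, and this is not the vectorized implementation in SbS.py.

```python
def uniform_h(number_of_neurons: int) -> list[float]:
    # Every neuron starts with the same share of the normalized activity.
    return [1.0 / number_of_neurons] * number_of_neurons


def sbs_update(h: list[float], weights_for_spike: list[float], epsilon: float) -> list[float]:
    """One spike-by-spike update for a single observed input spike s.

    weights_for_spike[i] plays the role of p(s | i), the weight between neuron i and input s.
    """
    denominator = sum(h_i * w_i for h_i, w_i in zip(h, weights_for_spike))
    if denominator == 0.0:
        # No neuron can explain this spike; leave h unchanged.
        return list(h)
    return [
        (h_i + epsilon * h_i * w_i / denominator) / (1.0 + epsilon)
        for h_i, w_i in zip(h, weights_for_spike)
    ]
```

Because `h` starts normalized, each update keeps `sum(h) == 1` up to rounding, which makes a uniform start a natural default.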
---
```python
def backup_epsilon_xy(self) -> None: ...
def restore_epsilon_xy(self) -> None: ...
def backup_weights(self) -> None: ...
def restore_weights(self) -> None: ...
```
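These methods presumably keep a copy of the current parameters so they can be reset later, for example after an update produced unusable values. Below is a minimal sketch of that backup/restore pattern, with a hypothetical `ParameterBackup` class standing in for the layer. The important detail is that the backup must be an independent copy, not a reference to the live data.

```python
import copy


class ParameterBackup:
    """Keep a snapshot of a parameter and restore it on demand (hypothetical helper)."""

    def __init__(self, value):
        self.value = value
        self._backup = None

    def backup(self) -> None:
        # Store an independent copy so later in-place changes do not leak into the backup.
        # With torch tensors this would typically be value.detach().clone().
        self._backup = copy.deepcopy(self.value)

    def restore(self) -> None:
        if self._backup is None:
            raise RuntimeError("restore() called before backup()")
        # Copy again so the stored backup survives further modifications.
        self.value = copy.deepcopy(self._backup)
```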
---
```python
def threshold_epsilon_xy(self, threshold: float) -> None: ...
def threshold_weights(self, threshold: float) -> None: ...
```
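A plausible reading of the two threshold methods is that they clamp values from below so that weights and epsilon_xy stay positive; the JSON parameters `learning_rate_threshold_w` and `learning_rate_threshold_eps_xy` (default `0.00001`) look like the matching thresholds. This is an assumption, sketched here with a hypothetical helper in plain Python. On a tensor, the same operation is `tensor.clamp_(min=threshold)`.

```python
def threshold_values(values: list[float], threshold: float) -> list[float]:
    # ASSUMPTION: thresholding clamps every entry from below, so no value
    # drops under `threshold` (values stay strictly positive if threshold > 0).
    if threshold < 0:
        raise ValueError("threshold must be non-negative")
    return [max(v, threshold) for v in values]
```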
---
```python
def mean_epsilon_xy(self) -> None: ...
```
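`mean_epsilon_xy` is probably related to the per-layer `eps_xy_mean` switch in `network_structure`. Assuming it replaces the epsilon_xy values with their mean, the operation looks like the sketch below. Which axes the real method averages over is not documented here, and `mean_fill` is a hypothetical name.

```python
def mean_fill(matrix: list[list[float]]) -> list[list[float]]:
    # ASSUMPTION: every entry is replaced by the mean over the whole matrix.
    # The real method may average over different axes of the epsilon_xy tensor.
    flat = [v for row in matrix for v in row]
    mean = sum(flat) / len(flat)
    return [[mean] * len(row) for row in matrix]
```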
---
```python
def norm_weights(self) -> None: ...
```
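In the spike-by-spike model of the cited papers, the weights of a neuron `i` act as a probability distribution `p(s|i)` over the input neurons `s`, so they have to sum to 1. Below is a sketch of such a normalization, assuming a layout in which `weights[i][s]` is the weight between neuron `i` and input `s`. The real tensor layout in SbS.py may differ, and `normalize_weights` is a hypothetical name.

```python
def normalize_weights(weights: list[list[float]]) -> list[list[float]]:
    """Normalize each neuron's weight vector so it sums to 1.

    ASSUMPTION: weights[i][s] is the weight between neuron i and input neuron s, i.e. p(s | i).
    """
    normalized = []
    for row in weights:
        total = sum(row)
        if total <= 0:
            raise ValueError("each neuron needs a positive total weight")
        normalized.append([w / total for w in row])
    return normalized
```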
# Parameters in JSON file

```python
data_mode: str = field(default="")
data_path: str = field(default="./")

batch_size: int = field(default=500)

learning_step: int = field(default=0)
learning_step_max: int = field(default=10000)

number_of_cpu_processes: int = field(default=-1)

number_of_spikes: int = field(default=0)
cooldown_after_number_of_spikes: int = field(default=0)

weight_path: str = field(default="./Weights/")
eps_xy_path: str = field(default="./EpsXY/")

reduction_cooldown: float = field(default=25.0)
epsilon_0: float = field(default=1.0)

update_after_x_batch: float = field(default=1.0)
```
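The `field(default=...)` notation above is `dataclasses` syntax, so the parameters map naturally onto a dataclass filled from the JSON file. Below is a minimal sketch with a subset of the top-level parameters that rejects misspelled keys instead of silently ignoring them. `Config` and `load_config` are hypothetical names, and this is not necessarily how the repository reads its JSON file.

```python
import json
from dataclasses import dataclass, field, fields


@dataclass
class Config:
    # A subset of the top-level parameters listed above, with the same names and defaults.
    data_path: str = field(default="./")
    batch_size: int = field(default=500)
    number_of_spikes: int = field(default=0)
    epsilon_0: float = field(default=1.0)


def load_config(json_text: str) -> Config:
    raw = json.loads(json_text)
    known = {f.name for f in fields(Config)}
    unknown = set(raw) - known
    if unknown:
        raise KeyError(f"unknown parameter(s): {sorted(unknown)}")
    # Keys missing from the JSON fall back to the dataclass defaults.
    return Config(**raw)
```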

## network_structure (required!)
Parameters of the network: the number of output neurons and the details of its layers.
```python
number_of_output_neurons: int = field(default=0)
forward_neuron_numbers: list[list[int]] = field(default_factory=list)
is_pooling_layer: list[bool] = field(default_factory=list)

forward_kernel_size: list[list[int]] = field(default_factory=list)
strides: list[list[int]] = field(default_factory=list)
dilation: list[list[int]] = field(default_factory=list)
padding: list[list[int]] = field(default_factory=list)

w_trainable: list[bool] = field(default_factory=list)
eps_xy_trainable: list[bool] = field(default_factory=list)
eps_xy_mean: list[bool] = field(default_factory=list)
```
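Most entries in `network_structure` are lists with one element per layer, so lists of different lengths are an easy mistake in a hand-written JSON file. Below is a hypothetical check, assuming that every list it inspects has exactly one entry per layer; `check_layer_lists` is not part of the repository.

```python
def check_layer_lists(structure: dict) -> int:
    """Return the number of layers, raising ValueError if the per-layer lists disagree.

    ASSUMPTION: every list below has exactly one entry per layer.
    """
    per_layer_keys = [
        "forward_neuron_numbers", "is_pooling_layer", "forward_kernel_size",
        "strides", "dilation", "padding",
        "w_trainable", "eps_xy_trainable", "eps_xy_mean",
    ]
    # A missing key raises KeyError, matching the "(required!)" note above.
    lengths = {key: len(structure[key]) for key in per_layer_keys}
    if len(set(lengths.values())) != 1:
        raise ValueError(f"per-layer lists differ in length: {lengths}")
    return lengths["forward_neuron_numbers"]
```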
## learning_parameters
Parameters required for training.
```python
learning_active: bool = field(default=True)

loss_coeffs_mse: float = field(default=0.5)
loss_coeffs_kldiv: float = field(default=1.0)

optimizer_name: str = field(default="Adam")
learning_rate_gamma_w: float = field(default=-1.0)
learning_rate_gamma_eps_xy: float = field(default=-1.0)
learning_rate_threshold_w: float = field(default=0.00001)
learning_rate_threshold_eps_xy: float = field(default=0.00001)

lr_schedule_name: str = field(default="ReduceLROnPlateau")
lr_scheduler_factor_w: float = field(default=0.75)
lr_scheduler_patience_w: int = field(default=-1)

lr_scheduler_factor_eps_xy: float = field(default=0.75)
lr_scheduler_patience_eps_xy: int = field(default=-1)

number_of_batches_for_one_update: int = field(default=1)
overload_path: str = field(default="./Previous")

weight_noise_amplitude: float = field(default=0.01)
eps_xy_intitial: float = field(default=0.1)

test_every_x_learning_steps: int = field(default=50)
test_during_learning: bool = field(default=True)

alpha_number_of_iterations: int = field(default=0)
```
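`loss_coeffs_mse` and `loss_coeffs_kldiv` suggest a loss that mixes a mean squared error term with a Kullback-Leibler divergence term. The plain-Python sketch below computes it for one sample. It assumes the two terms are simply added with these coefficients and that the KL term is `KL(target || output)` summed over the output neurons, which is what `torch.nn.KLDivLoss(reduction="sum")` computes when given `log(output)`. `combined_loss` is a hypothetical name, not the repository's loss function.

```python
import math


def combined_loss(
    output: list[float],
    target: list[float],
    loss_coeffs_mse: float = 0.5,
    loss_coeffs_kldiv: float = 1.0,
) -> float:
    # ASSUMPTION: total loss = coeff_mse * MSE + coeff_kldiv * KL(target || output).
    n = len(output)
    mse = sum((o - t) ** 2 for o, t in zip(output, target)) / n
    # Terms with target == 0 contribute 0; output must be > 0 wherever target > 0.
    kldiv = sum(t * math.log(t / o) for o, t in zip(output, target) if t > 0)
    return loss_coeffs_mse * mse + loss_coeffs_kldiv * kldiv
```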
## augmentation
Parameters used for data augmentation.
```python
crop_width_in_pixel: int = field(default=2)

flip_p: float = field(default=0.5)

jitter_brightness: float = field(default=0.5)
jitter_contrast: float = field(default=0.1)
jitter_saturation: float = field(default=0.1)
jitter_hue: float = field(default=0.15)

use_on_off_filter: bool = field(default=True)
```
## ImageStatistics (please ignore)
(Statistical) information about the input, i.e. the mean values and the x and y size of the input.
```python
mean: list[float] = field(default_factory=list)
the_size: list[int] = field(default_factory=list)
```