diff --git a/pytorch/networks/README.md b/pytorch/networks/README.md index af3185f..dfef358 100644 --- a/pytorch/networks/README.md +++ b/pytorch/networks/README.md @@ -1203,3 +1203,80 @@ network.load_state_dict(torch.load("torch_network_dict_class.pt")) network.eval() ``` +## But how do I get the activities? + + +* Option 0: Use sequential as explained above. I think this is the best way. +* Option 1: Rewrite the forward function and save the data as e.g. self.activation_conv1 +* Option 2: Install forward hooks (see [Forward and Backward Function Hooks](https://pytorch.org/tutorials/beginner/former_torchies/nnft_tutorial.html#forward-and-backward-function-hooks)) + +First we need a function that we can hook in: + +```python +activation = {} + + +def get_activation(name): + def hook(model, input, output): + activation[name] = output.detach() + + return hook +``` + +Next we need to register the hooks: + +```python +network = MyNetworkClass() +network.conv1.register_forward_hook(get_activation("Conv1")) +network.relu1.register_forward_hook(get_activation("ReLU1")) +network.max_pooling_1.register_forward_hook(get_activation("MaxPooling1")) +network.conv2.register_forward_hook(get_activation("Conv2")) +network.relu2.register_forward_hook(get_activation("ReLU2")) +network.max_pooling_2.register_forward_hook(get_activation("MaxPooling2")) +network.flatten1.register_forward_hook(get_activation("Flatten1")) +network.fully_connected_1.register_forward_hook(get_activation("FullyConnected1")) +``` + +Then we can run the input through the network: + +```python +number_of_pattern: int = 111 +input_number_of_channel: int = 1 +input_dim_x: int = 24 +input_dim_y: int = 24 + +fake_input = torch.rand( + (number_of_pattern, input_number_of_channel, input_dim_x, input_dim_y), + dtype=torch.float32, +) +output = network(fake_input) +``` + +And we will find the activations in the variable activation. + +```python +for name, value in activation.items(): + print(f"Hook name: {name}") + print(value.shape) +``` + +Output: + +```python +Hook name: Conv1 +torch.Size([111, 32, 20, 20]) +Hook name: ReLU1 +torch.Size([111, 32, 20, 20]) +Hook name: MaxPooling1 +torch.Size([111, 32, 10, 10]) +Hook name: Conv2 +torch.Size([111, 64, 6, 6]) +Hook name: ReLU2 +torch.Size([111, 64, 6, 6]) +Hook name: MaxPooling2 +torch.Size([111, 64, 3, 3]) +Hook name: Flatten1 +torch.Size([111, 576]) +Hook name: FullyConnected1 +torch.Size([111, 10]) +```