Update README.md

Signed-off-by: David Rotermund <54365609+davrot@users.noreply.github.com>
This commit is contained in:
David Rotermund 2024-01-02 20:40:35 +01:00 committed by GitHub
parent 06d7daf0cc
commit a5c21ee512
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1203,3 +1203,80 @@ network.load_state_dict(torch.load("torch_network_dict_class.pt"))
network.eval() network.eval()
``` ```
## But how do I get the activities?
* Option 0: Use sequential as explained above. I think this is the best way.
* Option 1: Rewrite the forward function and save the data as e.g. self.activation_conv1
* Option 2: Install forward hooks (see [Forward and Backward Function Hooks](https://pytorch.org/tutorials/beginner/former_torchies/nnft_tutorial.html#forward-and-backward-function-hooks))
First we need a function that we can hook in:
```python
activation = {}
def get_activation(name):
def hook(model, input, output):
activation[name] = output.detach()
return hook
```
Next we need to register the hooks:
```python
network = MyNetworkClass()
network.conv1.register_forward_hook(get_activation("Conv1"))
network.relu1.register_forward_hook(get_activation("ReLU1"))
network.max_pooling_1.register_forward_hook(get_activation("MaxPooling1"))
network.conv2.register_forward_hook(get_activation("Conv2"))
network.relu2.register_forward_hook(get_activation("ReLU2"))
network.max_pooling_2.register_forward_hook(get_activation("MaxPooling2"))
network.flatten1.register_forward_hook(get_activation("Flatten1"))
network.fully_connected_1.register_forward_hook(get_activation("FullyConnected1"))
```
Then we can run the input through the network:
```python
number_of_pattern: int = 111
input_number_of_channel: int = 1
input_dim_x: int = 24
input_dim_y: int = 24
fake_input = torch.rand(
(number_of_pattern, input_number_of_channel, input_dim_x, input_dim_y),
dtype=torch.float32,
)
output = network(fake_input)
```
And we will find the activations in the variable activation.
```python
for name, value in activation.items():
print(f"Hook name: {name}")
print(value.shape)
```
Output:
```python
Hook name: Conv1
torch.Size([111, 32, 20, 20])
Hook name: ReLU1
torch.Size([111, 32, 20, 20])
Hook name: MaxPooling1
torch.Size([111, 32, 10, 10])
Hook name: Conv2
torch.Size([111, 64, 6, 6])
Hook name: ReLU2
torch.Size([111, 64, 6, 6])
Hook name: MaxPooling2
torch.Size([111, 64, 3, 3])
Hook name: Flatten1
torch.Size([111, 576])
Hook name: FullyConnected1
torch.Size([111, 10])
```