Update README.md
Signed-off-by: David Rotermund <54365609+davrot@users.noreply.github.com>
This commit is contained in:
parent
06d7daf0cc
commit
a5c21ee512
1 changed files with 77 additions and 0 deletions
|
@ -1203,3 +1203,80 @@ network.load_state_dict(torch.load("torch_network_dict_class.pt"))
|
|||
network.eval()
|
||||
```
|
||||
|
||||
## But how do I get the activities?
|
||||
|
||||
|
||||
* Option 0: Use sequential as explained above. I think this is the best way.
|
||||
* Option 1: Rewrite the forward function and save the data as e.g. self.activation_conv1
|
||||
* Option 2: Install forward hooks (see [Forward and Backward Function Hooks](https://pytorch.org/tutorials/beginner/former_torchies/nnft_tutorial.html#forward-and-backward-function-hooks))
|
||||
|
||||
First we need a function that we can hook in:
|
||||
|
||||
```python
|
||||
activation = {}
|
||||
|
||||
|
||||
def get_activation(name):
|
||||
def hook(model, input, output):
|
||||
activation[name] = output.detach()
|
||||
|
||||
return hook
|
||||
```
|
||||
|
||||
Next we need to register the hooks:
|
||||
|
||||
```python
|
||||
network = MyNetworkClass()
|
||||
network.conv1.register_forward_hook(get_activation("Conv1"))
|
||||
network.relu1.register_forward_hook(get_activation("ReLU1"))
|
||||
network.max_pooling_1.register_forward_hook(get_activation("MaxPooling1"))
|
||||
network.conv2.register_forward_hook(get_activation("Conv2"))
|
||||
network.relu2.register_forward_hook(get_activation("ReLU2"))
|
||||
network.max_pooling_2.register_forward_hook(get_activation("MaxPooling2"))
|
||||
network.flatten1.register_forward_hook(get_activation("Flatten1"))
|
||||
network.fully_connected_1.register_forward_hook(get_activation("FullyConnected1"))
|
||||
```
|
||||
|
||||
Then we can run the input through the network:
|
||||
|
||||
```python
|
||||
number_of_pattern: int = 111
|
||||
input_number_of_channel: int = 1
|
||||
input_dim_x: int = 24
|
||||
input_dim_y: int = 24
|
||||
|
||||
fake_input = torch.rand(
|
||||
(number_of_pattern, input_number_of_channel, input_dim_x, input_dim_y),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
output = network(fake_input)
|
||||
```
|
||||
|
||||
And we will find the activations in the variable activation.
|
||||
|
||||
```python
|
||||
for name, value in activation.items():
|
||||
print(f"Hook name: {name}")
|
||||
print(value.shape)
|
||||
```
|
||||
|
||||
Output:
|
||||
|
||||
```python
|
||||
Hook name: Conv1
|
||||
torch.Size([111, 32, 20, 20])
|
||||
Hook name: ReLU1
|
||||
torch.Size([111, 32, 20, 20])
|
||||
Hook name: MaxPooling1
|
||||
torch.Size([111, 32, 10, 10])
|
||||
Hook name: Conv2
|
||||
torch.Size([111, 64, 6, 6])
|
||||
Hook name: ReLU2
|
||||
torch.Size([111, 64, 6, 6])
|
||||
Hook name: MaxPooling2
|
||||
torch.Size([111, 64, 3, 3])
|
||||
Hook name: Flatten1
|
||||
torch.Size([111, 576])
|
||||
Hook name: FullyConnected1
|
||||
torch.Size([111, 10])
|
||||
```
|
||||
|
|
Loading…
Reference in a new issue