Update README.md
Signed-off-by: David Rotermund <54365609+davrot@users.noreply.github.com>
This commit is contained in:
parent
0f7e9cc01f
commit
09da889f92
1 changed files with 101 additions and 0 deletions
|
@ -976,7 +976,108 @@ Our main interest is located in \_parameters :
|
|||
print(network._modules["0"].__dict__["_parameters"].keys())
|
||||
```
|
||||
|
||||
And here we find:
|
||||
|
||||
```python
|
||||
odict_keys(['weight', 'bias'])
|
||||
```
|
||||
|
||||
## Who has parameters?
|
||||
|
||||
Now we can analyse which of the layers have parameters:
|
||||
|
||||
```python
|
||||
for module_id in range(0, len(network._modules)):
|
||||
print(
|
||||
f'ID: {module_id} ==> {network._modules[str(module_id)].__dict__["_parameters"].keys()}'
|
||||
)
|
||||
```
|
||||
|
||||
Output:
|
||||
|
||||
```python
|
||||
ID: 0 ==> odict_keys(['weight', 'bias'])
|
||||
ID: 1 ==> odict_keys([])
|
||||
ID: 2 ==> odict_keys([])
|
||||
ID: 3 ==> odict_keys(['weight', 'bias'])
|
||||
ID: 4 ==> odict_keys([])
|
||||
ID: 5 ==> odict_keys([])
|
||||
ID: 6 ==> odict_keys([])
|
||||
ID: 7 ==> odict_keys(['weight', 'bias'])
|
||||
ID: 8 ==> odict_keys([])
|
||||
ID: 9 ==> odict_keys(['weight', 'bias'])
|
||||
```
|
||||
|
||||
## Give me your weights!
|
||||
|
||||
|
||||
```python
|
||||
conv1_bias = network._modules["0"].__dict__["_parameters"]["bias"].data
|
||||
conv1_weights = network._modules["0"].__dict__["_parameters"]["weight"].data
|
||||
|
||||
conv2_bias = network._modules["3"].__dict__["_parameters"]["bias"].data
|
||||
conv2_weights = network._modules["3"].__dict__["_parameters"]["weight"].data
|
||||
|
||||
full1_bias = network._modules["7"].__dict__["_parameters"]["bias"].data
|
||||
full1_weights = network._modules["7"].__dict__["_parameters"]["weight"].data
|
||||
|
||||
output_bias = network._modules["9"].__dict__["_parameters"]["bias"].data
|
||||
output_weights = network._modules["9"].__dict__["_parameters"]["weight"].data
|
||||
|
||||
|
||||
print(conv1_bias.shape) # -> torch.Size([32])
|
||||
print(conv1_weights.shape) # -> torch.Size([32, 1, 5, 5])
|
||||
|
||||
print(conv2_bias.shape) # -> torch.Size([64])
|
||||
print(conv2_weights.shape) # -> torch.Size([64, 32, 5, 5])
|
||||
|
||||
print(full1_bias.shape) # -> torch.Size([1024])
|
||||
print(full1_weights.shape) # -> torch.Size([1024, 576])
|
||||
|
||||
print(output_bias.shape) # -> torch.Size([10])
|
||||
print(output_weights.shape) # -> torch.Size([10, 1024])
|
||||
```
|
||||
|
||||
**Note: The order of the dimensions is strange. It is [Output Channel, Input Channel, Kernel X, Kernel Y] for the 2D convolution layer and [Output Channel, Input Channel] for the full layer.**
|
||||
|
||||
**Note: If you want to interact with the weights, then you have to use .data** If you write directly into e.g. \_\_dict\_\_["_parameters"]["bias"] you might accidently convert it from a parameter into a tensor and/or destroy the connection to the optimizer (which holds only a reference to the weights).
|
||||
|
||||
## Replace weights
|
||||
|
||||
We can now easily replace the weights
|
||||
|
||||
```python
|
||||
network[0].__dict__["_parameters"]["bias"].data = 5 * torch.ones(
|
||||
(32), dtype=torch.float32
|
||||
)
|
||||
network[0].__dict__["_parameters"]["weight"].data = torch.ones(
|
||||
(32, 1, 5, 5), dtype=torch.float32
|
||||
)
|
||||
|
||||
|
||||
fake_input = torch.ones(
|
||||
(1, 1, 24, 24),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
output = network[0](fake_input)
|
||||
print(output)
|
||||
print(output.shape) # -> torch.Size([1, 32, 20, 20])
|
||||
```
|
||||
|
||||
Output:
|
||||
|
||||
```python
|
||||
tensor([[[[30., 30., 30., ..., 30., 30., 30.],
|
||||
[30., 30., 30., ..., 30., 30., 30.],
|
||||
[30., 30., 30., ..., 30., 30., 30.],
|
||||
...,
|
||||
[30., 30., 30., ..., 30., 30., 30.],
|
||||
[30., 30., 30., ..., 30., 30., 30.],
|
||||
[30., 30., 30., ..., 30., 30., 30.]],
|
||||
[...]
|
||||
[30., 30., 30., ..., 30., 30., 30.],
|
||||
[30., 30., 30., ..., 30., 30., 30.],
|
||||
[30., 30., 30., ..., 30., 30., 30.]]]],
|
||||
grad_fn=<ConvolutionBackward0>)
|
||||
```
|
||||
|
|
Loading…
Reference in a new issue