Update README.md

Signed-off-by: David Rotermund <54365609+davrot@users.noreply.github.com>
David Rotermund 2024-01-02 20:36:04 +01:00 committed by GitHub
parent fd4b1e2a4a
commit 06d7daf0cc
GPG key ID: 4AEE18F83AFDEB23

@@ -1088,7 +1088,7 @@ Usually you will see this construct in tutorials:
```python
 class MyNetworkClass(torch.nn.Module):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         input_number_of_channel: int = 1
@@ -1147,7 +1147,7 @@ class MyNetworkClass(torch.nn.Module):
             bias=True,
         )
-    def forward(self, input):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
         out = self.conv1(input)
         out = self.relu1(out)
         out = self.max_pooling_1(out)
@@ -1160,3 +1160,46 @@ class MyNetworkClass(torch.nn.Module):
```
In the constructor of the class, you define the layers as attributes of the class. The forward function then connects the flow of information from the input to the output by feeding it through the layers.
Now we can do the following:
```python
network = MyNetworkClass()
number_of_pattern: int = 111
input_number_of_channel: int = 1
input_dim_x: int = 24
input_dim_y: int = 24
fake_input = torch.rand(
    (number_of_pattern, input_number_of_channel, input_dim_x, input_dim_y),
    dtype=torch.float32,
)
output = network(fake_input)
print(fake_input.shape) # -> torch.Size([111, 1, 24, 24])
print(output.shape) # -> torch.Size([111, 10])
```
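Calling `network(fake_input)` works because `torch.nn.Module` routes the call to `forward`. The following is a minimal sketch of that behaviour. `TinyNetwork` and its layer sizes are hypothetical stand-ins chosen for brevity, not the layers of `MyNetworkClass`:

```python
import torch


class TinyNetwork(torch.nn.Module):
    """Hypothetical, shortened stand-in for MyNetworkClass."""

    def __init__(self) -> None:
        super().__init__()
        # Layers are stored as attributes of the module.
        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=4, kernel_size=5)
        self.relu1 = torch.nn.ReLU()
        self.flatten1 = torch.nn.Flatten(start_dim=1)
        # A 24x24 input shrinks to 20x20 after a 5x5 convolution without padding.
        self.fully_connected_1 = torch.nn.Linear(in_features=4 * 20 * 20, out_features=10)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out = self.conv1(input)
        out = self.relu1(out)
        out = self.flatten1(out)
        out = self.fully_connected_1(out)
        return out


tiny_network = TinyNetwork()
fake_input = torch.rand((3, 1, 24, 24), dtype=torch.float32)

with torch.no_grad():
    # Calling the module object runs forward() (plus any registered hooks).
    output_call = tiny_network(fake_input)
    output_forward = tiny_network.forward(fake_input)

print(output_call.shape)  # -> torch.Size([3, 10])
print(torch.allclose(output_call, output_forward))  # -> True
```

In practice you would call the module object rather than `forward` directly, since only the call also runs hooks that may have been registered on the module.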
To access the internals, we now need to address the layers via their attribute names:
```python
print(network.conv1.__dict__["_parameters"].keys()) # -> odict_keys(['weight', 'bias'])
print(network.conv2.__dict__["_parameters"].keys()) # -> odict_keys(['weight', 'bias'])
print(network.fully_connected_1.__dict__["_parameters"].keys()) # -> odict_keys(['weight', 'bias'])
```
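As an alternative to reading `__dict__["_parameters"]`, the public method `named_parameters()` lists all parameters together with their attribute path. This is a small sketch, again using a hypothetical `TinyNetwork` whose layer sizes are assumptions; only the attribute names `conv1` and `fully_connected_1` mirror the class above:

```python
import torch


class TinyNetwork(torch.nn.Module):
    """Hypothetical stand-in; the layer sizes are chosen only for illustration."""

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=4, kernel_size=5)
        self.fully_connected_1 = torch.nn.Linear(in_features=4, out_features=10)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Average over the two spatial dimensions, then apply the linear layer.
        out = self.conv1(input).mean(dim=(2, 3))
        return self.fully_connected_1(out)


tiny_network = TinyNetwork()

# Parameter names are built from the attribute name and the parameter name.
parameter_shapes = {
    name: tuple(parameter.shape) for name, parameter in tiny_network.named_parameters()
}
for name, shape in parameter_shapes.items():
    print(name, shape)
# -> conv1.weight (4, 1, 5, 5)
# -> conv1.bias (4,)
# -> fully_connected_1.weight (10, 4)
# -> fully_connected_1.bias (10,)

# A single parameter can also be reached directly via the attribute name.
print(tiny_network.conv1.weight.shape)  # -> torch.Size([4, 1, 5, 5])
```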
Saving still works as before:
```python
torch.save(network.state_dict(), "torch_network_dict_class.pt")
```
Loading becomes shorter if the class definition can be reused:
```python
network = MyNetworkClass()
network.load_state_dict(torch.load("torch_network_dict_class.pt"))
network.eval()
```
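To check that saving and loading really restores the weights, one option is a round trip through a temporary file. The sketch below assumes a hypothetical one-layer `TinyNetwork` and a made-up filename. The argument `map_location="cpu"` is an addition that is not used above; it keeps the loaded tensors on the CPU:

```python
import os
import tempfile

import torch


class TinyNetwork(torch.nn.Module):
    """Hypothetical one-layer stand-in for MyNetworkClass."""

    def __init__(self) -> None:
        super().__init__()
        self.fully_connected_1 = torch.nn.Linear(in_features=4, out_features=2)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self.fully_connected_1(input)


network_a = TinyNetwork()
network_b = TinyNetwork()  # separately initialized, so its weights generally differ

with tempfile.TemporaryDirectory() as directory:
    filename = os.path.join(directory, "tiny_network_dict.pt")  # hypothetical filename
    torch.save(network_a.state_dict(), filename)
    # load_state_dict reports keys that did not match between file and class.
    load_result = network_b.load_state_dict(torch.load(filename, map_location="cpu"))
network_b.eval()

print(load_result.missing_keys)  # -> []
print(load_result.unexpected_keys)  # -> []

fake_input = torch.rand((5, 4), dtype=torch.float32)
with torch.no_grad():
    same_output = torch.allclose(network_a(fake_input), network_b(fake_input))
print(same_output)  # -> True
```

If the class definition changes between saving and loading, for example because a layer is renamed, the affected names show up in `missing_keys` or `unexpected_keys` when `strict=False` is passed to `load_state_dict`; with the default `strict=True` an error is raised instead.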