Update README.md
Signed-off-by: David Rotermund <54365609+davrot@users.noreply.github.com>
This commit is contained in:
parent
fd4b1e2a4a
commit
06d7daf0cc
1 changed files with 45 additions and 2 deletions
|
@ -1088,7 +1088,7 @@ Usually you will see this construct in tutorials:
|
|||
|
||||
```python
|
||||
class MyNetworkClass(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
input_number_of_channel: int = 1
|
||||
|
@ -1147,7 +1147,7 @@ class MyNetworkClass(torch.nn.Module):
|
|||
bias=True,
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
out = self.conv1(input)
|
||||
out = self.relu1(out)
|
||||
out = self.max_pooling_1(out)
|
||||
|
@ -1160,3 +1160,46 @@ class MyNetworkClass(torch.nn.Module):
|
|||
|
||||
```
|
||||
|
||||
In the constructor of the class you define the layers as elements of the class. And we write a forward function that connections the flow of information from the input to the output feed through the layers.
|
||||
|
||||
Now we can do the following:
|
||||
|
||||
```python
|
||||
network = MyNetworkClass()
|
||||
|
||||
number_of_pattern: int = 111
|
||||
input_number_of_channel: int = 1
|
||||
input_dim_x: int = 24
|
||||
input_dim_y: int = 24
|
||||
|
||||
fake_input = torch.rand(
|
||||
(number_of_pattern, input_number_of_channel, input_dim_x, input_dim_y),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
output = network(fake_input)
|
||||
print(fake_input.shape) # -> torch.Size([111, 1, 24, 24])
|
||||
print(output.shape) # -> torch.Size([111, 10])
|
||||
```
|
||||
|
||||
For accessing the innards we now need to address them via their variable names:
|
||||
|
||||
```python
|
||||
print(network.conv1.__dict__["_parameters"].keys()) # -> odict_keys(['weight', 'bias'])
|
||||
print(network.conv2.__dict__["_parameters"].keys()) # -> odict_keys(['weight', 'bias'])
|
||||
print(network.fully_connected_1.__dict__["_parameters"].keys()) # -> odict_keys(['weight', 'bias'])
|
||||
```
|
||||
|
||||
Save is still like this:
|
||||
|
||||
```python
|
||||
torch.save(network.state_dict(), "torch_network_dict_class.pt")
|
||||
```
|
||||
|
||||
and load shortens, if it can reuse the class defintion, to:
|
||||
|
||||
```python
|
||||
network = MyNetworkClass()
|
||||
network.load_state_dict(torch.load("torch_network_dict_class.pt"))
|
||||
network.eval()
|
||||
```
|
||||
|
||||
|
|
Loading…
Reference in a new issue