Update README.md
Signed-off-by: David Rotermund <54365609+davrot@users.noreply.github.com>
This commit is contained in:
parent
fd4b1e2a4a
commit
06d7daf0cc
1 changed files with 45 additions and 2 deletions
|
@ -1088,7 +1088,7 @@ Usually you will see this construct in tutorials:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
class MyNetworkClass(torch.nn.Module):
|
class MyNetworkClass(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
input_number_of_channel: int = 1
|
input_number_of_channel: int = 1
|
||||||
|
@ -1147,7 +1147,7 @@ class MyNetworkClass(torch.nn.Module):
|
||||||
bias=True,
|
bias=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
out = self.conv1(input)
|
out = self.conv1(input)
|
||||||
out = self.relu1(out)
|
out = self.relu1(out)
|
||||||
out = self.max_pooling_1(out)
|
out = self.max_pooling_1(out)
|
||||||
|
@ -1160,3 +1160,46 @@ class MyNetworkClass(torch.nn.Module):
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
In the constructor of the class you define the layers as elements of the class. And we write a forward function that connections the flow of information from the input to the output feed through the layers.
|
||||||
|
|
||||||
|
Now we can do the following:
|
||||||
|
|
||||||
|
```python
|
||||||
|
network = MyNetworkClass()
|
||||||
|
|
||||||
|
number_of_pattern: int = 111
|
||||||
|
input_number_of_channel: int = 1
|
||||||
|
input_dim_x: int = 24
|
||||||
|
input_dim_y: int = 24
|
||||||
|
|
||||||
|
fake_input = torch.rand(
|
||||||
|
(number_of_pattern, input_number_of_channel, input_dim_x, input_dim_y),
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
output = network(fake_input)
|
||||||
|
print(fake_input.shape) # -> torch.Size([111, 1, 24, 24])
|
||||||
|
print(output.shape) # -> torch.Size([111, 10])
|
||||||
|
```
|
||||||
|
|
||||||
|
For accessing the innards we now need to address them via their variable names:
|
||||||
|
|
||||||
|
```python
|
||||||
|
print(network.conv1.__dict__["_parameters"].keys()) # -> odict_keys(['weight', 'bias'])
|
||||||
|
print(network.conv2.__dict__["_parameters"].keys()) # -> odict_keys(['weight', 'bias'])
|
||||||
|
print(network.fully_connected_1.__dict__["_parameters"].keys()) # -> odict_keys(['weight', 'bias'])
|
||||||
|
```
|
||||||
|
|
||||||
|
Save is still like this:
|
||||||
|
|
||||||
|
```python
|
||||||
|
torch.save(network.state_dict(), "torch_network_dict_class.pt")
|
||||||
|
```
|
||||||
|
|
||||||
|
and load shortens, if it can reuse the class defintion, to:
|
||||||
|
|
||||||
|
```python
|
||||||
|
network = MyNetworkClass()
|
||||||
|
network.load_state_dict(torch.load("torch_network_dict_class.pt"))
|
||||||
|
network.eval()
|
||||||
|
```
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue