From 06d7daf0cce4838e1ece07c97572c0906c0d10dd Mon Sep 17 00:00:00 2001 From: David Rotermund <54365609+davrot@users.noreply.github.com> Date: Tue, 2 Jan 2024 20:36:04 +0100 Subject: [PATCH] Update README.md Signed-off-by: David Rotermund <54365609+davrot@users.noreply.github.com> --- pytorch/networks/README.md | 47 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/pytorch/networks/README.md b/pytorch/networks/README.md index 29e4b85..af3185f 100644 --- a/pytorch/networks/README.md +++ b/pytorch/networks/README.md @@ -1088,7 +1088,7 @@ Usually you will see this construct in tutorials: ```python class MyNetworkClass(torch.nn.Module): - def __init__(self): + def __init__(self) -> None: super().__init__() input_number_of_channel: int = 1 @@ -1147,7 +1147,7 @@ class MyNetworkClass(torch.nn.Module): bias=True, ) - def forward(self, input): + def forward(self, input: torch.Tensor) -> torch.Tensor: out = self.conv1(input) out = self.relu1(out) out = self.max_pooling_1(out) @@ -1160,3 +1160,46 @@ class MyNetworkClass(torch.nn.Module): ``` +In the constructor of the class you define the layers as elements of the class. And we write a forward function that connections the flow of information from the input to the output feed through the layers. + +Now we can do the following: + +```python +network = MyNetworkClass() + +number_of_pattern: int = 111 +input_number_of_channel: int = 1 +input_dim_x: int = 24 +input_dim_y: int = 24 + +fake_input = torch.rand( + (number_of_pattern, input_number_of_channel, input_dim_x, input_dim_y), + dtype=torch.float32, +) +output = network(fake_input) +print(fake_input.shape) # -> torch.Size([111, 1, 24, 24]) +print(output.shape) # -> torch.Size([111, 10]) +``` + +For accessing the innards we now need to address them via their variable names: + +```python +print(network.conv1.__dict__["_parameters"].keys()) # -> odict_keys(['weight', 'bias']) +print(network.conv2.__dict__["_parameters"].keys()) # -> odict_keys(['weight', 'bias']) +print(network.fully_connected_1.__dict__["_parameters"].keys()) # -> odict_keys(['weight', 'bias']) +``` + +Save is still like this: + +```python +torch.save(network.state_dict(), "torch_network_dict_class.pt") +``` + +and load shortens, if it can reuse the class defintion, to: + +```python +network = MyNetworkClass() +network.load_state_dict(torch.load("torch_network_dict_class.pt")) +network.eval() +``` +