Update README.md

Signed-off-by: David Rotermund <54365609+davrot@users.noreply.github.com>
David Rotermund 2024-01-02 21:39:50 +01:00 committed by GitHub
parent 6519bcd40c
commit e6eed9ea79


@@ -417,7 +417,7 @@ test_data_load = torch.utils.data.DataLoader(
# -------------------------------------------
# The optimizer
-optimizer = torch.optim.Adam(network.parameters(), lr=0.01)
+optimizer = torch.optim.Adam(network.parameters(), lr=0.001)
# The LR Scheduler
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
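As an aside on the scheduler shown above: `ReduceLROnPlateau` only lowers the learning rate when a monitored quantity stops improving, so it has to be stepped once per epoch with that quantity. A minimal self-contained sketch (the linear model and the constant loss are placeholders, not the README's network):
```python
import torch

model = torch.nn.Linear(10, 10)  # stand-in for the real network
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

for epoch_id in range(0, 25):
    test_loss = 1.0  # placeholder for this epoch's evaluation result
    # The scheduler decides based on the monitored quantity:
    lr_scheduler.step(test_loss)
```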
@@ -498,3 +498,50 @@ for epoch_id in range(0, number_of_epoch):
tb.close()
```
## Mean square error
You might be inclined to use the mean square error (MSE) instead of the cross entropy.
But be aware that you need to change more than just the loss function from
```python
loss_function = torch.nn.CrossEntropyLoss()
```
to
```python
loss_function = torch.nn.MSELoss()
```
Why? Because the target changes from the correct class index (an integer) to a one-hot encoded vector.
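A minimal sketch of the resulting shape requirements (the batch size and class count are made-up values for illustration):
```python
import torch

batch_size = 64
number_of_classes = 10  # e.g. the ten MNIST digits

output = torch.randn((batch_size, number_of_classes))  # network output

# CrossEntropyLoss expects the class indices, shape [batch_size]:
target = torch.randint(0, number_of_classes, (batch_size,))
print(torch.nn.CrossEntropyLoss()(output, target))  # works

# MSELoss expects a target with the same shape as the output,
# i.e. [batch_size, number_of_classes] -- the integer labels will not do.
```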
A fast way to build these one-hot targets is the following function, which uses in-place `scatter_`
```python
def class_to_one_hot(
    correct_label: torch.Tensor, number_of_neurons: int
) -> torch.Tensor:
    # One row of zeros per sample, one column per class:
    target_one_hot: torch.Tensor = torch.zeros(
        (correct_label.shape[0], number_of_neurons)
    )
    # Write a 1 into the column given by each sample's class index:
    target_one_hot.scatter_(
        1, correct_label.unsqueeze(1), torch.ones((correct_label.shape[0], 1))
    )
    return target_one_hot
```
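For illustration, a quick check with made-up labels:
```python
labels = torch.tensor([0, 2, 1])
print(class_to_one_hot(labels, 3))
# tensor([[1., 0., 0.],
#         [0., 0., 1.],
#         [0., 1., 0.]])
```
Alternatively, `torch.nn.functional.one_hot(labels, num_classes=3).float()` produces the same tensor (it returns integers, hence the cast to float).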
Obviously, we also need to modify this line
```python
loss = loss_function(output, target)
```
to
```python
loss = loss_function(
output, class_to_one_hot(target, number_of_output_channels_full1)
)
```
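Putting it together, a minimal end-to-end check (random tensors stand in for the real network output and labels, and `class_to_one_hot` is the function defined above):
```python
import torch

number_of_output_channels_full1 = 10  # size of the output layer

output = torch.randn((64, number_of_output_channels_full1))
target = torch.randint(0, number_of_output_channels_full1, (64,))

loss_function = torch.nn.MSELoss()
loss = loss_function(
    output, class_to_one_hot(target, number_of_output_channels_full1)
)
print(loss)  # a scalar tensor
```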