Update README.md
Signed-off-by: David Rotermund <54365609+davrot@users.noreply.github.com>
parent 6519bcd40c
commit e6eed9ea79
1 changed file with 48 additions and 1 deletion
@@ -417,7 +417,7 @@ test_data_load = torch.utils.data.DataLoader(
# -------------------------------------------

# The optimizer
-optimizer = torch.optim.Adam(network.parameters(), lr=0.01)
+optimizer = torch.optim.Adam(network.parameters(), lr=0.001)

# The LR Scheduler
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
@@ -498,3 +498,50 @@ for epoch_id in range(0, number_of_epoch):
tb.close()
```

## Mean square error

You might be inclined to use the mean square error (MSE) instead of the cross entropy. Be aware, though, that this takes more than just changing the loss function from

```python
loss_function = torch.nn.CrossEntropyLoss()
```

to

```python
loss_function = torch.nn.MSELoss()
```

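Why? Because the target that is passed to the loss function changes: `torch.nn.CrossEntropyLoss` accepts the correct class as an integer index, while `torch.nn.MSELoss` compares the output element by element with a target of the same shape, here a one-hot encoded vector.
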
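To make the difference concrete, here is a small worked example in plain Python (no PyTorch needed). It is only a sketch: the helper names `one_hot` and `mean_squared_error` and the numbers are made up for illustration. It shows what a one-hot target looks like and how the mean over all squared differences is formed, which is what `MSELoss` does with its default settings.

```python
def one_hot(label: int, number_of_classes: int) -> list[float]:
    """Hypothetical helper: 1.0 at the index of the correct class, 0.0 elsewhere."""
    return [1.0 if i == label else 0.0 for i in range(number_of_classes)]


def mean_squared_error(output: list[float], target: list[float]) -> float:
    """Mean of the element-wise squared differences."""
    return sum((o - t) ** 2 for o, t in zip(output, target)) / len(output)


# One pattern, four output neurons, the correct class is 2 (made-up values)
output = [0.1, 0.2, 0.6, 0.1]
target = one_hot(2, 4)  # [0.0, 0.0, 1.0, 0.0]

# (0.01 + 0.04 + 0.16 + 0.01) / 4 = 0.055 (up to floating point rounding)
mse = mean_squared_error(output, target)
print(target, mse)
```

The integer label alone (here `2`) cannot be compared element by element with the four outputs, which is why the target has to be expanded first.
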
A fast way to do this conversion is the following function, which uses an in-place scatter:

```python
def class_to_one_hot(
    correct_label: torch.Tensor, number_of_neurons: int
) -> torch.Tensor:

    # One row per pattern, one column per class; created on the device of the labels
    target_one_hot: torch.Tensor = torch.zeros(
        (correct_label.shape[0], number_of_neurons),
        device=correct_label.device,
    )
    # In-place scatter along dim 1: write a 1 into the column given by each label
    target_one_hot.scatter_(
        1,
        correct_label.unsqueeze(1),
        torch.ones((correct_label.shape[0], 1), device=correct_label.device),
    )

    return target_one_hot
```

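As a cross-check, PyTorch also provides `torch.nn.functional.one_hot`, which produces the same kind of target. This is an alternative and not the approach used above. It returns an integer tensor, so it needs a cast to float before it can be passed to `MSELoss`. The label values and the number of neurons below are made up for illustration.

```python
import torch

# Three patterns with their correct classes as integer indices (made-up values)
correct_label = torch.tensor([2, 0, 3])
number_of_neurons = 4  # assumed number of output neurons

# Built-in one-hot encoding; the result is int64, so cast it to float32
target_one_hot = torch.nn.functional.one_hot(
    correct_label, num_classes=number_of_neurons
).to(torch.float32)

print(target_one_hot)
```

Each row contains exactly one 1, located at the column index of the correct class.
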
We also need to change the line that computes the loss from

```python
loss = loss_function(output, target)
```

to

```python
loss = loss_function(
    output, class_to_one_hot(target, number_of_output_channels_full1)
)
```
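
Putting the pieces together, one optimization step with the MSE could look like the following sketch. The single linear layer, the random mini-batch, the name `image` and the value 10 for `number_of_output_channels_full1` are stand-ins and are not taken from the network of this tutorial. `class_to_one_hot` is repeated so that the block runs on its own.

```python
import torch

torch.manual_seed(0)


def class_to_one_hot(
    correct_label: torch.Tensor, number_of_neurons: int
) -> torch.Tensor:
    # Repeated from above so that this block is self-contained
    target_one_hot = torch.zeros(
        (correct_label.shape[0], number_of_neurons), device=correct_label.device
    )
    target_one_hot.scatter_(
        1,
        correct_label.unsqueeze(1),
        torch.ones((correct_label.shape[0], 1), device=correct_label.device),
    )
    return target_one_hot


number_of_output_channels_full1 = 10  # assumption: ten classes / output neurons

# Stand-in network (assumption): one linear layer with 20 input features
network = torch.nn.Linear(20, number_of_output_channels_full1)
optimizer = torch.optim.Adam(network.parameters(), lr=0.001)
loss_function = torch.nn.MSELoss()

# Fake mini-batch: 8 patterns and their integer class labels
image = torch.randn(8, 20)
target = torch.randint(0, number_of_output_channels_full1, (8,))

optimizer.zero_grad()
output = network(image)
target_one_hot = class_to_one_hot(target, number_of_output_channels_full1)
loss = loss_function(output, target_one_hot)
loss.backward()
optimizer.step()

# The predicted class is still the index of the largest output
prediction = output.argmax(dim=1)
```

Only the target handed to the loss function changes; the rest of the step stays as it was with the cross entropy.
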