kk_contour_net_shallow/functions/test.py
2023-07-22 14:53:46 +02:00

58 lines
1.5 KiB
Python

import torch
import logging
@torch.no_grad()
def test(
model: torch.nn.modules.container.Sequential,
loader: torch.utils.data.dataloader.DataLoader,
device: torch.device,
tb,
epoch: int,
logger: logging.Logger,
test_accuracy: list[float],
test_losses: list[float],
scale_data: float,
) -> float:
test_loss: float = 0.0
correct: int = 0
pattern_count: float = 0.0
model.eval()
for data in loader:
label = data[0].to(device)
image = data[1].type(dtype=torch.float32).to(device)
if scale_data > 0:
image /= scale_data
output = model(image)
# loss and optimization
loss = torch.nn.functional.cross_entropy(output, label, reduction="sum")
pattern_count += float(label.shape[0])
test_loss += loss.item()
prediction = output.argmax(dim=1)
correct += prediction.eq(label).sum().item()
logger.info(
(
"Test set:"
f" Average loss: {test_loss / pattern_count:.3e},"
f" Accuracy: {correct}/{pattern_count},"
f"({100.0 * correct / pattern_count:.2f}%)"
)
)
logger.info("")
acc = 100.0 * correct / pattern_count
test_losses.append(test_loss / pattern_count)
test_accuracy.append(acc)
# add to tb:
tb.add_scalar("Test Loss", (test_loss / pattern_count), epoch)
tb.add_scalar("Test Performance", 100.0 * correct / pattern_count, epoch)
tb.add_scalar("Test Number Correct", correct, epoch)
tb.flush()
return acc