58 lines
1.6 KiB
Python
58 lines
1.6 KiB
Python
import torch
|
|
import logging
|
|
|
|
|
|
@torch.no_grad()
|
|
def test(
|
|
model: torch.nn.modules.container.Sequential,
|
|
loader: torch.utils.data.dataloader.DataLoader,
|
|
device: torch.device,
|
|
tb,
|
|
epoch: int,
|
|
logger: logging.Logger,
|
|
test_accuracy: list[float],
|
|
test_losses: list[float],
|
|
scale_data: float,
|
|
) -> float:
|
|
test_loss: float = 0.0
|
|
correct: int = 0
|
|
pattern_count: float = 0.0
|
|
|
|
model.eval()
|
|
|
|
for data in loader:
|
|
label = data[0].to(device)
|
|
image = data[1].type(dtype=torch.float32).to(device)
|
|
if scale_data > 0:
|
|
image /= scale_data
|
|
|
|
output = model(image)
|
|
|
|
# loss and optimization
|
|
loss = torch.nn.functional.cross_entropy(output, label, reduction="sum")
|
|
pattern_count += float(label.shape[0])
|
|
test_loss += loss.item()
|
|
prediction = output.argmax(dim=1)
|
|
correct += prediction.eq(label).sum().item()
|
|
|
|
logger.info(
|
|
(
|
|
"Test set:"
|
|
f" Average loss: {test_loss / pattern_count:.3e},"
|
|
f" Accuracy: {correct}/{pattern_count},"
|
|
f"({100.0 * correct / pattern_count:.2f}%)"
|
|
)
|
|
)
|
|
logger.info("")
|
|
|
|
acc = 100.0 * correct / pattern_count
|
|
test_losses.append(test_loss / pattern_count)
|
|
test_accuracy.append(acc)
|
|
|
|
# add to tb:
|
|
tb.add_scalar("Test Loss", (test_loss / pattern_count), epoch)
|
|
tb.add_scalar("Test Performance", 100.0 * correct / pattern_count, epoch)
|
|
tb.add_scalar("Test Number Correct", correct, epoch)
|
|
tb.flush()
|
|
|
|
return acc
|