gevi/functions/Anime.py

94 lines
2.7 KiB
Python
Raw Normal View History

2023-07-30 23:52:42 +02:00
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.animation
class Anime:
def __init__(self) -> None:
super().__init__()
def show(
self,
input: torch.Tensor | np.ndarray,
mask: torch.Tensor | np.ndarray | None = None,
vmin: float | None = None,
vmax: float | None = None,
cmap: str = "hot",
axis_off: bool = True,
show_frame_count: bool = True,
interval: int = 100,
repeat: bool = False,
colorbar: bool = True,
vmin_scale: float | None = None,
vmax_scale: float | None = None,
2023-09-27 15:45:55 +02:00
movie_file: str | None = None,
2023-07-30 23:52:42 +02:00
) -> None:
assert input.ndim == 3
if isinstance(input, torch.Tensor):
input_np: np.ndarray = input.cpu().numpy()
if mask is not None:
mask_np: np.ndarray | None = (mask == 0).cpu().numpy()
else:
mask_np = None
else:
input_np = input
if mask is not None:
mask_np = mask == 0 # type: ignore
else:
mask_np = None
if vmin is None:
vmin = float(np.where(np.isfinite(input_np), input_np, 0.0).min())
if vmax is None:
vmax = float(np.where(np.isfinite(input_np), input_np, 0.0).max())
if vmin_scale is not None:
vmin *= vmin_scale
if vmax_scale is not None:
vmax *= vmax_scale
fig = plt.figure()
image = np.nan_to_num(input_np[0, ...], copy=True, nan=0.0)
if mask_np is not None:
image[mask_np] = float("NaN")
image_handle = plt.imshow(
image,
cmap=cmap,
vmin=vmin,
vmax=vmax,
)
if colorbar:
plt.colorbar()
if axis_off:
plt.axis("off")
def next_frame(i: int) -> None:
image = np.nan_to_num(input_np[i, ...], copy=True, nan=0.0)
if mask_np is not None:
image[mask_np] = float("NaN")
image_handle.set_data(image)
if show_frame_count:
bar_length: int = 10
filled_length = int(round(bar_length * i / input_np.shape[0]))
bar = "\u25A0" * filled_length + "\u25A1" * (bar_length - filled_length)
plt.title(f"{bar} {i} of {int(input_np.shape[0]-1)}", loc="left")
return
2023-09-27 15:45:55 +02:00
ani = matplotlib.animation.FuncAnimation(
2023-07-30 23:52:42 +02:00
fig,
next_frame,
frames=int(input.shape[0]),
interval=interval,
repeat=repeat,
)
2023-09-27 15:45:55 +02:00
if movie_file is not None:
ani.save(movie_file)
else:
plt.show()