diff --git a/functions/Anime.py b/functions/Anime.py index 628624c..bfc4e46 100644 --- a/functions/Anime.py +++ b/functions/Anime.py @@ -22,6 +22,7 @@ class Anime: colorbar: bool = True, vmin_scale: float | None = None, vmax_scale: float | None = None, + movie_file: str | None = None, ) -> None: assert input.ndim == 3 @@ -79,12 +80,14 @@ class Anime: plt.title(f"{bar} {i} of {int(input_np.shape[0]-1)}", loc="left") return - _ = matplotlib.animation.FuncAnimation( + ani = matplotlib.animation.FuncAnimation( fig, next_frame, frames=int(input.shape[0]), interval=interval, repeat=repeat, ) - - plt.show() + if movie_file is not None: + ani.save(movie_file) + else: + plt.show() diff --git a/functions/DataContainer.py b/functions/DataContainer.py index 09103ad..ccc7233 100644 --- a/functions/DataContainer.py +++ b/functions/DataContainer.py @@ -232,10 +232,12 @@ class DataContainer(torch.nn.Module): else: assert self.acceptor is not None - assert self.acceptor.ndim == temp.ndim + assert self.acceptor.ndim + 1 == temp.ndim assert self.acceptor.shape[0] == temp.shape[0] assert self.acceptor.shape[1] == temp.shape[1] - assert self.acceptor.shape[3] == temp.shape[3] + # assert self.acceptor.shape[2] == temp.shape[2] + assert temp.shape[3] == 4 + self.acceptor = torch.cat( ( self.acceptor, @@ -258,10 +260,12 @@ class DataContainer(torch.nn.Module): else: assert self.donor is not None - assert self.donor.ndim == temp.ndim + assert self.donor.ndim + 1 == temp.ndim assert self.donor.shape[0] == temp.shape[0] assert self.donor.shape[1] == temp.shape[1] - assert self.donor.shape[3] == temp.shape[3] + # assert self.donor.shape[2] == temp.shape[2] + assert temp.shape[3] == 4 + self.donor = torch.cat( ( self.donor, @@ -284,10 +288,12 @@ class DataContainer(torch.nn.Module): ) else: assert self.oxygenation is not None - assert self.oxygenation.ndim == temp.ndim + assert self.oxygenation.ndim + 1 == temp.ndim assert self.oxygenation.shape[0] == temp.shape[0] assert self.oxygenation.shape[1] == temp.shape[1] - assert self.oxygenation.shape[3] == temp.shape[3] + # assert self.oxygenation.shape[2] == temp.shape[2] + assert temp.shape[3] == 4 + self.oxygenation = torch.cat( ( self.oxygenation, @@ -311,10 +317,12 @@ class DataContainer(torch.nn.Module): ) else: assert self.volume is not None - assert self.volume.ndim == temp.ndim + assert self.volume.ndim + 1 == temp.ndim assert self.volume.shape[0] == temp.shape[0] assert self.volume.shape[1] == temp.shape[1] - assert self.volume.shape[3] == temp.shape[3] + # assert self.volume.shape[2] == temp.shape[2] + assert temp.shape[3] == 4 + self.volume = torch.cat( ( self.volume,