# Sparsifier.py # ==================================== # matches dictionary patches to contour images # # Version V1.0, 07.03.2023: # slight parameter scaling changes to David's last code version... # # Version V1.1, 07.03.2023: # merged David's rebuild code (GUI capable) # import torch import math import matplotlib.pyplot as plt import time class Sparsifier(torch.nn.Module): dictionary_filter_fft: torch.Tensor | None = None dictionary_filter: torch.Tensor dictionary: torch.Tensor parameter_ready: bool dictionary_ready: bool contour_convolved_sum: torch.Tensor | None = None use_map: torch.Tensor | None = None position_found: torch.Tensor | None = None size_exp_deadzone: float number_of_patches: int padding_deadzone_size_x: int padding_deadzone_size_y: int plot_use_map: bool deadzone_exp: bool deadzone_hard_cutout: int # 0 = not, 1 = round, 2 = box deadzone_hard_cutout_size: float dictionary_prior: torch.Tensor | None pi: torch.Tensor torch_device: torch.device default_dtype = torch.float32 def __init__( self, dictionary_filter: torch.Tensor, dictionary: torch.Tensor, dictionary_prior: torch.Tensor | None = None, number_of_patches: int = 10, # size_exp_deadzone: float = 1.0, # padding_deadzone_size_x: int = 0, # padding_deadzone_size_y: int = 0, # plot_use_map: bool = False, deadzone_exp: bool = True, # deadzone_hard_cutout: int = 1, # 0 = not, 1 = round deadzone_hard_cutout_size: float = 1.0, # torch_device: str = "cpu", ): super().__init__() self.dictionary_ready = False self.parameter_ready = False self.torch_device = torch.device(torch_device) self.pi = torch.tensor( math.pi, device=self.torch_device, dtype=self.default_dtype, ) self.plot_use_map = plot_use_map self.update_parameter( number_of_patches, size_exp_deadzone, padding_deadzone_size_x, padding_deadzone_size_y, deadzone_exp, deadzone_hard_cutout, deadzone_hard_cutout_size, ) self.update_dictionary(dictionary_filter, dictionary, dictionary_prior) def update_parameter( self, number_of_patches: int, size_exp_deadzone: float, padding_deadzone_size_x: int, padding_deadzone_size_y: int, deadzone_exp: bool, deadzone_hard_cutout: int, deadzone_hard_cutout_size: float, ) -> None: self.parameter_ready = False assert size_exp_deadzone > 0.0 assert number_of_patches > 0 assert padding_deadzone_size_x >= 0 assert padding_deadzone_size_y >= 0 self.number_of_patches = number_of_patches self.size_exp_deadzone = size_exp_deadzone self.padding_deadzone_size_x = padding_deadzone_size_x self.padding_deadzone_size_y = padding_deadzone_size_y self.deadzone_exp = deadzone_exp self.deadzone_hard_cutout = deadzone_hard_cutout self.deadzone_hard_cutout_size = deadzone_hard_cutout_size self.parameter_ready = True def update_dictionary( self, dictionary_filter: torch.Tensor, dictionary: torch.Tensor, dictionary_prior: torch.Tensor | None = None, ) -> None: self.dictionary_ready = False assert dictionary_filter.ndim == 4 assert dictionary.ndim == 4 # Odd number of pixels. Please! assert (dictionary_filter.shape[-2] % 2) == 1 assert (dictionary_filter.shape[-1] % 2) == 1 self.dictionary_filter = dictionary_filter.detach().clone() self.dictionary = dictionary if dictionary_prior is not None: assert dictionary_prior.ndim == 1 assert dictionary_prior.shape[0] == dictionary_filter.shape[0] self.dictionary_prior = dictionary_prior self.dictionary_filter_fft = None self.dictionary_ready = True def fourier_convolution(self, contour: torch.Tensor): # Pattern, X, Y assert contour.dim() == 4 assert self.dictionary_filter is not None assert contour.shape[-2] >= self.dictionary_filter.shape[-2] assert contour.shape[-1] >= self.dictionary_filter.shape[-1] t0 = time.time() contour_fft = torch.fft.rfft2(contour, dim=(-2, -1)) t1 = time.time() if ( (self.dictionary_filter_fft is None) or (self.dictionary_filter_fft.dim() != 4) or (self.dictionary_filter_fft.shape[-2] != contour.shape[-2]) or (self.dictionary_filter_fft.shape[-1] != contour.shape[-1]) ): dictionary_padded: torch.Tensor = torch.zeros( ( self.dictionary_filter.shape[0], self.dictionary_filter.shape[1], contour.shape[-2], contour.shape[-1], ), device=self.torch_device, dtype=self.default_dtype, ) dictionary_padded[ :, :, : self.dictionary_filter.shape[-2], : self.dictionary_filter.shape[-1], ] = self.dictionary_filter.flip(dims=[-2, -1]) dictionary_padded = dictionary_padded.roll( ( -(self.dictionary_filter.shape[-2] // 2), -(self.dictionary_filter.shape[-1] // 2), ), (-2, -1), ) self.dictionary_filter_fft = torch.fft.rfft2( dictionary_padded, dim=(-2, -1) ) t2 = time.time() assert self.dictionary_filter_fft is not None # dimension order for multiplication: [pat, feat, ori, x, y] self.contour_convolved_sum = torch.fft.irfft2( contour_fft.unsqueeze(1) * self.dictionary_filter_fft.unsqueeze(0), dim=(-2, -1), ).sum(dim=2) # --> [pat, feat, x, y] t3 = time.time() self.use_map: torch.Tensor = torch.ones( (contour.shape[0], 1, contour.shape[-2], contour.shape[-1]), device=self.torch_device, dtype=self.default_dtype, ) if self.padding_deadzone_size_x > 0: self.use_map[:, 0, : self.padding_deadzone_size_x, :] = 0.0 self.use_map[:, 0, -self.padding_deadzone_size_x :, :] = 0.0 if self.padding_deadzone_size_y > 0: self.use_map[:, 0, :, : self.padding_deadzone_size_y] = 0.0 self.use_map[:, 0, :, -self.padding_deadzone_size_y :] = 0.0 t4 = time.time() print( "Sparsifier-convol {:.3f}s: fft-img-{:.3f}s, fft-fil-{:.3f}s, convol-{:.3f}s, pad-{:.3f}s".format( t4 - t0, t1 - t0, t2 - t1, t3 - t2, t4 - t3 ) ) return def find_next_element(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: assert self.use_map is not None assert self.contour_convolved_sum is not None assert self.dictionary_filter is not None # store feature index, x pos and y pos (3 entries) position_found = torch.zeros( (self.contour_convolved_sum.shape[0], 3), dtype=torch.int64, device=self.torch_device, ) overlap_found = torch.zeros( (self.contour_convolved_sum.shape[0], 2), device=self.torch_device, dtype=self.default_dtype, ) for pattern_id in range(0, self.contour_convolved_sum.shape[0]): t0 = time.time() # search_tensor: torch.Tensor = ( # self.contour_convolved[pattern_id] * self.use_map[pattern_id] # ).sum(dim=1) search_tensor: torch.Tensor = ( self.contour_convolved_sum[pattern_id] * self.use_map[pattern_id] ) t1 = time.time() if self.dictionary_prior is not None: search_tensor *= self.dictionary_prior.unsqueeze(-1).unsqueeze(-1) t2 = time.time() temp, max_0 = search_tensor.max(dim=0) temp, max_1 = temp.max(dim=0) temp_overlap, max_2 = temp.max(dim=0) position_base_3: int = int(max_2) position_base_2: int = int(max_1[position_base_3]) position_base_1: int = int(max_0[position_base_2, position_base_3]) position_base_0: int = int(pattern_id) position_found[position_base_0, 0] = position_base_1 position_found[position_base_0, 1] = position_base_2 position_found[position_base_0, 2] = position_base_3 overlap_found[pattern_id, 0] = temp_overlap overlap_found[pattern_id, 1] = self.contour_convolved_sum[ position_base_0, position_base_1, position_base_2, position_base_3 ] t3 = time.time() x_max: int = int(self.contour_convolved_sum.shape[-2]) y_max: int = int(self.contour_convolved_sum.shape[-1]) # Center arround the max position x_0 = int(-position_base_2) x_1 = int(x_max - position_base_2) y_0 = int(-position_base_3) y_1 = int(y_max - position_base_3) temp_grid_x: torch.Tensor = torch.arange( start=x_0, end=x_1, device=self.torch_device, dtype=self.default_dtype, ) temp_grid_y: torch.Tensor = torch.arange( start=y_0, end=y_1, device=self.torch_device, dtype=self.default_dtype, ) x, y = torch.meshgrid(temp_grid_x, temp_grid_y, indexing="ij") # discourage the neigbourhood around for the future if self.deadzone_exp is True: self.temp_map: torch.Tensor = 1.0 - torch.exp( -(x**2 + y**2) / (2 * self.size_exp_deadzone**2) ) else: self.temp_map = torch.ones( (x.shape[0], x.shape[1]), device=self.torch_device, dtype=self.default_dtype, ) assert self.deadzone_hard_cutout >= 0 assert self.deadzone_hard_cutout <= 1 assert self.deadzone_hard_cutout_size >= 0 if self.deadzone_hard_cutout == 1: temp = x**2 + y**2 self.temp_map *= torch.where( temp <= self.deadzone_hard_cutout_size**2, 0.0, 1.0, ) self.use_map[position_base_0, 0, :, :] *= self.temp_map t4 = time.time() # Only for keeping it in the float32 range: self.use_map[position_base_0, 0, :, :] /= self.use_map[ position_base_0, 0, :, : ].max() t5 = time.time() # print( # "{}, {}, {}, {}, {}".format(t1 - t0, t2 - t1, t3 - t2, t4 - t3, t5 - t4) # ) return ( position_found, overlap_found, torch.tensor((t1 - t0, t2 - t1, t3 - t2, t4 - t3, t5 - t4)), ) def forward(self, input: torch.Tensor) -> None: assert self.number_of_patches > 0 assert self.dictionary_ready is True assert self.parameter_ready is True self.position_found = torch.zeros( (input.shape[0], self.number_of_patches, 3), dtype=torch.int64, device=self.torch_device, ) self.overlap_found = torch.zeros( (input.shape[0], self.number_of_patches, 2), device=self.torch_device, dtype=self.default_dtype, ) # Folding the images and the dictionary self.fourier_convolution(input) t0 = time.time() dt = torch.tensor([0, 0, 0, 0, 0]) for patch_id in range(0, self.number_of_patches): ( self.position_found[:, patch_id, :], self.overlap_found[:, patch_id, :], dt_tmp, ) = self.find_next_element() dt = dt + dt_tmp if self.plot_use_map is True: assert self.position_found.shape[0] == 1 assert self.use_map is not None print("Position Saliency:") print(self.overlap_found[0, :, 0]) print("Overlap with Contour Image:") print(self.overlap_found[0, :, 1]) plt.subplot(1, 2, 1) plt.imshow(self.use_map[0, 0].cpu(), cmap="gray") plt.title(f"patch-number: {patch_id}") plt.axis("off") plt.colorbar(shrink=0.5) plt.subplot(1, 2, 2) plt.imshow( (self.use_map[0, 0] * input[0].sum(dim=0)).cpu(), cmap="gray" ) plt.title(f"patch-number: {patch_id}") plt.axis("off") plt.colorbar(shrink=0.5) plt.show() # self.overlap_found /= self.overlap_found.max(dim=1, keepdim=True)[0] t1 = time.time() print( "Sparsifier-forward {:.3f}s: usemap-{:.3f}s, prior-{:.3f}s, findmax-{:.3f}s, notch-{:.3f}s, norm-{:.3f}s: (sum-{:.3f})".format( t1 - t0, dt[0], dt[1], dt[2], dt[3], dt[4], dt.sum() ) )