# ContourExtract.py # ==================================== # extracts contours from gray-level images # # Version V1.0, pre-07.03.2023: # no actual changes, is David's last code version... # # Version V1.1, 07.03.2023: # merged David's rebuild code (GUI capable) # import torch import time import math class ContourExtract(torch.nn.Module): image_width: int = -1 image_height: int = -1 output_threshold: torch.Tensor = torch.tensor(-1) n_orientations: int sigma_kernel: float = 0 lambda_kernel: float = 0 gamma_aspect_ratio: float = 0 kernel_axis_x: torch.Tensor | None = None kernel_axis_y: torch.Tensor | None = None target_orientations: torch.Tensor weight_vector: torch.Tensor image_scale: float | None psi_phase_offset_cos: torch.Tensor psi_phase_offset_sin: torch.Tensor fft_gabor_cos_bank: torch.Tensor | None = None fft_gabor_sin_bank: torch.Tensor | None = None pi: torch.Tensor torch_device: torch.device default_dtype = torch.float32 rebuild_kernels: bool def __init__( self, n_orientations: int, sigma_kernel: float, lambda_kernel: float, gamma_aspect_ratio: float = 1.0, image_scale: float | None = 255.0, torch_device: str = "cpu", ): super().__init__() self.torch_device = torch.device(torch_device) self.pi = torch.tensor( math.pi, device=self.torch_device, dtype=self.default_dtype, ) self.psi_phase_offset_cos = torch.tensor( 0.0, device=self.torch_device, dtype=self.default_dtype, ) self.psi_phase_offset_sin = -self.pi / 2 self.gamma_aspect_ratio = gamma_aspect_ratio self.image_scale = image_scale self.update_settings( n_orientations=n_orientations, sigma_kernel=sigma_kernel, lambda_kernel=lambda_kernel, ) def forward(self, input: torch.Tensor) -> torch.Tensor: assert input.ndim == 4, "input must have 4 dims!" # We can only handle one color channel assert input.shape[1] == 1, "input.shape[1] must be 1!" input = input.type(dtype=self.default_dtype).to(device=self.torch_device) if self.image_scale is not None: # scale grey level [0, 255] to range [0.0, 1.0] input /= self.image_scale # Do we have valid kernels? if input.shape[-2] != self.image_width: self.image_width = input.shape[-2] self.rebuild_kernels = True if input.shape[-1] != self.image_height: self.image_height = input.shape[-1] self.rebuild_kernels = True assert self.image_width > 0 assert self.image_height > 0 t0 = time.time() # We need to rebuild the kernels if self.rebuild_kernels is True: self.kernel_axis_x = self.create_kernel_axis(self.image_width) self.kernel_axis_y = self.create_kernel_axis(self.image_height) assert self.kernel_axis_x is not None assert self.kernel_axis_y is not None gabor_cos_bank, gabor_sin_bank = self.create_gabor_filter_bank() assert gabor_cos_bank is not None assert gabor_sin_bank is not None self.fft_gabor_cos_bank = torch.fft.rfft2( gabor_cos_bank, s=None, dim=(-2, -1), norm=None ).unsqueeze(0) self.fft_gabor_sin_bank = torch.fft.rfft2( gabor_sin_bank, s=None, dim=(-2, -1), norm=None ).unsqueeze(0) # compute threshold for ignoring non-zero outputs arising due # to numerical imprecision (fft, kernel definition) assert self.weight_vector is not None norm_input = torch.full( (1, 1, self.image_width, self.image_height), 1.0, device=self.torch_device, dtype=self.default_dtype, ) norm_fft_input = torch.fft.rfft2( norm_input, s=None, dim=(-2, -1), norm=None ) norm_output_sin = torch.fft.irfft2( norm_fft_input * self.fft_gabor_sin_bank, s=None, dim=(-2, -1), norm=None, ) norm_output_cos = torch.fft.irfft2( norm_fft_input * self.fft_gabor_cos_bank, s=None, dim=(-2, -1), norm=None, ) norm_value_matrix = torch.sqrt(norm_output_sin**2 + norm_output_cos**2) # norm_output = torch.abs( # (self.weight_vector * norm_value_matrix).sum(dim=-1) # ).type(dtype=torch.float32) self.output_threshold = norm_value_matrix.max() self.rebuild_kernels = False assert self.fft_gabor_cos_bank is not None assert self.fft_gabor_sin_bank is not None t1 = time.time() fft_input = torch.fft.rfft2(input, s=None, dim=(-2, -1), norm=None) output_sin = torch.fft.irfft2( fft_input * self.fft_gabor_sin_bank, s=None, dim=(-2, -1), norm=None ) output_cos = torch.fft.irfft2( fft_input * self.fft_gabor_cos_bank, s=None, dim=(-2, -1), norm=None ) t2 = time.time() output = torch.sqrt(output_sin**2 + output_cos**2) t3 = time.time() print( "ContourExtract {:.3f}s: prep-{:.3f}s, fft-{:.3f}s, out-{:.3f}s".format( t3 - t0, t1 - t0, t2 - t1, t3 - t2 ) ) return output def create_collapse(self, input: torch.Tensor) -> torch.Tensor: assert self.weight_vector is not None output = torch.abs((self.weight_vector * input).sum(dim=1)).type( dtype=self.default_dtype ) return output def create_kernel_axis(self, axis_size: int) -> torch.Tensor: lower_bound_axis: int = -int(math.floor(axis_size / 2)) upper_bound_axis: int = int(math.ceil(axis_size / 2)) kernel_axis = torch.arange( lower_bound_axis, upper_bound_axis, device=self.torch_device, dtype=self.default_dtype, ) kernel_axis = torch.roll(kernel_axis, int(math.ceil(axis_size / 2))) return kernel_axis def create_gabor_filter_bank(self) -> tuple[torch.Tensor, torch.Tensor]: assert self.kernel_axis_x is not None assert self.kernel_axis_y is not None assert self.target_orientations is not None orientation_matrix = ( self.target_orientations.unsqueeze(-1).unsqueeze(-1).detach().clone() ) x_kernel_matrix = self.kernel_axis_x.unsqueeze(0).unsqueeze(-1).detach().clone() y_kernel_matrix = self.kernel_axis_y.unsqueeze(0).unsqueeze(0).detach().clone() r2 = x_kernel_matrix**2 + self.gamma_aspect_ratio * y_kernel_matrix**2 kr = x_kernel_matrix * torch.cos( orientation_matrix ) + y_kernel_matrix * torch.sin(orientation_matrix) c0 = torch.exp(-2 * (self.pi * self.sigma_kernel / self.lambda_kernel) ** 2) gauss: torch.Tensor = torch.exp(-r2 / 2 / self.sigma_kernel**2) gabor_cos_bank: torch.Tensor = gauss * ( torch.cos(2 * self.pi * kr / self.lambda_kernel + self.psi_phase_offset_cos) - c0 * torch.cos(self.psi_phase_offset_cos) ) gabor_sin_bank: torch.Tensor = gauss * ( torch.cos(2 * self.pi * kr / self.lambda_kernel + self.psi_phase_offset_sin) - c0 * torch.cos(self.psi_phase_offset_sin) ) return gabor_cos_bank, gabor_sin_bank def update_settings( self, n_orientations: int, sigma_kernel: float, lambda_kernel: float, ): self.rebuild_kernels = True self.n_orientations = n_orientations self.sigma_kernel = sigma_kernel self.lambda_kernel = lambda_kernel # generate orientation axis and axis for complex summation self.target_orientations: torch.Tensor = ( torch.arange( start=0, end=int(self.n_orientations), device=self.torch_device, dtype=self.default_dtype, ) * torch.tensor( math.pi, device=self.torch_device, dtype=self.default_dtype, ) / self.n_orientations ) self.weight_vector: torch.Tensor = ( torch.exp( 2.0 * torch.complex(torch.tensor(0.0), torch.tensor(1.0)) * self.target_orientations ) .unsqueeze(-1) .unsqueeze(-1) .unsqueeze(0) )