diff --git a/pytorch/augmentation/README.md b/pytorch/augmentation/README.md index 38746e5..eb82234 100644 --- a/pytorch/augmentation/README.md +++ b/pytorch/augmentation/README.md @@ -467,7 +467,7 @@ plt.show() Apply randomly a list of transformations with a given probability. -**Note: It randomly applies the whole list of transformation or none. ** +**Note: It randomly applies the whole list of transformation or none.** ```python randomapply_transform = tv.transforms.RandomApply( @@ -485,3 +485,41 @@ plt.show() ``` ![image25](image25.png) + + +## Building your own filter +In the case you need a special filter then you just can write it very easily on your own. Here is an example. + +```python +import torch + + +class OnOffFilter(torch.nn.Module): + def __init__(self, p: float = 0.5) -> None: + super(OnOffFilter, self).__init__() + + self.p: float = p + + def forward(self, tensor: torch.Tensor) -> torch.Tensor: + + assert tensor.shape[1] == 1 + + tensor -= self.p + temp_0: torch.Tensor = torch.where( + tensor < 0.0, -tensor, tensor.new_zeros(tensor.shape, dtype=tensor.dtype) + ) + temp_1: torch.Tensor = torch.where( + tensor >= 0.0, tensor, tensor.new_zeros(tensor.shape, dtype=tensor.dtype) + ) + + new_tensor: torch.Tensor = torch.cat((temp_0, temp_1), dim=1) + + return new_tensor + + def __repr__(self): + return self.__class__.__name__ + "(p={0})".format(self.p) + + +if __name__ == "__main__": + pass +```