Update README.md
Signed-off-by: David Rotermund <54365609+davrot@users.noreply.github.com>
This commit is contained in:
parent
f3f3394199
commit
99468e3663
1 changed files with 39 additions and 1 deletions
|
@ -485,3 +485,41 @@ plt.show()
|
|||
```
|
||||
|
||||
![image25](image25.png)
|
||||
|
||||
|
||||
## Building your own filter
|
||||
In the case you need a special filter then you just can write it very easily on your own. Here is an example.
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
|
||||
class OnOffFilter(torch.nn.Module):
|
||||
def __init__(self, p: float = 0.5) -> None:
|
||||
super(OnOffFilter, self).__init__()
|
||||
|
||||
self.p: float = p
|
||||
|
||||
def forward(self, tensor: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
assert tensor.shape[1] == 1
|
||||
|
||||
tensor -= self.p
|
||||
temp_0: torch.Tensor = torch.where(
|
||||
tensor < 0.0, -tensor, tensor.new_zeros(tensor.shape, dtype=tensor.dtype)
|
||||
)
|
||||
temp_1: torch.Tensor = torch.where(
|
||||
tensor >= 0.0, tensor, tensor.new_zeros(tensor.shape, dtype=tensor.dtype)
|
||||
)
|
||||
|
||||
new_tensor: torch.Tensor = torch.cat((temp_0, temp_1), dim=1)
|
||||
|
||||
return new_tensor
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + "(p={0})".format(self.p)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
```
|
||||
|
|
Loading…
Reference in a new issue