41 lines
1.3 KiB
Python
41 lines
1.3 KiB
Python
import torch
|
|
from typing import Callable, Any
|
|
|
|
|
|
class Functional2Layer(torch.nn.Module):
    """Expose a plain tensor function as a ``torch.nn.Module`` layer.

    Any extra positional and keyword arguments supplied at construction are
    remembered and forwarded on every call, after the input tensor, so that
    ``layer(x)`` evaluates ``func(x, *args, **kwargs)``.

    Example:
        >>> layer = Functional2Layer(torch.permute, dims=(0, 2, 3, 1))
        >>> layer(torch.zeros((10, 11, 12, 13))).shape
        torch.Size([10, 12, 13, 11])
    """

    def __init__(
        self, func: Callable[..., torch.Tensor], *args: Any, **kwargs: Any
    ) -> None:
        """Remember ``func`` together with the arguments bound to every call.

        Args:
            func: Callable invoked as ``func(input, *args, **kwargs)``.
            *args: Extra positional arguments placed after the input tensor.
            **kwargs: Extra keyword arguments passed through to ``func``.
        """
        super().__init__()
        # Tuple assignment binds left to right, so attributes are set in the
        # same order (func, args, kwargs) via nn.Module.__setattr__.
        self.func, self.args, self.kwargs = func, args, kwargs

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Call the wrapped function on ``input`` with the stored arguments.

        Args:
            input: Tensor passed as the first positional argument to ``func``.
                (The name shadows the builtin but is kept because callers may
                pass it by keyword.)

        Returns:
            Whatever ``func`` returns for this input.
        """
        fn, extra_pos, extra_kw = self.func, self.args, self.kwargs
        return fn(input, *extra_pos, **extra_kw)

    def extra_repr(self) -> str:
        """Describe the wrapped callable and its bound arguments for ``repr``.

        The callable is shown by its ``__name__`` when it has one, otherwise
        by ``str(func)``; positional and keyword arguments use ``repr``.
        """
        if hasattr(self.func, "__name__"):
            label = self.func.__name__
        else:
            label = str(self.func)
        positional = ", ".join(repr(arg) for arg in self.args)
        keyword = ", ".join(
            f"{name}={value!r}" for name, value in self.kwargs.items()
        )
        return f"func={label}, args=({positional}), kwargs={{{keyword}}}"
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo: wrap plain torch functions as nn.Module layers and show their
    # effect and repr.

    print("Permute Example")
    test_layer_permute = Functional2Layer(func=torch.permute, dims=(0, 2, 3, 1))
    # Renamed from `input` to avoid shadowing the builtin.
    example_input = torch.zeros((10, 11, 12, 13))
    output = test_layer_permute(example_input)
    print(example_input.shape)
    print(output.shape)
    print(test_layer_permute)

    print()

    print("Clamp Example")
    test_layer_clamp = Functional2Layer(func=torch.clamp, min=5, max=100)
    # Fix: previously this called test_layer_permute, so the clamp layer was
    # never actually run.
    output = test_layer_clamp(example_input)
    # The input is all zeros, so each element is raised to min=5.
    print(output[0, 0, 0, 0])
    print(test_layer_clamp)
|