# nnmf_24b/Functional2Layer.py
# (41 lines, 1.3 KiB, Python)

import torch
from typing import Callable, Any
class Functional2Layer(torch.nn.Module):
    """Wrap a plain tensor function so it can be used as an ``nn.Module`` layer.

    Any extra positional and keyword arguments given at construction time are
    stored on the instance. On every forward call they are passed to ``func``
    after the input tensor, i.e. ``func(input, *args, **kwargs)``.
    """

    def __init__(
        self, func: Callable[..., torch.Tensor], *args: Any, **kwargs: Any
    ) -> None:
        super().__init__()
        # ``args``/``kwargs`` are kept as a plain tuple/dict rather than as
        # registered parameters or buffers. Tensors passed here are therefore
        # not moved by ``.to()`` and are not included in ``state_dict``.
        self.func = func
        self.args = args
        self.kwargs = kwargs

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``func(input, *args, **kwargs)`` using the stored arguments."""
        call_args = (input, *self.args)
        return self.func(*call_args, **self.kwargs)

    def extra_repr(self) -> str:
        """Describe the wrapped function and its bound arguments for ``repr``."""
        # Callables such as ``functools.partial`` objects have no ``__name__``;
        # fall back to their string form in that case.
        if hasattr(self.func, "__name__"):
            label = self.func.__name__
        else:
            label = str(self.func)
        positional_text = ", ".join(repr(value) for value in self.args)
        keyword_text = ", ".join(
            f"{key}={value!r}" for key, value in self.kwargs.items()
        )
        return "func={}, args=({}), kwargs={{{}}}".format(
            label, positional_text, keyword_text
        )
if __name__ == "__main__":
    # Demo 1: torch.permute with dims=(0, 2, 3, 1) moves axis 1 to the end,
    # so a (10, 11, 12, 13) tensor becomes (10, 12, 13, 11).
    print("Permute Example")
    test_layer_permute = Functional2Layer(func=torch.permute, dims=(0, 2, 3, 1))
    # Named ``example_input`` rather than ``input`` to avoid shadowing the builtin.
    example_input = torch.zeros((10, 11, 12, 13))
    output = test_layer_permute(example_input)
    print(example_input.shape)
    print(output.shape)
    print(test_layer_permute)
    print()

    # Demo 2: torch.clamp into [5, 100]. The input is all zeros, so every
    # element is raised to 5. The clamp layer itself must be applied here
    # (the original mistakenly reused the permute layer).
    print("Clamp Example")
    test_layer_clamp = Functional2Layer(func=torch.clamp, min=5, max=100)
    output = test_layer_clamp(example_input)
    print(output[0, 0, 0, 0])
    print(test_layer_clamp)