nnmf_24b/PositionalEncoding.py

import torch


class PositionalEncoding(torch.nn.Module):
    """Adds a learnable positional embedding to the input tensor."""

    init_std: float
    pos_embedding: torch.nn.Parameter

    def __init__(self, dim: list[int], init_std: float = 0.2):
        super().__init__()
        self.init_std = init_std
        # Expect a 3D shape, e.g. (channels, height, width).
        assert len(dim) == 3
        # The leading 1 adds a batch axis so the embedding broadcasts over the batch.
        self.pos_embedding: torch.nn.Parameter = torch.nn.Parameter(
            torch.randn(1, *dim)
        )
        self.init_parameters()

    def init_parameters(self):
        # Re-initialize the embedding with a truncated normal of the configured std.
        torch.nn.init.trunc_normal_(self.pos_embedding, std=self.init_std)

    def forward(self, input: torch.Tensor):
        return input + self.pos_embedding
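

# Usage sketch (assumption: dim is (channels, height, width) for 4D feature maps;
# the actual shape convention used elsewhere in this repo may differ).
if __name__ == "__main__":
    pe = PositionalEncoding(dim=[32, 16, 16])
    x = torch.randn(8, 32, 16, 16)  # batch of 8
    out = pe(x)  # the (1, 32, 16, 16) embedding broadcasts over the batch axis
    assert out.shape == x.shape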