Create README.md
Signed-off-by: David Rotermund <54365609+davrot@users.noreply.github.com>
This commit is contained in:
parent
1d956be537
commit
f9a6fa068c
1 changed files with 116 additions and 0 deletions
116
pytorch/replace_autograd/README.md
Normal file
116
pytorch/replace_autograd/README.md
Normal file
|
@ -0,0 +1,116 @@
|
||||||
|
# Replace autograd
|
||||||
|
{:.no_toc}
|
||||||
|
|
||||||
|
<nav markdown="1" class="toc-class">
|
||||||
|
* TOC
|
||||||
|
{:toc}
|
||||||
|
</nav>
|
||||||
|
|
||||||
|
## Top
|
||||||
|
|
||||||
|
Questions to [David Rotermund](mailto:davrot@uni-bremen.de)
|
||||||
|
|
||||||
|
## Example
|
||||||
|
|
||||||
|
```python
|
||||||
|
class FunctionalLinear(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward( # type: ignore
|
||||||
|
ctx, input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
|
||||||
|
) -> torch.Tensor:
|
||||||
|
output = (weight.unsqueeze(0) * input.unsqueeze(1)).sum(dim=-1)
|
||||||
|
|
||||||
|
if bias is not None:
|
||||||
|
output = output + bias.unsqueeze(0)
|
||||||
|
|
||||||
|
# ###########################################################
|
||||||
|
# Save the necessary data for the backward pass
|
||||||
|
# ###########################################################
|
||||||
|
ctx.save_for_backward(
|
||||||
|
input,
|
||||||
|
weight,
|
||||||
|
bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward( # type: ignore
|
||||||
|
ctx, grad_output: torch.Tensor
|
||||||
|
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
|
||||||
|
# ##############################################
|
||||||
|
# Get the variables back
|
||||||
|
# ##############################################
|
||||||
|
(
|
||||||
|
input,
|
||||||
|
weight,
|
||||||
|
bias,
|
||||||
|
) = ctx.saved_tensors
|
||||||
|
|
||||||
|
# ##############################################
|
||||||
|
# Default output
|
||||||
|
# ##############################################
|
||||||
|
grad_input: torch.Tensor | None = None
|
||||||
|
grad_weight: torch.Tensor | None = None
|
||||||
|
grad_bias: torch.Tensor | None = None
|
||||||
|
|
||||||
|
grad_weight = grad_output.unsqueeze(-1) * input.unsqueeze(-2)
|
||||||
|
|
||||||
|
if bias is not None:
|
||||||
|
grad_bias = grad_output.detach().clone()
|
||||||
|
|
||||||
|
grad_input = (grad_output.unsqueeze(-1) * weight.unsqueeze(0)).sum(dim=1)
|
||||||
|
|
||||||
|
return grad_input, grad_weight, grad_bias
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
class MyOwnLayer(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_features: int,
|
||||||
|
out_features: int,
|
||||||
|
bias: bool = True,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
assert in_features > 0
|
||||||
|
assert out_features > 0
|
||||||
|
|
||||||
|
self.in_features: int = in_features
|
||||||
|
self.out_features: int = out_features
|
||||||
|
|
||||||
|
self.weight = torch.nn.Parameter(
|
||||||
|
torch.empty(
|
||||||
|
(out_features, in_features),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if bias:
|
||||||
|
self.bias = torch.nn.Parameter(
|
||||||
|
torch.empty(
|
||||||
|
out_features,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.register_parameter("bias", None)
|
||||||
|
self.reset_parameters()
|
||||||
|
|
||||||
|
self.functional_linear = FunctionalLinear.apply
|
||||||
|
|
||||||
|
def reset_parameters(self) -> None:
|
||||||
|
torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
||||||
|
if self.bias is not None:
|
||||||
|
fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
|
||||||
|
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
||||||
|
torch.nn.init.uniform_(self.bias, -bound, bound)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return self.functional_linear(input, self.weight, self.bias)
|
||||||
|
|
||||||
|
def extra_repr(self) -> str:
|
||||||
|
return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"
|
||||||
|
```
|
||||||
|
|
Loading…
Reference in a new issue