Update README.md

Signed-off-by: David Rotermund <54365609+davrot@users.noreply.github.com>
This commit is contained in:
David Rotermund 2024-01-05 16:36:14 +01:00 committed by GitHub
parent 8818ee51ba
commit 450813f03c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -10,6 +10,49 @@
Questions to [David Rotermund](mailto:davrot@uni-bremen.de) Questions to [David Rotermund](mailto:davrot@uni-bremen.de)
## [TORCH.AUTOGRAD.FUNCTION.FUNCTIONCTX.SAVE_FOR_BACKWARD](https://pytorch.org/docs/stable/generated/torch.autograd.function.FunctionCtx.save_for_backward.html)
```python
FunctionCtx.save_for_backward(*tensors)
```
> Saves given tensors for a future call to backward().
>
> save_for_backward should be called at most once, only from inside the forward() method, and only with tensors.
>
> All tensors intended to be used in the backward pass should be saved with save_for_backward (as opposed to directly on ctx) to prevent incorrect gradients and memory leaks, and enable the application of saved tensor hooks.
>
> Note that if intermediary tensors, tensors that are neither inputs nor outputs of forward(), are saved for backward, your custom Function may not support double backward. Custom Functions that do not support double backward should decorate their backward() method with @once_differentiable so that performing double backward raises an error. If youd like to support double backward, you can either recompute intermediaries based on the inputs during backward or return the intermediaries as the outputs of the custom Function. See the [double backward tutorial](https://pytorch.org/tutorials/intermediate/custom_function_double_backward_tutorial.html) for more details.
>
> In backward(), saved tensors can be accessed through the saved_tensors attribute. Before returning them to the user, a check is made to ensure they werent used in any in-place operation that modified their content.
>
> **Arguments can also be None.** This is a no-op.
### Save
```python
ctx.save_for_backward(x, y, w, out)
```
Non-tensor (e.g. int):
```python
ctx.z = z
```
### Access
```python
x, y, w, out = ctx.saved_tensors
```
Non-tensor (e.g. int):
```python
z = ctx.z
```
## Example ## Example
```python ```python
@ -35,6 +78,7 @@ class FunctionalLinear(torch.autograd.Function):
return output return output
@staticmethod @staticmethod
@torch.autograd.function.once_differentiable
def backward( # type: ignore def backward( # type: ignore
ctx, grad_output: torch.Tensor ctx, grad_output: torch.Tensor
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: