Custom Autograd Functions

A simple example

For complex operations not supported by PyTorch's built-in functions, we can define custom autograd functions. When the computation is composed of operations autograd already knows how to differentiate, a plain Python function is enough:

import torch

def my_custom_autograd_function(x):
    # built entirely from differentiable ops; equivalent to 2 * x**2
    return x**2 + x*x

x = torch.tensor(2.0, requires_grad=True)
y = my_custom_autograd_function(x)

y.backward()

print(x.grad)  # tensor(8.), since d(2*x**2)/dx = 4*x and x = 2
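
When an operation cannot be expressed through ops autograd already understands, PyTorch's torch.autograd.Function lets us spell out the forward and backward passes explicitly. Below is a minimal sketch of the same computation with a hand-written backward; the class name MySquare is only illustrative.

import torch

class MySquare(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # save the input so backward can use it
        ctx.save_for_backward(x)
        return 2 * x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        # d(2*x^2)/dx = 4*x, scaled by the incoming gradient
        (x,) = ctx.saved_tensors
        return 4 * x * grad_output

x = torch.tensor(2.0, requires_grad=True)
y = MySquare.apply(x)
y.backward()
print(x.grad)  # tensor(8.)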

Gradient Clipping

To prevent exploding gradients, we can clip the gradients after the backward pass has computed them and before the optimizer applies them.

import torch

# model is assumed to be an nn.Module defined elsewhere.
# Rescale all parameter gradients in place so their total norm is at most 1.0.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
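
A minimal sketch of where the call sits in a training step; the tiny linear model, optimizer, and random batch here are placeholders, not part of the original example.

import torch
import torch.nn as nn

model = nn.Linear(10, 1)                                   # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
inputs, targets = torch.randn(4, 10), torch.randn(4, 1)    # dummy batch

optimizer.zero_grad()
loss = nn.functional.mse_loss(model(inputs), targets)
loss.backward()                                            # gradients are computed here
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # then clipped
optimizer.step()                                           # and only then applied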
