pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License
1.38k stars 102 forks source link

Small difference between functorch grads and torch.autograd.grad #1144

Open alanjeffares opened 1 month ago

alanjeffares commented 1 month ago

Hello, I am using this linked solution from Stack Overflow to compute gradients more efficiently than with a manual loop.

I notice that there is a small difference between the gradients calculated by the two methods (i.e. `torch.abs(grads_torch - grads_func).sum()` returns ~1e-05). What might explain this difference? Is one solution more correct than the other?

MWE (minimal working example)

import torch
from torchvision import datasets, transforms
import torch.nn as nn

###### SETUP ######

class MLP(nn.Module):
    """Two-layer perceptron computing ``fc2(relu(fc1(x)))``.

    Args:
        input_size: number of input features fed to ``fc1``.
        hidden_size: width of the hidden layer.
        output_size: number of outputs produced by ``fc2``.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Submodules are registered in this order (fc1, relu, fc2); that order
        # determines the order of ``parameters()`` / ``named_parameters()``,
        # which the gradient code in this script uses when flattening.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Apply Linear -> ReLU -> Linear to ``x`` and return the result."""
        return self.fc2(self.relu(self.fc1(x)))

# MNIST training split, converted to tensors and normalized with mean 0.5 / std 0.5.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]),
)

# shuffle=False: iterating the loader always starts from the same batch.
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=False)

# The first batch of 2 samples -- deterministic, NOT random, because shuffle=False.
X, y = next(iter(train_dataloader))

# Input size 28*28 matches the per-sample rows produced by flattening X below;
# 20 hidden units, 10 outputs.
net = MLP(28*28, 20, 10)

###### CALCULATE GRADIENTS WITH TORCH AUTOGRAD GRAD ######
def calculate_gradients(model, X, output_size=None):
    """Compute the per-sample Jacobian of ``model`` w.r.t. its parameters with a loop.

    For every sample ``X[i]`` and every output index ``j`` this differentiates
    the scalar ``model(X[i])[j]`` with :func:`torch.autograd.grad` and stores
    the flattened parameter gradients, concatenated in ``model.parameters()``
    order, in row ``[i, j]`` of the result.

    The forward pass is run once per sample and its graph is reused for every
    output index (``retain_graph=True``) instead of re-running the model for
    each ``j``.

    Args:
        model: module called on one sample at a time; ``model(X[i])[j]`` must
            be a scalar (i.e. the per-sample output is presumably 1-D).
        X: batch of inputs; iterated along its first dimension.
        output_size: number of output indices to differentiate. Defaults to
            the length of the model's per-sample output (the original
            implementation hard-coded 10).

    Returns:
        Tensor of shape ``(X.shape[0], output_size, n_params)`` where
        ``n_params`` is the total number of parameter elements. For an empty
        batch, a zero tensor of shape ``(0, output_size or 10, n_params)``.
    """
    # Materialize once: ``parameters()`` is a generator and is needed repeatedly.
    params = list(model.parameters())
    n_params = sum(p.numel() for p in params)

    per_sample = []
    for x in X:
        output = model(x)
        n_out = output.shape[0] if output_size is None else output_size
        rows = []
        for j in range(n_out):
            # autograd.grad returns the gradients instead of accumulating into
            # ``.grad``, so no zero_grad() is needed; retain_graph keeps the
            # single forward graph alive for the remaining output indices.
            grads = torch.autograd.grad(output[j], params, retain_graph=True)
            # reshape (not view) also works for non-contiguous gradient tensors.
            rows.append(torch.cat([g.reshape(-1) for g in grads]))
        per_sample.append(torch.stack(rows))

    if not per_sample:
        # No samples to infer the output size from; keep the original default of 10.
        return torch.zeros(0, 10 if output_size is None else output_size, n_params)
    return torch.stack(per_sample)

# Reference result: per-sample Jacobians from the explicit autograd loop,
# with each sample flattened to a single row.
grads_torch = calculate_gradients(net, X.flatten(start_dim=1))

###### NOW CALCULATE THE SAME GRADIENTS WITH FUNCTORCH ######
# Detached copies of the network's parameters and buffers, keyed by name,
# in the form torch.func.functional_call expects.
params = {name: tensor.detach() for name, tensor in net.named_parameters()}
buffers = {name: tensor.detach() for name, tensor in net.named_buffers()}

def one_sample(sample):
    """Return the Jacobian of ``net``'s output w.r.t. its parameters for one input.

    Reads ``net``, ``params`` and ``buffers`` from module scope. Each output's
    gradient with respect to every parameter is computed with
    :func:`torch.func.jacrev`; the per-parameter Jacobians are flattened and
    concatenated along the last dimension in ``params`` order (which follows
    ``net.named_parameters()``), the same layout ``calculate_gradients`` uses.

    Args:
        sample: a single input row for ``net``.

    Returns:
        Tensor of shape ``(n_outputs, n_params)``, assuming ``net`` returns a
        1-D output for one sample (TODO confirm for other models).
    """
    def call(param_dict):
        # Run ``net`` functionally with the given parameters (and the fixed
        # buffers) so jacrev can differentiate with respect to ``param_dict``.
        return torch.func.functional_call(net, (param_dict, buffers), sample)

    # jacrev returns a dict with the same keys as ``params``; each value has the
    # output dimension first, followed by that parameter's own shape.
    jac = torch.func.jacrev(call)(params)

    # Flatten each parameter's part to (n_outputs, numel) and join them.
    return torch.cat([v.flatten(1) for v in jac.values()], -1)

# Now vmap one_sample over the flattened batch to get every per-sample Jacobian in one call.
grads_func = torch.vmap(one_sample)(X.flatten(start_dim=1))

# Compare with the loop-based result. The reporter observed allclose -> True and
# a summed absolute difference of about 1.4e-05.
# NOTE(review): a gap this small is presumably float32 rounding from the two paths
# accumulating in different orders -- confirm; allclose passes at default tolerances.
print(torch.allclose(grads_torch, grads_func))
print((grads_torch - grads_func).abs().sum())