graphcore / poptorch

PyTorch interface for the IPU
https://docs.graphcore.ai/projects/poptorch-user-guide/en/latest/
MIT License

Calculation of derivatives on the IPU (autograd.grad) #18

Open Bedledl opened 1 year ago

Bedledl commented 1 year ago

Hello,

In the poptorch documentation I read about the poptorch.identity_loss() function, which is supposed to be the equivalent of the backward() function in PyTorch. Is there a way to access the gradient of the input tensor after the identity_loss() call?
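
From the documentation, my understanding is that identity_loss is normally used together with poptorch.trainingModel to mark a custom loss, roughly like this (just a sketch based on the docs, not my actual code; the TrainingModelWithLoss wrapper and the squared-error loss are placeholders):

import torch
import poptorch
from torch import nn

class TrainingModelWithLoss(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, target):
        out = self.model(x)
        # identity_loss marks this tensor as the loss that poptorch backpropagates from
        loss = poptorch.identity_loss((out - target) ** 2, reduction="mean")
        return out, loss

model = nn.Linear(2, 1)
training_model = poptorch.trainingModel(
    TrainingModelWithLoss(model),
    optimizer=torch.optim.SGD(model.parameters(), lr=0.01),
)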

I made a small minimal example that shows that, on the CPU, I can either call backward() and access the input.grad attribute, or use autograd.grad to retrieve the derivative with respect to an input.

My question is: how can I retrieve the gradient of the input tensor after an identity_loss call and return it as an additional return value? With which arguments would I have to call the identity_loss method? Can I access this gradient even with the inferenceModel wrapper?

Thank you for your help :)

Minimal example:

import torch
import poptorch
from torch import nn

input_dim = 2
hidden_dim = 4
output_dim = 1

model_0 = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.Tanh(), nn.Linear(hidden_dim, output_dim)).double()

# Model wrapped with poptorch.inferenceModel: marks the network output as a
# loss via identity_loss and tries to return the gradient of the input.
class IPUModel(nn.Module):
    def __init__(self):
        super(IPUModel, self).__init__()
        self.nn = model_0

    def forward(self, in_):
        out = self.nn(in_)
        poptorch.identity_loss(out, reduction="none")
        # On the IPU in_.grad is still None here - this is my problem.
        return out, in_.grad

# Plain CPU reference model; the gradient is computed outside the forward
# pass with torch.autograd.grad.
class CPUModel(nn.Module):
    def __init__(self):
        super(CPUModel, self).__init__()
        self.nn = model_0

    def forward(self, in_):
        out = self.nn(in_)
        return out

# CPU model that calls backward() inside forward() and returns input.grad.
class CPUModelBackward(nn.Module):
    def __init__(self):
        super(CPUModelBackward, self).__init__()
        self.nn = model_0

    def forward(self, in_):
        out = self.nn(in_)
        out.backward(torch.ones_like(out))
        return out, in_.grad

model_IPU = poptorch.inferenceModel(IPUModel())
model_CPU = CPUModel()
model_CPU_back = CPUModelBackward()

x = torch.tensor(
    [[1, 2], [3, 42]], dtype=torch.float64, requires_grad=True
)

y_ipu, grad_ipu = model_IPU(x)
y_cpu = model_CPU(x)
autograd_grad = torch.autograd.grad(y_cpu, x, retain_graph=True, grad_outputs=torch.ones_like(y_cpu))[0]
y, grad_input = model_CPU_back(x)

print(f"The input.grad from the IPU is unfortunatley none: {grad_ipu}")

if torch.all(autograd_grad.eq(grad_input)):
    print(f"Success. This was expected. The gradient calculated by autograd.grad and by backward is {autograd_grad}")
else:
    print("Error!")

Background to this question: I am currently trying to run a molecular dynamics simulation with a SchNet NN on the IPU. I use the implementation from torchmd-net (https://github.com/torchmd/torchmd-net/blob/main/torchmdnet/models/torchmd_gn.py). At the end of the model (https://github.com/torchmd/torchmd-net/blob/main/torchmdnet/models/model.py, line 289) the derivatives of the outputs w.r.t. the inputs are calculated with the autograd.grad function. On the IPU this leads to "Unsupported ops found in compiled model: [aten::_index_put_impl, aten::index_add]" errors from the grad() call.
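
For illustration, the pattern at the end of the torchmd-net model is roughly the following (a simplified sketch of mine, not the actual torchmd-net code; the EnergyForceModel wrapper and the .sum() reduction are placeholders):

import torch
from torch import nn

class EnergyForceModel(nn.Module):
    def __init__(self, energy_model):
        super().__init__()
        self.energy_model = energy_model

    def forward(self, pos):
        pos.requires_grad_(True)
        # Predicted total energy for the given atom positions.
        energy = self.energy_model(pos).sum()
        # Forces are the negative gradient of the energy w.r.t. the positions.
        # This autograd.grad call inside forward() is what fails to compile on the IPU for me.
        forces = -torch.autograd.grad(energy, pos, create_graph=True)[0]
        return energy, forces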