locuslab / torchdeq

Modern Fixed Point Systems using Pytorch
MIT License
82 stars 10 forks source link

Custom autograd fails with torchdeq in eval mode #4

Open BurgerAndreas opened 7 months ago

BurgerAndreas commented 7 months ago

It's a very niche problem, but it tripped me up big time :')

Issue

In model.eval() mode, z_pred is not tracked by autograd (z_pred[-1].requires_grad == False). With a custom torch.autograd.grad call this leads to the error: RuntimeError: One of the differentiated Tensors does not require grad.

Minimal example


import torch

import torchdeq
from torchdeq import get_deq
from torchdeq.norm import apply_norm, reset_norm

class MyModel(torch.nn.Module):
    """Minimal DEQ model reproducing the eval-mode ``autograd.grad`` failure.

    The fixed-point function is ``z -> layer(z) + x + pos_proj(pos)``, so the
    DEQ output depends on ``pos`` and ``forces = -d(energy)/d(pos)`` is a
    well-defined derivative. Any remaining failure of ``torch.autograd.grad``
    can then be attributed to the DEQ output not being attached to the autograd
    graph (the eval-mode behaviour reported in this issue), rather than to
    ``pos`` being disconnected from the computation.

    Args:
        dim: Feature size of ``x`` and of the fixed-point state ``z``.
        pos_dim: Size of the last dimension of ``pos``.
    """

    def __init__(self, dim=10, pos_dim=3):
        super().__init__()
        self.layer = torch.nn.Linear(dim, dim)
        # Projects positions into feature space so they can be injected into
        # the fixed-point function; without this, energy would not depend on
        # pos and d(energy)/d(pos) could not be computed.
        self.pos_proj = torch.nn.Linear(pos_dim, dim)

        # deq
        self.deq = get_deq()
        apply_norm(self.layer, 'weight_norm')

    def implicit_layer(self, z, injection=None):
        """Apply the fixed-point function once.

        Args:
            z: Current fixed-point state.
            injection: Optional input-dependent term added to ``layer(z)``.
                ``None`` keeps the plain ``layer(z)`` behaviour.

        Returns:
            ``layer(z)`` (plus ``injection`` when given).
        """
        out = self.layer(z)
        if injection is not None:
            out = out + injection
        return out

    def forward(self, x, pos):
        """Solve for the fixed point and return per-sample energy and forces.

        Args:
            x: Input features; last dimension must equal ``dim``.
            pos: Positions; last dimension must equal ``pos_dim``. If it does
                not already track gradients, ``requires_grad`` is enabled on it
                in place, because ``torch.autograd.grad`` only accepts
                differentiated inputs that require grad.

        Returns:
            ``(energy, forces)``. ``energy`` is the final DEQ state summed over
            its last dimension (keepdim). ``forces`` is ``-d(energy)/d(pos)``
            computed with ``create_graph=True`` so it can be trained on.
        """
        if not pos.requires_grad:
            pos.requires_grad_(True)

        z = torch.zeros_like(x)

        reset_norm(self.layer)

        # Input injection is computed once, outside the solver iterations.
        injection = x + self.pos_proj(pos)

        def fixed_point_fn(z_state):
            return self.implicit_layer(z_state, injection)

        z_pred, _ = self.deq(fixed_point_fn, z)

        # if model.eval() -> z_pred[-1].requires_grad is False!
        energy = z_pred[-1].sum(dim=-1, keepdim=True)
        forces = -1 * (
            torch.autograd.grad(
                energy,
                # diff with respect to pos
                pos,
                grad_outputs=torch.ones_like(energy),
                create_graph=True,
            )[0]
        )

        return energy, forces

def run(model, eval=False, steps=10):
    """Drive ``model`` on random batches in train or eval mode.

    Args:
        model: Module called as ``model(x, pos)`` that returns
            ``(energy, forces)``.
        eval: If True, put ``model`` in eval mode and skip backward/optimizer
            steps. Otherwise put it in train mode and take an Adam step per
            batch. (The name shadows the builtin ``eval``; it is kept so
            existing ``run(model, eval=...)`` calls keep working.)
        steps: Number of random batches to process.

    Returns:
        True once every step has completed. Any failure inside the model
        propagates as an exception.
    """
    if eval:
        model.eval()
    else:
        model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for _ in range(steps):
        x = torch.randn(10, 10)
        # The model differentiates energy w.r.t. pos, and torch.autograd.grad
        # requires every differentiated input to track gradients.
        pos = torch.randn(10, 3, requires_grad=True)
        energy, forces = model(x, pos)

        # loss: targets are shaped like the predictions so mse_loss does not
        # silently broadcast mismatched shapes.
        optimizer.zero_grad()
        energy_target = torch.randn_like(energy)
        energy_loss = torch.nn.functional.mse_loss(energy, energy_target)
        force_target = torch.randn_like(forces)
        force_loss = torch.nn.functional.mse_loss(forces, force_target)
        loss = energy_loss + force_loss

        if not eval:
            loss.backward()
            optimizer.step()

    return True

if __name__ == '__main__':
    # Exercise the same model first in train mode, then in eval mode; the
    # eval-mode pass is the one expected to expose the reported failure.
    demo_model = MyModel()
    print(f'train success: {run(demo_model, eval=False)}')
    print(f'eval success: {run(demo_model, eval=True)}')

In model.train() mode it works perfectly well; in model.eval() mode we get the error: RuntimeError: One of the differentiated Tensors does not require grad.

Desired behaviour

A flag that keeps z_pred[-1].requires_grad True even in model.eval() mode, e.g. self.deq = get_deq(grad_in_eval=True).

Gsunshine commented 6 months ago

Hi Andreas @BurgerAndreas ,

Thanks a lot for your interest! I think a quick fix is to keep self.deq in train mode while the other components of the model are in eval mode (e.g., call model.eval() and then model.deq.train()).

I appreciate the suggestion! I think we can implement such a feature in the lib. Feel free to submit a PR. I'll be back to close this issue soon.

Thanks, Zhengyang