Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Enable gradients in validation_step #15765

Closed jharris1679 closed 1 year ago

jharris1679 commented 1 year ago

🚀 Feature

Enable gradient calculations during evaluation loops.

Motivation

Some loss functions require the gradients of the outputs with respect to the inputs. For example, a physics informed neural network uses these gradients in a differential equation as its loss function.

Pitch

Add a set_grad_enabled flag to validation_step to make the following possible, for example for learning the cosine function:

def validation_step(
    self,
    batch: tuple[torch.Tensor, torch.Tensor],
    batch_idx: int,
    set_grad_enabled: bool = True,
) -> dict[str, torch.Tensor]:

    x, t = batch
    x.requires_grad_()  # grads are taken with respect to the inputs

    u_hat = self.forward(x, t)

    # derivative of the network output with respect to x
    (dydx,) = torch.autograd.grad(
        u_hat,
        x,
        grad_outputs=torch.ones_like(u_hat),
        create_graph=True,
    )

    # mean squared residual of the ODE du/dx = sin(x)
    physics_loss = torch.mean((torch.sin(x) - dydx) ** 2)

    return {"loss": physics_loss}

Alternatives

I tried simply wrapping my code in with torch.set_grad_enabled(True):, but of course that didn't work.

Additional context

Physics-Informed Neural Networks (PINNs) are a recently introduced approach to learning differential equations by embedding them in the loss function of a neural network. The key idea is to repurpose the auto-differentiation machinery to obtain derivatives of the network's outputs with respect to its inputs, then plug those into the residual form of a differential equation. The network learns accurate derivatives by minimizing this residual.

PINNs paper: https://faculty.sites.iastate.edu/hliu/files/inline-files/PINN_RPK_2019_1.pdf



carmocca commented 1 year ago

Hi!

You can enable grad during validation. You need to pass Trainer(inference_mode=False) and then you are free to use the context manager.
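
For reference, a minimal sketch of the Trainer side (an illustration, not a complete example):

import pytorch_lightning as pl

# inference_mode=False makes the evaluation loops run under torch.no_grad()
# instead of torch.inference_mode(), so gradients can be re-enabled locally
# inside validation_step with torch.enable_grad().
trainer = pl.Trainer(inference_mode=False)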

Duplicate of https://github.com/Lightning-AI/lightning/issues/13948

jharris1679 commented 1 year ago

Oh fantastic, I'll check this out, thank you!

jharris1679 commented 1 year ago

I've tried this, but I'm still seeing RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn during validation steps, whereas I can get gradients successfully during a training step. Is there something else I'm missing?

carmocca commented 1 year ago

You will need to set tensor.requires_grad_() for the tensors you want to compute grads for.

jharris1679 commented 1 year ago

Thanks for such a quick response! In my training step, I don't need to do that for the net output, and the loss function works well. However, in my validation_step, if I set requires_grad like so

output = self.net(x)
output.requires_grad_(requires_grad=True)

and then take the grads like so

  grads = torch.autograd.grad(
      output, 
      [x],
      grad_outputs=torch.ones_like(output),
      create_graph=True,
      retain_graph=True
  )

it gives

RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behaviour.

which is the same error I would get if I didn't set inference_mode=False. If I set allow_unused=True in torch.autograd.grad, then the grads are None. It seems as though my Trainer(inference_mode=False) is maybe not doing the trick? Should I perhaps open another issue with steps to reproduce?

Thanks again for your eyes on this!

jharris1679 commented 1 year ago

Ahh, in addition to .requires_grad_() I had to put everything in my validation_step under with torch.enable_grad():, and now it works! Thanks again.
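
For reference, a simplified sketch of the working combination (together with Trainer(inference_mode=False)):

def validation_step(self, batch, batch_idx):
    x, t = batch
    with torch.enable_grad():      # only effective with Trainer(inference_mode=False)
        x = x.requires_grad_()     # grads are taken w.r.t. the inputs, not the outputs
        u_hat = self.forward(x, t)
        (dydx,) = torch.autograd.grad(
            u_hat,
            x,
            grad_outputs=torch.ones_like(u_hat),
            create_graph=True,
        )
        physics_loss = torch.mean((torch.sin(x) - dydx) ** 2)
    return {"loss": physics_loss}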

GeoffNN commented 1 year ago

Have you tried using functorch to compute the partials? It's much easier to use.

I have an example script here. https://github.com/GeoffNN/deeponet-fno/blob/main/src/burgers/pytorch_deeponet.py
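
For example, with the torch.func API (the successor to functorch); the toy function u below stands in for a network's forward pass, so the names are illustrative:

import torch
from torch.func import grad, vmap

def u(x, t):
    # stand-in for a scalar-valued model evaluated at a single (x, t) point
    return torch.sin(x) * torch.exp(-t)

x_batch = torch.linspace(0.0, 1.0, 8)
t_batch = torch.zeros(8)

# du/dx at every point in the batch, without manual graph bookkeeping
du_dx = vmap(grad(u, argnums=0))(x_batch, t_batch)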