Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

no_grad is lost in jitted functions #1486

Open beverlylytle opened 1 week ago

beverlylytle commented 1 week ago

🐛 Bug

For a function decorated with torch.no_grad, the compile data of the jitted version has is_grad_enabled set to True when I would expect it to be False.

Code sample

import torch
import thunder

@torch.no_grad
def f(x):
    print(torch.is_grad_enabled())
    return x * 2

jf = thunder.jit(f)

x = torch.ones((2,2))

f(x)                                              # prints False
jf(x)                                             # prints True
thunder.compile_data(jf).is_grad_enabled          # True
kshitij12345 commented 6 days ago

Here is what is happening:

# This is how the thunder interpreter sees the above function.
def f(x):
    torch._C._set_grad_enabled(False)
    print(torch.is_grad_enabled())
    result = x * 2  # CompileData.is_grad_enabled = False only for this part.
    torch._C._set_grad_enabled(True)
    return result

We can verify this with thunder.last_traces by adding the following lines to the above script:

# Verifying from trace
traces = thunder.last_traces(jf)
print(traces[0])

Trace

def computation(x):
  # x: "cpu f32[2, 2]"

  # /home/kkalambarkar/git/pytorch/torch/autograd/grad_mode.py:187:             torch._C._set_grad_enabled(mode)
  ltorch._set_grad_enabled_with_warning(False)

  # /home/kkalambarkar/lightning-thunder/scratchpad/test.py:153:            return x * 2
  t0 = ltorch.mul(x, 2)  # t0: "cpu f32[2, 2]"
    # t0 = prims.mul(x, 2.0)  # t0: "cpu f32[2, 2]"

  # /home/kkalambarkar/git/pytorch/torch/autograd/grad_mode.py:187:             torch._C._set_grad_enabled(mode)
  ltorch._set_grad_enabled_with_warning(True)
  return {'output': t0, 'flat_args': [x]}

So we can see that, at the end of the function, there is a _set_grad_enabled(True) which sets CompileData.is_grad_enabled back to True.

CompileData.is_grad_enabled will only be False within the no_grad region. So, if ltorch.mul or another operator had queried this state there, it would have seen False.

NOTE - Symbols which care about this should query the state during tracing. After that, is_grad_enabled will reflect the last state it was updated to.
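
To see that the final value is just the last recorded update, here is a small, untested sketch; it assumes a direct call to torch._C._set_grad_enabled is intercepted by the same lookaside that torch.no_grad triggers:

import torch
import thunder

def g(x):
    # The last grad-state update the interpreter records is False;
    # nothing restores the previous mode before returning.
    torch._C._set_grad_enabled(False)
    return x * 2

jg = thunder.jit(g)
jg(torch.ones((2, 2)))
print(thunder.compile_data(jg).is_grad_enabled)  # expected: False

# Restore the global grad mode in case the executed trace toggled it.
torch.set_grad_enabled(True)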

Also, the reason print(torch.is_grad_enabled()) prints True is that our lookaside for _set_grad_enabled only updates the state in CompileData and doesn't actually call torch._C._set_grad_enabled, which would have updated PyTorch's own state. So PyTorch never sees the change while tracing, and torch.is_grad_enabled returns True. (I am not sure whether we do/want to support printing during tracing.)
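
For reference, here is a minimal, purely illustrative sketch of that behaviour; FakeCompileData and fake_set_grad_enabled_lookaside are made-up names, not thunder internals:

import torch

class FakeCompileData:
    # Stand-in for thunder's CompileData, holding only the grad flag.
    def __init__(self):
        self.is_grad_enabled = torch.is_grad_enabled()

def fake_set_grad_enabled_lookaside(cd: FakeCompileData, mode: bool) -> None:
    # Record the state on the compile data only; deliberately do NOT call
    # torch._C._set_grad_enabled(mode), so PyTorch never sees the change.
    cd.is_grad_enabled = mode

cd = FakeCompileData()
fake_set_grad_enabled_lookaside(cd, False)
print(cd.is_grad_enabled)       # False: what tracing-time queries would see
print(torch.is_grad_enabled())  # True: PyTorch's global state is unchanged,
                                # which is why the print inside f shows True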