Open beverlylytle opened 1 week ago
Here is what is happening:
# This is how the thunder interpreter sees the above function.
def f(x):
    torch._C._set_grad_enabled(False)
    print(torch.is_grad_enabled())
    result = x * 2  # CompileData.is_grad_enabled = False only for this part.
    torch._C._set_grad_enabled(True)
    return result
We can verify this from thunder.last_traces by adding the following lines to the above script:
# Verifying from trace
traces = thunder.last_traces(jf)
print(traces[0])
Trace
def computation(x):
  # x: "cpu f32[2, 2]"

  # /home/kkalambarkar/git/pytorch/torch/autograd/grad_mode.py:187: torch._C._set_grad_enabled(mode)
  ltorch._set_grad_enabled_with_warning(False)

  # /home/kkalambarkar/lightning-thunder/scratchpad/test.py:153: return x * 2
  t0 = ltorch.mul(x, 2)  # t0: "cpu f32[2, 2]"
    # t0 = prims.mul(x, 2.0)  # t0: "cpu f32[2, 2]"

  # /home/kkalambarkar/git/pytorch/torch/autograd/grad_mode.py:187: torch._C._set_grad_enabled(mode)
  ltorch._set_grad_enabled_with_warning(True)
  return {'output': t0, 'flat_args': [x]}
So we can see that at the end of the function there is a _set_grad_enabled(True) call, which sets CompileData.is_grad_enabled to True.
CompileData.is_grad_enabled will only be False inside the no_grad region. So if ltorch.mul or any other operator queried this state while being traced, it would have seen False.
NOTE - Symbols that care about this should query the state during tracing. Once tracing has finished, is_grad_enabled simply reflects the last state it was updated to.
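The point in the note can be illustrated with a small toy model. This is only a sketch in plain Python, not thunder's implementation: ToyCompileData, set_grad_enabled, mul and trace_f are hypothetical names made up for illustration. It shows that an operation reading the flag while it is traced sees False, whereas reading the flag after tracing sees the last value written, True.

```python
from dataclasses import dataclass, field


@dataclass
class ToyCompileData:
    # Hypothetical stand-in for the compile-time state discussed above.
    is_grad_enabled: bool = True
    trace: list = field(default_factory=list)


def set_grad_enabled(cd, mode):
    # Record the call in the toy trace and update the tracing-time flag.
    cd.trace.append(f"set_grad_enabled({mode})")
    cd.is_grad_enabled = mode


def mul(cd, a, b):
    # A symbol that cares about grad mode reads the flag while it is traced.
    cd.trace.append(f"mul({a}, {b})  # grad_enabled={cd.is_grad_enabled}")
    return cd.is_grad_enabled


def trace_f(cd):
    set_grad_enabled(cd, False)  # entering the no_grad region
    seen_by_mul = mul(cd, "x", 2)
    set_grad_enabled(cd, True)   # leaving the no_grad region
    return seen_by_mul


cd = ToyCompileData()
seen_by_mul = trace_f(cd)
print(seen_by_mul)         # False: value observed during tracing
print(cd.is_grad_enabled)  # True: last value written, read after tracing
```

Under this model, checking the flag on the compile data after the jitted function has run says nothing about what the operations inside the no_grad region observed.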
Also, the reason print(torch.is_grad_enabled()) prints True is that our lookaside for _set_grad_enabled only updates the state in CompileData and doesn't actually call torch._C._set_grad_enabled, which would have updated the state for PyTorch. So PyTorch never learns about the change while tracing, and hence is_grad_enabled returns True. (I am not sure if we do/want to support printing during tracing.)
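The lookaside behaviour described here can be sketched with a toy dispatcher. Everything below is an assumption for illustration (real_set_grad_enabled, LOOKASIDES, traced_call and the two state dictionaries are hypothetical and do not mirror thunder's internals). The setter is redirected to a substitute that only touches the tracer's state, while the getter has no substitute and still reads the untouched "real" state.

```python
# Toy "framework" state, standing in for PyTorch's real grad mode.
real_state = {"grad_enabled": True}

# Toy tracer state, standing in for the flag kept in the compile data.
tracer_state = {"is_grad_enabled": True}


def real_set_grad_enabled(mode):
    real_state["grad_enabled"] = mode


def real_is_grad_enabled():
    return real_state["grad_enabled"]


def lookaside_set_grad_enabled(mode):
    # Only the tracer's view is updated; the real setter is never called.
    tracer_state["is_grad_enabled"] = mode


# Hypothetical lookaside table: calls to a key are redirected to its value.
LOOKASIDES = {real_set_grad_enabled: lookaside_set_grad_enabled}


def traced_call(fn, *args):
    # Dispatch through the lookaside table, falling back to the real function.
    return LOOKASIDES.get(fn, fn)(*args)


observed = []


def f():
    traced_call(real_set_grad_enabled, False)
    # The getter has no lookaside, so it reads the untouched real state.
    observed.append(
        (traced_call(real_is_grad_enabled), tracer_state["is_grad_enabled"])
    )
    traced_call(real_set_grad_enabled, True)


f()
print(observed)  # [(True, False)]
```

In this sketch the getter reports True inside the region even though the tracer's own flag is False, which matches the print output described above.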
🐛 Bug
For a function decorated with torch.no_grad, the compile data of the jitted version has is_grad_enabled set to True when I would expect it to be False.
Code sample