Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source to source compiler for PyTorch. It enables using different hardware executors at once; across one or thousands of GPUs.
Apache License 2.0

Support `torch.Tensor.register_hook` #307

Open carmocca opened 3 months ago

carmocca commented 3 months ago

🚀 Feature

Motivation

In Lightning Fabric, we use this once to check that the user properly called backward: https://github.com/Lightning-AI/pytorch-lightning/blob/096b063d6eeb41567409f4a6b9bac6f5af28ed93/src/lightning/fabric/wrappers.py#L232-L233. cc @awaelchli

I don't expect Thunder to actually run this hook on backward, but it would be useful to simply ignore the call instead of failing on it, perhaps with a warning.
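To make the "ignore it, maybe with a warning" behaviour concrete, here is a minimal pure-Python sketch. `TensorProxySketch` and `RemovableHandleSketch` are hypothetical names invented for illustration; they are not Thunder's proxy classes, and this is not a proposal for where the change would live in Thunder. The only thing taken from PyTorch is the shape of the call: `register_hook` accepts a callable and returns a handle with a `remove()` method.

```python
import warnings


class RemovableHandleSketch:
    """Hypothetical stand-in for the handle that register_hook returns."""

    def remove(self) -> None:
        # Nothing was registered, so there is nothing to remove.
        pass


class TensorProxySketch:
    """Hypothetical traced-tensor placeholder (not Thunder's actual proxy)."""

    def __init__(self, name: str) -> None:
        self.name = name

    def register_hook(self, hook):
        # Accept the call so tracing can continue, but tell the user
        # that the hook is dropped and will not run on backward.
        warnings.warn(
            f"register_hook on {self.name!r} is ignored; "
            "the hook will not run on backward",
            UserWarning,
            stacklevel=2,
        )
        return RemovableHandleSketch()


def fn(x):
    y = x  # stands in for `x * 2` in the repro
    handle = y.register_hook(lambda grad: print("Hello"))
    return y, handle
```

Under this sketch, calling `fn(TensorProxySketch("t0"))` emits one `UserWarning` and returns normally, which is the behaviour Fabric's check would need in order not to break compilation.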

Pitch

import thunder
import torch

def hook(_):
    print("Hello")

def fn(x):
    y = x * 2
    y.register_hook(hook)
    return y

t = torch.tensor([1.0], requires_grad=True)
fn = thunder.jit(fn)
out = fn(t)
out.backward()
print(out)

Running this currently fails with:

    y.register_hook(hook)
  File "/home/carmocca/git/lightning-thunder/thunder/core/interpreter.py", line 5862, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/carmocca/git/lightning-thunder/thunder/core/interpreter.py", line 1243, in jit_wrapped
    res = ufn(*uargs, **ukwargs)
  File "/home/carmocca/git/lightning-thunder/thunder/core/proxies.py", line 1210, in __getattr__
    method: None | Callable = resolve_method(attr, self)
  File "/home/carmocca/git/lightning-thunder/thunder/core/langctxs.py", line 68, in resolve_method
    method: Callable = ctx.get_method(id, *args, **kwargs)
  File "/home/carmocca/git/lightning-thunder/thunder/torch/langctx.py", line 40, in get_method
    raise AttributeError(f"The {self.name} language context has no method {id}")
AttributeError: The torch language context has no method register_hook

Additional context

t-vi commented 3 months ago

You mean something along the lines of Lightning-AI/lit-thunder-LEGACY#1779?

carmocca commented 3 months ago

Oh yes, perfect. I was happy with just not erroring out, because otherwise we would need to comment this out in Fabric if we want to compile the forward and the loss together.

t-vi commented 3 months ago

Thinking about this more, it's not so clear that it is a reasonable implementation, though: because we don't use the autograd engine for the JITed things, backward hook calls would need to be generated for them.
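A toy illustration of that concern, in plain Python rather than PyTorch or Thunder: every name below (`TapeValue`, `eager_double_backward`, `generated_double_backward`, and so on) is hypothetical and only models the idea. When an engine routes the gradient through the value, registered hooks fire as a side effect; a stand-alone backward function never touches the value, so the hook calls have to be emitted into it explicitly. For simplicity the hooks here only observe the gradient and do not replace it.

```python
from typing import Callable, List

Hook = Callable[[float], None]


class TapeValue:
    """Toy scalar that owns its hooks, loosely like an eager autograd tensor."""

    def __init__(self, data: float) -> None:
        self.data = data
        self.hooks: List[Hook] = []

    def register_hook(self, hook: Hook) -> None:
        self.hooks.append(hook)

    def receive_grad(self, grad: float) -> float:
        # The "engine" runs every registered hook when the gradient arrives.
        for hook in self.hooks:
            hook(grad)
        return grad


def eager_double_backward(y: TapeValue, grad_y: float) -> float:
    # Engine-driven path for y = 2 * x: the gradient flows through y,
    # so hooks registered on y fire.
    return 2.0 * y.receive_grad(grad_y)


def generated_double_backward(grad_y: float) -> float:
    # Stand-alone backward for y = 2 * x, as a compiler might emit it.
    # It never touches y, so hooks registered on y are silently skipped.
    return 2.0 * grad_y


def generated_double_backward_with_hooks(grad_y: float, hooks: List[Hook]) -> float:
    # Same backward, but with the hook calls emitted explicitly.
    for hook in hooks:
        hook(grad_y)
    return 2.0 * grad_y
```

All three functions return the same gradient; they differ only in whether the hook runs, which is why silently accepting `register_hook` and actually honouring it are separate pieces of work.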