Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

enter/exit_autocast of torch.amp.autocast_mode #824

Open tfogal opened 1 month ago

tfogal commented 1 month ago

🚀 Model / language coverage

I'm trying to get a fuller picture of what we need to support NeVA. As such I'm using:

import thunder.examine

def thunder_backend(gm, args):
    gm.real_recompile()
    try:  # examine may raise an exception
        thunder.examine.examine(gm, *args)
    except Exception as e:
        print(f"Hit problem with examine:\n{e}")
    # Don't actually use Thunder; just return the original graph
    return gm

...
#model.model = thunder.jit(model.model)
model.model = torch.compile(backend=thunder_backend)(model.model)

(thanks Ivan for the great idea!)

One of the issues it reports is, e.g.:

Found 2 distinct operations, of which 0 (0.0%) are supported
Please file an issue requesting the following operators here: https://github.com/Lightning-AI/lightning-thunder/issues/new
_enter_autocast of torch.amp.autocast_mode
_exit_autocast of torch.amp.autocast_mode

Pitch

This looks like it is going to be important for #343.

Alternatives / Potential work-arounds

Minimal Repro

cc @apaz-cli @crcrpar @tfogal

tfogal commented 1 month ago

Assigning to me until I can fill out a better reproducer

IvanYashchuk commented 1 month ago

Thunder doesn't support PyTorch's context managers (autocast, no_grad, enable_grad, etc.) inside the compiled function. With the recent addition of autocast-specific dispatch at tracing time (https://github.com/Lightning-AI/lightning-thunder/pull/705, https://github.com/Lightning-AI/lightning-thunder/pull/810), supporting this might not take much work; the challenge is not to reorder these enter and exit calls inappropriately. Another way to approach this problem is to ensure Thunder never sees these calls by adding more graph breaks.

kshitij12345 commented 1 month ago

I think this could be the minimal repro:

import torch
import thunder

class ThunderJitBackend:
    def __init__(self, **compile_options) -> None:        
        self.thunder_jit_fns = []
        self.dynamo_graphs = []
        self.cnt = 0
        self.compile_options = compile_options

    def compile(self, gm, sample_args):
        self.dynamo_graphs.append(gm)
        gm.real_recompile()
        thunder_jit_fn = thunder.jit(gm, **self.compile_options)
        self.thunder_jit_fns.append(thunder_jit_fn)
        self.cnt += 1
        return thunder_jit_fn

dev = "cuda"

def foo(x):
    with torch.autocast(dev, torch.bfloat16):
        y = x @ x
    return x + 2, y

with torch.device(dev):
    model = foo
    x = torch.randn(16, 16)
    args = (x,)
    kwargs = {}

jit_backend = ThunderJitBackend()
cmodel = torch.compile(model, backend=jit_backend.compile)

o = cmodel(*args, **kwargs)
print(f"GRAPHS {jit_backend.cnt}")
print(jit_backend.dynamo_graphs[0])

for tfn in jit_backend.thunder_jit_fns:
    print(thunder.last_traces(tfn)[-1])

torch.testing.assert_close(o, model(*args, **kwargs))

Dynamo Graph

def forward(self, L_x_ : torch.Tensor):
    l_x_ = L_x_
    _enter_autocast = torch.amp.autocast_mode._enter_autocast('cuda', torch.bfloat16, True, None)
    y = l_x_ @ l_x_
    _exit_autocast = torch.amp.autocast_mode._exit_autocast(_enter_autocast);  _enter_autocast = None
    add = l_x_ + 2;  l_x_ = None
    return (add, y)

With dev="cuda", we see the following error

  File "/home/kkalambarkar/git/pytorch/torch/_dynamo/eval_frame.py", line 410, in _fn
    return fn(*args, **kwargs)
  File "/home/kkalambarkar/git/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/kkalambarkar/git/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/kkalambarkar/lightning-thunder/thunder/core/module.py", line 61, in forward
    res = self._forward_fn(*args, **kwargs)
  File "/home/kkalambarkar/lightning-thunder/thunder/__init__.py", line 685, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
  File "/home/kkalambarkar/lightning-thunder/thunder/__init__.py", line 225, in cache_info_wrapper
    res = fn(*args, **kwargs)
  File "/home/kkalambarkar/lightning-thunder/thunder/__init__.py", line 506, in get_computation_and_inputs
    jit_results: TraceResults = interpreter(
  File "/home/kkalambarkar/lightning-thunder/thunder/__init__.py", line 213, in _general_frontend
    return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges, record_history=record_history)
  File "/home/kkalambarkar/lightning-thunder/thunder/core/jit_ext.py", line 1768, in thunder_general_jit
    result = jfn(*args, **kwargs)
  File "/home/kkalambarkar/lightning-thunder/thunder/core/interpreter.py", line 6760, in fn_
    raise InterpreterError(msg) from e
thunder.core.interpreter.InterpreterError: Encountered exception TypeError: unhashable type: 'instancemethod' while tracing GraphModule()

With dev="cpu", the program compiles but silently ignores the autocast in the function (computes in single precision), failing at torch.testing.assert_close.
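For context on why the CPU case fails at the comparison, here is a minimal eager-only illustration (not from the thread): under autocast the matmul runs in bfloat16, so a compiled version that silently drops the autocast computes in float32 and diverges from eager:

```python
import torch

x = torch.randn(16, 16)

# Eager semantics: inside autocast the CPU matmul runs in bfloat16.
with torch.autocast("cpu", torch.bfloat16):
    y_autocast = x @ x
assert y_autocast.dtype == torch.bfloat16

# Without autocast the same matmul stays in float32, which is what a
# program that ignores the context manager would compute.
y_fp32 = x @ x
assert y_fp32.dtype == torch.float32
```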

tfogal commented 1 month ago

Removing my assignment because Kshiteej is a hero w.r.t. finding minimal reproducers 😄. Thank you

tfogal commented 1 month ago

> Another way to approach this problem is to ensure Thunder never sees these calls by adding more graph breaks.

Yes, this is a great idea for the interim. I'll see if these are in actionable parts of the code (NeMo, or maybe megatron).

tfogal commented 1 week ago

Some automation revealed an even simpler reproducer:

import torch
import thunder

class DynamoModule(torch.nn.Module):
    def forward(self):
        _enter_autocast = torch.amp.autocast_mode._enter_autocast('cuda', torch.bfloat16, True, None)
        _exit_autocast = torch.amp.autocast_mode._exit_autocast(_enter_autocast);  _enter_autocast = _exit_autocast = None
        return ()

inputs = []
fqn = thunder.jit(DynamoModule())
fqn(*inputs)