Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

Support for torchvision models, e.g., a simple ViT #93

Closed rasbt closed 3 months ago

rasbt commented 5 months ago

🐛 Bug

I was trying to run a simple torchvision ViT and am getting the following error:

  File "/teamspace/studios/this_studio/minimal-vit/01_pytorch-vit.py", line 136, in <module>
    train(
  File "/teamspace/studios/this_studio/minimal-vit/01_pytorch-vit.py", line 31, in train
    logits = model(features)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/__init__.py", line 194, in forward
    res = self._forward_fn(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/__init__.py", line 611, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/__init__.py", line 262, in cache_info_wrapper
    res = fn(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/__init__.py", line 498, in get_computation_and_inputs
    prologue_trc, computation_trc, *maybe_epilogue = interpreter(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/__init__.py", line 175, in _general_frontend
    return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/jit_ext.py", line 1386, in thunder_general_jit
    result = jfn(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 6580, in fn_
    raise e
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 6543, in fn_2
    return fn(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torchvision/models/vision_transformer.py", line 298, in forward
    x = self.encoder(x)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torchvision/models/vision_transformer.py", line 157, in forward
    return self.ln(self.layers(self.dropout(input)))
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torchvision/models/vision_transformer.py", line 113, in forward
    x, _ = self.self_attention(x, x, x, need_weights=False)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/activation.py", line 1236, in forward
    any_nested = query.is_nested or key.is_nested or value.is_nested
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 1253, in wrapping_wrapper
    res = ufn(*uargs, **ukwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/proxies.py", line 1234, in __getattr__
    method: None | Callable = resolve_method(attr, self)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/langctxs.py", line 68, in resolve_method
    method: Callable = ctx.get_method(id, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/torch/langctx.py", line 40, in get_method
    raise AttributeError(f"The {self.name} language context has no method {id}")
AttributeError: The torch language context has no method is_nested

I'm not sure how to go about debugging this. I thought sharing it might help improve Thunder's support for more models and edge cases.

To Reproduce

Steps to reproduce the behavior:

I attached self-contained code in the zip.

# Runs PyTorch eager, works ok!
python 01_pytorch-vit.py

# Runs torch.compile, works ok!
python 01_pytorch-vit.py --compilation_option "torch.compile"

# Runs thunder.jit(), fails! (See error above)
python 01_pytorch-vit.py --compilation_option "thunder_default"

Code sample

See zip attached

Expected behavior

Either a clearer error message or, ideally, it should work :)

Environment

Same as Zero to Thunder studio.

Archive.zip

cc @apaz-cli

carmocca commented 5 months ago

Hey Seb! @nikitaved just merged a PR to improve the messaging here: #78

The TLDR is that you want to run examine on the model to get a report of what's not working:

from thunder.examine import examine

x = ...
model = ...
examine(model, x)

It would be useful if you could include here what it reports for those models.

rasbt commented 5 months ago

This is nice, thanks! The report is:

Files already downloaded and verified
Found 18 distinct operations, of which 15 (83.3%) are supported
Please file an issue requesting the following operators here: https://github.com/Lightning-AI/lightning-thunder/issues/new
TensorBase.is_nested
multi_head_attention_forward of torch.nn.functional
_assert of torch

So the culprit seems to be https://github.com/pytorch/pytorch/blob/1e8d4b389b5f03cea191ed558051f036fe04f92d/torch/nn/functional.py#L5163
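
Judging from the traceback, the error is raised when an attribute lookup on a tensor proxy falls through to a table of supported methods and finds no entry. The following is only a toy sketch of that general pattern, not Thunder's actual code: `ToyLangCtx` and `ToyTensorProxy` are hypothetical names made up for illustration.

```python
import math
from functools import partial


class ToyLangCtx:
    """Hypothetical stand-in for a language context: a named table of supported methods."""

    def __init__(self, name, methods):
        self.name = name
        self._methods = dict(methods)

    def get_method(self, id):
        # Unknown names are rejected with the same style of message as in the traceback.
        if id not in self._methods:
            raise AttributeError(f"The {self.name} language context has no method {id}")
        return self._methods[id]


class ToyTensorProxy:
    """Hypothetical proxy that stands in for a tensor while code is being traced."""

    def __init__(self, ctx, shape):
        self._ctx = ctx
        self.shape = shape

    def __getattr__(self, attr):
        # Called only when normal attribute lookup fails.
        if attr.startswith("_"):
            raise AttributeError(attr)
        return partial(self._ctx.get_method(attr), self)


ctx = ToyLangCtx("torch", {"numel": lambda p: math.prod(p.shape)})
query = ToyTensorProxy(ctx, (2, 3))

print(query.numel())  # a supported method resolves normally
try:
    query.is_nested  # not in the table, so the lookup fails
except AttributeError as err:
    print(err)
```

In this sketch, any attribute that the method table does not list fails in the same way, which is consistent with `is_nested` showing up as unsupported in the report.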

mruberry commented 5 months ago

triage review:

We think there are three issues here:

First issue:

Second issue:

Third issue:

Can we break this issue up into those three, @rasbt?

rasbt commented 5 months ago

This sounds totally reasonable, please feel free to break it up into these three.

Re first issue: I'm not sure if that's feasible, but perhaps automatically calling examine upon failure would not be a bad thing for users.
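
A minimal sketch of what "calling a diagnostic upon failure" could look like in user code. The helper `call_with_diagnostics` is hypothetical and not part of Thunder; it only assumes that the diagnostic accepts the same inputs as the function that failed.

```python
from typing import Any, Callable


def call_with_diagnostics(
    fn: Callable[..., Any],
    diagnose: Callable[..., Any],
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Call fn; if it raises, run diagnose on the same inputs, then re-raise.

    Hypothetical helper for illustration only.
    """
    try:
        return fn(*args, **kwargs)
    except Exception:
        try:
            diagnose(*args, **kwargs)
        except Exception as diag_err:
            # The diagnostic is best effort and must not hide the original error.
            print(f"diagnostics failed: {diag_err!r}")
        raise  # re-raise the original exception from fn
```

With Thunder, one could presumably pass the jitted model as `fn` and something like `lambda *a, **kw: examine(model, *a, **kw)` as `diagnose`, based on the `examine(model, x)` call shown earlier in this thread. The original exception is still raised, so the behavior on failure is unchanged apart from the extra report.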

t-vi commented 3 months ago

As of today, we do seem to be able to run ResNet and vit_b_16 (at least), thanks to #584 and #633. :tada: