Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

Updating Module's buffer leads to crash in thunder.jit #646

Open kshitij12345 opened 4 months ago

kshitij12345 commented 4 months ago
import torch
import thunder

class TestModel(torch.nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.register_buffer("buffer", torch.tensor([1,]))

    def forward(self, x):
        self.buffer = self.buffer + 1
        return x + self.buffer

m = TestModel()

x = torch.randn(3, )
print(m(x))

jm = thunder.jit(TestModel())
print(jm(x))
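For context, here is a minimal pure-Python sketch (a hypothetical `EagerModel` with a plain int standing in for the tensor buffer, no torch required) of the semantics the repro relies on: `self.buffer = self.buffer + 1` rebinds the attribute to a new object on every call, so a tracing compiler has to propagate that updated value back into the module state.

```python
# Hypothetical stand-in model; a plain int replaces torch.tensor([1]).
class EagerModel:
    def __init__(self):
        self.buffer = 1  # scalar stand-in for the registered buffer

    def forward(self, x):
        # Rebinding, not an in-place update: a new object is assigned
        # to self.buffer on every call. A compiler like thunder must
        # replay this write-back to keep the module state consistent.
        self.buffer = self.buffer + 1
        return x + self.buffer

m = EagerModel()
print(m.forward(0), m.forward(0))  # 2 3
```

In eager mode this just works, incrementing the buffer on each call; it is this write-back of module state that thunder's epilogue handling trips over below.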

Errors with:

Traceback (most recent call last):
  File "lightning-thunder/scratchpad/test.py", line 73, in <module>
    print(jm(x))
  File "git/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "git/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "lightning-thunder/thunder/core/module.py", line 61, in forward
    res = self._forward_fn(*args, **kwargs)
  File "lightning-thunder/thunder/__init__.py", line 675, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
  File "lightning-thunder/thunder/__init__.py", line 224, in cache_info_wrapper
    res = fn(*args, **kwargs)
  File "lightning-thunder/thunder/__init__.py", line 503, in get_computation_and_inputs
    jit_results: TraceResults = interpreter(
  File "lightning-thunder/thunder/__init__.py", line 212, in _general_frontend
    return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges, record_history=record_history)
  File "lightning-thunder/thunder/core/jit_ext.py", line 1660, in thunder_general_jit
    process_recorded_modifications(ctx, epilogue_trace)
  File "lightning-thunder/thunder/core/jit_ext.py", line 1554, in process_recorded_modifications
    typ, name, root_module_provenance = get_parameter_or_buffer_or_submodule_name_and_root(
  File "lightning-thunder/thunder/core/jit_ext.py", line 1255, in get_parameter_or_buffer_or_submodule_name_and_root
    assert provenance.inputs[0].inst is PseudoInst.LOAD_ATTR
IndexError: list index out of range
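Note that the failure surfaces as an IndexError rather than a failed assert because of evaluation order: `provenance.inputs[0]` is evaluated before the `is PseudoInst.LOAD_ATTR` comparison, so an empty `inputs` list raises before the assert can report anything. A minimal sketch, using a hypothetical `Provenance` stand-in for thunder's internal provenance record:

```python
# Hypothetical stand-in for thunder's provenance record.
class Provenance:
    def __init__(self, inputs):
        self.inputs = inputs

provenance = Provenance(inputs=[])  # no recorded inputs for this access

try:
    # Mirrors `assert provenance.inputs[0].inst is PseudoInst.LOAD_ATTR`:
    # the subscript raises before the assert condition is ever checked.
    provenance.inputs[0]
except IndexError as e:
    print(e)  # list index out of range
```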
tfogal commented 3 months ago

triage review:

riccardofelluga commented 1 month ago

This issue is still present in the torchbench HF models: hf_GPT2, hf_GPT2_large, hf_T5, hf_T5_base, hf_T5_large. In particular, the error can be reproduced by setting the buffer to None:

import torch

class TestModel(torch.nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.register_buffer("buffer", torch.ones((1, 2)))

    def forward(self, x):
        self.buffer = None
        return x

With a stack trace similar to:

  File "/opt/pytorch/lightning-thunder/thunder/core/jit_ext.py", line 1698, in thunder_general_jit
    process_recorded_modifications(ctx, epilogue_trace)
  File "/opt/pytorch/lightning-thunder/thunder/core/jit_ext.py", line 1578, in process_recorded_modifications
    assert isinstance(value.value, Proxy)
AssertionError
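The failure mode here is different from the original report: assigning None means the value recorded for the buffer is a plain None rather than a tensor proxy, so the `isinstance(value.value, Proxy)` check fails outright. A rough sketch, where `Proxy` and the `RecordedValue` wrapper are hypothetical stand-ins for thunder's internals:

```python
class Proxy:  # stand-in for thunder's Proxy type
    pass

class RecordedValue:  # hypothetical wrapper mirroring what the epilogue sees
    def __init__(self, value):
        self.value = value

value = RecordedValue(None)  # forward() assigned None to the buffer

try:
    # Mirrors `assert isinstance(value.value, Proxy)` in
    # process_recorded_modifications: None is not a Proxy.
    assert isinstance(value.value, Proxy)
except AssertionError:
    print("AssertionError: None is not a Proxy")
```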

To reproduce the original error from HF, check out the hf-benchmarks branch and run pytest thunder/benchmarks/targets.py -k "torchbench and hf and -thunder-hf" -v --benchmark-disable. Checking out the branch is unnecessary if #1238 has been merged at the time of reading.