Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0
1.07k stars 60 forks source link

Updating a Module's buffer leads to a crash in thunder.jit #646

Status: Open — opened by kshitij12345 3 days ago

kshitij12345 commented 3 days ago
import torch
import thunder

class TestModel(torch.nn.Module):
    """Minimal module that rebinds one of its registered buffers in ``forward``.

    Reproducer for issue #646: running this module eagerly works, but under
    ``thunder.jit`` the buffer reassignment in ``forward`` ends in an
    ``IndexError`` raised from
    ``get_parameter_or_buffer_or_submodule_name_and_root`` (reached via
    ``process_recorded_modifications``), as shown in the issue traceback.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Register a buffer named "buffer" holding the 1-element integer
        # tensor ``tensor([1])``.
        self.register_buffer("buffer", torch.tensor([1,]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Increment the buffer by 1 (by rebinding it) and return ``x + buffer``.

        Each call permanently advances the stored buffer value by one.
        """
        # Out-of-place update followed by attribute assignment (not an
        # in-place ``add_``). Because "buffer" is a registered buffer,
        # nn.Module.__setattr__ stores the new tensor back into the module's
        # buffers. This rebinding is the modification that trips up thunder.jit.
        self.buffer = self.buffer + 1
        # Reads the freshly assigned buffer value.
        return x + self.buffer

# Eager baseline. Calling ``m(x)`` also advances ``m.buffer`` from 1 to 2,
# so the printed value is ``x + 2``.
m = TestModel()
x = torch.randn(3)
eager_result = m(x)
print(eager_result)

# Compile a *fresh* instance (its buffer starts again at 1), so its first call
# should also return ``x + 2`` and match the eager output above. Per issue
# #646, this call raises IndexError inside thunder instead.
jm = thunder.jit(TestModel())
jit_result = jm(x)
print(jit_result)

The eager call succeeds, but calling the thunder.jit-compiled model fails with:

Traceback (most recent call last):
  File "lightning-thunder/scratchpad/test.py", line 73, in <module>
    print(jm(x))
  File "git/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "git/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "lightning-thunder/thunder/core/module.py", line 61, in forward
    res = self._forward_fn(*args, **kwargs)
  File "lightning-thunder/thunder/__init__.py", line 675, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
  File "lightning-thunder/thunder/__init__.py", line 224, in cache_info_wrapper
    res = fn(*args, **kwargs)
  File "lightning-thunder/thunder/__init__.py", line 503, in get_computation_and_inputs
    jit_results: TraceResults = interpreter(
  File "lightning-thunder/thunder/__init__.py", line 212, in _general_frontend
    return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges, record_history=record_history)
  File "lightning-thunder/thunder/core/jit_ext.py", line 1660, in thunder_general_jit
    process_recorded_modifications(ctx, epilogue_trace)
  File "lightning-thunder/thunder/core/jit_ext.py", line 1554, in process_recorded_modifications
    typ, name, root_module_provenance = get_parameter_or_buffer_or_submodule_name_and_root(
  File "lightning-thunder/thunder/core/jit_ext.py", line 1255, in get_parameter_or_buffer_or_submodule_name_and_root
    assert provenance.inputs[0].inst is PseudoInst.LOAD_ATTR
IndexError: list index out of range