Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0
1.07k stars, 60 forks
source link
Updating a Module's buffer leads to a crash in `thunder.jit` (#646)
import torch
import thunder
class TestModel(torch.nn.Module):
    """Minimal repro module whose ``forward`` reassigns a registered buffer.

    On every call, ``forward`` rebinds the buffer named ``"buffer"`` to a new
    tensor (``self.buffer + 1``, an out-of-place update) and adds the updated
    value to the input. In eager PyTorch the reassignment keeps ``buffer``
    registered as a buffer, so the value accumulates across calls. Per the
    issue title, this buffer update is what crashes ``thunder.jit``. The
    reported traceback ends in an ``IndexError`` raised from
    ``get_parameter_or_buffer_or_submodule_name_and_root``.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # One-element integer tensor starting at 1, registered as a buffer so
        # it is part of the module's state (not a trainable parameter).
        self.register_buffer("buffer", torch.tensor([1]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Increment ``buffer`` by one (out-of-place), then return ``x + buffer``.

        Args:
            x: Input tensor; it must broadcast against the one-element buffer.

        Returns:
            ``x`` plus the freshly incremented buffer value.
        """
        # Deliberately rebinds the attribute instead of mutating in place
        # (e.g. ``add_``); this reassignment pattern is the subject of the repro.
        self.buffer = self.buffer + 1
        return x + self.buffer
# Eager reference run: the buffer starts at 1 and is bumped to 2 before use.
m = TestModel()
x = torch.randn(3)
eager_out = m(x)
print(eager_out)

# The same computation through thunder.jit, on a fresh module so that its
# buffer also starts at 1. Per the reported traceback, this call raises
# IndexError instead of printing a result.
jm = thunder.jit(TestModel())
jit_out = jm(x)
print(jit_out)
Errors with:
Traceback (most recent call last):
File "lightning-thunder/scratchpad/test.py", line 73, in <module>
print(jm(x))
File "git/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "git/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "lightning-thunder/thunder/core/module.py", line 61, in forward
res = self._forward_fn(*args, **kwargs)
File "lightning-thunder/thunder/__init__.py", line 675, in fn_
cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
File "lightning-thunder/thunder/__init__.py", line 224, in cache_info_wrapper
res = fn(*args, **kwargs)
File "lightning-thunder/thunder/__init__.py", line 503, in get_computation_and_inputs
jit_results: TraceResults = interpreter(
File "lightning-thunder/thunder/__init__.py", line 212, in _general_frontend
return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges, record_history=record_history)
File "lightning-thunder/thunder/core/jit_ext.py", line 1660, in thunder_general_jit
process_recorded_modifications(ctx, epilogue_trace)
File "lightning-thunder/thunder/core/jit_ext.py", line 1554, in process_recorded_modifications
typ, name, root_module_provenance = get_parameter_or_buffer_or_submodule_name_and_root(
File "lightning-thunder/thunder/core/jit_ext.py", line 1255, in get_parameter_or_buffer_or_submodule_name_and_root
assert provenance.inputs[0].inst is PseudoInst.LOAD_ATTR
IndexError: list index out of range
Errors with: