Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0
ThunderModule - fix load_original_state_dict to work with modules with buffers #648
Repro:

```python
import torch
import thunder


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_parameter("param", torch.nn.Parameter(torch.randn(3)))
        self.register_buffer("buffer", torch.randn(3))

    def forward(self, x):
        return x


m = Model()
thunder_module = thunder.jit(Model())

# Fails with: NameError: name 'model' is not defined
thunder_module.load_original_state_dict(m.state_dict())
```
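The title names modules with buffers as the case the fix has to handle. As a plain-PyTorch sketch (not Thunder's actual implementation), the snippet below shows why buffers need separate treatment: `state_dict()` holds buffers alongside parameters, while `named_parameters()` lists only parameters, so a loader that walked parameters alone would skip `buffer`. `copy_state_into` is a hypothetical helper introduced only for illustration.

```python
import torch


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_parameter("param", torch.nn.Parameter(torch.randn(3)))
        self.register_buffer("buffer", torch.randn(3))

    def forward(self, x):
        return x


def copy_state_into(module: torch.nn.Module, state_dict: dict) -> None:
    """Hypothetical helper (not a Thunder API): copy every entry of
    ``state_dict`` into ``module``, covering buffers as well as parameters."""
    params = dict(module.named_parameters())
    buffers = dict(module.named_buffers())
    # In-place copies into leaf parameters must not be tracked by autograd.
    with torch.no_grad():
        for name, value in state_dict.items():
            if name in params:
                params[name].copy_(value)
            elif name in buffers:
                buffers[name].copy_(value)
            else:
                raise KeyError(f"unexpected key in state_dict: {name}")


src = Model()
dst = Model()

sd = src.state_dict()
print(sorted(sd.keys()))  # ['buffer', 'param']
# Parameters only; the buffer is not listed here.
print([n for n, _ in dst.named_parameters()])  # ['param']

copy_state_into(dst, sd)
print(torch.equal(dst.param, src.param), torch.equal(dst.buffer, src.buffer))  # True True
```

Iterating over the state dict keys, rather than over `named_parameters()`, is what lets the buffer entry reach the target module in this sketch.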