Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

ThunderModule - fix load_original_state_dict to work with modules with buffers #648

Closed. kshitij12345 closed this 3 days ago

kshitij12345 commented 3 days ago

Repro:

import torch
import thunder

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_parameter("param", torch.nn.Parameter(torch.randn(3)))
        self.register_buffer("buffer", torch.randn(3))

    def forward(self, x):
        return x

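# Source module whose state dict (one parameter, one buffer) will be loaded.
m = Model()

# Separate instance with the same layout, wrapped by Thunder.
thunder_module = thunder.jit(Model())

# Fails with NameError: name 'model' is not defined
thunder_module.load_original_state_dict(m.state_dict())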
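
For reference, here is a minimal sketch of the behavior the repro expects, using plain PyTorch only (no Thunder calls, so it says nothing about how Thunder implements the fix). It shows that `state_dict()` for this module carries both the parameter and the buffer, and that a regular `load_state_dict` round trip restores both. The names `src`, `dst`, `state`, and the `*_names` / `*_match` variables are illustrative only.

```python
import torch


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_parameter("param", torch.nn.Parameter(torch.randn(3)))
        self.register_buffer("buffer", torch.randn(3))

    def forward(self, x):
        return x


# Two independent instances with different random values.
src = Model()
dst = Model()

# The state dict contains entries for both parameters and (persistent) buffers.
state = src.state_dict()
state_keys = list(state.keys())
param_names = [name for name, _ in src.named_parameters()]
buffer_names = [name for name, _ in src.named_buffers()]

# Plain PyTorch round trip: copies the parameter and the buffer into dst.
result = dst.load_state_dict(state)
params_match = torch.equal(dst.param, src.param)
buffers_match = torch.equal(dst.buffer, src.buffer)
```

Since the buffer shows up in the state dict alongside the parameter, `load_original_state_dict` presumably has to handle both kinds of entries for the repro above to pass.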