Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source to source compiler for PyTorch. It enables using different hardware executors at once; across one or thousands of GPUs.
Apache License 2.0

ThunderModule - `load_original_state_dict` followed by `state_dict` returns `state_dict` of the original module. #647

Open kshitij12345 opened 3 days ago

kshitij12345 commented 3 days ago
import torch
import thunder

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_parameter("param", torch.nn.Parameter(torch.randn(3)))

    def forward(self, x):
        return x

m = Model()

thunder_module = thunder.jit(Model())
# This only loads into `ThunderModule._overrides_parameters`;
# it doesn't update the original module.
thunder_module.load_original_state_dict(m.state_dict())

print("THUNDER OVERRIDES:", thunder_module._overrides_parameters)
print("THUNDER STATE_DICT:", thunder_module.state_dict())  # This pulls in the state_dict of the original module (which we never updated).
print("ORIGINAL STATE_DICT:", m.state_dict())

Output

THUNDER OVERRIDES: {'param': tensor([ 1.4207,  1.2219, -0.4130])}
THUNDER STATE_DICT: OrderedDict([('param', tensor([ 0.2293,  0.0467, -0.8747]))])
ORIGINAL STATE_DICT: OrderedDict([('param', tensor([ 1.4207,  1.2219, -0.4130]))])

I think there are two options to make this consistent:

  1. `load_original_state_dict` should also update the parameters of the original module that `ThunderModule` wraps.
  2. `ThunderModule.state_dict` should prefer to pull tensors from `_overrides_parameters` and `_overrides_buffers`.

Given that the original module's parameters and buffers may be on the `meta` device, I think option 2 makes more sense.
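
A rough sketch of what option 2 could look like, not Thunder's actual implementation. The helper name `merged_state_dict` is hypothetical, and plain lists stand in for tensors so the snippet runs without `torch` or `thunder`. It assumes `_overrides_parameters` and `_overrides_buffers` are dict-like mappings keyed by the same names as the original module's `state_dict`.

```python
from collections import OrderedDict


def merged_state_dict(original_state, overrides_parameters, overrides_buffers):
    """Build a state dict that prefers override entries over the original ones.

    Hypothetical helper: keys and their order come from the original module's
    state dict; values come from the overrides whenever an override exists.
    """
    merged = OrderedDict()
    for name, value in original_state.items():
        if name in overrides_parameters:
            merged[name] = overrides_parameters[name]
        elif name in overrides_buffers:
            merged[name] = overrides_buffers[name]
        else:
            # No override recorded, so fall back to the original module's value.
            merged[name] = value
    return merged


# Plain lists stand in for tensors; the values mirror the output shown above.
original = OrderedDict([("param", [0.2293, 0.0467, -0.8747])])
overrides_parameters = {"param": [1.4207, 1.2219, -0.4130]}
overrides_buffers = {}

result = merged_state_dict(original, overrides_parameters, overrides_buffers)
print("MERGED STATE_DICT:", result)
```

With this preference order, the value loaded via `load_original_state_dict` is the one that `state_dict` reports, and the original module's tensors are never read when an override exists, which matters if they live on the `meta` device.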

cc: @t-vi

Tagging with the label `transforms`, as we care about load/save_state_dict with transforms: https://github.com/Lightning-AI/lightning-thunder/issues/483

t-vi commented 3 days ago

Yes, 2 makes a lot more sense. Thank you.