Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
ThunderModule - `state_dict` followed by `load_original_state_dict` returns `state_dict` of the original module. #647
```python
import torch
import thunder


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_parameter("param", torch.nn.Parameter(torch.randn(3)))

    def forward(self, x):
        return x


m = Model()
thunder_module = thunder.jit(Model())

# This only loads into `ThunderModule._overrides_parameters`;
# it doesn't update the original module.
thunder_module.load_original_state_dict(m.state_dict())

print("THUNDER OVERRIDES:", thunder_module._overrides_parameters)
# This pulls in the state_dict of the original module (which we never updated).
print("THUNDER STATE_DICT:", thunder_module.state_dict())
print("ORIGINAL STATE_DICT:", m.state_dict())
```
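To make the reported mismatch concrete, a check along these lines (appended to the repro above) should show that the loaded values do not round-trip through `state_dict()`; the expected result follows from the issue report, not from running the code here:

```python
# The values loaded via `load_original_state_dict` are not the values that
# `state_dict()` returns, because the override dicts are ignored.
loaded_param = m.state_dict()["param"]
returned_param = thunder_module.state_dict()["param"]
print("round-trips:", torch.equal(loaded_param, returned_param))  # reportedly False
```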
I think there are two options to make this consistent:

1. `load_original_state_dict` should also update the parameters of the original module which `ThunderModule` wraps.
2. `ThunderModule.state_dict` should prefer to pull tensors from `_overrides_parameters` and `_overrides_buffers`.

Given that the original module's parameters and buffers may be on device `meta`, I think option 2 makes more sense (sketched below).

cc: @t-vi
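A minimal sketch of what option 2 could look like, assuming the overrides are plain dicts keyed by state-dict names (as `_overrides_parameters` in the repro suggests); `merged_state_dict` and its arguments are hypothetical illustrations, not Thunder's actual API:

```python
import torch


def merged_state_dict(original_module: torch.nn.Module,
                      overrides_parameters: dict,
                      overrides_buffers: dict) -> dict:
    # Hypothetical sketch: start from the wrapped module's state_dict, then
    # prefer any override tensor loaded via `load_original_state_dict` over
    # the (possibly stale, possibly meta-device) original entry.
    sd = original_module.state_dict()
    for name, tensor in {**overrides_parameters, **overrides_buffers}.items():
        sd[name] = tensor
    return sd
```

With this behavior, the repro above would print the values loaded into `_overrides_parameters` rather than the wrapped module's untouched parameters.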
Tagging with label `transforms` as we care about `load/save_state_dict` with transforms - https://github.com/Lightning-AI/lightning-thunder/issues/483