t-vi opened this issue 2 months ago:
I see this as the materialization transform:

```python
import torch
from torch.testing import assert_close

import thunder

# test-harness helpers; assumed to come from thunder's test utilities
from thunder.tests import litgpt_model
from thunder.tests.framework import requiresCUDA


@requiresCUDA
def test_materialization():
    from thunder.transforms import MaterializationTransform

    config = litgpt_model.Config.from_name("llama2-like")
    with torch.device("cuda"):
        ref_m = litgpt_model.GPT(config).to(torch.bfloat16)
    with torch.device("meta"):
        m = litgpt_model.GPT(config).to(torch.bfloat16)

    # tag every meta tensor with the device it should be materialized on
    for p in m.parameters():
        p.__thunder_device = torch.device("cuda")
    for b in m.buffers():
        b.__thunder_device = torch.device("cuda")

    init_from_sd = MaterializationTransform.from_original_state_dict(ref_m.state_dict())
    jm = thunder.jit(
        m,
        transforms=[MaterializationTransform("cuda", init=init_from_sd)],
    )

    x = torch.randint(1, 255, (1, 10), device="cuda")
    input_pos = torch.arange(10, device="cuda")

    expected = ref_m(x, input_pos)
    actual = jm(x, input_pos)
    assert_close(actual, expected)
```
wdyt?
Looks super neat! To simplify, we could also do
```python
jm = thunder.jit(
    m,
    transforms=[MaterializationTransform("cuda", init=ref_m.state_dict())],
)
```
I thought that would be neat, but this is what kept me from offering it: we could have

And there we would have two state dicts that I would not know how to differentiate.
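To make the ambiguity concrete, here is a small self-contained sketch (the toy `nn.Linear` module and the fake int8 "quantization" are made up purely for illustration and are not Thunder APIs): a plain state dict from the original bf16 model and one produced after a transform look the same to a bare `init=<state_dict>` argument, so the transform could not tell whether it still needs to convert the weights.

```python
import torch
import torch.nn as nn

# Toy stand-ins for the two kinds of checkpoints (illustrative only, not Thunder code).
original = nn.Linear(4, 4).to(torch.bfloat16)

original_sd = original.state_dict()  # weights of the unmodified model (bf16)
# pretend an already-applied transform (e.g. quantization) produced this checkpoint
transformed_sd = {k: (v * 127).to(torch.int8) for k, v in original_sd.items()}

# Both are plain dicts of tensors; a bare `init=<state_dict>` cannot tell which
# one it was handed, so it would not know whether to apply the weight conversion.
print({k: v.dtype for k, v in original_sd.items()})
print({k: v.dtype for k, v in transformed_sd.items()})
```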
aah makes sense
With #867 and #868 we have initial support, but three of the four modes are yet to be fleshed out:
What's the benefit of allowing meta-device tensors to interact with other devices rather than erroring out? What could the alternatives be? Have you considered using PyTorch's FakeTensor for initialization instead of plain meta tensors?
We want to facilitate running models that can only fit into memory after transforms. The main mechanism PyTorch currently offers for instantiating a model without allocating memory is `meta` devices. So the plan is to enable attaching `__thunder_device` to the meta tensors (and/or maybe as a default option for compilation) and then use that device for proxying the meta tensors.

cc @apaz-cli