Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

allow tracing models on meta device #862

Open t-vi opened 2 months ago

t-vi commented 2 months ago

We want to facilitate running models that only fit into memory after transforms have been applied. The main mechanism PyTorch offers for instantiating a model without allocating memory is the meta device.
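For context, the standard meta-device idiom in PyTorch looks roughly like this (a minimal sketch, not Thunder-specific):

```python
import torch
import torch.nn as nn

# Instantiating under the meta device allocates no storage: parameters
# exist with shapes, dtypes, and layouts only, so even a model that is
# far too large for host memory can be constructed.
with torch.device("meta"):
    m = nn.Linear(10_000, 10_000)

# Meta tensors carry metadata but no data; operations on them only
# propagate shapes, they never touch real memory.
assert m.weight.device.type == "meta"
print(m.weight.shape)  # torch.Size([10000, 10000])
```

The catch, and the motivation for a materialization transform, is that such a model cannot run until its parameters are moved to a real device and filled with values.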

So the plan is to enable:

cc @apaz-cli

t-vi commented 2 months ago

I see this as the materialization transform:

@requiresCUDA
def test_materialization():
    from thunder.transforms import MaterializationTransform

    config = litgpt_model.Config.from_name("llama2-like")
    with torch.device("cuda"):
        ref_m = litgpt_model.GPT(config).to(torch.bfloat16)
    with torch.device("meta"):
        m = litgpt_model.GPT(config).to(torch.bfloat16)
    for p in m.parameters():
        p.__thunder_device = torch.device("cuda")
    for b in m.buffers():
        b.__thunder_device = torch.device("cuda")

    init_from_sd = MaterializationTransform.from_original_state_dict(ref_m.state_dict())
    jm = thunder.jit(
        m,
        transforms=MaterializationTransform("cuda", init=init_from_sd)
    )
    x = torch.randint(1, 255, (1, 10), device="cuda")
    input_pos = torch.arange(10, device="cuda")

    expected = ref_m(x, input_pos)
    actual = jm(x, input_pos)

    assert_close(actual, expected)

wdyt?
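For comparison, this is roughly what plain PyTorch requires to materialize a meta-instantiated module (a sketch of the standard `to_empty` + `load_state_dict` idiom, which the transform above would fold into `thunder.jit`):

```python
import torch
import torch.nn as nn

# Build the module without allocating any parameter storage.
with torch.device("meta"):
    m = nn.Linear(8, 8)

# Reference weights, e.g. from a checkpoint or an eagerly built model.
ref_sd = nn.Linear(8, 8).state_dict()

# to_empty() allocates uninitialized storage on the target device;
# load_state_dict() then copies the reference values into it.
m = m.to_empty(device="cpu")
m.load_state_dict(ref_sd)

assert m.weight.device.type == "cpu"
```

On recent PyTorch versions, `load_state_dict(ref_sd, assign=True)` can skip the `to_empty` step by assigning the checkpoint tensors directly.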

lantiga commented 2 months ago

Looks super neat; to simplify, we could also do

    jm = thunder.jit(
        m,
        transforms=MaterializationTransform("cuda", init=ref_m.state_dict())
    )
t-vi commented 2 months ago

I thought that would be neat too, but this is what kept me from proposing it: we could have

And there we would have two state dicts that I would not know how to tell apart.

lantiga commented 2 months ago

aah makes sense

t-vi commented 2 months ago

With #867 and #868 we have initial support, but three of the four modes are yet to be fleshed out:

IvanYashchuk commented 1 month ago

What's the benefit of allowing meta device tensors to interact with tensors on other devices instead of erroring out? What could the alternatives be? Have you considered using PyTorch's FakeTensor for initialization instead of plain meta tensors?
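For reference, FakeTensors differ from plain meta tensors in that they live under a dispatch mode that propagates full metadata (including the intended device) through operations without allocating storage; a minimal sketch:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor

# Under FakeTensorMode, tensor factories produce FakeTensors: no real
# storage is allocated, but shape, dtype, and device metadata propagate
# through subsequent operations.
with FakeTensorMode():
    a = torch.empty(4, 4)
    b = a @ a  # traced without allocating or computing any data

assert isinstance(b, FakeTensor)
assert b.shape == (4, 4)
```

This is the mechanism `torch.compile` uses for shape propagation, which is presumably why it is being suggested here as an alternative to raw meta tensors.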