tenstorrent / tt-forge-fe

The TT-Forge FE is a graph compiler designed to optimize and transform computational graphs for deep learning models, enhancing their performance and efficiency.
https://docs.tenstorrent.com/tt-forge-fe/
Apache License 2.0

Ability to move inputs/parameters between memory locations (e.g. host, device) #175

Open nvukobratTT opened 3 months ago

nvukobratTT commented 3 months ago

Summary

For more flexible inference/training we need more granular control over:

The proposal is to use the .to("cpu"/"tt") functionality (similar to torch 2.0).

Details

In order to:

It makes sense to have an API that supports this kind of control. Therefore, the proposal is to follow a style inspired by torch 2.0.

For example:

    import torch
    import pybuda

    # MNISTLinear is assumed to be defined elsewhere (not shown in this issue)

    # Put inputs on device
    input_tensor = torch.rand(1, 784).to("tt")

    # Define the model and instruct it to compile and run on the TT device
    framework_model = MNISTLinear()
    framework_model.to("tt")

    # Keep the loss on CPU
    loss = torch.nn.L1Loss().to("cpu")

    # Compile the device model and the CPU loss
    tt_model = pybuda.compile(framework_model)
    cpu_model = pybuda.compile(loss)

We should investigate whether we can use metadata from torch objects, e.g. whether a module is specified to run on CPU (.to("cpu")) or on some other backend. If so, we can follow the above example and use pybuda.compile for CPU devices as well. This would integrate nicely into flows where we also need to be aware of specific inputs/outputs for CPU models (e.g. a loss or optimizer running on CPU, where we need to tell the runtime where to gather the appropriate tensors from).
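A minimal sketch of the kind of metadata check this paragraph has in mind, using only vanilla PyTorch: it collects the device types of a module's parameters and buffers. The helper names placement and runs_on_cpu, and the rule that a module with no tensors (such as L1Loss) counts as a CPU module, are hypothetical choices for illustration, not existing PyBuda APIs.

```python
from torch import nn


def placement(module: nn.Module) -> set:
    """Return the set of device types that hold the module's parameters and buffers."""
    tensors = list(module.parameters()) + list(module.buffers())
    return {t.device.type for t in tensors}


def runs_on_cpu(module: nn.Module) -> bool:
    # Hypothetical policy: a module whose tensors all live on "cpu", or which has
    # no tensors at all (e.g. L1Loss), is treated as a CPU module.
    return placement(module) <= {"cpu"}


linear = nn.Linear(784, 10)
loss = nn.L1Loss()
print(placement(linear), runs_on_cpu(linear))
print(placement(loss), runs_on_cpu(loss))
```

Note that a parameter-free module like L1Loss carries no device metadata at all, so calling .to("cpu") on it leaves nothing to inspect; that case would need an explicit default.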

Passing criteria

nsmithtt commented 3 months ago
    framework_model = MNISTLinear()
    framework_model.to("tt")

FYI I think you can only run .to on a module after it's been pybuda.compile'd.

I think the above could be written as:

    import torch
    import pybuda

    # Define model and instruct it to compile and run on TT device
    # (MNISTLinear is assumed to be defined elsewhere; not shown here)
    framework_model = MNISTLinear()
    tt_model = pybuda.compile(framework_model)
    tt_model.to("tt")

    # Create a torch loss and leave it on CPU
    loss = torch.nn.L1Loss()

    # Put inputs on device
    input_tensor = torch.rand(1, 784).to("tt")
    # Run device module
    model_out = tt_model(input_tensor)
    # Pull output back to CPU
    model_out_cpu = model_out.to("cpu")
    # Run loss on CPU; L1Loss needs a target (hypothetical placeholder of matching shape)
    target = torch.zeros_like(model_out_cpu)
    loss_out_cpu = loss(model_out_cpu, target)

The tricky part, if we want to support this level of flexibility, will be wiring up the backward pass. If the user does:

    loss_out_cpu.backward()

we'll have to hook this backward call somehow and propagate the gradient back into the device module.
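One standard PyTorch way to make loss_out_cpu.backward() reach code that lives outside autograd is a custom torch.autograd.Function. This is a swapped-in alternative to the tensor hook discussed later in this thread, not something the thread proposes. In the sketch below, DeviceLinear is hypothetical, and its forward/backward are plain CPU matmuls standing in for calls into the device runtime.

```python
import torch


class DeviceLinear(torch.autograd.Function):
    """Hypothetical stand-in for a device module whose math runs outside autograd."""

    @staticmethod
    def forward(ctx, x, weight):
        # A real integration would dispatch to the device here; this is a CPU matmul.
        ctx.save_for_backward(x, weight)
        return x @ weight

    @staticmethod
    def backward(ctx, grad_out):
        # Autograd calls this when loss.backward() reaches this node; this is
        # where the CPU gradient would be handed to the device's backward.
        x, weight = ctx.saved_tensors
        grad_x = grad_out @ weight.t()
        grad_w = x.t() @ grad_out
        return grad_x, grad_w


x = torch.ones(1, 4)
weight = torch.ones(4, 3, requires_grad=True)
out = DeviceLinear.apply(x, weight)

# Loss computed on CPU, then a normal backward call
loss = torch.nn.L1Loss()(out, torch.zeros(1, 3))
loss.backward()
print(weight.grad)
```

The appeal of this shape is that the user-facing call stays a plain .backward(); the cost is that every device module output must be produced through such a Function.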

nvukobratTT commented 3 months ago

Thanks for the comments!

Re "FYI I think you can only run .to on a module after it's been pybuda.compile'd": I'm not 100% sure that is correct. I'll double-check it later today.

While writing these issues, I tried to reference both vanilla torch 2.0 and our current integration on the PyBuda end. For example, in this snippet:

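    import torch

    # Assumes the "tt" device and torch.compile backend are registered by the
    # PyBuda integration; MNISTLinear is defined elsewhere
    inputs = [torch.rand(1, 784)]

    framework_model = MNISTLinear()
    fw_out = framework_model(*inputs)

    framework_model.to("tt")
    compiled_model = torch.compile(framework_model, backend="tt")
    co_out = compiled_model(*[i.to("tt") for i in inputs])

    co_out = [co.to("cpu") for co in co_out]
    # all(...) is required: asserting a non-empty list always passes
    assert all(torch.allclose(fo, co) for fo, co in zip(fw_out, co_out))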

you'll see that there is a similar pattern:

    framework_model.to("tt")
    compiled_model = torch.compile(framework_model, backend="tt")

However, if I remove .to("tt") before the compile call, I get an error like this:

    E       torch._dynamo.exc.TorchRuntimeError: Failed running call_module L__self___l1(*(FakeTensor(..., device='tt:0', size=(1, 784)),), **{}):
    E       Unhandled FakeTensor Device Propagation for aten.mm.default, found two different devices tt:0, cpu

I'm not sure whether this is torch-related or a consequence of how we implemented it on our end. In any case, I'd like to stick to the torch 2.0 default. @AleksKnezevic, can you provide more context on this workflow? I'm mostly interested in the .to("tt") details.

Edit: This doc from a related issue clarified things for me. I agree that we should stick to what the doc outlines, i.e. using .to(*) only to move parameters and buffers between memory locations. I'll update the issue later today.
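A small CPU-only illustration of the standard torch semantics the doc describes: nn.Module.to converts the module's parameters and buffers in place and returns the module itself, while Tensor.to returns a new tensor when a conversion happens. Because a "tt" device does not exist in vanilla PyTorch, a dtype conversion stands in for a device move here; note that a dtype-only conversion leaves integer buffers such as num_batches_tracked untouched, whereas a device move would relocate them too.

```python
import torch
from torch import nn

# BatchNorm1d has both parameters (weight, bias) and buffers (running stats)
model = nn.BatchNorm1d(4)
returned = model.to(torch.float64)  # in place; returns the same module

t = torch.rand(2, 4)
t64 = t.to(torch.float64)  # returns a new tensor; t itself is unchanged

print(returned is model, model.running_mean.dtype, t.dtype, t64.dtype)
```

This is why framework_model.to("tt") needs no assignment, while an input has to be rebound (input_tensor = ... .to("tt")).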

nvukobratTT commented 3 months ago

Re "The tricky thing will be wiring up backwards though, if we want to support this level of flexibility": Agreed.

That said, I see value in exploring it for the initial iteration. If this proves too difficult, we can always expose a new runtime API to handle this case, but I'm much more in favor of using .backward() as is.

Once we clarify these details, I'll update the issue description to cover all of it.

nsmithtt commented 3 months ago

Re the error you get when removing .to("tt") before the compile call:

I think that might be a quirk of how we had to do it in the old runtime. With the new runtime, we should be able to support .to per the torch spec and have it mean "prepare and move weights to the device".
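A hypothetical sketch of what "prepare and move weights to the device" could look like on the Python side: a wrapper nn.Module that overrides to() to run a preparation step before delegating to the standard nn.Module.to. DeviceModule, prepare_weights and prepared_for are invented names for illustration, and the sketch is exercised with "cpu" because that is the only device guaranteed to exist.

```python
from torch import nn


class DeviceModule(nn.Module):
    """Hypothetical wrapper: .to() first 'prepares' weights, then moves them."""

    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner
        self.prepared_for = None  # records the last target passed to .to()

    def prepare_weights(self, target):
        # Placeholder for device-specific work (e.g. layout conversion);
        # a real runtime would do its preparation here.
        self.prepared_for = target

    def to(self, *args, **kwargs):
        # Run the preparation step, then defer to nn.Module.to, which moves
        # parameters and buffers in place and returns self.
        self.prepare_weights(args[0] if args else kwargs.get("device"))
        return super().to(*args, **kwargs)

    def forward(self, x):
        return self.inner(x)


wrapped = DeviceModule(nn.Linear(4, 2))
moved = wrapped.to("cpu")
```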

nsmithtt commented 3 months ago

Re loss.backward(), just making a note of these PyTorch docs: https://pytorch.org/docs/stable/generated/torch.Tensor.register_hook.html

We can register a backward hook on the graph outputs of tt modules. This way the gradient can come from the CPU and hit our hooked backward function:

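    def forward_to_tt_module(tt_module, grad):
        # Move the incoming CPU gradient onto the TT device
        device_grad = grad.to("tt")
        # Run the device module's backward (tt_module.backwards as proposed here)
        tt_module.backwards(device_grad)
        # Return None so autograd keeps tt_output's gradient unchanged
        return None

    # Runtime, before returning the output tensor, registers a hook
    tt_output.register_hook(lambda grad: forward_to_tt_module(tt_module, grad))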
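A CPU-only sketch of the register_hook mechanism: a plain matmul stands in for the TT module's output, and hypothetical_device_backward stands in for tt_module.backwards by simply recording the gradient it receives. It shows that the hook fires during loss.backward() with the gradient of the CPU-side loss with respect to that output.

```python
import torch

received = []


def hypothetical_device_backward(grad):
    # Stand-in for tt_module.backwards: record the gradient arriving from the
    # CPU-side loss. Returning None keeps the gradient unchanged.
    received.append(grad.detach().clone())


x = torch.ones(1, 4)
weight = torch.ones(4, 3, requires_grad=True)
tt_output = x @ weight  # stands in for the output of a TT module

tt_output.register_hook(hypothetical_device_backward)

# Loss runs on CPU; backward() triggers the hook on tt_output
loss = torch.nn.L1Loss()(tt_output, torch.zeros(1, 3))
loss.backward()
print(received[0])
```

Since the hook returns None, autograd keeps the original gradient; returning a tensor instead would replace the gradient that flows to everything upstream of tt_output.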