nvukobratTT opened this issue 3 months ago
> ```python
> framework_model = MNISTLinear()
> framework_model.to("tt")
> ```

FYI, I think you can only run `.to` on a module after it's been `pybuda.compile`'d.
I think the above could be written as:

```python
import torch
import pybuda

# Define model and instruct it to compile and run on TT device
# (MNISTLinear is the example model referenced in this issue; its definition is not shown)
framework_model = MNISTLinear()
tt_model = pybuda.compile(framework_model)
tt_model.to("tt")
# Create a torch loss and leave on CPU
loss = torch.nn.L1Loss()
# Put inputs on device
input_tensor = torch.rand(1, 784).to("tt")
# Run device module
model_out = tt_model(input_tensor)
# Pull output back to CPU
model_out_cpu = model_out.to("cpu")
# Run loss on CPU; L1Loss needs a target, so a random placeholder is used here
target = torch.rand_like(model_out_cpu)
loss_out_cpu = loss(model_out_cpu, target)
```
The tricky thing will be wiring up the backward pass, though, if we want to support this level of flexibility. If the user does:

```python
loss_out_cpu.backward()
```

we'll have to hook this backward method somehow and propagate the gradient back into the device module.
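One way to hook the backward pass is to wrap each device-module call in a custom `torch.autograd.Function`, so that a plain `.backward()` on a CPU-side loss ends up calling the device module's own backward. This is a swapped-in technique, not what the thread settles on (the tensor-hook approach is discussed further down), and it is only a minimal sketch: `DeviceLinearStub` and `DeviceCall` are hypothetical names, and everything runs on CPU because stock PyTorch has no `"tt"` device.

```python
import torch


class DeviceLinearStub:
    """Hypothetical stand-in for a compiled TT module with explicit forward/backward entry points.

    Runs on CPU here; a real runtime would move tensors to and from the device.
    """

    def __init__(self, in_features, out_features):
        self.weight = torch.randn(out_features, in_features)
        self.weight_grad = None

    def forward(self, x):
        return x @ self.weight.t()

    def backward(self, grad_out, x):
        # Weight gradient stays with the "device" module (e.g. for a device-side optimizer step)
        self.weight_grad = grad_out.t() @ x
        # Input gradient is handed back so autograd can keep propagating upstream
        return grad_out @ self.weight


class DeviceCall(torch.autograd.Function):
    """Routes autograd's backward pass for one device-module call into module.backward."""

    @staticmethod
    def forward(ctx, x, module):
        ctx.module = module
        ctx.save_for_backward(x)
        return module.forward(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        grad_in = ctx.module.backward(grad_out, x)
        # One return value per forward input; the module argument gets no gradient
        return grad_in, None


torch.manual_seed(0)
tt_module = DeviceLinearStub(784, 10)
input_tensor = torch.rand(1, 784, requires_grad=True)

model_out = DeviceCall.apply(input_tensor, tt_module)  # "device" forward
loss = torch.nn.L1Loss()
loss_out_cpu = loss(model_out, torch.zeros_like(model_out))  # loss computed on CPU
loss_out_cpu.backward()  # a plain torch backward reaches DeviceCall.backward
```

Unlike a hook on the output alone, the `Function` returns the input gradient to autograd, so gradients can keep flowing into any CPU-side modules that run before the device module.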
Thanks for the comments!
> FYI I think you can only run `.to` on a module after it's been `pybuda.compile`'d.

I'm not 100% sure that's correct. I'll double-check it later today.
While writing these issues, I tried to reference both vanilla torch 2.0 and our current integration on the PyBuda end. In sum, if you check out this example:

```python
import torch

inputs = [torch.rand(1, 784)]
framework_model = MNISTLinear()
fw_out = framework_model(*inputs)
framework_model.to("tt")
compiled_model = torch.compile(framework_model, backend="tt")
co_out = compiled_model(*[i.to("tt") for i in inputs])
co_out = [co.to("cpu") for co in co_out]
# all(...) checks every comparison; asserting the bare list would only check it is non-empty
assert all(torch.allclose(fo, co) for fo, co in zip(fw_out, co_out))
```
you'll see that there is a similar pattern:

```python
framework_model.to("tt")
compiled_model = torch.compile(framework_model, backend="tt")
```
However, if I remove `.to("tt")` before the compile call, I get something like this:

```
E torch._dynamo.exc.TorchRuntimeError: Failed running call_module L__self___l1(*(FakeTensor(..., device='tt:0', size=(1, 784)),), **{}):
E Unhandled FakeTensor Device Propagation for aten.mm.default, found two different devices tt:0, cpu
```
I'm not sure whether this is torch-related or comes from how we implemented it on our end. In any case, I would like to strive for what is the default in torch 2.0. @AleksKnezevic, can you provide more context on this workflow? I'm mostly interested in the `.to("tt")` details.
Edit:
This doc from a related issue provided more clarity on my end. I agree that we should stick to what the doc outlines, which is using `.to(*)` just to move parameters and buffers across different memory locations. I'll make the issue edits during the day.
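For reference, this is how stock PyTorch's `.to()` behaves for modules versus tensors. The sketch uses a dtype conversion on CPU because a `"tt"` device is not available in a plain PyTorch install; per the PyTorch docs, device moves on a module are likewise applied in place to its parameters and buffers.

```python
import torch

model = torch.nn.Linear(784, 10)
model.register_buffer("scale", torch.ones(10))  # a buffer, to show buffers are converted too

# Module.to() converts parameters *and* buffers in place and returns the same module
moved = model.to(torch.float64)
assert moved is model
assert model.weight.dtype == torch.float64
assert model.scale.dtype == torch.float64

# Tensor.to() is not in place: it returns a new tensor when a conversion is needed,
# and the original tensor itself when nothing would change
x = torch.rand(1, 784)
assert x.to(torch.float64) is not x
assert x.to(torch.float32) is x
```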
> The tricky thing will be wiring up backwards though, if we want to support this level of flexibility, if the user does:

Agreed. However, I see value in exploring it for the initial run. If this proves to be more difficult, we can always expose a new runtime API to handle this case, but I'm much more in favor of using `.backward()` as is.
In sum, once we clarify these details, I'll update the root description to encapsulate all of it.
> However, if I remove .to("tt") before compile call, I would get something like this:

I think that might be a quirk of how we had to do it in the old runtime. With the new runtime, we should be able to support `.to` per the torch spec and have it mean "prepare and move weights to device".
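A minimal, framework-free sketch of what "`.to` means prepare and move weights to device" could look like. `DeviceModuleSketch`, `_prepare`, and the transfer log are hypothetical illustrations, not the PyBuda runtime API; weights are plain Python lists so the sketch runs anywhere.

```python
class DeviceModuleSketch:
    """Hypothetical wrapper illustrating `.to(device)` as "prepare, then move weights"."""

    def __init__(self, weights):
        self.weights = dict(weights)  # weight name -> host data
        self.device = "cpu"
        self.prepared_for = None
        self.transfers = []  # log of (weight name, source, destination)

    def _prepare(self, device):
        # Hypothetical hook point for device-specific preparation (e.g. layout conversion)
        self.prepared_for = device

    def to(self, device):
        if device == self.device:
            return self  # nothing to move, mirroring Tensor.to returning self when nothing changes
        if device != "cpu" and self.prepared_for != device:
            self._prepare(device)  # prepare once per target device
        for name in self.weights:
            self.transfers.append((name, self.device, device))
        self.device = device
        return self  # return self so calls can be chained, as torch.nn.Module.to does


m = DeviceModuleSketch({"weight": [0.1, 0.2], "bias": [0.0]})
m.to("tt")
```

Calling `.to("tt")` a second time is a no-op, and moving back to `"cpu"` logs one transfer per weight without re-running preparation.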
Re `loss.backward()`, just making a note of these PyTorch docs: https://pytorch.org/docs/stable/generated/torch.Tensor.register_hook.html

We can register a backward hook on the graph outputs from tt modules. This way the gradient can come from the CPU and it will hit our hooked backward function:

```python
def forward_to_tt_module(tt_module, grad):
    # Move the CPU-side gradient to the device and run the device module's backward
    device_grad = grad.to("tt")
    tt_module.backwards(device_grad)
    # A tensor hook's return value, if not None, replaces the gradient seen by the
    # rest of the graph, so return None to leave it unchanged
    return None

# Before returning the output tensor, the runtime registers a hook on it
tt_output.register_hook(lambda grad: forward_to_tt_module(tt_module, grad))
```
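A CPU-only sketch of the hook approach, using the real `Tensor.register_hook` API. `DeviceModuleStub` is a hypothetical stand-in for a tt module that only records what it receives, and the `grad.to("tt")` step is left out because stock PyTorch has no `"tt"` device. Note that `register_hook` raises an error on a tensor that does not require grad, so the runtime would need to return outputs that are part of the autograd graph.

```python
import torch


class DeviceModuleStub:
    """Hypothetical stand-in for a tt module; it only records the gradients it is sent."""

    def __init__(self):
        self.received_grads = []

    def backwards(self, grad):
        self.received_grads.append(grad.detach().clone())


def forward_to_tt_module(tt_module, grad):
    # A real runtime would do grad.to("tt") here; this sketch stays on CPU
    tt_module.backwards(grad)
    return None  # leave the CPU-side gradient unchanged


tt_module = DeviceModuleStub()
x = torch.rand(1, 10, requires_grad=True)
tt_output = x * 2.0  # pretend this tensor was returned by the device module

tt_output.register_hook(lambda grad: forward_to_tt_module(tt_module, grad))

loss = torch.nn.MSELoss()
loss(tt_output, torch.zeros(1, 10)).backward()  # the gradient originates on CPU
```

With `MSELoss` against a zero target, the hook receives `2 * tt_output / 10`, the derivative of the mean squared error over the 10 output elements.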
Summary
For more flexible inference/training we need more granular control over:
The proposal is to use the `.to("cpu"/"tt")` functionality (similar to torch 2.0).

Details
In order to:
It would be appropriate to have an API that supports this kind of control. Therefore, the proposal is to use a style inspired by Torch 2.0.
For example:
We should investigate whether we can utilize metadata from torch elements, e.g., whether an element is specified to run on CPU (`to("cpu")`) or on some other backend. If that's possible, we can follow the example above and use `pybuda.compile` for CPU devices as well. This would integrate nicely into a flow where we need to be aware of specific inputs/outputs for CPU models as well (e.g., a loss or optimizer running on CPU, where we need to tell the runtime where to gather the appropriate tensors from).

Passing criteria
`pybuda.compile` should return appropriate tensor layout details to the runtime (e.g., whether a tensor is stored on host, DRAM, L1, etc.).
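As an illustration of this passing criterion, the layout details could be a small per-tensor record that the runtime queries, for example to decide which outputs must be read back before a CPU-side loss runs. `MemoryLocation`, `TensorLayout`, and `tensors_to_gather` are hypothetical names, and the shapes in the usage lines are illustrative only.

```python
from dataclasses import dataclass
from enum import Enum


class MemoryLocation(Enum):
    HOST = "host"
    DRAM = "dram"
    L1 = "l1"


@dataclass(frozen=True)
class TensorLayout:
    """Hypothetical per-tensor record that pybuda.compile could hand to the runtime."""

    name: str
    shape: tuple
    location: MemoryLocation


def tensors_to_gather(layouts, cpu_consumers):
    """Names of tensors needed by CPU-side code (e.g. a loss) that are not already on host."""
    return [
        layout.name
        for layout in layouts
        if layout.name in cpu_consumers and layout.location is not MemoryLocation.HOST
    ]


layouts = [
    TensorLayout("input_tensor", (1, 784), MemoryLocation.HOST),
    TensorLayout("model_out", (1, 10), MemoryLocation.DRAM),
]
print(tensors_to_gather(layouts, {"model_out"}))  # ['model_out']
```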