Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

Represent slices natively in traces #160

Open carmocca opened 5 months ago

carmocca commented 5 months ago

🚀 Feature

Motivation

Tensor slices are represented in traces as:

  t107 = torch_slice_prim_impl(t53, [0, 0, 0, 0], [4, 32, 2048, 0], [1, 1, 1, 1])  # t107: "cuda:0 bf16[4, 32, 2048, 0]"

But the trace never imports torch_slice_prim_impl, so it is not a runnable program as printed. The same operation can be expressed in plain Python.

This reference comes from:

https://github.com/Lightning-AI/lightning-thunder/blob/ea1d1302f4a630e3832c07dc3adfe559111ba099/thunder/executors/torchex.py#L533-L534

https://github.com/Lightning-AI/lightning-thunder/blob/ea1d1302f4a630e3832c07dc3adfe559111ba099/thunder/executors/torchex.py#L507-L517

Pitch

Instead, represent it with __getitem__ and slice():

t123 = t321.__getitem__([slice(0, 3), slice(0, 5)])  # t123: "cuda:..."
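
A minimal sketch of how the start, end, and stride lists from the current trace line could map onto Python slice objects. `to_slices` is a hypothetical helper, not part of Thunder, and it assumes the end indices are exclusive, which matches the output shape shown in the trace above.

```python
def to_slices(start_indices, end_indices, strides):
    """Hypothetical helper: turn per-dimension bounds into slice objects."""
    return tuple(
        slice(start, end, stride)
        for start, end, stride in zip(start_indices, end_indices, strides)
    )


# Arguments copied from the torch_slice_prim_impl call in the trace above.
key = to_slices([0, 0, 0, 0], [4, 32, 2048, 0], [1, 1, 1, 1])

# key is a tuple of four slice objects, one per dimension, and could be
# passed to a tensor's __getitem__, e.g. t53.__getitem__(key).
```

The last dimension becomes slice(0, 0, 1), which selects nothing, consistent with the trailing 0 in the bf16[4, 32, 2048, 0] shape.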

Alternatives

Add the torch_slice_prim_impl import from torchex to the trace so that it's a valid program.

cc @apaz-cli @nikitaved

mruberry commented 4 months ago

We can definitely represent it using __getitem__ when it's executed by PyTorch (we can even use a subscript instead of the dunder).
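
That the subscript and the dunder call are interchangeable is plain Python behaviour, so it can be checked without a tensor. In this sketch, `KeyRecorder` is a hypothetical stand-in that simply returns the key it is indexed with.

```python
class KeyRecorder:
    """Hypothetical stand-in for a tensor: returns the key it receives."""

    def __getitem__(self, key):
        return key


t = KeyRecorder()

# Subscript syntax: Python builds a tuple of slice objects and calls __getitem__.
via_subscript = t[0:3, 0:5]

# The explicit dunder call with the same tuple of slices.
via_dunder = t.__getitem__((slice(0, 3), slice(0, 5)))

same_key = via_subscript == via_dunder
```

Note that the subscript form passes a tuple of slices, so a trace line written as t321[0:3, 0:5] would carry the same information as the explicit __getitem__ call.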

In general we should be sure we're importing the necessary functions so that we always generate a Python function that can be run independently.