Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

Indexing with slice and list #643

Open t-vi opened 4 days ago

t-vi commented 4 days ago

Repro:

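import thunder, torch

def fn(input_ids):
    # Basic slice on dim 0 combined with a list (advanced) index on dim 1
    return input_ids[:, [-1, 0]]

a = torch.randn(5, 5)
jfn = thunder.jit(fn)
jfn(a)  # eager fn(a) returns a (5, 2) tensor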
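For reference, the expected semantics can be spelled out in plain PyTorch. A basic slice on dim 0 combined with a list index on dim 1 is equivalent to normalizing negative entries and then gathering columns with `torch.index_select`. The helper name `slice_then_list_index` is hypothetical. This is a sketch of one possible decomposition under that assumption, not how thunder implements (or will implement) indexing.

```python
import torch


def slice_then_list_index(x: torch.Tensor, idx: list[int]) -> torch.Tensor:
    """Hypothetical reference lowering of ``x[:, idx]`` for a list of ints.

    Not thunder's implementation; it only spells out the expected semantics.
    """
    dim_size = x.shape[1]
    # Normalize negative entries the way Python indexing does (-1 -> dim_size - 1).
    normalized = [i + dim_size if i < 0 else i for i in idx]
    for raw, i in zip(idx, normalized):
        if not 0 <= i < dim_size:
            raise IndexError(f"index {raw} is out of bounds for dim 1 with size {dim_size}")
    index = torch.tensor(normalized, dtype=torch.long, device=x.device)
    # The leading ':' keeps dim 0 whole; the list selects (and may reorder or repeat) columns.
    return torch.index_select(x, 1, index)


a = torch.randn(5, 5)
out = slice_then_list_index(a, [-1, 0])
print(out.shape)  # torch.Size([5, 2])
assert torch.equal(out, a[:, [-1, 0]])
```

Because the list is an advanced index, the result is a copy rather than a view of `a`, unlike a pure slice such as `a[:, 0:2]`.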

Supporting this is needed to run Hugging Face transformers' BERT under thunder.jit:

import thunder, torch, transformers

m = transformers.BertForSequenceClassification(transformers.BertConfig())
x = torch.randint(1, 10, (1, 32))
jm = thunder.jit(m)
jm(x)

Ideally, this would be fixed together with #460 and #187.

cc @apaz-cli

t-vi commented 4 days ago

@k223kim wants to work on this. Thank you!

k223kim commented 1 day ago

Hey Team! I will be working on this issue. Thanks!