Open t-vi opened 4 days ago
Repro:
import thunder, torch

def fn(input_ids):
    return input_ids[:, [-1, 0]]

a = torch.randn(5, 5)
jfn = thunder.jit(fn)
jfn(a)
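For reference, here is a minimal sketch of what eager PyTorch returns for the same indexing expression, which is presumably the result `thunder.jit` should match. The deterministic input (`torch.arange(25.0).reshape(5, 5)`) and the `manual` comparison are assumptions added for readability; they are not part of the original repro.

```python
import torch

def fn(input_ids):
    # Same indexing expression as in the repro.
    return input_ids[:, [-1, 0]]

# Deterministic input (an assumption; the repro uses torch.randn(5, 5)).
a = torch.arange(25.0).reshape(5, 5)

# Eager PyTorch result, no thunder involved.
expected = fn(a)

# Negative entries count from the end of the indexed dimension,
# so [-1, 0] selects the last column followed by the first column.
manual = torch.stack([a[:, a.shape[1] - 1], a[:, 0]], dim=1)

assert expected.shape == (5, 2)
assert torch.equal(expected, manual)
```

Comparing the jitted output against `expected` with `torch.equal` could serve as a regression check once the indexing is supported.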
Needed for transformers BERT:
import thunder, torch, transformers

m = transformers.BertForSequenceClassification(transformers.BertConfig())
x = torch.randint(1, 10, (1, 32))
jm = thunder.jit(m)
jm(x)
Might ideally be fixed along with #460 and #187
cc @apaz-cli
@k223kim wants to work on this. Thank you!
Hey Team! I will be working on this issue. Thanks!