pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Add aten::lift_fresh to the exception list that converts input to torch tensor #8360

Closed · barney-s closed 2 weeks ago

barney-s commented 2 weeks ago

This was causing the fastNLP_BERT model to fail.

https://github.com/fastnlp/fastNLP/blob/4e95989e973f59b2ecb7f718647257e8b6fea0c7/fastNLP/embeddings/bert_embedding.py#L427

    word_pieces[i, 1:word_pieces_lengths[i] + 1] = torch.LongTensor(word_pieces_i)

torch.LongTensor was returning a regular torch.Tensor instead of an XLA tensor.
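
For reference, a minimal sketch of the failing pattern (assuming torch_xla is installed and an XLA device is available; the shapes and token ids here are made up, only the assignment mirrors the fastNLP code):

    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()

    # word_pieces lives on the XLA device, as it does in fastNLP's BertEmbedding.
    word_pieces = torch.zeros(2, 8, dtype=torch.long, device=device)
    word_pieces_i = [101, 2023, 2003, 102]  # made-up word-piece ids

    # The legacy constructor produces a regular torch.Tensor; without the
    # aten::lift_fresh conversion added in this PR, the slice assignment
    # into the XLA tensor below is where the model failed.
    word_pieces[0, 1:len(word_pieces_i) + 1] = torch.LongTensor(word_pieces_i)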

Debugging the call trace:

DISPATCH: aten::lift_fresh
 FUNCTION: aten::lift_fresh
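
The DISPATCH/FUNCTION lines above come from pytorch/xla's debug logging. As a rough illustration (not the exact tooling used for this trace), a TorchDispatchMode from core PyTorch can be used to watch aten::lift_fresh reach the dispatcher when such a factory call runs:

    import torch
    from torch.utils._python_dispatch import TorchDispatchMode

    class DispatchLogger(TorchDispatchMode):
        """Print every aten op that reaches the Python dispatcher."""
        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            print(f"DISPATCH: {func}")
            return func(*args, **(kwargs or {}))

    with DispatchLogger():
        # Tensor factory calls like this one are routed through
        # aten::lift_fresh, the op added to the exception list in this PR.
        t = torch.LongTensor([101, 2023, 2003, 102])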

Fixes: https://github.com/pytorch/xla/issues/8126