pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Fallback `_embedding_bag_backward` and force `sparse=false`. #7584

Closed ysiraichi closed 2 months ago

ysiraichi commented 3 months ago

Fix: #6008

This PR, together with https://github.com/pytorch/pytorch/pull/129691, adds support for embedding bag calls whose sparse parameter is true.

Problem: when sparse=true, _embedding_bag_backward calls the at::_sparse_coo_tensor_unsafe_symint function, which returns a sparse tensor. Since PyTorch/XLA does not support sparse tensors, this resulted in a dispatch error (see the original issue).
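A minimal sketch of where the sparse tensor comes from (run on a CPU build of PyTorch for illustration; the XLA device is not needed to see the gradient layout that triggered the dispatch error):

```python
import torch

# With sparse=True, the backward pass of EmbeddingBag produces a
# sparse COO gradient for the weight -- the tensor kind that
# PyTorch/XLA cannot represent.
emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=3,
                            mode="sum", sparse=True)
indices = torch.tensor([1, 2, 4, 5])
offsets = torch.tensor([0, 2])  # two bags: [1, 2] and [4, 5]

emb(indices, offsets).sum().backward()
print(emb.weight.grad.layout)  # torch.sparse_coo
```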

Solution: although we should ideally support sparse tensors, in the short term we decided (in an offline discussion) to fall back to the dense backward function.
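Why the fallback is numerically safe can be sketched on CPU (a hypothetical illustration, not the actual lowering code): the dense backward computes the same gradient values as the sparse one; only the storage layout differs.

```python
import torch

torch.manual_seed(0)
weight = torch.randn(10, 3)

# Same embedding-bag computation twice, differing only in `sparse`.
emb_sparse = torch.nn.EmbeddingBag.from_pretrained(
    weight, freeze=False, mode="sum", sparse=True)
emb_dense = torch.nn.EmbeddingBag.from_pretrained(
    weight.clone(), freeze=False, mode="sum", sparse=False)

indices = torch.tensor([1, 2, 4, 5])
offsets = torch.tensor([0, 2])

emb_sparse(indices, offsets).sum().backward()
emb_dense(indices, offsets).sum().backward()

# Densifying the sparse gradient reproduces the dense backward's
# result exactly, so forcing sparse=false changes layout, not values.
assert torch.equal(emb_sparse.weight.grad.to_dense(),
                   emb_dense.weight.grad)
```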

cc @miladm @JackCaoG