pytorch/xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

[bug] SparseXLA backend not implemented #6853

Open Mon-ius opened 6 months ago

Mon-ius commented 6 months ago
import torch
import torch_xla.core.xla_model as xm

def voxelize(point_cloud, voxel_size):
    normalized_point_cloud = (point_cloud / voxel_size).long()
    ones = torch.ones(len(normalized_point_cloud))

    size = tuple((torch.max(normalized_point_cloud, dim=0).values + 1).tolist())
    voxel_grid = torch.sparse_coo_tensor(indices=normalized_point_cloud.t(), values=ones, size=size)  # raises NotImplementedError on XLA (see traceback below)

    voxel_grid = torch.stack([voxel_grid, voxel_grid], dim=1)
    return voxel_grid

xla = xm.xla_device()
point_cloud = torch.rand((100, 3), device=xla)
voxel_size = 0.1
voxel_grid = voxelize(point_cloud, voxel_size)

Traceback:

Traceback (most recent call last):
  File "/home/m0niusplus/xla/t1.py", line 17, in <module>
    voxel_grid = voxelize(point_cloud, voxel_size)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/m0niusplus/xla/t1.py", line 9, in voxelize
    voxel_grid = torch.sparse_coo_tensor(indices=normalized_point_cloud.t(), values=ones, size=size)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
NotImplementedError: Could not run 'aten::_sparse_coo_tensor_with_dims_and_tensors' with arguments from the 'SparseXLA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::_sparse_coo_tensor_with_dims_and_tensors' is only available for these backends: [XLA, Meta, SparseCPU, SparseMeta, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMTIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradMeta, AutogradNestedTensor, Tracer, AutocastCPU, AutocastXLA, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

XLA: registered at torch_xla/csrc/aten_cpu_fallback.cpp:51 [backend fallback]
Meta: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/build/aten/src/ATen/RegisterMeta.cpp:26984 [kernel]
SparseCPU: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/build/aten/src/ATen/RegisterSparseCPU.cpp:1387 [kernel]
SparseMeta: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/build/aten/src/ATen/RegisterSparseMeta.cpp:249 [kernel]
BackendSelect: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/build/aten/src/ATen/RegisterBackendSelect.cpp:807 [kernel]
Python: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/aten/src/ATen/core/PythonFallbackKernel.cpp:154 [backend fallback]
FuncTorchDynamicLayerBackMode: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/aten/src/ATen/functorch/DynamicLayer.cpp:498 [backend fallback]
Functionalize: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/aten/src/ATen/FunctionalizeFallbackKernel.cpp:324 [backend fallback]
Named: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
Conjugate: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/aten/src/ATen/ConjugateFallback.cpp:17 [backend fallback]
Negative: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/aten/src/ATen/native/NegateFallback.cpp:19 [backend fallback]
ZeroTensor: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
ADInplaceOrView: fallthrough registered at /opt/conda/conda-bld/pytorch_1711403238793/work/aten/src/ATen/core/VariableFallbackKernel.cpp:86 [backend fallback]
AutogradOther: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradCPU: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradCUDA: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradHIP: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradXLA: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradMPS: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradIPU: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradXPU: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradHPU: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradVE: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradLazy: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradMTIA: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradPrivateUse1: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradPrivateUse2: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradPrivateUse3: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradMeta: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradNestedTensor: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
Tracer: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/torch/csrc/autograd/generated/TraceType_2.cpp:17346 [kernel]
AutocastCPU: fallthrough registered at /opt/conda/conda-bld/pytorch_1711403238793/work/aten/src/ATen/autocast_mode.cpp:378 [backend fallback]
AutocastXLA: fallthrough registered at torch_xla/csrc/autocast_mode.cpp:25 [backend fallback]
AutocastCUDA: fallthrough registered at /opt/conda/conda-bld/pytorch_1711403238793/work/aten/src/ATen/autocast_mode.cpp:244 [backend fallback]
FuncTorchBatched: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:720 [backend fallback]
BatchedNestedTensor: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:746 [backend fallback]
FuncTorchVmapMode: fallthrough registered at /opt/conda/conda-bld/pytorch_1711403238793/work/aten/src/ATen/functorch/VmapModeRegistrations.cpp:28 [backend fallback]
Batched: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/aten/src/ATen/LegacyBatchingRegistrations.cpp:1075 [backend fallback]
VmapMode: fallthrough registered at /opt/conda/conda-bld/pytorch_1711403238793/work/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
FuncTorchGradWrapper: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/aten/src/ATen/functorch/TensorWrapper.cpp:203 [backend fallback]
PythonTLSSnapshot: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/aten/src/ATen/core/PythonFallbackKernel.cpp:162 [backend fallback]
FuncTorchDynamicLayerFrontMode: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/aten/src/ATen/functorch/DynamicLayer.cpp:494 [backend fallback]
PreDispatch: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/aten/src/ATen/core/PythonFallbackKernel.cpp:166 [backend fallback]
PythonDispatcher: registered at /opt/conda/conda-bld/pytorch_1711403238793/work/aten/src/ATen/core/PythonFallbackKernel.cpp:158 [backend fallback]
Mon-ius commented 6 months ago

Environment setup:

conda create -n xla python=3.11 transformers diffusers datasets accelerate evaluate torchvision torchaudio bitsandbytes safetensors sentencepiece imageio scipy numpy pyglet gradio open3d fire rich -c conda-forge -c pytorch -y

conda activate xla
conda env config vars set LD_LIBRARY_PATH="$CONDA_PREFIX/lib"
conda env config vars set HF_HOME="/dev/shm"
conda env config vars set PJRT_DEVICE=TPU
# conda env config vars set XLA_USE_BF16=1
# conda env config vars set XLA_USE_SPMD=1
conda deactivate && conda activate xla

pip install 'torch~=2.2.0' --index-url https://download.pytorch.org/whl/cpu
pip install 'torch_xla[tpu]~=2.2.0' -f https://storage.googleapis.com/libtpu-releases/index.html
pip uninstall -y accelerate
pip install git+https://github.com/huggingface/accelerate
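
A quick sanity check that the environment actually sees the TPU (a minimal sketch, assuming a TPU VM with the PJRT_DEVICE=TPU variable set above):

import torch
import torch_xla.core.xla_model as xm

dev = xm.xla_device()           # resolves to the TPU when PJRT_DEVICE=TPU
t = torch.ones(3, device=dev)   # allocate a small tensor on the XLA device
print(dev, t.cpu())
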
JackCaoG commented 6 months ago

Seems like this op got code-generated in https://github.com/pytorch/pytorch/blob/3243be7c3a7e871acfc9923eea817493f996da9a/torchgen/model.py#L166 but we didn't implement the corresponding sparse kernels. I can turn this into a feature request, but it's unlikely we'll have the resources to work on sparse-related projects anytime soon.
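
For anyone blocked by this in the meantime, one possible workaround is to build the COO tensor on CPU, where the kernel is registered (SparseCPU in the dispatch list above), and only move a densified copy to the XLA device. A rough sketch, assuming a dense voxel grid fits in host memory:

import torch
import torch_xla.core.xla_model as xm

xla = xm.xla_device()
point_cloud = torch.rand((100, 3), device=xla)
voxel_size = 0.1

coords = (point_cloud / voxel_size).long().cpu()                  # move indices to CPU
ones = torch.ones(len(coords))
size = tuple((coords.max(dim=0).values + 1).tolist())
voxel_grid_cpu = torch.sparse_coo_tensor(coords.t(), ones, size)  # uses the SparseCPU kernel
voxel_grid = voxel_grid_cpu.to_dense().to(xla)                    # dense copy back onto the XLA device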

Mon-ius commented 6 months ago

Do we have an alternative way to implement torch.sparse_coo_tensor on an XLA device?

JackCaoG commented 6 months ago

Not that I am aware of; we haven't thought too much about sparsity yet.
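
If a dense occupancy grid is acceptable, the sparse constructor can also be avoided entirely by scattering counts with index_put_(..., accumulate=True), which should work on XLA tensors (possibly via the CPU fallback). A sketch of that approach; the helper name voxelize_dense is just for illustration:

import torch
import torch_xla.core.xla_model as xm

def voxelize_dense(point_cloud, voxel_size):
    coords = (point_cloud / voxel_size).long()                 # integer voxel coordinates, shape (N, 3)
    size = tuple((coords.max(dim=0).values + 1).tolist())
    grid = torch.zeros(size, device=point_cloud.device)
    ones = torch.ones(len(coords), device=point_cloud.device)
    grid.index_put_(tuple(coords.t()), ones, accumulate=True)  # sum contributions of points sharing a voxel
    return grid

xla = xm.xla_device()
grid = voxelize_dense(torch.rand((100, 3), device=xla), 0.1)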