Open johnsutor opened 4 days ago
@alanwaketan can you take a look? I am not entirely sure whether Pallas works on XLA:CPU, but the failure seems to happen in our Python code.
@johnsutor Can you try nightly?
@alanwaketan I installed nightly using
! pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp310-cp310-linux_x86_64.whl
but then I get the following error when I attempt to run the code block
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
import os
from torch_xla.experimental.custom_kernel import make_kernel_from_pallas
import torch
os.environ["PJRT_DEVICE"] = "CPU"
---------------------------------------------------------------------------
ImportError Traceback (most recent call last)
<ipython-input-3-193790483d72> in <cell line: 5>()
3 import jax.numpy as jnp
4 import os
----> 5 from torch_xla.experimental.custom_kernel import make_kernel_from_pallas
6 import torch
7 os.environ["PJRT_DEVICE"] = "CPU"
/usr/local/lib/python3.10/dist-packages/torch_xla/__init__.py in <module>
6
7 import torch
----> 8 import _XLAC
9 from ._internal import tpu
10 from .version import __version__
ImportError: /usr/local/lib/python3.10/dist-packages/_XLAC.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZNK3c105Error4whatEv
---------------------------------------------------------------------------
@johnsutor This error usually happens when there is a mismatch between the installed torch and torch_xla versions. Try updating PyTorch to the nightly version as well:
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
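Since `import torch_xla` itself fails when the two packages are out of sync, one way to check the installed versions is to read the package metadata without importing anything. The following is only a rough sketch: the helper names (`installed_versions`, `release_prefix`, `same_release`) are made up for illustration, and comparing only the major.minor part is a heuristic assumption. Matching release numbers do not guarantee that two nightly builds are ABI-compatible.

```python
from importlib import metadata


def installed_versions(packages):
    """Return {name: version string or None} without importing the packages."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def release_prefix(version):
    """Reduce '2.3.0+cu121' to '2.3'; None stays None."""
    if version is None:
        return None
    public = version.split("+", 1)[0]  # drop the local build tag, e.g. '+cu121'
    return ".".join(public.split(".")[:2])


def same_release(a, b):
    """Heuristic: True only if both versions exist and share major.minor."""
    pa, pb = release_prefix(a), release_prefix(b)
    return pa is not None and pa == pb


if __name__ == "__main__":
    found = installed_versions(["torch", "torch_xla"])
    print(found)
    if not same_release(found["torch"], found["torch_xla"]):
        print("torch and torch_xla look mismatched (or one is missing)")
```

In the Colab session above this would report torch as 2.3.0+cu121 next to a nightly torch_xla wheel, which is the kind of mismatch that can produce the undefined symbol error.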
@bhavya01 Unfortunately that does not work; I get an error from the PyTorch side when attempting to install in Colab.
Looking in indexes: https://download.pytorch.org/whl/nightly/cpu
Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.3.0+cu121)
Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (0.18.0+cu121)
Requirement already satisfied: torchaudio in /usr/local/lib/python3.10/dist-packages (2.3.0+cu121)
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.15.4)
Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)
Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.12.1)
Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.3)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.4)
Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2023.6.0)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
ERROR: HTTP error 403 while getting https://download.pytorch.org/whl/nightly/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (from https://download.pytorch.org/whl/nightly/cpu/nvidia-cuda-nvrtc-cu12/)
ERROR: Could not install requirement nvidia-cuda-nvrtc-cu12==12.1.105 from https://download.pytorch.org/whl/nightly/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (from torch) because of HTTP error 403 Client Error: Forbidden for url: https://download.pytorch.org/whl/nightly/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl for URL https://download.pytorch.org/whl/nightly/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (from https://download.pytorch.org/whl/nightly/cpu/nvidia-cuda-nvrtc-cu12/)
However, the PyTorch nightly install works fine on my M2 Mac, but then I can't install torch_xla there.
@johnsutor Did you include torchvision and torchaudio? If so, we can remove them from the command.
@alanwaketan On Colab, I had to uninstall torch to get it to work before installing torch nightly. However, I now get the following error when I attempt to run this code (you can see the results in the notebook here)
def add_vectors_kernel(x_ref, y_ref, o_ref):
    x, y = x_ref[...], y_ref[...]
    o_ref[...] = x + y

@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
    return pl.pallas_call(add_vectors_kernel,
                          out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
                          interpret=True
                          )(x, y)
q = torch.randn(3, 2, 128, 4).to("xla")
k = torch.randn(3, 2, 128, 4).to("xla")
v = torch.randn(3, 2, 128, 4).to("xla")
# From the tutorial
pt_kernel = make_kernel_from_pallas(add_vectors, lambda x, y: [(x.shape, x.dtype)])
output = pt_kernel(q, k)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-4-d7393bd6c1ca> in <cell line: 3>()
1 # From the tutorial
2 pt_kernel = make_kernel_from_pallas(add_vectors, lambda x, y: [(x.shape, x.dtype)])
----> 3 output = pt_kernel(q, k)
/usr/local/lib/python3.10/dist-packages/torch_xla/experimental/custom_kernel.py in wrapped_kernel(kernel, output_shape_dtype_fn, static_argnums, static_argnames, *args, **kwargs)
143 output_shapes = [shape for shape, _ in output_shape_dtype]
144 output_dtypes = [dtype for _, dtype in output_shape_dtype]
--> 145 outputs = torch_xla._XLAC._xla_tpu_custom_call(tensor_args, payload,
146 output_shapes, output_dtypes)
147
TypeError: _xla_tpu_custom_call(): incompatible function arguments. The following argument types are supported:
1. (arg0: list[torch.Tensor], arg1: str, arg2: list[list[int]], arg3: list[object]) -> list[torch.Tensor]
Invoked with: [tensor([[[[ 0.9387, -0.0999, 1.2256, 1.0983],
[ 0.1434, 0.6569, -1.0381, -1.5455],
[ 0.6816, 1.0613, -1.0494, -0.1487],
...,
[-0.7309, -0.7079, 0.9284, 0.8607],
[ 0.0899, 1.5027, -0.5658, -1.0447],
[ 1.3098, 0.5118, 0.0765, -0.2424]],
[[ 0.0088, 1.8775, 0.0542, -1.1126],
[-0.6061, -0.3355, -0.7491, -1.8286],
[ 0.7737, 0.0899, -0.2609, 0.4641],
...,
[ 0.7544, 1.3189, 0.9427, 0.6183],
[ 0.9660, -0.5893, -0.3516, 1.0709],
[-2.3094, -0.2950, -1.1045, -0.0845]]],
[[[-0.4123, 0.7556, -0.4119, 0.6650],
[-0.4744, 0.9115, 0.1186, -0.7852],
[-0.2252, -0.6484, 1.5036, 0.7215],
...,
[-0.1553, 1.5585, 1.1157, 0.8698],
[ 0.5997, 0.5789, 0.3054, -1.8421],
[-0.5578, -0.8656, 0.1356, 1.1475]],
[[ 2.0097, 0.7483, -0.5908, 0.0702],
[-1.0810, -0.6120, -0.7814, 0.5367],
[-0.9203, -0.9630, -1.7621, 1.3503],
...,
[ 0.2230, -0.2255, 1.2624, -0.7935],
[-1.0775, 0.5843, 0.5457, -0.1265],
[-2.4482, -1.0382, -0.9038, -0.9088]]],
[[[ 0.6542, 2.3457, 0.0888, -0.2082],
[-0.2973, -0.4685, 0.8633, 1.2241],
[ 0.1258, 0.1412, 0.9298, -1.0842],
...,
[-0.6876, -1.5594, -1.0357, 0.3485],
[ 1.1975, -0.1514, -1.2257, 0.9857],
[ 1.7342, -1.5681, 0.4157, 0.9439]],
[[ 1.1967, 0.2086, -0.5509, -1.1779],
[ 0.4936, -0.8626, -0.6094, -0.7941],
[ 0.0440, -0.5978, 1.2477, 1.2164],
...,
[-1.8150, 1.0365, 0.5270, -0.4706],
[ 0.5347, -1.1803, -0.2394, -0.1587],
[-0.9638, -1.0259, -1.2330, -0.2761]]]], device='xla:0'), tensor([[[[ 6.7709e-01, 5.0651e-01, -3.9015e-01, 1.1769e+00],
[-5.7469e-01, 6.5236e-01, -6.9628e-01, 3.3803e-02],
[-2.8044e-01, 7.1211e-01, 2.3748e-01, 3.7293e-01],
...,
[ 2.8838e+00, 8.6530e-01, -1.5567e-01, 2.1392e-01],
[-5.7115e-01, -2.6569e+00, 1.2452e+00, 1.0137e-01],
[ 8.7078e-01, -8.3965e-01, -9.3462e-01, 5.8777e-01]],
[[-3.5002e-01, -1.0575e+00, -1.4964e+00, 9.9756e-01],
[ 7.8972e-01, -4.1112e-02, -1.2023e+00, -5.3902e-01],
[-5.9894e-01, -8.5050e-01, -3.6425e-01, -9.7505e-01],
...,
[ 1.2945e-02, -3.0388e-01, -1.3666e+00, -8.1373e-01],
[-1.2614e+00, 1.3913e-01, -5.6531e-01, -4.5330e-01],
[-1.1217e+00, -9.0676e-01, -1.0731e+00, -1.9240e-01]]],
[[[ 4.0985e-01, 1.1629e+00, -5.1721e-01, 1.6515e-01],
[-3.1879e-01, 7.2867e-01, -1.5622e+00, -6.3426e-01],
[-6.6151e-01, -3.2032e-01, 2.1753e+00, -8.9741e-01],
...,
[-3.5038e-01, 2.1497e-01, -2.1903e-01, 3.8987e-01],
[ 5.4283e-01, -5.4239e-01, -8.3459e-01, 4.8928e-01],
[ 1.2570e+00, -1.4615e+00, 4.1475e-01, 1.5395e+00]],
[[-8.7543e-01, -3.6893e-01, -6.6030e-01, -4.0877e-01],
[-2.3046e-02, -6.1282e-01, 1.8114e-01, -5.9609e-01],
[-1.8128e-01, 1.1691e+00, -5.3699e-01, -1.2312e-01],
...,
[ 2.0114e-01, -7.7060e-01, 1.1129e+00, -2.0385e-01],
[ 1.0480e+00, 4.0939e-01, -5.2975e-01, -2.1745e-01],
[-9.5069e-01, -9.6135e-01, -1.1307e+00, 1.1766e+00]]],
[[[ 3.3030e+00, 6.5805e-01, -1.7184e+00, 3.5029e-01],
[-1.2511e-01, 4.9624e-01, -1.3216e+00, -6.8949e-01],
[ 1.5174e+00, -6.8108e-01, 1.5536e-01, -7.8465e-01],
...,
[-9.0042e-01, 6.8886e-01, 5.7894e-01, 1.2891e-01],
[ 1.1717e-01, -2.0201e-01, -3.0778e-01, -1.8447e+00],
[ 3.1883e-01, 6.5378e-01, 1.3329e+00, 2.5843e-01]],
[[-2.1050e-01, -5.7218e-01, -5.8443e-01, -1.4757e+00],
[-1.6935e+00, -3.2765e-02, 9.5702e-01, 8.3929e-01],
[ 1.6788e-01, -1.0459e+00, -2.0357e-01, 3.7145e-02],
...,
[-9.4363e-01, 9.8749e-01, -4.5407e-01, -9.5364e-01],
[-1.3861e-01, -2.1635e-01, 9.0047e-01, -2.7273e-02],
[ 2.2375e+00, 2.0899e-03, 1.3707e+00, -6.9060e-01]]]],
device='xla:0')], None, [torch.Size([3, 2, 128, 4])], [torch.float32]
Okay, the payload is None, which means no IR was traced from the kernel. I think we just don't support interpret mode.
Can you tell me more about why you are trying to use interpret mode? Maybe we can add that support in the future.
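The TypeError in the traceback comes from passing None where `_xla_tpu_custom_call` expects a `str` as its second argument. As an illustration only, a caller could check the payload before the call and fail with a clearer message. `require_payload` below is a hypothetical helper, not part of torch_xla, and the suggestion that interpret mode is the cause is the assumption stated in this thread rather than something the code can verify.

```python
from typing import Optional


def require_payload(payload: Optional[str], kernel_name: str) -> str:
    """Return the payload unchanged, or raise a readable error if it is empty.

    Hypothetical guard: a None or empty payload means no IR was traced,
    so passing it on would only fail later with an opaque TypeError.
    """
    if payload is None or payload == "":
        raise ValueError(
            f"No payload was traced for kernel {kernel_name!r}; "
            "interpret mode may not be supported on this path."
        )
    return payload
```

With a guard like this, the failing call would stop with a short ValueError naming the kernel instead of printing both input tensors.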
🐛 Bug
I am attempting to implement custom Pallas kernels locally on a CPU for use with a TPU. I'm attempting to follow the official example here, with the minor modification being that I run the script on a CPU using interpret mode. After investigating, it appears that the main branch's latest code for a custom kernel should fix any issues with this error.
To Reproduce
Please use the colab here:
Steps to reproduce the behavior:
Expected behavior
It should execute the code without any errors
Environment
Additional context
N/A