pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)

Pallas on CPU #7599

Open johnsutor opened 4 days ago

johnsutor commented 4 days ago

🐛 Bug

I am trying to develop custom Pallas kernels locally on a CPU for later use on a TPU. I'm following the official example here, with one minor modification: I run the script on a CPU using interpret mode. After investigating, it appears that the latest custom-kernel code on the main branch should fix this error.
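For reference, here is a minimal sketch of what "interpret mode on a CPU" means with plain JAX, with no torch_xla involved. It assumes a JAX version whose `pl.pallas_call` accepts the `interpret` keyword. The kernel is the same vector-add example that appears later in this thread.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_vectors_kernel(x_ref, y_ref, o_ref):
    # Read the full input blocks and write their elementwise sum.
    x, y = x_ref[...], y_ref[...]
    o_ref[...] = x + y


def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
    # interpret=True evaluates the kernel body as ordinary JAX ops,
    # so it can run on the CPU backend instead of being compiled for TPU.
    return pl.pallas_call(
        add_vectors_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,
    )(x, y)


if __name__ == "__main__":
    x = jnp.arange(8, dtype=jnp.float32)
    y = jnp.ones(8, dtype=jnp.float32)
    print(add_vectors(x, y))
```

This path runs entirely inside JAX. The failure reported in this issue only appears once the kernel is wrapped for PyTorch/XLA.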

To Reproduce

Please use the Colab notebook here:

Steps to reproduce the behavior:

  1. Run the colab
  2. Observe errors in the last two cells

Expected behavior

The code should run without any errors.


Additional context


JackCaoG commented 4 days ago

@alanwaketan can you take a look? I'm not entirely sure whether Pallas works on XLA:CPU, but the failure seems to happen in our Python code.

alanwaketan commented 3 days ago

@johnsutor Can you try nightly?

johnsutor commented 3 days ago

@alanwaketan I installed the nightly build using

! pip install

but then I get the following error when I run this code block:

import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
import os 
from torch_xla.experimental.custom_kernel import make_kernel_from_pallas
import torch
os.environ["PJRT_DEVICE"] = "CPU"
ImportError                               Traceback (most recent call last)
<ipython-input-3-193790483d72> in <cell line: 5>()
      3 import jax.numpy as jnp
      4 import os
----> 5 from torch_xla.experimental.custom_kernel import make_kernel_from_pallas
      6 import torch
      7 os.environ["PJRT_DEVICE"] = "CPU"

/usr/local/lib/python3.10/dist-packages/torch_xla/ in <module>
      7 import torch
----> 8 import _XLAC
      9 from ._internal import tpu
     10 from .version import __version__

ImportError: /usr/local/lib/python3.10/dist-packages/ undefined symbol: _ZNK3c105Error4whatEv

bhavya01 commented 3 days ago

@johnsutor This error usually happens when the installed torch and torch_xla versions don't match. Try updating PyTorch to the nightly version as well.

pip3 install --pre torch torchvision torchaudio --index-url
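A rough way to check for the version mismatch described above, before importing torch_xla, is to compare the installed package versions. This is a heuristic sketch, not an official check: matching major.minor versions is necessary but not sufficient, and nightly torch_xla builds generally also expect a torch nightly from about the same date. The helper names below are hypothetical.

```python
import re
from importlib import metadata
from typing import Optional, Tuple


def major_minor(version: str) -> Tuple[int, int]:
    """Return (major, minor) from strings like '2.3.0+cu121' or '2.5.0.dev20240705'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def versions_look_compatible(torch_version: str, torch_xla_version: str) -> bool:
    # Heuristic only: equal major.minor does not guarantee ABI compatibility.
    return major_minor(torch_version) == major_minor(torch_xla_version)


def installed_version(dist_name: str) -> Optional[str]:
    # Returns None when the distribution is not installed.
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    torch_v = installed_version("torch")
    xla_v = installed_version("torch_xla")
    if torch_v is None or xla_v is None:
        print(f"torch={torch_v}, torch_xla={xla_v}: at least one is not installed")
    elif not versions_look_compatible(torch_v, xla_v):
        print(f"Likely mismatch: torch {torch_v} vs torch_xla {xla_v}")
    else:
        print(f"torch {torch_v} and torch_xla {xla_v} share major.minor")
```

For example, the Colab environment later in this thread reports torch `2.3.0+cu121`; a torch_xla nightly with a different major.minor would be flagged as a likely mismatch.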
johnsutor commented 3 days ago

@bhavya01 Unfortunately, that doesn't work: I get an error from the PyTorch side when installing in Colab.

Looking in indexes:
Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.3.0+cu121)
Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (0.18.0+cu121)
Requirement already satisfied: torchaudio in /usr/local/lib/python3.10/dist-packages (2.3.0+cu121)
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.15.4)
Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)
Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.12.1)
Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.3)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.4)
Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2023.6.0)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  ERROR: HTTP error 403 while getting (from
ERROR: Could not install requirement nvidia-cuda-nvrtc-cu12==12.1.105 from (from torch) because of HTTP error 403 Client Error: Forbidden for url: for URL (from

However, the PyTorch nightly installs fine on my M2 Mac, but then I can't install torch_xla there.

alanwaketan commented 2 days ago

@johnsutor Are you including torchvision and torchaudio? If so, you can remove them from the command.

johnsutor commented 2 days ago

@alanwaketan On Colab, I had to uninstall torch before installing the torch nightly to get it to work. However, I now get the following error when I run this code (you can see the results in the notebook here):

def add_vectors_kernel(x_ref, y_ref, o_ref):
  x, y = x_ref[...], y_ref[...]
  o_ref[...] = x + y

def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
  return pl.pallas_call(add_vectors_kernel,
                        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
                        )(x, y)

q = torch.randn(3, 2, 128, 4).to("xla")
k = torch.randn(3, 2, 128, 4).to("xla")
v = torch.randn(3, 2, 128, 4).to("xla")

# From the tutorial
pt_kernel = make_kernel_from_pallas(add_vectors, lambda x, y: [(x.shape, x.dtype)])
output = pt_kernel(q, k)
TypeError                                 Traceback (most recent call last)
<ipython-input-4-d7393bd6c1ca> in <cell line: 3>()
      1 # From the tutorial
      2 pt_kernel = make_kernel_from_pallas(add_vectors, lambda x, y: [(x.shape, x.dtype)])
----> 3 output = pt_kernel(q, k)

/usr/local/lib/python3.10/dist-packages/torch_xla/experimental/ in wrapped_kernel(kernel, output_shape_dtype_fn, static_argnums, static_argnames, *args, **kwargs)
    143     output_shapes = [shape for shape, _ in output_shape_dtype]
    144     output_dtypes = [dtype for _, dtype in output_shape_dtype]
--> 145     outputs = torch_xla._XLAC._xla_tpu_custom_call(tensor_args, payload,
    146                                                    output_shapes, output_dtypes)

TypeError: _xla_tpu_custom_call(): incompatible function arguments. The following argument types are supported:
    1. (arg0: list[torch.Tensor], arg1: str, arg2: list[list[int]], arg3: list[object]) -> list[torch.Tensor]

Invoked with: [tensor([[[[ 0.9387, -0.0999,  1.2256,  1.0983],
          [ 0.1434,  0.6569, -1.0381, -1.5455],
          [ 0.6816,  1.0613, -1.0494, -0.1487],
          [-0.7309, -0.7079,  0.9284,  0.8607],
          [ 0.0899,  1.5027, -0.5658, -1.0447],
          [ 1.3098,  0.5118,  0.0765, -0.2424]],

         [[ 0.0088,  1.8775,  0.0542, -1.1126],
          [-0.6061, -0.3355, -0.7491, -1.8286],
          [ 0.7737,  0.0899, -0.2609,  0.4641],
          [ 0.7544,  1.3189,  0.9427,  0.6183],
          [ 0.9660, -0.5893, -0.3516,  1.0709],
          [-2.3094, -0.2950, -1.1045, -0.0845]]],

        [[[-0.4123,  0.7556, -0.4119,  0.6650],
          [-0.4744,  0.9115,  0.1186, -0.7852],
          [-0.2252, -0.6484,  1.5036,  0.7215],
          [-0.1553,  1.5585,  1.1157,  0.8698],
          [ 0.5997,  0.5789,  0.3054, -1.8421],
          [-0.5578, -0.8656,  0.1356,  1.1475]],

         [[ 2.0097,  0.7483, -0.5908,  0.0702],
          [-1.0810, -0.6120, -0.7814,  0.5367],
          [-0.9203, -0.9630, -1.7621,  1.3503],
          [ 0.2230, -0.2255,  1.2624, -0.7935],
          [-1.0775,  0.5843,  0.5457, -0.1265],
          [-2.4482, -1.0382, -0.9038, -0.9088]]],

        [[[ 0.6542,  2.3457,  0.0888, -0.2082],
          [-0.2973, -0.4685,  0.8633,  1.2241],
          [ 0.1258,  0.1412,  0.9298, -1.0842],
          [-0.6876, -1.5594, -1.0357,  0.3485],
          [ 1.1975, -0.1514, -1.2257,  0.9857],
          [ 1.7342, -1.5681,  0.4157,  0.9439]],

         [[ 1.1967,  0.2086, -0.5509, -1.1779],
          [ 0.4936, -0.8626, -0.6094, -0.7941],
          [ 0.0440, -0.5978,  1.2477,  1.2164],
          [-1.8150,  1.0365,  0.5270, -0.4706],
          [ 0.5347, -1.1803, -0.2394, -0.1587],
          [-0.9638, -1.0259, -1.2330, -0.2761]]]], device='xla:0'), tensor([[[[ 6.7709e-01,  5.0651e-01, -3.9015e-01,  1.1769e+00],
          [-5.7469e-01,  6.5236e-01, -6.9628e-01,  3.3803e-02],
          [-2.8044e-01,  7.1211e-01,  2.3748e-01,  3.7293e-01],
          [ 2.8838e+00,  8.6530e-01, -1.5567e-01,  2.1392e-01],
          [-5.7115e-01, -2.6569e+00,  1.2452e+00,  1.0137e-01],
          [ 8.7078e-01, -8.3965e-01, -9.3462e-01,  5.8777e-01]],

         [[-3.5002e-01, -1.0575e+00, -1.4964e+00,  9.9756e-01],
          [ 7.8972e-01, -4.1112e-02, -1.2023e+00, -5.3902e-01],
          [-5.9894e-01, -8.5050e-01, -3.6425e-01, -9.7505e-01],
          [ 1.2945e-02, -3.0388e-01, -1.3666e+00, -8.1373e-01],
          [-1.2614e+00,  1.3913e-01, -5.6531e-01, -4.5330e-01],
          [-1.1217e+00, -9.0676e-01, -1.0731e+00, -1.9240e-01]]],

        [[[ 4.0985e-01,  1.1629e+00, -5.1721e-01,  1.6515e-01],
          [-3.1879e-01,  7.2867e-01, -1.5622e+00, -6.3426e-01],
          [-6.6151e-01, -3.2032e-01,  2.1753e+00, -8.9741e-01],
          [-3.5038e-01,  2.1497e-01, -2.1903e-01,  3.8987e-01],
          [ 5.4283e-01, -5.4239e-01, -8.3459e-01,  4.8928e-01],
          [ 1.2570e+00, -1.4615e+00,  4.1475e-01,  1.5395e+00]],

         [[-8.7543e-01, -3.6893e-01, -6.6030e-01, -4.0877e-01],
          [-2.3046e-02, -6.1282e-01,  1.8114e-01, -5.9609e-01],
          [-1.8128e-01,  1.1691e+00, -5.3699e-01, -1.2312e-01],
          [ 2.0114e-01, -7.7060e-01,  1.1129e+00, -2.0385e-01],
          [ 1.0480e+00,  4.0939e-01, -5.2975e-01, -2.1745e-01],
          [-9.5069e-01, -9.6135e-01, -1.1307e+00,  1.1766e+00]]],

        [[[ 3.3030e+00,  6.5805e-01, -1.7184e+00,  3.5029e-01],
          [-1.2511e-01,  4.9624e-01, -1.3216e+00, -6.8949e-01],
          [ 1.5174e+00, -6.8108e-01,  1.5536e-01, -7.8465e-01],
          [-9.0042e-01,  6.8886e-01,  5.7894e-01,  1.2891e-01],
          [ 1.1717e-01, -2.0201e-01, -3.0778e-01, -1.8447e+00],
          [ 3.1883e-01,  6.5378e-01,  1.3329e+00,  2.5843e-01]],

         [[-2.1050e-01, -5.7218e-01, -5.8443e-01, -1.4757e+00],
          [-1.6935e+00, -3.2765e-02,  9.5702e-01,  8.3929e-01],
          [ 1.6788e-01, -1.0459e+00, -2.0357e-01,  3.7145e-02],
          [-9.4363e-01,  9.8749e-01, -4.5407e-01, -9.5364e-01],
          [-1.3861e-01, -2.1635e-01,  9.0047e-01, -2.7273e-02],
          [ 2.2375e+00,  2.0899e-03,  1.3707e+00, -6.9060e-01]]]],
       device='xla:0')], None, [torch.Size([3, 2, 128, 4])], [torch.float32]
alanwaketan commented 1 day ago

Okay, the payload is None, which means no IR was traced from the kernel. I think we just don't support interpret mode.
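One way to see this from the JAX side is to lower an interpret-mode kernel and inspect the emitted module. This sketch assumes, without confirmation from this thread, that torch_xla builds its payload from the TPU custom call that a non-interpret `pallas_call` lowers to. In interpret mode the kernel is lowered to ordinary ops instead, so no such call appears and there is nothing to extract, which would be consistent with `None` being passed as the `str` argument of `_xla_tpu_custom_call` in the traceback above.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_vectors_kernel(x_ref, y_ref, o_ref):
    o_ref[...] = x_ref[...] + y_ref[...]


def add_vectors(x, y):
    return pl.pallas_call(
        add_vectors_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # lowered as regular JAX ops, not as a Mosaic TPU kernel
    )(x, y)


def lowered_text(fn, *args) -> str:
    # Lower (without compiling or running) the jitted function and return the module text.
    return jax.jit(fn).lower(*args).as_text()


if __name__ == "__main__":
    x = jnp.zeros((8,), jnp.float32)
    text = lowered_text(add_vectors, x, x)
    # Interpret mode should not emit a TPU custom call.
    print("tpu_custom_call" in text)
```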

alanwaketan commented 1 day ago

Can you tell me more about why you're trying to use interpret mode? Maybe we can add support for it in the future.