jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

jit cache key should include device (really, device type) #12669

Open mattjj opened 2 years ago

mattjj commented 2 years ago
from jax import device_put, devices, jit
import jax.numpy as np
import numpy as onp

def f(x):
    return np.sum(x)

f0 = jit(f, device=devices()[0])
f1 = jit(f, device=devices()[1])

data = onp.random.rand(10000)

x0 = device_put(data, device=devices()[0])
x1 = device_put(data, device=devices()[1])

print("f on gpu 0:", f0(x0))
print("f on gpu 1:", f1(x1))

When I run this code on my machine, I get an error like this:

f on gpu 0: 4949.328
2020-04-15 13:00:27.288577: E external/org_tensorflow/tensorflow/compiler/xla/python/local_client.cc:758] Execution of replica 0 failed: Invalid argument: executable is built for device CUDA:0 of type "GeForce RTX 2080 Ti"; cannot run it on device CUDA:1 of type "Tesla P100-PCIE-16GB"
Traceback (most recent call last):
  File "test3.py", line 21, in <module>
    print("f on gpu 1:", f1(x1))
  File "/usr/local/lib/python3.6/dist-packages/jax/api.py", line 150, in f_jitted
    out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend)
  File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 592, in call_bind
    outs = primitive.impl(f, *args, **params)
  File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py", line 402, in _xla_call_impl
    return compiled_fun(*args)
  File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py", line 486, in _execute_compiled
    out_bufs = compiled.Execute(input_bufs).destructure()
RuntimeError: Invalid argument: executable is built for device CUDA:0 of type "GeForce RTX 2080 Ti"; cannot run it on device CUDA:1 of type "Tesla P100-PCIE-16GB"

Could you give me some advice?

Originally posted by @caihao in https://github.com/google/jax/issues/1899#issuecomment-614026930
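The failure mode named in the title can be illustrated abstractly. The following is a hypothetical, simplified model of a jit compilation cache (not JAX's actual internals; all names here are invented for illustration): if the cache key omits the target device, a lookup for the second device returns the executable already built for the first one, which is exactly the mismatch the traceback reports.

```python
# Hypothetical model of a jit compilation cache (NOT JAX internals).
class FakeExecutable:
    def __init__(self, fn_name, built_for):
        self.fn_name = fn_name
        self.built_for = built_for  # device this executable targets

    def execute(self, device):
        # Mimics the XLA runtime check that produced the RuntimeError above.
        if device != self.built_for:
            raise RuntimeError(
                f"executable is built for {self.built_for}; "
                f"cannot run it on {device}")
        return f"{self.fn_name} ran on {device}"

def make_cached_compile(include_device_in_key):
    cache = {}
    def compile_fn(fn_name, device):
        # The crux: does the cache key distinguish devices?
        key = (fn_name, device) if include_device_in_key else (fn_name,)
        if key not in cache:
            cache[key] = FakeExecutable(fn_name, device)
        return cache[key]
    return compile_fn

# Buggy: the key omits the device, so the second lookup reuses the
# executable compiled for cuda:0, and execution on cuda:1 fails.
buggy = make_cached_compile(include_device_in_key=False)
buggy("f", "cuda:0").execute("cuda:0")
try:
    buggy("f", "cuda:1").execute("cuda:1")
except RuntimeError as e:
    print("buggy cache:", e)

# Fixed: including the device in the key compiles once per device.
fixed = make_cached_compile(include_device_in_key=True)
print(fixed("f", "cuda:0").execute("cuda:0"))
print(fixed("f", "cuda:1").execute("cuda:1"))
```

With the device in the key, each device gets its own compiled executable, at the cost of one extra compilation per device, which is the behavior the issue title asks for.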

rajasekharporeddy commented 6 months ago

Hi @mattjj

This issue appears to have been resolved in later versions of JAX. I tried to reproduce it with the latest JAX release, 0.4.26, on a cloud VM with 4 T4 GPUs, and it now runs without any error. Please see the screenshot below for reference.

[screenshot: successful run, no device-mismatch error]

Thank you.
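For readers hitting this today: newer JAX versions steer away from the `device=` argument to `jit` (it was deprecated in the 0.4 series; treat the exact version as an assumption). Instead, `device_put` commits an array to a device, and the jitted function is compiled and cached per device based on where its arguments live. A minimal sketch of that pattern, assuming only that at least one JAX device is available (it also runs on CPU-only installs):

```python
import jax
import jax.numpy as jnp
import numpy as np

def f(x):
    return jnp.sum(x)

f_jitted = jax.jit(f)

data = np.random.rand(10000).astype(np.float32)

# device_put commits the array to a specific device; the jitted
# function then runs (and is compiled/cached) on whichever device
# its input lives on, so no per-device jit wrappers are needed.
for d in jax.devices():
    x = jax.device_put(data, d)
    print(f"f on {d}:", f_jitted(x))
```

On a multi-GPU machine this loop executes the same jitted function once per device, which is the scenario the original report exercised with `f0`/`f1`.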