jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

enforcing eigendecomposition on CPU in TPU process #13959

Open yonghakim opened 1 year ago

yonghakim commented 1 year ago

Description

Hi,

I'm trying to compute the eigendecomposition of a non-Hermitian matrix on a TPU in Colab. Since eig is only available on CPU, I forced it to run on CPU through a jit option.

I tried two options:

  1. without host_callback (eig in the code that follows)
  2. with host_callback (eig2 in the code that follows)

from functools import partial

import jax
import jax.numpy as jnp
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

from jax.experimental import host_callback

@partial(jax.jit, backend='cpu')
def eig(mat):

    print('eig')
    return jnp.linalg.eig(mat)

@partial(jax.jit, static_argnums=(1, ))
def eig2(matrix, type_complex=jnp.complex128):  #TODO: use type_complex arg
    """Wraps jnp.linalg.eig so that it can be jit-ed on a machine with GPUs."""
    eigenvalues_shape = jax.ShapeDtypeStruct(matrix.shape[:-1], type_complex)
    eigenvectors_shape = jax.ShapeDtypeStruct(matrix.shape, type_complex)
    return host_callback.call(
        # We force this computation to be performed on the cpu by jit-ing and
        # explicitly specifying the device.
        jax.jit(jnp.linalg.eig, device=jax.devices('cpu')[0]),
        matrix.astype(type_complex),
        result_shape=(eigenvalues_shape, eigenvectors_shape),
    )

aa = jnp.arange(9).reshape((3,3))

try:
    eig(aa)
except Exception as e:
    print(1, e)

# Second call to the same jitted function; this is the call that fails.
try:
    eig(aa)
except Exception as e:
    print(2, e)

try:
    eig2(aa)
except Exception as e:
    print(3, e)

Result

eig
2 Unable to cast Python instance to C++ type (#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for details)
3 'NoneType' object has no attribute 'add_outfeed'
<ipython-input-4-353dc3404a3b>:25: UserWarning: Explicitly requested dtype <class 'jax.numpy.complex128'> requested in astype is not available, and will be truncated to dtype complex64. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  matrix.astype(type_complex),

With the first option, the initial call succeeded, but the second call raised the casting error. The second option raised the 'add_outfeed' error.

Are these bugs, or am I misusing the API?

What jax/jaxlib version are you using?

jax 0.3.25, jaxlib 0.3.25+cuda11.cudnn805

Which accelerator(s) are you using?

TPU

Additional system info

google colab

NVIDIA GPU info

No response

hawkinsp commented 1 year ago

I suspect this is an artifact of TPU Colab, which is less well supported than TPU VMs. @skye ?
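
For readers hitting the same 'add_outfeed' error, one possible alternative to host_callback.call is jax.pure_callback, which hands the matrix to a plain NumPy function on the host. The sketch below is an assumption rather than a fix confirmed in this thread: it presumes a JAX release that provides jax.pure_callback, it has not been tested on TPU Colab, and the names _host_eig and eig_via_callback are hypothetical. It uses complex64 because x64 is disabled by default, as the warning in the report shows.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _host_eig(m):
    # Runs on the host with NumPy. np.linalg.eig returns real arrays when all
    # eigenvalues are real, so cast explicitly to the declared complex dtype.
    w, v = np.linalg.eig(np.asarray(m))
    return w.astype(np.complex64), v.astype(np.complex64)


@jax.jit
def eig_via_callback(matrix):
    # Hypothetical wrapper: declare the output shapes and dtypes up front,
    # mirroring eigenvalues_shape / eigenvectors_shape in the report.
    matrix = matrix.astype(jnp.float32)
    result_shapes = (
        jax.ShapeDtypeStruct(matrix.shape[:-1], jnp.complex64),
        jax.ShapeDtypeStruct(matrix.shape, jnp.complex64),
    )
    return jax.pure_callback(_host_eig, result_shapes, matrix)


aa = jnp.arange(9).reshape((3, 3))
w, v = eig_via_callback(aa)
```

Calling eig_via_callback(aa) a second time reuses the compiled function, which is the case that failed for the backend='cpu' variant in the report; whether it behaves differently on TPU Colab would need to be checked there.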