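I'm trying to compute the eigendecomposition of a non-Hermitian matrix on a TPU in Colab.
As we know, jnp.linalg.eig is only available on CPU, so I forced it to run on the CPU using jit's backend/device options.
I tried two options:
1. without host_callback (jax.jit with backend='cpu')
2. with host_callback (host_callback.call wrapping a CPU-jitted jnp.linalg.eig)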
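Another way to run eig on the CPU is to move the data there. This sketch assumes JAX's placement rule holds here: outside of jit, computation follows committed data. Committing the input to the CPU device with jax.device_put should then make jnp.linalg.eig execute on the CPU. This has not been checked on a Colab TPU with jax 0.3.25.

```python
import numpy as np
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]

mat = jnp.arange(9.0).reshape((3, 3))  # created on the default device (the TPU in Colab)
mat_cpu = jax.device_put(mat, cpu)     # committed copy on the host CPU

# Assumption: computation follows committed data, so this eager eig call
# (no jit involved) runs on the CPU.
w, v = jnp.linalg.eig(mat_cpu)

# Sanity check: A @ v should equal v with column j scaled by w[j].
residual = np.abs(np.asarray(mat_cpu) @ np.asarray(v) - np.asarray(v) * np.asarray(w)).max()
print(residual)
```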
from functools import partial
import jax
import jax.numpy as jnp
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
from jax.experimental import host_callback

@partial(jax.jit, backend='cpu')
def eig(mat):
    print('eig')
    return jnp.linalg.eig(mat)

@partial(jax.jit, static_argnums=(1, ))
def eig2(matrix, type_complex=jnp.complex128):  # TODO: use type_complex arg
    """Wraps jnp.linalg.eig so that it can be jit-ed on a machine with GPUs."""
    eigenvalues_shape = jax.ShapeDtypeStruct(matrix.shape[:-1], type_complex)
    eigenvectors_shape = jax.ShapeDtypeStruct(matrix.shape, type_complex)
    return host_callback.call(
        # We force this computation to be performed on the cpu by jit-ing and
        # explicitly specifying the device.
        jax.jit(jnp.linalg.eig, device=jax.devices('cpu')[0]),
        matrix.astype(type_complex),
        result_shape=(eigenvalues_shape, eigenvectors_shape),
    )

aa = jnp.arange(9).reshape((3,3))

try:
    eig(aa)
except Exception as e:
    print(1, e)

try:
    eig(aa)
except Exception as e:
    print(2, e)

try:
    eig2(aa)
except Exception as e:
    print(3, e)
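For the host_callback route, JAX also provides jax.pure_callback, which calls a pure host-side Python function from jitted code. The sketch below swaps in a different technique from the code above: it runs NumPy's eig on the host and declares complex64 outputs to avoid the x64 truncation. The helper names _host_eig and eig_via_pure_callback are illustrative only. It is untested whether this avoids the add_outfeed error on TPU with this jax version.

```python
import numpy as np
import jax
import jax.numpy as jnp

def _host_eig(mat):
    # Runs on the host with NumPy; cast so outputs match the declared dtypes exactly.
    w, v = np.linalg.eig(np.asarray(mat))
    return w.astype(np.complex64), v.astype(np.complex64)

@jax.jit
def eig_via_pure_callback(mat):
    # Declare output shapes/dtypes so XLA can compile around the host call.
    result_shapes = (
        jax.ShapeDtypeStruct(mat.shape[:-1], jnp.complex64),  # eigenvalues
        jax.ShapeDtypeStruct(mat.shape, jnp.complex64),       # eigenvectors
    )
    return jax.pure_callback(_host_eig, result_shapes, mat)

aa = jnp.arange(9.0).reshape((3, 3))
w, v = eig_via_pure_callback(aa)
```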
Result
eig
2 Unable to cast Python instance to C++ type (#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for details)
3 'NoneType' object has no attribute 'add_outfeed'
<ipython-input-4-353dc3404a3b>:25: UserWarning: Explicitly requested dtype <class 'jax.numpy.complex128'> requested in astype is not available, and will be truncated to dtype complex64. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
matrix.astype(type_complex),
With the first option, the initial call succeeded, but the second call raised a casting error ("Unable to cast Python instance to C++ type").
The second option raised the 'add_outfeed' error.
Are these bugs, or am I misusing the API?
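The UserWarning in the output says complex128 is truncated to complex64 because 64-bit types are disabled by default. The warning itself names the jax_enable_x64 option. This sketch turns that option on at the top of the notebook, assuming it runs before any arrays are created. It only changes JAX's dtype handling. Whether 64-bit types are usable on the TPU itself is a separate question and is not checked here.

```python
import jax

# Enable 64-bit types; arrays created before this call keep their 32-bit dtypes.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

mat = jnp.arange(9).reshape((3, 3))
mat_c = mat.astype(jnp.complex128)  # no longer truncated to complex64
print(mat_c.dtype)  # complex128
```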
What jax/jaxlib version are you using?
jax 0.3.25, jaxlib 0.3.25+cuda11.cudnn805
Which accelerator(s) are you using?
TPU
Additional system info
google colab
NVIDIA GPU info
No response