jonasrauber / eagerpy

PyTorch, TensorFlow, JAX and NumPy — all of them natively using the same code
https://eagerpy.jonasrauber.de
MIT License

Jax tensors are not correctly recognized if they are stored on GPUs #31

Closed: zimmerrol closed this 3 years ago

zimmerrol commented 3 years ago

I noticed that converting JAX ndarrays to eagerpy JAXTensors does not work on machines with GPUs: there, JAX represents arrays (as DeviceArrays) in the jaxlib namespace instead of the usual jax namespace. This PR adds jaxlib as an alias so that these arrays are detected as JAX tensors.
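A minimal sketch of why detection by module name misses these arrays, assuming eagerpy infers the framework from the top-level package name of the object's class. The two classes below are hypothetical stand-ins whose `__module__` mimics a `jax...` array and a `jaxlib...` DeviceArray, so the sketch runs without JAX installed; the helper names are made up for illustration.

```python
# Hypothetical stand-in classes: only their __module__ matters here.
class OldDeviceArray:
    pass


class NewDeviceArray:
    pass


# Mimic the two kinds of class path (assumed for illustration).
OldDeviceArray.__module__ = "jax.interpreters.xla"
NewDeviceArray.__module__ = "jaxlib.xla_extension"


def top_level_module(x: object) -> str:
    """Return the top-level package name of x's class, e.g. 'jax'."""
    return type(x).__module__.split(".")[0]


def looks_like_jax_strict(x: object) -> bool:
    """Detection that only accepts the 'jax' namespace (the behaviour being fixed)."""
    return top_level_module(x) == "jax"


old, new = OldDeviceArray(), NewDeviceArray()
print(top_level_module(old), looks_like_jax_strict(old))  # jax True
print(top_level_module(new), looks_like_jax_strict(new))  # jaxlib False
```

An array whose class lives under `jaxlib` yields the name `jaxlib`, so a check that compares against `"jax"` alone rejects it.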

codecov[bot] commented 3 years ago

Codecov Report

Merging #31 (b2b48c9) into master (850a905) will not change coverage. The diff coverage is 100.00%.


@@            Coverage Diff            @@
##            master       #31   +/-   ##
=========================================
  Coverage   100.00%   100.00%           
=========================================
  Files           16        16           
  Lines         1753      1753           
=========================================
  Hits          1753      1753           
Impacted Files           Coverage Δ
eagerpy/astensor.py      100.00% <100.00%> (ø)

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Powered by Codecov. Last update 850a905...b2b48c9.

jonasrauber commented 3 years ago

Apparently, this changed in a recent JAX version: in 0.1.70 the class is jax.interpreters.xla.DeviceArray, whereas in 0.2.12 it is jaxlib.xla_extension.DeviceArray.

But let's just add it to the existing if condition to simplify testing:

if (name == "jax" or name == "jaxlib") and ...
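Folding `jaxlib` into the existing name test can be sketched with a set of accepted package names. This is a hypothetical helper, not eagerpy's code; the class paths are the two quoted above for JAX 0.1.70 and 0.2.12, simulated with throwaway classes so the sketch runs without JAX.

```python
# Accepted top-level package names for JAX arrays (assumption based on this thread).
JAX_MODULE_NAMES = frozenset({"jax", "jaxlib"})


def make_stand_in(module_path: str) -> object:
    """Create an instance of a throwaway class whose __module__ is module_path."""
    cls = type("DeviceArray", (), {"__module__": module_path})
    return cls()


def is_jax_name(x: object) -> bool:
    """Name-only check: accept either namespace, ignoring the rest of the path."""
    name = type(x).__module__.split(".")[0]
    return name in JAX_MODULE_NAMES


for path in ("jax.interpreters.xla", "jaxlib.xla_extension", "numpy"):
    print(path, is_jax_name(make_stand_in(path)))
# jax.interpreters.xla True
# jaxlib.xla_extension True
# numpy False
```

Because only the first path component is compared, both version-specific class paths pass. A name test like this is also cheap: it needs no framework import, which is presumably why a name comparison comes first in the condition.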

zimmerrol commented 3 years ago

@jonasrauber I think this doesn't work because of the isinstance check on the same line (https://github.com/jonasrauber/eagerpy/blob/b1d72892024a045d8716705ce2c462332abf1fa8/eagerpy/astensor.py#L55)

This check only works if name == "jax", since there is no class jaxlib.numpy.ndarray. So should we change this line to the following?

if (name == "jax" or name == "jaxlib") and isinstance(x, m["jax"].numpy.ndarray):
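The key point of this line is that the name test may accept either namespace, but the `isinstance` lookup must always go through `jax`, because only `jax.numpy.ndarray` exists. A self-contained sketch with fake `jax` and `jaxlib` modules standing in for the real ones: the dict `m` plays the role of the module table in the line above, the two `is_jax_tensor_*` functions are hypothetical, and the assumption that a jaxlib DeviceArray counts as a `jax.numpy.ndarray` is modelled here by plain subclassing.

```python
import types

# Fake "jax" package exposing jax.numpy.ndarray (stand-in for the real one).
jax = types.ModuleType("jax")
jax.numpy = types.ModuleType("jax.numpy")


class ndarray:  # stand-in for jax.numpy.ndarray
    pass


jax.numpy.ndarray = ndarray

# Fake "jaxlib" package: it has no jaxlib.numpy attribute at all.
jaxlib = types.ModuleType("jaxlib")


class DeviceArray(ndarray):  # assumed to be recognised as a jax.numpy.ndarray
    pass


DeviceArray.__module__ = "jaxlib.xla_extension"

# Plays the role of the module table `m` in this sketch.
m = {"jax": jax, "jaxlib": jaxlib}


def is_jax_tensor_buggy(x: object) -> bool:
    name = type(x).__module__.split(".")[0]
    # Looks the class up under the detected name: raises for "jaxlib".
    return name in ("jax", "jaxlib") and isinstance(x, m[name].numpy.ndarray)


def is_jax_tensor_fixed(x: object) -> bool:
    name = type(x).__module__.split(".")[0]
    # Accept either namespace, but always resolve the class via "jax".
    return (name == "jax" or name == "jaxlib") and isinstance(x, m["jax"].numpy.ndarray)


x = DeviceArray()
print(is_jax_tensor_fixed(x))  # True
try:
    is_jax_tensor_buggy(x)
except AttributeError as e:
    print("buggy check failed:", e)
```

The buggy variant raises `AttributeError` on `m["jaxlib"].numpy`, while the fixed variant accepts the array and still rejects objects from other packages, since the short-circuiting `and` never reaches `isinstance` for them.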

jonasrauber commented 3 years ago

So should we change this line to?:

Yes 👍