zimmerrol closed 3 years ago
Merging #31 (b2b48c9) into master (850a905) will not change coverage. The diff coverage is 100.00%.
```diff
@@           Coverage Diff            @@
##            master       #31   +/-   ##
=========================================
  Coverage   100.00%   100.00%
=========================================
  Files           16        16
  Lines         1753      1753
=========================================
  Hits          1753      1753
```
| Impacted Files | Coverage Δ |
|---|---|
| eagerpy/astensor.py | 100.00% <100.00%> (ø) |
Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data.
Powered by Codecov. Last update 850a905...b2b48c9.
Apparently, this was changed in a recent JAX version: in 0.1.70 it is `jax.interpreters.xla.DeviceArray`, while in 0.2.12 it is `jaxlib.xla_extension.DeviceArray`.
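If the framework name is derived from the module that defines a value's type, this move changes the detected name from `jax` to `jaxlib`. A minimal sketch, assuming the name is the first dotted component of `type(x).__module__` (the helper `top_level_module_name` and the two stand-in classes are hypothetical, used so the example runs without JAX installed), shows the two versions resolving to different names:

```python
class DeviceArrayOld:
    """Stand-in for jax.interpreters.xla.DeviceArray (JAX 0.1.70)."""
    __module__ = "jax.interpreters.xla"


class DeviceArrayNew:
    """Stand-in for jaxlib.xla_extension.DeviceArray (JAX 0.2.12)."""
    __module__ = "jaxlib.xla_extension"


def top_level_module_name(x: object) -> str:
    """Return the first component of the module that defines type(x)."""
    return type(x).__module__.split(".")[0]


print(top_level_module_name(DeviceArrayOld()))  # jax
print(top_level_module_name(DeviceArrayNew()))  # jaxlib
```

Any check that only accepts the name `jax` therefore stops matching the newer array type.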
But let's just add it to the existing `if` to simplify testing: `if (name == "jax" or name == "jaxlib") and` …
@jonasrauber I think this doesn't work because of the `isinstance` check in the same line (https://github.com/jonasrauber/eagerpy/blob/b1d72892024a045d8716705ce2c462332abf1fa8/eagerpy/astensor.py#L55). This check only works if `name == "jax"`, since there is no class `jaxlib.numpy.ndarray`. So should we change this line to the following?

```python
if (name == "jax" or name == "jaxlib") and isinstance(x, m["jax"].numpy.ndarray):
```
Yes 👍
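The failure of the first suggestion and the effect of the agreed fix can be reproduced without JAX. In this sketch a plain dict stands in for the module table the check indexes (assumed here to be `sys.modules`), and `FakeNdarray`, the `modules` dict, and both helper functions are hypothetical stand-ins. The `jaxlib` entry is modelled with no `numpy` attribute, matching the observation that there is no `jaxlib.numpy.ndarray`:

```python
import types


class FakeNdarray:
    """Stand-in for a JAX array whose type lives in the jaxlib namespace."""
    __module__ = "jaxlib.xla_extension"


# Hypothetical module table standing in for sys.modules.
modules = {
    "jax": types.SimpleNamespace(numpy=types.SimpleNamespace(ndarray=FakeNdarray)),
    "jaxlib": types.SimpleNamespace(),  # no `numpy` attribute: no jaxlib.numpy.ndarray
}


def is_jax_array_first_try(x, m=modules):
    name = type(x).__module__.split(".")[0]
    # Indexes the table with `name`, so for "jaxlib" it looks up m["jaxlib"].numpy.
    return (name == "jax" or name == "jaxlib") and isinstance(x, m[name].numpy.ndarray)


def is_jax_array_fixed(x, m=modules):
    name = type(x).__module__.split(".")[0]
    # Always resolve the array class through the `jax` entry, whichever alias matched.
    return (name == "jax" or name == "jaxlib") and isinstance(x, m["jax"].numpy.ndarray)


x = FakeNdarray()
try:
    is_jax_array_first_try(x)
except AttributeError as err:
    print("first try fails:", err)

print(is_jax_array_fixed(x))    # True
print(is_jax_array_fixed(1.0))  # False
```

The `and` short-circuits, so `m["jax"]` is only looked up for values whose type comes from `jax` or `jaxlib`; the sketch assumes `jax` has already been imported whenever such a value exists.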
I noticed that the conversion of `ndarray`s in JAX to `JAXTensor`s in `eagerpy` does not work on machines with GPUs, since there JAX uses the `jaxlib` namespace to represent the arrays (as `DeviceArray`s) instead of the normal `jax` namespace. This PR adds `jaxlib` as an alias to detect JAX tensors.
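To see which namespace a given installation reports for its arrays, a small diagnostic along these lines can be run on the machine in question. It is only a sketch: the helper name `jax_array_namespace` is hypothetical, it returns `None` when JAX cannot be imported, and the result depends on the installed JAX/jaxlib versions and backend.

```python
def jax_array_namespace():
    """Return the top-level package of a JAX array's type, or None without JAX."""
    try:
        import jax.numpy as jnp
    except ImportError:
        return None
    x = jnp.zeros(3)
    # First dotted component of the defining module, e.g. "jax" or "jaxlib".
    return type(x).__module__.split(".")[0]


print(jax_array_namespace())
```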