pymc-devs / pytensor

PyTensor allows you to define, optimize, and efficiently evaluate mathematical expressions involving multi-dimensional arrays.
https://pytensor.readthedocs.io

minor bug: jax tests reference jax.interpreters.xla.DeviceArray #660

Open · leventov opened this issue 7 months ago

leventov commented 7 months ago

Describe the issue:

In PR https://github.com/pymc-devs/pytensor/pull/133, the second branch of this conditional:

https://github.com/pymc-devs/pytensor/blob/d175203b4e00f48db9c61b68a5f70263a1fbb645/tests/link/jax/test_basic.py#L73-L77

was evidently left unchanged, and this went unnoticed only because that branch is never executed by the tests.
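One way to keep the two branches from drifting apart is to route both the single-output and list-of-outputs cases through one type check. The snippet below is a minimal sketch, not the current PyTensor test code. `assert_all_instances` is a hypothetical helper, and the sketch assumes a recent JAX release, where `jax.Array` is the public array type that replaced `DeviceArray`.

```python
from typing import Any


def assert_all_instances(result: Any, array_type: type) -> None:
    """Hypothetical helper: assert every output is an instance of array_type.

    A single output is wrapped in a one-element list, so the single-output
    and multi-output cases share one code path and one type reference.
    """
    outputs = result if isinstance(result, (list, tuple)) else [result]
    for i, out in enumerate(outputs):
        assert isinstance(out, array_type), (
            f"output {i} is {type(out).__name__}, expected {array_type.__name__}"
        )


# In the JAX tests this would be called as, e.g.:
#     assert_all_instances(jax_res, jax.Array)
```

Because the target type is named once at the call site, there is no second branch that can still point at a removed attribute.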

cc @ricardoV94

Reproducible code example:

n/a

Error message:

No response

PyTensor version information:

n/a

Context for the issue:

No response

ricardoV94 commented 7 months ago

@leventov want to update it?

leventov commented 7 months ago

Yes, I can do it later.

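Also: this function, as well as compare_numba_and_py(), should probably use itertools.zip_longest rather than zip, because zip silently stops at the shorter sequence, so a mismatch in the number of outputs would go unnoticed.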