tslearn-team / tslearn

The machine learning toolkit for time series analysis in Python
https://tslearn.readthedocs.io
BSD 2-Clause "Simplified" License

Add the JAX backend #504

Open YannCabanes opened 10 months ago

YannCabanes commented 10 months ago

Add a JAX backend (https://jax.readthedocs.io/en/latest/). The JAX backend can be used for automatic differentiation: https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html
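As a minimal illustration of what automatic differentiation with JAX provides, the sketch below uses `jax.grad` to differentiate a squared Euclidean distance between two short time series with respect to the first series. The function `squared_euclidean` is hypothetical and exists only for this illustration. It is not part of tslearn.

```python
import jax
import jax.numpy as jnp


def squared_euclidean(x, y):
    # Hypothetical example metric: sum of squared differences between two series.
    return jnp.sum((x - y) ** 2)


# jax.grad differentiates with respect to the first argument by default.
grad_wrt_x = jax.grad(squared_euclidean)

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([0.0, 2.0, 5.0])
g = grad_wrt_x(x, y)

# Analytically, the gradient is 2 * (x - y), i.e. [2, 0, -4] here.
print(g)
```

A differentiable backend would, in principle, let gradients like this flow through tslearn metrics written with backend-agnostic operations.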

YannCabanes commented 10 months ago

The following error is obtained:

tslearn/tests/test_metrics.py:15: in <module>
    backends = [Backend("numpy"), None]
tslearn/backend/backend.py:99: in __init__
    self.backend = select_backend(data)
tslearn/backend/backend.py:75: in select_backend
    backends_instances = [NumPyBackend(), JAXBackend(), PyTorchBackend()]
tslearn/backend/backend.py:13: in __init__
    raise ValueError("Could not use JAX backend since JAX is not installed.")
E   ValueError: Could not use JAX backend since JAX is not installed.

YannCabanes commented 10 months ago

The previous error message suggests that JAX is not installed in the continuous integration environment; therefore, the tests in test_metrics.py are not run with the JAX backend.
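The traceback shows `select_backend` instantiating `NumPyBackend()`, `JAXBackend()` and `PyTorchBackend()` unconditionally, so importing the test module fails as soon as JAX is missing. A minimal sketch, assuming one wants to skip backends whose library is not installed, can check for the package with `importlib.util.find_spec` before building the backend list. The helper names `is_installed` and `available_backend_names`, and the backend name strings, are hypothetical and are not tslearn's actual API.

```python
import importlib.util


def is_installed(module_name):
    # find_spec returns None when the top-level module cannot be found,
    # without importing it.
    return importlib.util.find_spec(module_name) is not None


def available_backend_names():
    # Hypothetical helper: list only the backends whose library is installed,
    # so that a missing optional dependency (JAX or PyTorch) does not raise
    # when the backends are constructed.
    names = ["numpy"]
    if is_installed("jax"):
        names.append("jax")
    if is_installed("torch"):
        names.append("pytorch")
    return names


print(available_backend_names())
```

Installing JAX in the CI environment would still be needed for the JAX code path to actually be tested.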

On my local machine, these tests fail with the following error message:

TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

YannCabanes commented 10 months ago

Difficulty: JAX arrays are immutable (see https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html and https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html).
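To make the difficulty concrete, the sketch below reproduces the error from the previous comment and shows the functional update recommended in the JAX documentation: `.at[idx].set(value)` (and related methods such as `.add`) return a new array and leave the original untouched.

```python
import jax.numpy as jnp

x = jnp.zeros(4)

# In-place item assignment raises a TypeError for JAX arrays.
try:
    x[1] = 5.0
except TypeError as err:
    print("TypeError:", err)

# Functional update: .at[idx].set(value) returns a new array.
y = x.at[1].set(5.0)
# Other .at[] methods, such as .add, work the same way.
z = y.at[2:].add(1.0)

print(x)  # the original array is unchanged
print(y)
print(z)
```

Any tslearn code that fills arrays element by element (for example a cost matrix) would have to be rewritten in this style, or go through a wrapper that emulates mutation.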

YannCabanes commented 10 months ago

One solution would be to create a class named JAXNumPyInterface that defines the operators of mutable objects, such as item assignment. The Python operators are listed here: https://docs.python.org/3/library/operator.html
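A minimal sketch of what such a `JAXNumPyInterface` class might look like, assuming it simply wraps a JAX array and emulates mutation by rebinding its internal array to the result of `.at[idx].set(...)`. Only a few operators (`__setitem__`, `+`, `*`, `+=`) are shown, dispatched through the `operator` module; the attribute name `array` and the helper `_binary` are hypothetical. A complete version would cover the other operators listed in the `operator` documentation.

```python
import operator

import jax.numpy as jnp


class JAXNumPyInterface:
    """Hypothetical mutable wrapper around an immutable JAX array."""

    def __init__(self, array):
        self.array = jnp.asarray(array)

    def __getitem__(self, idx):
        return self.array[idx]

    def __setitem__(self, idx, value):
        # Emulate x[idx] = value by rebinding to an updated copy.
        self.array = self.array.at[idx].set(value)

    @staticmethod
    def _unwrap(other):
        # Accept either a wrapper or a plain array/scalar.
        return other.array if isinstance(other, JAXNumPyInterface) else other

    def _binary(self, other, op):
        # Apply a binary operator and wrap the result.
        return JAXNumPyInterface(op(self.array, self._unwrap(other)))

    def __add__(self, other):
        return self._binary(other, operator.add)

    def __mul__(self, other):
        return self._binary(other, operator.mul)

    def __iadd__(self, other):
        # Augmented assignment rebinds the wrapped array.
        self.array = operator.add(self.array, self._unwrap(other))
        return self

    @property
    def shape(self):
        return self.array.shape

    def __repr__(self):
        return f"JAXNumPyInterface({self.array!r})"


x = JAXNumPyInterface(jnp.zeros((2, 3)))
x[0, 1] = 4.0  # NumPy-style item assignment
x += 1.0
y = x * 2.0
print(y.array)
```

One design consequence of rebinding: other references to the previous underlying array are not updated, so code relying on NumPy view semantics (writing through a slice to modify the parent array) would behave differently.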