tslearn-team / tslearn

The machine learning toolkit for time series analysis in Python
https://tslearn.readthedocs.io
BSD 2-Clause "Simplified" License

Add the JAX backend #504

Open YannCabanes opened 5 months ago

YannCabanes commented 5 months ago

Add a JAX backend to tslearn (https://jax.readthedocs.io/en/latest/). The JAX backend can be used for automatic differentiation: https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html
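
As a rough illustration of what automatic differentiation provides, here is a minimal sketch that differentiates a squared Euclidean distance with `jax.grad`. The function `squared_euclidean` is a hypothetical stand-in written for this example, not a tslearn metric.

```python
import jax
import jax.numpy as jnp


def squared_euclidean(x, y):
    """Squared Euclidean distance between two equal-length series."""
    return jnp.sum((x - y) ** 2)


# jax.grad differentiates with respect to the first argument by default.
grad_wrt_x = jax.grad(squared_euclidean)

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([0.0, 0.0, 0.0])

# Analytically, the gradient with respect to x is 2 * (x - y).
g = grad_wrt_x(x, y)
```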

YannCabanes commented 5 months ago

The continuous integration tests fail with the following error:

tslearn/tests/test_metrics.py:15: in <module>
    backends = [Backend("numpy"), None]
tslearn/backend/backend.py:99: in __init__
    self.backend = select_backend(data)
tslearn/backend/backend.py:75: in select_backend
    backends_instances = [NumPyBackend(), JAXBackend(), PyTorchBackend()]
tslearn/backend/backend.py:13: in __init__
    raise ValueError("Could not use JAX backend since JAX is not installed.")
E   ValueError: Could not use JAX backend since JAX is not installed.

YannCabanes commented 5 months ago

Judging by the previous error message, JAX is not installed in the continuous integration environment, so the tests in test_metrics.py do not run with the JAX backend.
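
The traceback suggests that every backend is instantiated unconditionally, so a missing optional dependency raises an error. One possible guard, sketched here under assumptions, is to check which libraries are importable before building the list. The helper `available_backends` and its default candidate names are hypothetical and are not part of tslearn.

```python
import importlib.util


def available_backends(candidates=("numpy", "jax", "torch")):
    """Return the candidate top-level module names that are importable here.

    Hypothetical helper: find_spec returns None for a missing top-level
    module, so nothing is imported and nothing raises when JAX is absent.
    """
    return [
        name
        for name in candidates
        if importlib.util.find_spec(name) is not None
    ]
```

With such a check, a backend class would only be instantiated when its library appears in the returned list, and the remaining backends would still be usable in an environment without JAX.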

On my local machine, these tests fail with the following error message:

TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

YannCabanes commented 5 months ago

Difficulty: JAX arrays are immutable. See https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html and https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
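
A minimal sketch of the difference, using only the `.at[...].set(...)` syntax named in the error message above. The variable names are illustrative.

```python
import jax.numpy as jnp

x = jnp.zeros(3)

# NumPy-style in-place assignment is rejected for JAX arrays.
try:
    x[0] = 1.0
    assignment_failed = False
except TypeError:
    assignment_failed = True

# The functional update returns a new array and leaves x untouched.
y = x.at[0].set(1.0)
```

Any NumPy-style code in the metrics that fills an array element by element would need to be rewritten in this form, or wrapped, before it could run on JAX arrays.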

YannCabanes commented 5 months ago

One solution would be to create a class, named JAXNumPyInterface, that defines the operators of mutable objects (for example, item assignment) on top of immutable JAX arrays. The Python operators are listed here: https://docs.python.org/3/library/operator.html
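
One possible reading of this proposal, as a rough sketch only: a thin wrapper that keeps a reference to an immutable JAX array and rebinds it on every "mutating" operator. The class borrows the name proposed above, but its methods, including `unwrap`, are assumptions made for illustration, and only two operators are shown.

```python
import jax.numpy as jnp


class JAXNumPyInterface:
    """Hypothetical wrapper that emulates in-place operators on a JAX array."""

    def __init__(self, values):
        self._array = jnp.asarray(values)

    def __getitem__(self, idx):
        return self._array[idx]

    def __setitem__(self, idx, value):
        # Emulate x[idx] = value by rebinding to the functionally updated array.
        self._array = self._array.at[idx].set(value)

    def __iadd__(self, other):
        # Emulate x += other; the wrapped array itself is never modified.
        self._array = self._array + other
        return self

    def unwrap(self):
        """Return the underlying JAX array (assumed helper name)."""
        return self._array


acc = JAXNumPyInterface([0.0, 0.0, 0.0])
acc[1] = 5.0   # dispatches to __setitem__
acc += 1.0     # dispatches to __iadd__
```

A wrapper like this is a plain Python object rather than a JAX array type, so it would presumably have to be unwrapped, or registered as a pytree, before being passed to transformations such as `jax.jit` or `jax.grad`.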