tslearn-team / tslearn

The machine learning toolkit for time series analysis in Python
https://tslearn.readthedocs.io
BSD 2-Clause "Simplified" License
2.83k stars 331 forks

Fix error in _soft_dtw, add test and add option compute_with_backend in soft_dtw and soft_dtw_alignment. #481

Closed YannCabanes closed 9 months ago

YannCabanes commented 9 months ago

In the file tslearn/metrics/soft_dtw_fast.py, the functions _soft_dtw, _soft_dtw_batch, _soft_dtw_grad and _soft_dtw_grad_batch do not return a value; instead, they write their results directly into input arrays passed as mutable arguments. PR #479 introduced an error by casting these inputs to the selected backend with:

    D = be.array(D, dtype=be.float64)
    R = be.array(R, dtype=be.float64)
    gamma = be.array(gamma, dtype=be.float64)

Indeed, be.array casts the input values to the desired backend, but it also creates a copy of them. The results are then written into these local copies, so the arrays passed in by the caller are never modified. These lines are removed in this PR.
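A minimal sketch of this failure mode, using NumPy's `np.array` as a stand-in for `be.array` (an assumption: tslearn's backend wrapper may behave differently in details, but `np.array` copies its input by default). The helper names `fill_buggy` and `fill_fixed` are hypothetical:

```python
import numpy as np


def fill_buggy(R):
    # np.array(...) returns a new array by default, so R is rebound to a copy
    # and the write below never reaches the caller's array.
    R = np.array(R, dtype=np.float64)
    R[:] = 1.0


def fill_fixed(R):
    # Writing into the argument itself mutates the caller's array in place.
    R[:] = 1.0


R1 = np.zeros(3)
fill_buggy(R1)
print(R1)  # [0. 0. 0.]

R2 = np.zeros(3)
fill_fixed(R2)
print(R2)  # [1. 1. 1.]
```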

We add a test to make sure that this error does not reappear.
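The kind of regression test involved can be sketched as follows. This is neither tslearn's `_soft_dtw` nor the test added in this PR: `soft_dtw_inplace` is a hypothetical pure-Python version of the soft-DTW forward recursion that, like the functions above, returns nothing and fills `R` in place. The test checks that the caller's `R` actually holds the result; for a 2x2 zero cost matrix and `gamma=1`, the last cell is the soft-minimum of three zeros, i.e. `-log(3)`.

```python
import math

import numpy as np


def _softmin3(a, b, c, gamma):
    # Stable soft-minimum: -gamma * log(exp(-a/gamma) + exp(-b/gamma) + exp(-c/gamma)).
    z = [-a / gamma, -b / gamma, -c / gamma]
    zmax = max(z)
    return -gamma * (zmax + math.log(sum(math.exp(v - zmax) for v in z)))


def soft_dtw_inplace(D, R, gamma):
    # Hypothetical helper: fills R (shape (m + 1, n + 1), R[0, 0] = 0, other
    # borders = inf) in place and returns None, mirroring the "no return value"
    # convention of the functions described above.
    m, n = D.shape
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            R[i, j] = D[i - 1, j - 1] + _softmin3(
                R[i - 1, j], R[i - 1, j - 1], R[i, j - 1], gamma
            )


def test_soft_dtw_writes_into_caller_array():
    D = np.zeros((2, 2))
    R = np.full((3, 3), np.inf)
    R[0, 0] = 0.0
    soft_dtw_inplace(D, R, gamma=1.0)
    # If R had been copied inside the function, R[2, 2] would still be inf.
    assert np.isfinite(R[2, 2])
    assert abs(R[2, 2] + math.log(3)) < 1e-12


test_soft_dtw_writes_into_caller_array()
```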

We also add the option compute_with_backend to the functions soft_dtw, soft_dtw_alignment, cdist_soft_dtw and cdist_soft_dtw_normalized. Previously, the input data was always cast to NumPy arrays so that the Numba-compiled functions (decorated with @njit) could be used, and the results were then converted back to the backend of the input data. With the new option, the computation can be performed directly in the backend of the input data, so PyTorch automatic differentiation can now be used with these functions.
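To illustrate why keeping the computation in the input's backend matters, here is a minimal sketch written directly with PyTorch tensor operations. It is not tslearn's implementation: `soft_dtw_torch` is a hypothetical helper that assumes a squared Euclidean cost and uses the soft-min recursion. Because every step is a PyTorch operation, autograd can differentiate the result with respect to the inputs, which is not possible once the data has been converted to NumPy for Numba.

```python
import torch


def soft_dtw_torch(x, y, gamma=1.0):
    # Hypothetical helper, not tslearn's code.
    # x: (m, d) tensor, y: (n, d) tensor; cost = squared Euclidean distance.
    D = ((x[:, None, :] - y[None, :, :]) ** 2).sum(dim=-1)
    m, n = D.shape
    inf = torch.tensor(float("inf"), dtype=D.dtype)
    # Python lists of scalar tensors avoid in-place writes into a tensor
    # that autograd is tracking.
    R = [[inf] * (n + 1) for _ in range(m + 1)]
    R[0][0] = torch.tensor(0.0, dtype=D.dtype)
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            prev = torch.stack([R[i - 1][j], R[i - 1][j - 1], R[i][j - 1]])
            # soft-min_gamma(prev) = -gamma * logsumexp(-prev / gamma)
            R[i][j] = D[i - 1, j - 1] - gamma * torch.logsumexp(-prev / gamma, dim=0)
    return R[m][n]


x = torch.tensor([[0.0], [1.0], [2.0]], dtype=torch.float64, requires_grad=True)
y = torch.tensor([[0.0], [2.0]], dtype=torch.float64)
value = soft_dtw_torch(x, y, gamma=0.1)
value.backward()  # gradients flow back to x through the whole recursion
print(value.item(), x.grad)
```

In tslearn itself, the intended usage is presumably to pass PyTorch tensors to `soft_dtw` with `compute_with_backend=True`; check the tslearn documentation for the exact signature.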

codecov-commenter commented 9 months ago

Codecov Report

All modified lines are covered by tests ✅

Comparison is base (e7a177c) 92.70% compared to head (74ef6c4) 93.32%.

❗ Current head 74ef6c4 differs from pull request most recent head 7e228ee. Consider uploading reports for the commit 7e228ee to get more accurate results.


Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #481      +/-   ##
==========================================
+ Coverage   92.70%   93.32%   +0.61%
==========================================
  Files          67       67
  Lines        5732     5725       -7
==========================================
+ Hits         5314     5343      +29
+ Misses        418      382      -36
```

| [Files](https://app.codecov.io/gh/tslearn-team/tslearn/pull/481?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=tslearn-team) | Coverage Δ | |
|---|---|---|
| [tslearn/metrics/soft\_dtw\_fast.py](https://app.codecov.io/gh/tslearn-team/tslearn/pull/481?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=tslearn-team#diff-dHNsZWFybi9tZXRyaWNzL3NvZnRfZHR3X2Zhc3QucHk=) | `75.42% <ø> (+24.26%)` | ⬆️ |
| [tslearn/metrics/softdtw\_variants.py](https://app.codecov.io/gh/tslearn-team/tslearn/pull/481?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=tslearn-team#diff-dHNsZWFybi9tZXRyaWNzL3NvZnRkdHdfdmFyaWFudHMucHk=) | `94.58% <100.00%> (+0.49%)` | ⬆️ |
| [tslearn/tests/test\_metrics.py](https://app.codecov.io/gh/tslearn-team/tslearn/pull/481?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=tslearn-team#diff-dHNsZWFybi90ZXN0cy90ZXN0X21ldHJpY3MucHk=) | `99.22% <100.00%> (+<0.01%)` | ⬆️ |

... and [1 file with indirect coverage changes](https://app.codecov.io/gh/tslearn-team/tslearn/pull/481/indirect-changes?src=pr&el=tree-more&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=tslearn-team)
