aeon-toolkit / aeon

A toolkit for machine learning from time series
https://aeon-toolkit.org/
BSD 3-Clause "New" or "Revised" License

[BUG] aeon.distances.pairwise_distance fails for unequal-length time series datasets #1227

Closed: CodeLionX closed this issue 6 months ago

CodeLionX commented 8 months ago

Describe the bug

All pairwise distance functions raise an error when given a dataset of unequal-length time series, such as "japanese_vowels". The underlying distance implementations support unequal-length time series, but the pairwise wrapping function assumes a single np.ndarray as input, whereas unequal-length datasets are loaded as a List[np.ndarray].
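To make the type mismatch concrete: an equal-length collection is a single 3D array of shape (n_cases, n_channels, n_timepoints), while an unequal-length collection is a plain Python list of 2D arrays, which has no `.ndim` attribute at all. Below is a minimal sketch of the input normalization a wrapper could perform before dispatching. The helper name `to_series_list` is hypothetical and not part of aeon's API, and this is not how the linked PRs fixed the issue.

```python
import numpy as np


def to_series_list(X):
    """Return a collection as a list of 2D (n_channels, n_timepoints) arrays.

    Hypothetical helper for illustration only; it is not part of aeon.
    Accepts a 3D ndarray (equal-length collection), a 2D ndarray
    (n_cases, n_timepoints) of univariate series, or a list of 1D/2D arrays
    whose lengths may differ.
    """
    # Check the container type first: a plain list has no .ndim attribute,
    # which is what the compiled `X.ndim == 3` check in the traceback trips over.
    if isinstance(X, np.ndarray):
        if X.ndim == 3:
            return [x for x in X]
        if X.ndim == 2:
            return [x.reshape(1, -1) for x in X]
        raise ValueError(f"expected a 2D or 3D array, got {X.ndim}D")
    if isinstance(X, list):
        out = []
        for x in X:
            x = np.asarray(x, dtype=np.float64)
            if x.ndim == 1:
                x = x.reshape(1, -1)  # univariate series -> one channel
            elif x.ndim != 2:
                raise ValueError(f"each series must be 1D or 2D, got {x.ndim}D")
            out.append(x)
        return out
    raise TypeError(f"unsupported collection type: {type(X).__name__}")
```

For a japanese_vowels-style input such as `[np.zeros((12, 20)), np.zeros((12, 26))]`, the helper returns the list unchanged in shape, while a 3D array is split along its first axis, so downstream code only ever sees one container type.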

Steps/Code to reproduce the bug

import aeon
from aeon.datasets import load_japanese_vowels
from aeon.distances import pairwise_distance, get_distance_function_names

print("aeon version", aeon.__version__)
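# JapaneseVowels contains unequal-length series, so X is loaded as a list of 2D arrays
X, y = load_japanese_vowels(split="train")
print("Type of input:", type(X), "shapes=")
print([x.shape for x in X])

# Try every registered distance; keep the most recent exception for printing below
error = None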
for distance in get_distance_function_names():
    try:
        distance_matrix = pairwise_distance(X, metric=distance)
    except Exception as e:
        error = e
        print(f"Distance '{distance}' failed to process varying-length time series!")

if error is not None:
    print("Last traceback:")
    print(error)

Expected results

All distance functions produce a distance matrix instead of raising an exception.
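For n input series, the expected result is an n × n matrix; for a symmetric metric such as DTW it is also symmetric with a zero diagonal. The sketch below computes such a matrix over a list of unequal-length multivariate series using a plain dynamic-programming DTW with squared-Euclidean point costs. It is an illustrative stand-in written in pure NumPy, not aeon's numba implementation, and the names `dtw` and `pairwise_dtw` are hypothetical.

```python
import numpy as np


def dtw(x, y):
    """Unconstrained DTW between 2D series x (c, m) and y (c, n).

    The point cost is the squared Euclidean distance across channels;
    the return value is the accumulated cost (no final square root).
    """
    m, n = x.shape[1], y.shape[1]
    # Point-to-point costs via broadcasting: (c, m, 1) - (c, 1, n) -> (m, n).
    cost = ((x[:, :, None] - y[:, None, :]) ** 2).sum(axis=0)
    acc = np.full((m + 1, n + 1), np.inf)
    acc[0, 0] = 0.0
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            acc[i, j] = cost[i - 1, j - 1] + min(
                acc[i - 1, j], acc[i, j - 1], acc[i - 1, j - 1]
            )
    return acc[m, n]


def pairwise_dtw(X):
    """Distance matrix for a list of 2D arrays that may differ in length."""
    n_cases = len(X)
    D = np.zeros((n_cases, n_cases))
    for i in range(n_cases):
        for j in range(i + 1, n_cases):
            # DTW is symmetric, so fill the upper triangle and mirror it.
            D[i, j] = D[j, i] = dtw(X[i], X[j])
    return D
```

Because each pair is handled independently, nothing here requires the series to share a length. The nested Python loops make this slow for real datasets, which is why a compiled (numba) kernel is attractive, and also why the list input becomes a typing problem there.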

Actual results

aeon version 0.7.0
Type of input: <class 'list'> shapes=
[(12, 20), (12, 26), (12, 22), (12, 20), (12, 21), (12, 23), (12, 22), (12, 18), (12, 24), (12, 15), (12, 23), (12, 15), (12, 17), (12, 14), (12, 14), (12, 15), (12, 15), (12, 21), (12, 16), (12, 15), (12, 15), (12, 19), (12, 22), (12, 20), (12, 17), (12, 17), (12, 13), (12, 16), (12, 13), (12, 14), (12, 18), (12, 17), (12, 16), (12, 17), (12, 15), (12, 15), (12, 15), (12, 18), (12, 18), (12, 21), (12, 17), (12, 13), (12, 13), (12, 17), (12, 15), (12, 16), (12, 15), (12, 18), (12, 15), (12, 11), (12, 17), (12, 15), (12, 13), (12, 12), (12, 14), (12, 16), (12, 12), (12, 16), (12, 15), (12, 15), (12, 21), (12, 19), (12, 14), (12, 12), (12, 14), (12, 15), (12, 20), (12, 14), (12, 7), (12, 17), (12, 11), (12, 15), (12, 15), (12, 9), (12, 12), (12, 18), (12, 13), (12, 12), (12, 12), (12, 16), (12, 13), (12, 12), (12, 17), (12, 15), (12, 13), (12, 19), (12, 11), (12, 13), (12, 10), (12, 15), (12, 21), (12, 18), (12, 23), (12, 23), (12, 21), (12, 21), (12, 23), (12, 21), (12, 24), (12, 19), (12, 23), (12, 23), (12, 17), (12, 23), (12, 19), (12, 17), (12, 18), (12, 19), (12, 15), (12, 22), (12, 18), (12, 15), (12, 17), (12, 25), (12, 21), (12, 19), (12, 23), (12, 16), (12, 20), (12, 22), (12, 13), (12, 11), (12, 12), (12, 10), (12, 11), (12, 10), (12, 16), (12, 13), (12, 11), (12, 15), (12, 11), (12, 15), (12, 16), (12, 13), (12, 17), (12, 14), (12, 16), (12, 15), (12, 15), (12, 10), (12, 15), (12, 14), (12, 18), (12, 15), (12, 15), (12, 12), (12, 12), (12, 10), (12, 12), (12, 10), (12, 17), (12, 16), (12, 16), (12, 21), (12, 18), (12, 13), (12, 16), (12, 17), (12, 21), (12, 23), (12, 19), (12, 16), (12, 18), (12, 18), (12, 15), (12, 16), (12, 18), (12, 19), (12, 16), (12, 16), (12, 18), (12, 14), (12, 16), (12, 19), (12, 23), (12, 17), (12, 18), (12, 16), (12, 18), (12, 15), (12, 16), (12, 15), (12, 15), (12, 17), (12, 15), (12, 12), (12, 16), (12, 18), (12, 18), (12, 14), (12, 20), (12, 19), (12, 17), (12, 14), (12, 14), (12, 18), (12, 17), (12, 20), (12, 17), (12, 14), (12, 13), (12, 16), (12, 16), (12, 21), (12, 16), (12, 15), (12, 21), (12, 16), (12, 21), (12, 25), (12, 10), (12, 12), (12, 11), (12, 11), (12, 10), (12, 11), (12, 11), (12, 14), (12, 13), (12, 12), (12, 12), (12, 13), (12, 17), (12, 13), (12, 11), (12, 12), (12, 11), (12, 12), (12, 11), (12, 12), (12, 14), (12, 15), (12, 14), (12, 15), (12, 13), (12, 10), (12, 13), (12, 15), (12, 13), (12, 16), (12, 17), (12, 12), (12, 18), (12, 16), (12, 17), (12, 19), (12, 20), (12, 19), (12, 18), (12, 11), (12, 11), (12, 17), (12, 15), (12, 12), (12, 15), (12, 14), (12, 14), (12, 11), (12, 13), (12, 14), (12, 12), (12, 17), (12, 11), (12, 10), (12, 14), (12, 15), (12, 17), (12, 12), (12, 14), (12, 9)]
Distance 'adtw' failed to process varying-length time series!
Distance 'ddtw' failed to process varying-length time series!
Distance 'dtw' failed to process varying-length time series!
Distance 'edr' failed to process varying-length time series!
Distance 'erp' failed to process varying-length time series!
Distance 'euclidean' failed to process varying-length time series!
Distance 'lcss' failed to process varying-length time series!
Distance 'manhattan' failed to process varying-length time series!
Distance 'minkowski' failed to process varying-length time series!
Distance 'msm' failed to process varying-length time series!
Distance 'shape_dtw' failed to process varying-length time series!
Distance 'squared' failed to process varying-length time series!
Distance 'twe' failed to process varying-length time series!
Distance 'wddtw' failed to process varying-length time series!
Distance 'wdtw' failed to process varying-length time series!
Last traceback:
Failed in nopython mode pipeline (step: nopython frontend)
Unknown attribute 'ndim' of type reflected list(array(float64, 2d, C))<iv=None>

File "../../../.conda/envs/aeon/lib/python3.10/site-packages/aeon/distances/_wdtw.py", line 303:
def wdtw_pairwise_distance(
    <source elided>
        # To self
        if X.ndim == 3:
        ^

During: typing of get attribute at /home/sebastian/.conda/envs/aeon/lib/python3.10/site-packages/aeon/distances/_wdtw.py (303)

File "../../../.conda/envs/aeon/lib/python3.10/site-packages/aeon/distances/_wdtw.py", line 303:
def wdtw_pairwise_distance(
    <source elided>
        # To self
        if X.ndim == 3:
        ^

Versions

System:
  python: 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0]
  executable: /home/sebastian/.conda/envs/aeon/bin/python
  machine: Linux-5.15.0-94-generic-x86_64-with-glibc2.35

Python dependencies:
  pip: 23.3.1
  setuptools: 68.2.2
  scikit-learn: 1.4.1.post1
  aeon: 0.7.0
  statsmodels: None
  numpy: 1.23.5
  scipy: 1.10.0
  pandas: 2.0.3
  matplotlib: None
  joblib: 1.3.2
  numba: 0.56.4
  pmdarima: None
  tsfresh: None

TonyBagnall commented 8 months ago

Hi, thanks for this. It is due to problems with mixing lists and ndarrays within numba: it's easy to do in plain Python code but causes all sorts of compilation issues. @chrisholder has an idea about how to fix this; we will update on progress here.
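One common way to sidestep numba's typing of reflected Python lists (shown here as a general technique, not necessarily what the linked PRs implemented) is to pack the unequal-length series into a single rectangular array padded with NaN, plus a vector of true lengths, so a compiled kernel only has to type one homogeneous array. A minimal NumPy sketch, where `pack_unequal` and `unpack` are hypothetical names:

```python
import numpy as np


def pack_unequal(X):
    """Pack a list of (n_channels, n_timepoints_i) arrays into one padded array.

    Returns (data, lengths): data has shape (n_cases, n_channels, max_len)
    and is NaN past each series' end; lengths holds each true length.
    Assumes every series has the same number of channels.
    """
    n_channels = X[0].shape[0]
    lengths = np.array([x.shape[1] for x in X], dtype=np.int64)
    data = np.full((len(X), n_channels, lengths.max()), np.nan)
    for i, x in enumerate(X):
        if x.shape[0] != n_channels:
            raise ValueError("all series must have the same number of channels")
        data[i, :, : x.shape[1]] = x
    return data, lengths


def unpack(data, lengths):
    """Recover the original list of unequal-length 2D arrays (as views)."""
    return [data[i, :, : lengths[i]] for i in range(len(lengths))]
```

A compiled distance kernel would then read each series as `data[i, :, :lengths[i]]`, so the padding never enters the computation. The trade-off is memory proportional to the longest series times the number of cases.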

chrisholder commented 6 months ago

This has now been resolved in PR #1287 and PR #1356.