aeon-toolkit / aeon

A toolkit for machine learning from time series
https://aeon-toolkit.org/
BSD 3-Clause "New" or "Revised" License

[BUG] classification base class does not check for different length series in `predict` #1712

Open MatthewMiddlehurst opened 1 week ago

MatthewMiddlehurst commented 1 week ago

Describe the bug

From #1696

The classification base class does not check whether the series passed to predict have a different length from those used in fit. This should raise an exception unless the classifier can handle unequal length series.

This possibly impacts other base classes too, but I have not checked.

In the example below, two estimators raise an exception as expected, but the check is internal to the specific estimator and comes from the following lines:

        n_cases, n_channels, n_timepoints = X.shape

        if n_channels != self.n_channels_:
            raise ValueError(
                "The number of channels in the train data does not match the number "
                "of channels in the test data"
            )
        if n_timepoints != self.n_timepoints_:
            raise ValueError(
                "The series length of the train data does not match the series length "
                "of the test data"
            )

Steps/Code to reproduce the bug

from aeon.classification.convolution_based import Arsenal
from aeon.classification.hybrid import HIVECOTEV2
from aeon.classification.distance_based import KNeighborsTimeSeriesClassifier
from aeon.classification.interval_based import DrCIFClassifier
import numpy as np

cls = [Arsenal, HIVECOTEV2, KNeighborsTimeSeriesClassifier, DrCIFClassifier]

for c in cls:
    print(c)
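    X = np.random.random((10,20))    # train: 10 series of length 20
    X2 = np.random.random((10,20))   # 10 cases, same length as train
    X3 = np.random.random((1,20))    # single case, same length as train
    X4 = np.random.random((1,200))   # single case, longer than train
    X5 = np.random.random((10,200))  # 10 cases, longer than train
    y = np.array([0,1,0,1,0,1,0,1,0,1])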

    afc = c()
    afc.fit(X,y)

    try:
        print(afc.predict(X2))
    except Exception as e:
        print("X2 failed")
        print(e)
    try:
        print(afc.predict(X3))
    except Exception as e:
        print("X3 failed")
        print(e)
    try:
        print(afc.predict(X4))
    except Exception as e:
        print("X4 failed")
        print(e)
    try:
        print(afc.predict(X5))
    except Exception as e:
        print("X5 failed")
        print(e)

Expected results

All classifiers except KNN, which can handle unequal length series, raise an exception from the base class when predicting on the longer series (X4 and X5).
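The expected results can be tabulated without aeon. This is a minimal sketch that assumes, as stated above, that only KNN handles unequal length, and that takes the input shapes straight from the reproduction script. The dictionary and function names are illustrative only. It lists, per classifier, the predict calls the base class would be expected to reject, for comparison with the Actual results.

```python
# Shapes copied from the reproduction script: 2D arrays are (n_cases, n_timepoints).
train_length = 20
test_inputs = {
    "X2": (10, 20),
    "X3": (1, 20),
    "X4": (1, 200),
    "X5": (10, 200),
}

# Assumption from the issue text: only KNN can handle unequal length series.
handles_unequal_length = {
    "Arsenal": False,
    "HIVECOTEV2": False,
    "KNeighborsTimeSeriesClassifier": True,
    "DrCIFClassifier": False,
}


def should_raise(test_shape, can_handle_unequal):
    """Return True if predict on an input of this shape ought to be rejected."""
    length_differs = test_shape[1] != train_length
    return length_differs and not can_handle_unequal


# For each classifier, collect the inputs that predict should refuse.
expected = {
    clf: [name for name, shape in test_inputs.items() if should_raise(shape, ok)]
    for clf, ok in handles_unequal_length.items()
}

for clf, failing in expected.items():
    print(f"{clf}: should raise on {failing if failing else 'nothing'}")
```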

Actual results

Only DrCIF and HC2 (which contains DrCIF) raise exceptions, and neither exception comes from the base class.

<class 'aeon.classification.convolution_based._arsenal.Arsenal'>
[1 1 1 0 1 1 1 0 0 1]
[1]
[1]
[1 1 1 1 1 1 1 1 1 1]
<class 'aeon.classification.hybrid._hivecote_v2.HIVECOTEV2'>
[0 0 1 1 0 1 0 0 1 0]
[0]
X4 failed
The series length of the train data does not match the series length of the test data
X5 failed
The series length of the train data does not match the series length of the test data
<class 'aeon.classification.distance_based._time_series_neighbors.KNeighborsTimeSeriesClassifier'>
[1 1 1 1 0 1 1 1 1 0]
[1]
[0]
[1 0 0 0 0 1 1 0 1 0]
<class 'aeon.classification.interval_based._drcif.DrCIFClassifier'>
[1 0 1 1 1 1 0 1 1 1]
[1]
X4 failed
The series length of the train data does not match the series length of the test data
X5 failed
The series length of the train data does not match the series length of the test data

Versions

N/A

TonyBagnall commented 1 week ago

@MatthewMiddlehurst

I think this should always be raised

        if n_channels != self.n_channels_:
            raise ValueError(
                "The number of channels in the train data does not match the number "
                "of channels in the test data"
            )

since we have no mechanism for determining whether a classifier can work with a different number of channels, and this

        if n_timepoints != self.n_timepoints_:
            raise ValueError(
                "The series length of the train data does not match the series length "
                "of the test data"
            )

only if capability:unequal is false. Have I got that right? The PR would do the above and strip out the DrCIF checks.
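A minimal sketch of the proposal, assuming a template-method design in which the public predict validates the input and then delegates to an estimator-specific _predict. BaseClassifierSketch, MajorityClassifier, UnequalLengthMajorityClassifier and the tag key capability:unequal_length are hypothetical names chosen for illustration, not aeon's actual API. A 2D array is assumed to mean (n_cases, n_timepoints), as in the reproduction script. The channel check always runs; the length check is skipped when the tag is True.

```python
import numpy as np


class BaseClassifierSketch:
    """Hypothetical stand-in for a classification base class (not aeon's API)."""

    # Hypothetical capability tag, modeled on "capability:unequal" in the comment.
    _tags = {"capability:unequal_length": False}

    @staticmethod
    def _to_3d(X):
        # Assumption: a 2D array is (n_cases, n_timepoints), i.e. a single channel.
        X = np.asarray(X)
        return X[:, np.newaxis, :] if X.ndim == 2 else X

    def fit(self, X, y):
        X = self._to_3d(X)
        # Record the training shape so predict can compare against it.
        _, self.n_channels_, self.n_timepoints_ = X.shape
        self._fit(X, np.asarray(y))
        return self

    def predict(self, X):
        X = self._to_3d(X)
        _, n_channels, n_timepoints = X.shape
        # Always checked: there is no way to declare channel-count flexibility.
        if n_channels != self.n_channels_:
            raise ValueError(
                "The number of channels in the train data does not match the number "
                "of channels in the test data"
            )
        # Checked only when the estimator cannot handle unequal length series.
        if (
            not self._tags["capability:unequal_length"]
            and n_timepoints != self.n_timepoints_
        ):
            raise ValueError(
                "The series length of the train data does not match the series length "
                "of the test data"
            )
        return self._predict(X)

    def _fit(self, X, y):
        raise NotImplementedError

    def _predict(self, X):
        raise NotImplementedError


class MajorityClassifier(BaseClassifierSketch):
    """Toy estimator: always predicts the most frequent training label."""

    def _fit(self, X, y):
        labels, counts = np.unique(y, return_counts=True)
        self.majority_ = labels[np.argmax(counts)]

    def _predict(self, X):
        return np.full(X.shape[0], self.majority_)


class UnequalLengthMajorityClassifier(MajorityClassifier):
    # Opts out of the length check, as an unequal-length-capable estimator would.
    _tags = {"capability:unequal_length": True}
```

With this layout, estimator-specific checks such as the ones in DrCIF become redundant, because every subclass inherits the same validation in predict.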