aeon-toolkit / aeon

A toolkit for machine learning from time series
https://aeon-toolkit.org/
BSD 3-Clause "New" or "Revised" License

[ENH] Improve `BaseCollectionEstimator` `_check_X` collection datatype validation #1889

Open MatthewMiddlehurst opened 1 month ago

MatthewMiddlehurst commented 1 month ago

Describe the feature or idea you want to propose

`_check_X` has checks for multivariate, unequal length, and missing values, but this could be improved with checks for specific datatypes, e.g. numpy3d, np-list, etc.

We should check whether the input is one of these types and whether it is correctly formatted. Some of this is probably done elsewhere, but likely without raising an informative error.

Describe your proposed solution

Create a function, e.g. `check_collection`, that raises informative errors if the datatype is not valid.
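A minimal sketch of such a function, assuming only two of the formats mentioned above: numpy3D (a single 3D `np.ndarray` of shape `(n_cases, n_channels, n_timepoints)`) and np-list (a list of 2D `np.ndarray`, one per case, which may differ in length). The name `check_collection` comes from the proposal, but its signature, return values, and messages here are hypothetical and not aeon's API; the requirement that all np-list series share a channel count is also an assumption. It returns the detected type name or raises an error stating what was expected and what was received.

```python
import numpy as np


def check_collection(X):
    """Hypothetical validator: return the collection type name or raise.

    Only two collection formats are handled in this sketch:
    "numpy3D" and "np-list".
    """
    if isinstance(X, np.ndarray):
        # numpy3D: one array of shape (n_cases, n_channels, n_timepoints)
        if X.ndim != 3:
            raise ValueError(
                "np.ndarray input must be 3D (n_cases, n_channels, "
                f"n_timepoints), but got an array with {X.ndim} dimension(s) "
                f"and shape {X.shape}."
            )
        return "numpy3D"
    if isinstance(X, list):
        # np-list: a list of 2D arrays of shape (n_channels, n_timepoints_i)
        if len(X) == 0:
            raise ValueError("List input must contain at least one series.")
        for i, x in enumerate(X):
            if not isinstance(x, np.ndarray):
                raise TypeError(
                    "List input must contain only np.ndarray, but element "
                    f"{i} is of type {type(x)}."
                )
            if x.ndim != 2:
                raise ValueError(
                    "Each np.ndarray in a list must be 2D (n_channels, "
                    f"n_timepoints), but element {i} has shape {x.shape}."
                )
        # Assumed rule: lengths may differ, channel counts may not
        n_channels = X[0].shape[0]
        for i, x in enumerate(X):
            if x.shape[0] != n_channels:
                raise ValueError(
                    f"All series in a list must have {n_channels} channel(s), "
                    f"but element {i} has {x.shape[0]}."
                )
        return "np-list"
    # Anything else: say what was received and where to look
    raise TypeError(
        f"Input of type {type(X)} is not a supported collection type. "
        "Expected a 3D np.ndarray or a list of 2D np.ndarray; see "
        "https://aeon-toolkit.org/ for the accepted data types."
    )
```

Because each branch reports the shape or type it actually received, a call such as `check_collection(np.zeros((10, 50)))` fails with a message naming the 2D shape `(10, 50)` rather than a generic type error.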

Describe alternatives you've considered, if relevant

No response

Additional context

No response

Cyril-Meyer commented 1 month ago

The check can be done using `get_type`. `_check_X` already calls it through `_get_X_metadata` -> `is_univariate` -> `get_type`:

```python
from aeon.base import BaseCollectionEstimator

dummy = BaseCollectionEstimator()
all_tags = {
    "capability:multivariate": True,
    "capability:unequal_length": True,
    "capability:missing_values": True,
}
dummy.set_tags(**all_tags)

X = 42
dummy._check_X(X)
```

```
Traceback (most recent call last):
  File "C:\Users\cyril\Documents\Development\aeon\cyril.py", line 12, in <module>
    dummy._check_X(X)
  File "C:\Users\cyril\Documents\Development\aeon\aeon\base\_base_collection.py", line 139, in _check_X
    metadata = self._get_X_metadata(X)
  File "C:\Users\cyril\Documents\Development\aeon\aeon\base\_base_collection.py", line 232, in _get_X_metadata
    metadata["multivariate"] = not is_univariate(X)
  File "C:\Users\cyril\Documents\Development\aeon\aeon\utils\validation\collection.py", line 365, in is_univariate
    type = get_type(X)
  File "C:\Users\cyril\Documents\Development\aeon\aeon\utils\validation\collection.py", line 276, in get_type
    raise TypeError(
TypeError: ERROR passed input of type <class 'int'>, must be of type np.ndarray, pd.DataFrame or list of np.ndarray/pd.DataFrame
```

MatthewMiddlehurst commented 1 month ago

I feel this could still be improved, e.g. by moving the check to the top level and linking to resources on the accepted data types.

There are also some aspects that I think are still not checked, e.g. whether the dtype of the input is valid (int or float).
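A minimal sketch of such a dtype check, assuming the collection is either a single `np.ndarray` or a list of `np.ndarray`. `check_collection_dtype` is a hypothetical name, not part of aeon. It accepts any NumPy integer or floating dtype (so float input containing NaN still passes) and rejects everything else, such as object, string, bool, or complex arrays.

```python
import numpy as np


def check_collection_dtype(X):
    """Hypothetical check that all values in a collection are int or float."""
    # Treat a single array as a one-element list so both formats share a loop
    arrays = X if isinstance(X, list) else [X]
    for i, x in enumerate(arrays):
        dtype = np.asarray(x).dtype
        # np.integer / np.floating are abstract types covering every width
        if not (np.issubdtype(dtype, np.integer) or np.issubdtype(dtype, np.floating)):
            where = f"element {i} of the list" if isinstance(X, list) else "the input array"
            raise TypeError(
                f"Collection values must be int or float, but {where} has dtype {dtype}."
            )
```

Checking against the abstract `np.integer` and `np.floating` types means `int32`, `int64`, `float32`, `float64`, etc. all pass without being listed individually.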