sktime / sktime

A unified framework for machine learning with time series
https://www.sktime.net
BSD 3-Clause "New" or "Revised" License

[ENH] proliferation of Rockets #3854

Open fkiraly opened 1 year ago

fkiraly commented 1 year ago

It would appear that the contents of the folder transformations.panel.rocket are highly duplicative, with six variants of the rocket algorithm now in it?

__all__ = [
    "Rocket",
    "MiniRocket",
    "MiniRocketMultivariate",
    "MiniRocketMultivariateVariable",
    "MultiRocket",
    "MultiRocketMultivariate",
]

I'm not saying this is in itself a problem, since we decided to allow contributions of algorithms as long as they are well described.

What I am wondering about: there seems to be a lot of copy-paste repetition, so:

fkiraly commented 1 year ago

FYI authors @angus924, @michaelfeil

Related question: could MiniRocketMultivariateVariable not simply be a pipeline of an unequal length transformer and MiniRocketMultivariate, or is there some numba magic going on?
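A minimal sketch of the pipeline idea, assuming NumPy and a hypothetical `pad_to_equal_length` helper standing in for the unequal-length transformer: every series is right-padded to the longest length so the batch fits a fixed-length numpy3D layout `(n_instances, n_channels, n_timepoints)` that an equal-length transform could consume. Padding changes the values the downstream convolutions see, so features computed this way are not guaranteed to match a transform that processes each series at its own length.

```python
import numpy as np


def pad_to_equal_length(series, fill_value=0.0):
    """Hypothetical unequal-length step (not an sktime API).

    Right-pads a list of (n_channels, length_i) arrays so they stack into
    one numpy3D array of shape (n_instances, n_channels, max_length).
    """
    n_channels = series[0].shape[0]
    max_length = max(s.shape[1] for s in series)
    out = np.full((len(series), n_channels, max_length), fill_value, dtype=float)
    for i, s in enumerate(series):
        # copy the observed values, leave the tail as fill_value
        out[i, :, : s.shape[1]] = s
    return out


# three bivariate series of lengths 3, 5 and 4
series = [np.ones((2, 3)), np.ones((2, 5)), np.ones((2, 4))]
X = pad_to_equal_length(series)
print(X.shape)  # (3, 2, 5)
```

Within sktime, the padding step would presumably be its own `PaddingTransformer` chained in front of `MiniRocketMultivariate`; whether the padded features are an acceptable substitute for `MiniRocketMultivariateVariable` is the open question.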

Especially the content of https://github.com/angus924/minirocket seems to have been copy-pasted all over the place...

MatthewMiddlehurst commented 1 year ago

Relevant discussion in #3786

fkiraly commented 1 year ago

Sounds mysterious. Perhaps it is worthwhile to add what @angus924 has said to some of the docstrings, e.g., which rocket to use, or which ones currently have unresolved problems. Perhaps even open an issue to track the unresolved problems; we could leverage the wisdom of the crowd if they are precisely formulated in an open issue.

I have also seen users on various social channels being confused about there being too many rockets, so making the docstrings clearer in that respect would also help people to use the one that @angus924 currently considers the "right" one.

michaelfeil commented 1 year ago

I agree; I added the interface. This could be easier for users, e.g., with only a single MiniRocket class. This would require corresponding check functions for the univariate / multivariate / unequal length cases in fit(X) and transform(X).

Pseudocode:

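from sktime.transformations.panel.rocket import MiniRocket, MiniRocketMultivariate, MiniRocketMultivariateVariable

class MiniRocketNew:
    # ... (rest of the estimator omitted)

    def _fit(self, X, y=None):
        # is_unequal_length / is_multivariate are hypothetical check helpers
        if is_unequal_length(X):
            self.estimator_ = MiniRocketMultivariateVariable()
        elif is_multivariate(X):
            self.estimator_ = MiniRocketMultivariate()
        else:
            self.estimator_ = MiniRocket()
        # fit an instance, not the class itself
        self.estimator_.fit(X)
        return self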
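The check functions the pseudocode relies on could look roughly like this. This is a sketch assuming the input is a list of NumPy arrays of shape `(n_channels, length)`; `is_unequal_length`, `is_multivariate` and `select_backend` are hypothetical names, and the dispatcher only returns the name of the variant it would delegate to.

```python
import numpy as np


def is_unequal_length(X):
    """True if the series in X (list of (n_channels, length) arrays) differ in length."""
    return len({s.shape[1] for s in X}) > 1


def is_multivariate(X):
    """True if the series in X have more than one channel."""
    return X[0].shape[0] > 1


def select_backend(X):
    """Name of the rocket variant a single dispatching class would delegate to."""
    # unequal length is checked first: only the variable-length variant accepts it
    if is_unequal_length(X):
        return "MiniRocketMultivariateVariable"
    elif is_multivariate(X):
        return "MiniRocketMultivariate"
    return "MiniRocket"


print(select_backend([np.zeros((1, 10)), np.zeros((1, 10))]))  # MiniRocket
print(select_backend([np.zeros((3, 10)), np.zeros((3, 10))]))  # MiniRocketMultivariate
print(select_backend([np.zeros((3, 10)), np.zeros((3, 8))]))   # MiniRocketMultivariateVariable
```

The order of the checks matters: an unequal-length univariate input also goes to the variable-length variant, since neither of the other two can take it.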

There is, however, the issue that, based on the _tags, the data gets converted to the X_inner_mtype. For example, for unequal-length data, conversion to "X_inner_mtype": "numpy3D" should fail.
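A minimal illustration, assuming NumPy, of why a numpy3D inner type cannot hold unequal-length data: a 3D array has a single time-axis length shared by all instances, so stacking series of different lengths is rejected outright.

```python
import numpy as np

equal = [np.zeros((2, 5)), np.zeros((2, 5))]
unequal = [np.zeros((2, 5)), np.zeros((2, 3))]

# fine: shape (n_instances, n_channels, n_timepoints) == (2, 2, 5)
X3d = np.stack(equal)

try:
    np.stack(unequal)
    stacked_unequal = True
except ValueError:
    # numpy requires all input arrays to have the same shape
    stacked_unequal = False

print(X3d.shape, stacked_unequal)  # (2, 2, 5) False
```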

https://github.com/sktime/sktime/blob/12917f4027c325dcd77c3b1c0c93644d81fbca12/sktime/transformations/panel/rocket/_minirocket_multivariate_variable.py#L85-L95

https://github.com/sktime/sktime/blob/12917f4027c325dcd77c3b1c0c93644d81fbca12/sktime/transformations/panel/rocket/_minirocket_multivariate.py#L63-L73

https://github.com/sktime/sktime/blob/12917f4027c325dcd77c3b1c0c93644d81fbca12/sktime/transformations/panel/rocket/_multirocket.py#L80-L89

Also, the numba signatures differ for each function: minirocket takes a 2D array, minirocket_multivariate a numpy3D input, and minirocket_multivariate_variable a flattened multivariate 2D array of shape [n_dimensions, sum(length_series(i) for i in n_instances)], with the length of each series stored in an extra [n_instances] array.
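The flattened layout described above can be sketched as follows, assuming NumPy and a list of per-instance arrays of shape `(n_dimensions, length_i)`. The helper names are hypothetical; this mirrors the description in the comment, not sktime's actual conversion code.

```python
import numpy as np


def flatten_variable(series):
    """Pack (n_dimensions, length_i) arrays into one (n_dimensions, sum(length_i))
    array plus an (n_instances,) array of series lengths."""
    X_flat = np.concatenate(series, axis=1)
    lengths = np.array([s.shape[1] for s in series], dtype=np.int64)
    return X_flat, lengths


def unflatten_variable(X_flat, lengths):
    """Recover the per-instance arrays, splitting at the cumulative lengths."""
    return np.split(X_flat, np.cumsum(lengths)[:-1], axis=1)


series = [np.arange(6.0).reshape(2, 3), np.arange(10.0).reshape(2, 5)]
X_flat, lengths = flatten_variable(series)
print(X_flat.shape, lengths)  # (2, 8) [3 5]
```

Because the equal-length numpy3D and univariate 2D inputs can both be packed into this layout without loss, it is a natural candidate for a single shared backend.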

While we could keep the three numba versions in a single .py file, I am not sure how to adjust the _tags to cover all three cases.

fkiraly commented 1 year ago

@michaelfeil, I was referring more to the backend, i.e., the copy-pasting of the numba code.

Agree that the number of classes could also be reduced, but it's primarily the duplication of the internal code that I think we need to address; DRY is more severely violated there.
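One way to address the backend duplication is a single kernel on the flattened variable-length layout, with the equal-length and univariate entry points reduced to thin reshaping adapters. The sketch below assumes NumPy and uses a deliberately simple toy feature (one fixed kernel, channels summed, proportion of positive values); all function names are hypothetical, and this is not MiniRocket's kernel set or sktime's code.

```python
import numpy as np


def _ppv_backend(X_flat, lengths, kernel, bias):
    """Single shared backend on the flattened (n_channels, sum(lengths)) layout.

    Toy feature: convolve each channel with one kernel, sum over channels,
    and return the proportion of positive values of (conv - bias) per instance.
    """
    out = np.empty(len(lengths))
    start = 0
    for i, n in enumerate(lengths):
        x = X_flat[:, start : start + n]
        conv = sum(np.convolve(channel, kernel, mode="valid") for channel in x)
        out[i] = np.mean(conv > bias)
        start += n
    return out


def transform_variable(series, kernel, bias):
    """Unequal-length multivariate input: list of (n_channels, length_i) arrays."""
    X_flat = np.concatenate(series, axis=1)
    lengths = [s.shape[1] for s in series]
    return _ppv_backend(X_flat, lengths, kernel, bias)


def transform_multivariate(X3d, kernel, bias):
    """Equal-length numpy3D input (n_instances, n_channels, n_timepoints)."""
    return transform_variable(list(X3d), kernel, bias)


def transform_univariate(X2d, kernel, bias):
    """Univariate 2D input (n_instances, n_timepoints), treated as one channel."""
    return transform_multivariate(X2d[:, np.newaxis, :], kernel, bias)


kernel = np.array([-1.0, 2.0, -1.0])
X2d = np.array([[0.0, 1.0, 0.0, 1.0, 0.0], [1.0, 2.0, 3.0, 4.0, 5.0]])
print(transform_univariate(X2d, kernel, bias=0.0))  # approximately [0.667, 0.0]
```

Only `_ppv_backend` would need to be compiled with numba in this design. Since the multivariate version was reportedly much less efficient than the univariate one, a shared backend like this would need benchmarking against the specialised kernels before replacing them.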

TonyBagnall commented 1 year ago

Hi, I have talked to @angus924 (code owner) about the multiple rocket transforms; the original reason had to do with a problem he had with the multivariate version being much less efficient than the univariate one. It is something we will look at with him when time allows.