timeseriesAI / tsai

State-of-the-art Deep Learning library for Time Series and Sequences in Pytorch / fastai
https://timeseriesai.github.io/tsai/
Apache License 2.0

MiniRocketPlus, InceptionRocketPlus, XResNet1dPlus, TSSequencerPlus don't work with ndim classification targets #719

Closed: oguiza closed this issue 1 year ago

oguiza commented 1 year ago

Failed test:

from tsai.basics import *
from tsai.learner import all_arch_names

# ndim classification targets: one label per time step, so y has shape (samples, 1, steps)
num_classes = 5
X = torch.rand(8, 2, 50)
y = torch.randint(0, num_classes, (len(X), 1, 50))
splits = TimeSplitter(show_plot=False)(y)
vocab = np.arange(num_classes)

# Train every "Plus" architecture for one epoch and record the ones that raise
fail_test = []
for arch in all_arch_names:
    if "plus" not in arch.lower(): continue
    try:
        learn = TSClassifier(X, y, splits=splits, arch=arch, metrics=accuracy, vocab=vocab)
        with ContextManagers([learn.no_bar(), learn.no_logging()]):
            learn.fit_one_cycle(1, 1e-3)
    except Exception as e:
        fail_test.append(arch)
        print(arch, e)

# The test passes only if no architecture failed
test_eq(fail_test, [])
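In the test, y has shape (8, 1, 50): one class label per time step rather than one per sample. A model that supports such ndim targets therefore has to emit logits for every step. The sketch below illustrates the shapes involved in plain PyTorch. The 1x1 convolution head is a hypothetical stand-in, not how MiniRocketPlus, InceptionRocketPlus, XResNet1dPlus or TSSequencerPlus are built, and squeezing the singleton target dimension before the loss is an assumption for illustration, not something the issue confirms tsai does internally.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Shapes mirror the test: 8 samples, 2 variables, 50 time steps, 5 classes
bs, n_vars, seq_len, num_classes = 8, 2, 50, 5
X = torch.rand(bs, n_vars, seq_len)
y = torch.randint(0, num_classes, (bs, 1, seq_len))  # one label per time step

# Hypothetical toy head (not a tsai architecture): a 1x1 convolution maps
# the n_vars input channels to num_classes logits at every time step
head = nn.Conv1d(n_vars, num_classes, kernel_size=1)
logits = head(X)  # (bs, num_classes, seq_len)

# nn.CrossEntropyLoss accepts (N, C, d1) logits with (N, d1) integer targets,
# so the singleton dimension of y is squeezed out first (assumption)
target = y.squeeze(1)  # (bs, seq_len)
loss = nn.CrossEntropyLoss()(logits, target)
loss.backward()

# Per-step accuracy: argmax over the class dimension, compared element-wise
preds = logits.argmax(dim=1)  # (bs, seq_len)
acc = (preds == target).float().mean()
```

An architecture that instead pools over time and returns a (bs, num_classes) tensor cannot be matched against these per-step targets, which is the kind of shape mismatch the test is designed to catch.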
oguiza commented 1 year ago

Closed, as all architectures in all_arch_names whose names end in Plus now pass the test.