timeseriesAI / tsai

State-of-the-art Deep Learning library for Time Series and Sequences in Pytorch / fastai
https://timeseriesai.github.io/tsai/
Apache License 2.0

Getting train, validation and test datasets with get_splits? #602

Closed: ninoslavc closed this issue 1 year ago

ninoslavc commented 1 year ago

Hi @oguiza, when we use splits = get_splits(y, valid_size = 0.2, test_size = 0.2, stratify = True, random_state = seed), we get splits for all three datasets (train, valid and test). Could you give an example of how to use the test portion in that case, i.e. for inference with the trained model?

Is this the correct way of getting and using the test portion?

X_test = X[splits[2]]
y_test = y[splits[2]]

test_ds = dsets.valid.add_test(X_test, y_test)
test_dl = dls.valid.new(test_ds)  # what happens to the previous validation set at this point?

test_probas, test_targets, test_preds = learn.get_preds(dl=test_dl, with_decoded=True)
test_probas, test_targets, test_preds

Also, why is there no dsets.test, when there are dsets.train and dsets.valid? Thanks.

oguiza commented 1 year ago

Hi @ninoslavc ,

Here's an example. Let's first create some data with 3 splits:

from tsai.basics import *
X, y, _ = get_UCR_data('LSST', split_data=False)
splits = get_splits(y, valid_size = 0.2, test_size = 0.2, stratify = True, random_state = 1234)
splits
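For intuition, here is a rough sketch of what a stratified three-way split produces: three disjoint arrays of integer indices (train, valid, test) that together cover every sample, with each class represented in roughly the same proportion in every subset. The helper stratified_three_way_split is hypothetical and written in plain NumPy. It is not tsai's get_splits implementation and ignores edge cases such as very small classes. The toy X and y are made up for illustration.

```python
import numpy as np

def stratified_three_way_split(y, valid_size=0.2, test_size=0.2, random_state=None):
    """Hypothetical helper: return (train_idx, valid_idx, test_idx) index arrays.

    Each class is shuffled and cut separately, so class proportions are
    approximately preserved in every subset.
    """
    y = np.asarray(y)
    rng = np.random.default_rng(random_state)
    train, valid, test = [], [], []
    for label in np.unique(y):
        idx = np.flatnonzero(y == label)  # positions of this class
        rng.shuffle(idx)
        n_test = int(round(len(idx) * test_size))
        n_valid = int(round(len(idx) * valid_size))
        test.extend(idx[:n_test])
        valid.extend(idx[n_test:n_test + n_valid])
        train.extend(idx[n_test + n_valid:])
    # Sort so each subset keeps the original sample order.
    return tuple(np.sort(np.array(s, dtype=int)) for s in (train, valid, test))

# Toy data: 100 samples, 3 variables, 10 time steps, 3 imbalanced classes.
y = np.array(['a'] * 50 + ['b'] * 30 + ['c'] * 20)
X = np.random.default_rng(0).normal(size=(len(y), 3, 10))

splits = stratified_three_way_split(y, random_state=1234)
# The test portion is selected by indexing with the third split.
X_test, y_test = X[splits[2]], y[splits[2]]
```

Because the three index arrays are disjoint, passing only splits[:2] to the dataloaders guarantees that none of the samples in X_test are seen during training.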

This is the approach I'd recommend:

Pass the train and valid splits when building the dataloaders, and use the test split for inference.

# training
tfms = [None, TSClassification()]
batch_tfms = TSStandardize(by_sample=True)
dls = get_ts_dls(X, y, splits=splits[:2], tfms=tfms, batch_tfms=batch_tfms)  # train and valid splits only
learn = ts_learner(dls, metrics=accuracy)
learn.fit_one_cycle(10, 1e-2)

# inference
X_test = X[splits[2]]
y_test = y[splits[2]]
test_probas, test_targets, test_preds = learn.get_X_preds(X_test, y_test, with_decoded=True)
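With with_decoded=True, the third returned value holds decoded predictions. For single-label classification, decoding roughly amounts to taking the most probable class for each sample and mapping it back through the class vocabulary. The sketch below reproduces that step in plain NumPy and computes test accuracy from it. The vocab array, the probabilities and the labels are made-up assumptions, not tsai output.

```python
import numpy as np

# Hypothetical class vocabulary (order matches the probability columns).
vocab = np.array(['a', 'b', 'c'])

# Made-up predicted probabilities for 4 test samples (each row sums to 1).
test_probas = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.3, 0.3, 0.4],
    [0.5, 0.4, 0.1],
])
y_test = np.array(['a', 'b', 'b', 'a'])

# Decoding: index of the most probable class, mapped back to its label.
test_preds = vocab[test_probas.argmax(axis=1)]

# Fraction of test samples whose predicted label matches the true label.
test_accuracy = (test_preds == y_test).mean()
```

Here the third sample is misclassified ('c' instead of 'b'), so test_accuracy is 0.75.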

This behavior is inherited from fastai. fastai is built to create dsets.train and dsets.valid, but not dsets.test. However, it lets you build as many additional datasets and dataloaders as you need.