timeseriesAI / tsai

State-of-the-art Deep Learning library for Time Series and Sequences in Pytorch / fastai
https://timeseriesai.github.io/tsai/
Apache License 2.0

(Proposing fixes) Adding sampler arg to the data loader causes bugs in get_idxs() function #625

Closed GILAB-RS closed 1 year ago

GILAB-RS commented 1 year ago

Hi, first of all, thanks for the awesome package, and really helpful tutorials and docs!

The problem is the following: I am performing time series classification with imbalanced classes, so I wanted to test stratified batch sampling. I built two StratifiedSampler instances from the training and validation ys and passed them to the get_ts_dls function as sampler=[sampler_train, sampler_valid].

Problem 1) The line at https://github.com/timeseriesAI/tsai/blob/main/tsai/data/core.py#L565 returns a 2D array with shape (1, len(y)), which causes errors when training the learner. The fix would be to add [0] at the end of this line, which drops the extra first dimension of the returned numpy array.
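To illustrate the shape issue in Problem 1, here is a minimal sketch. It assumes the sampler yields all indices as one array, so collecting it into a list and converting to a numpy array adds a leading dimension. OneShotSampler is a hypothetical stand-in written for this example, not tsai's StratifiedSampler, and the conversion shown is an assumption about what the linked line does.

```python
import numpy as np


class OneShotSampler:
    """Hypothetical sampler (not tsai's) that yields every index in one array."""

    def __init__(self, y):
        self.y = np.asarray(y)

    def __iter__(self):
        # A single yield: iterating produces exactly one item.
        yield np.arange(len(self.y))


y = [0, 0, 1, 0, 1, 1]

# Collecting the sampler gives a list holding one array, so the result is 2D.
idxs_2d = np.array(list(OneShotSampler(y)))
print(idxs_2d.shape)  # (1, 6)

# Indexing with [0] drops the extra first dimension.
idxs_1d = idxs_2d[0]
print(idxs_1d.shape)  # (6,)
```

With the [0] applied, the indices are a flat array of length len(y), which is the shape the proposed fix aims for.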

Problem 2) However, this only partly fixes the problem. The second part occurs when using the trained model for inference on new data with learn.get_X_preds(), which creates a new data loader from the new data > https://github.com/timeseriesAI/tsai/blob/main/tsai/inference.py#L18

This new data loader is then passed to the get_preds() method of fastai's Learner, which uses the data loader's get_idxs() method > https://github.com/fastai/fastai/blob/master/fastai/learner.py#L294

As a result, inference always returns a constant number of results (the same number as the len(y) used when constructing the StratifiedSampler) > https://github.com/timeseriesAI/tsai/blob/main/tsai/data/core.py#L820

Fix: add **kwargs to the function definition at https://github.com/timeseriesAI/tsai/blob/main/tsai/data/core.py#L533 and to the function call at https://github.com/timeseriesAI/tsai/blob/main/tsai/data/core.py#L538

and add sampler=None to the function call at https://github.com/timeseriesAI/tsai/blob/main/tsai/inference.py#L17
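The idea behind the Problem 2 fix can be sketched with a toy loader. Everything here (ToyLoader, its new() method, the range used as a sampler) is hypothetical and written only for illustration. It is not tsai's or fastai's code, and it only assumes that a derived loader inherits the original sampler unless the caller overrides it through forwarded keyword arguments.

```python
class ToyLoader:
    """Hypothetical stand-in for a data loader (not tsai's or fastai's class)."""

    def __init__(self, dataset, sampler=None):
        self.dataset = dataset
        self.sampler = sampler

    def get_idxs(self):
        # With a sampler set, indices come from the sampler, not the dataset.
        if self.sampler is not None:
            return list(self.sampler)
        return list(range(len(self.dataset)))

    def new(self, dataset, **kwargs):
        # Inherit current settings unless the caller overrides them via kwargs.
        settings = {"sampler": self.sampler}
        settings.update(kwargs)
        return ToyLoader(dataset, **settings)


# Sampler built from 100 training labels (a plain range stands in for it).
train_dl = ToyLoader(list(range(100)), sampler=range(100))
new_data = list(range(7))

stale = train_dl.new(new_data)                # sampler inherited from training
fixed = train_dl.new(new_data, sampler=None)  # sampler explicitly cleared

n_stale = len(stale.get_idxs())
n_fixed = len(fixed.get_idxs())
print(n_stale, n_fixed)  # 100 7
```

In this sketch the inherited sampler makes the new loader produce 100 indices for 7 samples, while forwarding sampler=None makes the index count follow the new data.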

ognjenantonijevic commented 1 year ago

Sorry, I accidentally posted an issue through my company account. Reposting with my personal acc.