timeseriesAI / tsai

State-of-the-art Deep Learning library for Time Series and Sequences in Pytorch / fastai
https://timeseriesai.github.io/tsai/
Apache License 2.0

(Proposing fixes) Adding sampler arg to the data loader causes bugs in get_idxs() function #626

Closed ognjenantonijevic closed 1 year ago

ognjenantonijevic commented 1 year ago

Hi, first of all, thanks for the awesome package, and really helpful tutorials and docs!

The problem is the following: I am performing time series classification with imbalanced classes, so I wanted to test Stratified Batch Sampling. I built two StratifiedSamplers from the ys of the training and validation sets and passed them as sampler=[sampler_train, sampler_valid] to the get_ts_dls function.

Problem 1) The output from https://github.com/timeseriesAI/tsai/blob/main/tsai/data/core.py#L565 is a 2D array with shape (1, len(y)), which causes errors when trying to train the learner. The fix would be to add [0] at the end of this line, which drops the extra first dimension of the returned numpy array.
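To illustrate the shape problem described above, here is a minimal numpy sketch. It does not reproduce the tsai code at that line; `toy_sampler_output` is a hypothetical stand-in that simply returns indices with an extra leading axis, which is the symptom reported here.

```python
import numpy as np

# Toy labels standing in for the y used to build the sampler (assumption).
y = np.array([0, 0, 1, 0, 1, 1])

def toy_sampler_output(y):
    # Hypothetical stand-in, not tsai code: returns the indices wrapped
    # in an extra leading axis, i.e. an array of shape (1, len(y)).
    return np.arange(len(y))[None, :]

idxs = toy_sampler_output(y)
print(idxs.shape)   # 2D: one row holding all the indices

# The proposed fix: take element [0] to drop the extra first dimension.
fixed = idxs[0]
print(fixed.shape)  # 1D: one index per sample
```

With the 1D array, each element is a single sample index, which is what code iterating over the indices would normally expect.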

Problem 2) However, this only partly fixes the problem. The second part occurs when using the trained model to run inference on new data with learn.get_X_preds(), which uses the new data to create a new data loader > https://github.com/timeseriesAI/tsai/blob/main/tsai/inference.py#L18

This new data loader is then forwarded to the get_preds() function of fastai's Learner, which uses the data loader's get_idxs() method > https://github.com/fastai/fastai/blob/master/fastai/learner.py#L294

As a result, inference always returns a constant number of results, equal to the len(y) used when constructing the StratifiedSampler > https://github.com/timeseriesAI/tsai/blob/main/tsai/data/core.py#L820

Fix: add **kwargs to the function definition in https://github.com/timeseriesAI/tsai/blob/main/tsai/data/core.py#L533 and to the function call in https://github.com/timeseriesAI/tsai/blob/main/tsai/data/core.py#L538

and add sampler=None to function call in: https://github.com/timeseriesAI/tsai/blob/main/tsai/inference.py#L17
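The mechanism behind Problem 2 and the proposed fix can be sketched with a toy loader. This is only an illustration under stated assumptions: `ToyLoader` and its `new` method are hypothetical and are not the tsai or fastai classes. The sketch assumes that a loader created for new data inherits the original sampler unless an override is forwarded through **kwargs.

```python
class ToyLoader:
    """Hypothetical stand-in for a data loader (not tsai's or fastai's class)."""

    def __init__(self, items, sampler=None):
        self.items = list(items)
        self.sampler = sampler

    def get_idxs(self):
        # If a sampler is attached, its indices win, regardless of how
        # many items this loader actually holds.
        if self.sampler is not None:
            return list(self.sampler)
        return list(range(len(self.items)))

    def new(self, items, **kwargs):
        # Forward overrides via **kwargs; without one, the new loader
        # inherits the sampler built for the original data.
        sampler = kwargs.get("sampler", self.sampler)
        return ToyLoader(items, sampler=sampler)


# Sampler built from 10 training labels (toy data).
train_dl = ToyLoader(range(10), sampler=range(10))
new_items = ["a", "b", "c"]

stale = train_dl.new(new_items)                # inherits the 10-index sampler
fixed = train_dl.new(new_items, sampler=None)  # override, as proposed above

print(len(stale.get_idxs()))  # follows the sampler length, not the new data
print(len(fixed.get_idxs()))  # follows the new data length
```

In this toy version, the number of indices follows the sampler length until sampler=None is passed explicitly, which matches the constant number of results described above.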

ognjenantonijevic commented 1 year ago

Or I'll just create a PR in the next few days when I find the time.

oguiza commented 1 year ago

Hi @ognjenantonijevic , There are 3 approaches already available in tsai to handle target imbalance:

Based on all this, I'm not sure that the sampler is needed to handle target imbalance.

oguiza commented 1 year ago

Hi @ognjenantonijevic, Could you please let me know if you are still planning to submit a PR? Or should we close this issue?

oguiza commented 1 year ago

Closing this issue due to lack of activity and progress. If necessary, please create a new one.