timeseriesAI / tsai

Time series Timeseries Deep Learning Machine Learning Python Pytorch fastai | State-of-the-art Deep Learning library for Time Series and Sequences in Pytorch / fastai
https://timeseriesai.github.io/tsai/
Apache License 2.0
5.07k stars 633 forks source link

Can't train PatchTST on different (than X) dimensions for target #713

Open strakehyr opened 1 year ago

strakehyr commented 1 year ago

Hi @oguiza and thanks again for implementing a SOTA model.

I seem to have run into a limitation in PatchTST: I can't train it when y has different dimensions than X. With TSTPlus, by contrast, we can have several X covariates and a different number of y series (with different dimensions). I might just be using the wrong approach, as your PatchTST example uses a sliding window on the same variables for X and y.

The TSTPlus case:

X.shape, y.shape
arch_config = dict(
    n_layers=3,  
    ks = 4,
    n_heads=4,  
    d_model=16,  
    d_ff=128,  
    dropout=0.3,
)
learn = TSForecaster(X, y, splits=splits, batch_size=16, path="models", pipelines=[exp_pipe],
                     arch="TSTPlus", arch_config=arch_config, metrics=[mse, mae], cbs=ShowGraph())

n_epochs = 100
lr_max = 0.0025
lr_max = learn.lr_find().valley
((7416, 168, 4), (7416, 24, 3))
Epoch 1/1 : |█████---------------| 28.57% [92/322 00:01<00:03 1.1506]

The same setup doesn't work with PatchTST:

arch_config = dict(
    n_layers=3,
    n_heads=4,
    d_model=16,
    d_ff=128,
    attn_dropout=0.0,
    dropout=0.3,
    patch_len=1,
)
learn = TSForecaster(X, y, splits=splits, batch_size=16, path="models", pipelines=[exp_pipe],
                     arch="PatchTST", arch_config=arch_config, metrics=[mse, mae], cbs=ShowGraph())
n_epochs = 100
lr_max = 0.0025
lr_max = learn.lr_find().valley
RuntimeError: The size of tensor a (8064) must match the size of tensor b (1152) at non-singleton dimension 0

oguiza commented 1 year ago

Thanks for raising this @strakehyr. You are absolutely right. This is a limitation of the current PatchTST model. I'd like to develop a PatchTSTPlus model that can be used in scenarios like the one you described above. This is the normal process that I've followed in the library: the standard version replicates the model published in the paper/code as closely as possible, and then a Plus version adds functionality. Forecasting any number of variables is one of those scenarios. Another one is classification or regression. I'd like to work on this soon, but I need to find the time or resources. Would you be interested in creating a PR? If you are interested, I could give you some direction on what needs to be done.
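
For readers following along, the two sizes in the RuntimeError above can be reproduced with plain shape arithmetic. This is only a sketch under two assumptions that the thread does not state explicitly: that axis 1 of X and y is the variable axis and axis 2 the step axis, and that a channel-independent model returns one forecast per input variable. The variable names below are illustrative and not part of tsai.

```python
from math import prod

batch_size = 16             # batch_size passed to TSForecaster in the issue
x_shape = (7416, 168, 4)    # X.shape reported in the issue
y_shape = (7416, 24, 3)     # y.shape reported in the issue

# Assumption: axis 1 holds the variables and axis 2 holds the steps.
n_vars_in, n_vars_out, horizon = x_shape[1], y_shape[1], y_shape[2]

# Assumption: a channel-independent model emits one forecast per INPUT variable,
# so the prediction batch keeps n_vars_in channels instead of n_vars_out.
pred_batch_shape = (batch_size, n_vars_in, horizon)
targ_batch_shape = (batch_size, n_vars_out, horizon)

# A flattened element-wise loss compares these two element counts.
pred_numel = prod(pred_batch_shape)
targ_numel = prod(targ_batch_shape)
print(pred_numel, targ_numel)  # 8064 1152
```

Under those assumptions the counts match "tensor a (8064)" and "tensor b (1152)", which is consistent with the prediction carrying the variables of X rather than those of y.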

vrodriguezf commented 1 year ago

I am interested in this too. I might have some time to get my hands on it in about two weeks.

strakehyr commented 1 year ago

I would be happy to help, but I don't currently have the time.