zalandoresearch / pytorch-ts

PyTorch based Probabilistic Time Series forecasting framework based on GluonTS backend
MIT License

Matrix multiplication error during training TFT #49

Closed: vpozdnyakov closed this issue 3 years ago

vpozdnyakov commented 3 years ago

Environment Details

Error Description

I get a matrix multiplication error when training TFT:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (4608x2 and 1x32)

Steps to reproduce

  1. Open Colab https://colab.research.google.com/
  2. Run the following code:
!pip install pytorchts -q
!curl https://forecasters.org/data/m3comp/M3C.xls --create-dirs -o /root/.mxnet/gluon-ts/datasets/M3C.xls 

from gluonts.dataset.repository.datasets import get_dataset
from pts.model.tft import TemporalFusionTransformerEstimator
from pts import Trainer

dataset = get_dataset("m3_monthly", regenerate=False)

estimator = TemporalFusionTransformerEstimator(
    freq=dataset.metadata.freq,
    prediction_length=dataset.metadata.prediction_length,
    context_length=dataset.metadata.prediction_length,
    dropout_rate=0.1,
    num_outputs=15,
    trainer=Trainer(device='cpu',
                    epochs=20,
                    learning_rate=1e-3,
                    num_batches_per_epoch=100,
                    batch_size=128))

predictor = estimator.train(dataset.train)

kashif commented 3 years ago

Thanks @vpozdnyakov for the report! I'll try to reproduce it and figure out what's going on shortly.

nsomabalint commented 3 years ago

Hey, I have the exact same issue. Did you manage to find the reason? Do you have any tips on what could be causing it? Thank you!

kashif commented 3 years ago

@nsoma97 @vpozdnyakov sorry, I didn't get time to look into it. I will try to squeeze it in this week. The issue is most probably that I have some value hard-coded from the dataset I was testing with.
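
The shapes in the reported error message are consistent with a hard-coded input size. The following is a minimal sketch in plain Python (no PyTorch required) that reads the two shapes out of the message and reports which dimensions disagree. The helpers `parse_matmul_error` and `explain` are hypothetical names used for illustration only; they are not part of pytorch-ts, and the reading "a layer sized for 1 input feature received 2" is an assumption about the cause, not a confirmed diagnosis.

```python
import re


def parse_matmul_error(message):
    """Extract ((rows1, cols1), (rows2, cols2)) from a shape-mismatch message."""
    match = re.search(r"\((\d+)x(\d+) and (\d+)x(\d+)\)", message)
    if match is None:
        raise ValueError("no 'AxB and CxD' shape pair found in message")
    a, b, c, d = (int(group) for group in match.groups())
    return (a, b), (c, d)


def explain(message):
    """Say whether the inner dimensions of the two matrices agree."""
    mat1, mat2 = parse_matmul_error(message)
    # Matrix multiplication needs cols(mat1) == rows(mat2).
    if mat1[1] == mat2[0]:
        return f"shapes {mat1} and {mat2} are compatible"
    # Assumption: mat2 is the weight of a layer, so its row count is the
    # number of input features the layer was built for.
    return (
        f"mat1 has {mat1[1]} columns but mat2 has {mat2[0]} rows: "
        f"the layer expects {mat2[0]} input feature(s), the data has {mat1[1]}"
    )


msg = "RuntimeError: mat1 and mat2 shapes cannot be multiplied (4608x2 and 1x32)"
print(explain(msg))
```

Read this way, the second matrix was built for a single input feature while the data reaching it carries two, which is the kind of mismatch a value hard-coded for another dataset could produce.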

nsomabalint commented 3 years ago

@kashif aslinagy has started debugging it here; his changes seem to have resolved my issues: https://github.com/aslinagy/pytorch-ts/tree/tft_fixes

vpozdnyakov commented 3 years ago

@kashif there is still a matrix multiplication error. Could you please explain?

kashif commented 3 years ago

Ok let me check…

kashif commented 3 years ago

@vpozdnyakov I checked on the latest master and I do not get this issue any more. I can make a new release if you want, or you can test it yourself on the master branch.

vpozdnyakov commented 3 years ago

@kashif please make a new release. I do not know how to install the master branch version; I only ran !pip install pytorchts. Thanks!
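
For anyone in the same position: pip can install directly from a git branch. The sketch below only builds and prints the command rather than running it. The repository URL is an assumption inferred from the project name at the top of this page, the branch name "master" is taken from the comment above, and `pip_install_command` is a hypothetical helper.

```python
import sys

# Assumption: the repository URL, inferred from the project name
# "zalandoresearch / pytorch-ts"; adjust it if the project has moved.
REPO_URL = "https://github.com/zalandoresearch/pytorch-ts.git"


def pip_install_command(branch="master"):
    """Build the pip command that installs the package from a git branch."""
    return [sys.executable, "-m", "pip", "install", f"git+{REPO_URL}@{branch}"]


# Print the command instead of executing it, so nothing is installed by accident.
print(" ".join(pip_install_command()))

# To actually run it:
#   import subprocess
#   subprocess.check_call(pip_install_command())
```

In a Colab cell the equivalent would be a single `!pip install git+<repository URL>@master` line, followed by a runtime restart if an older version was already imported.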

kashif commented 3 years ago

@vpozdnyakov sure will do!

kashif commented 3 years ago

@vpozdnyakov OK, I made a new release. Can you test it?

vpozdnyakov commented 3 years ago

Hi @kashif! Thanks for your fixes, but I cannot test them, since there is another problem during training; see the screenshot.

[image: screenshot of the training error]

kashif commented 3 years ago

@vpozdnyakov OK, can you kindly update to 0.5.1 and try again?
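
Before re-running the training code, it can help to confirm that the environment really picked up the new release. A minimal sketch using only the standard library follows; the package name `pytorchts` comes from the install command earlier in this thread, while `parse_version` and `installed_at_least` are hypothetical helpers that handle only simple dotted version strings.

```python
import re
from importlib.metadata import PackageNotFoundError, version


def parse_version(text):
    """Turn '0.5.1' into (0, 5, 1), stopping at the first non-numeric part."""
    parts = []
    for piece in text.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def installed_at_least(package, minimum):
    """Return True if `package` is installed at version `minimum` or newer."""
    try:
        installed = version(package)
    except PackageNotFoundError:
        # Package is not installed in this environment at all.
        return False
    return parse_version(installed) >= parse_version(minimum)


print(installed_at_least("pytorchts", "0.5.1"))
```

If this prints `False` in Colab right after upgrading, restarting the runtime and checking again is a reasonable first step.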

vpozdnyakov commented 3 years ago

It works, thank you.