zalandoresearch / pytorch-ts

PyTorch based Probabilistic Time Series forecasting framework based on GluonTS backend
MIT License

unable to reproduce results from notebook #39

Open turmeric-blend opened 3 years ago

turmeric-blend commented 3 years ago

I am unable to reproduce the results from the TimeGrad notebook. The training loss diverges and then becomes NaN.

predictor = estimator.train(dataset_train, num_workers=8)

99it [00:22, 4.39it/s, avg_epoch_loss=0.945, epoch=0]
99it [00:22, 4.40it/s, avg_epoch_loss=0.495, epoch=1]
99it [00:22, 4.39it/s, avg_epoch_loss=0.466, epoch=2]
99it [00:22, 4.35it/s, avg_epoch_loss=0.795, epoch=3]
99it [00:22, 4.33it/s, avg_epoch_loss=0.852, epoch=4]
99it [00:22, 4.32it/s, avg_epoch_loss=nan, epoch=5]
99it [00:22, 4.33it/s, avg_epoch_loss=nan, epoch=6]
99it [00:22, 4.30it/s, avg_epoch_loss=nan, epoch=7]
99it [00:23, 4.30it/s, avg_epoch_loss=nan, epoch=8]
99it [00:22, 4.34it/s, avg_epoch_loss=nan, epoch=9]
99it [00:23, 4.29it/s, avg_epoch_loss=nan, epoch=10]
99it [00:23, 4.28it/s, avg_epoch_loss=nan, epoch=11]
99it [00:22, 4.33it/s, avg_epoch_loss=nan, epoch=12]
99it [00:23, 4.21it/s, avg_epoch_loss=nan, epoch=13]
99it [00:23, 4.30it/s, avg_epoch_loss=nan, epoch=14]
99it [00:23, 4.30it/s, avg_epoch_loss=nan, epoch=15]
99it [00:22, 4.34it/s, avg_epoch_loss=nan, epoch=16]
99it [00:22, 4.34it/s, avg_epoch_loss=nan, epoch=17]
99it [00:22, 4.34it/s, avg_epoch_loss=nan, epoch=18]
99it [00:23, 4.20it/s, avg_epoch_loss=nan, epoch=19]
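
For anyone else debugging this: once the loss turns NaN, the remaining epochs contribute nothing, so it can help to stop at the first non-finite value and inspect the model state there. The following is only a minimal sketch. The helper name `first_nonfinite_epoch` is hypothetical and not part of pytorch-ts, and the loss values are copied from the epoch 0 to 5 lines of the log above.

```python
import math

# avg_epoch_loss values copied from the log above (epochs 0 to 5).
epoch_losses = [0.945, 0.495, 0.466, 0.795, 0.852, float("nan")]


def first_nonfinite_epoch(losses):
    """Return the index of the first NaN/inf loss, or None if all are finite.

    Hypothetical helper for illustration; not a pytorch-ts function.
    """
    for epoch, loss in enumerate(losses):
        if not math.isfinite(loss):
            return epoch
    return None


bad_epoch = first_nonfinite_epoch(epoch_losses)
if bad_epoch is not None:
    print(f"loss became non-finite at epoch {bad_epoch}")
```

In a real training loop the same check could be applied to each batch loss, which narrows the failure down to a single step rather than a whole epoch.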

kashif commented 3 years ago

thanks for letting me know.. perhaps i screwed something up while refactoring, I'll check and get back to you

kashif commented 3 years ago

@turmeric-blend which parameters did you give the model? I remember that when I set beta_end too high I got NaNs...

turmeric-blend commented 3 years ago

My beta_end is 0.07. I ran everything exactly as in the notebook example.
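
To give a rough feel for why beta_end matters, here is a small sketch of a linear diffusion noise schedule and the cumulative product of (1 - beta) at the last step. It is an illustration under assumptions, not the pytorch-ts implementation: the start value 1e-4 and the 100 diffusion steps are assumed defaults, and the function names are made up for this example. It does not show that the schedule is the cause of the NaNs in this issue; it only shows how quickly the retained signal fraction shrinks as beta_end grows.

```python
def linear_betas(beta_end, diff_steps=100, beta_start=1e-4):
    """Linearly spaced betas from beta_start to beta_end (both assumed values)."""
    step = (beta_end - beta_start) / (diff_steps - 1)
    return [beta_start + i * step for i in range(diff_steps)]


def final_alpha_bar(beta_end, diff_steps=100):
    """Product of (1 - beta_t) over all steps: the fraction of signal left at the end."""
    alpha_bar = 1.0
    for beta in linear_betas(beta_end, diff_steps):
        alpha_bar *= 1.0 - beta
    return alpha_bar


# Compare the value reported in this thread (0.07) with larger settings.
for beta_end in (0.07, 0.1, 0.5):
    print(f"beta_end={beta_end}: final alpha_bar={final_alpha_bar(beta_end):.3e}")
```

Larger beta_end values push the final product toward zero, so any term that divides by it or takes its square root becomes more sensitive to rounding.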

kashif commented 3 years ago

thanks! I just re-ran it again on my machine and it all worked out... very strange... 🤔

turmeric-blend commented 3 years ago

I am running from the downloaded zip folder (didn't install via pip install pytorchts). Not sure if this affects anything.

kashif commented 3 years ago

hmm not sure... perhaps in the downloaded zip folder do: pip install . and then try?

turmeric-blend commented 3 years ago

I feel like it's related to the random seed, since we are using different machines...

kashif commented 3 years ago

also which version of pytorch do you use? I am using pytorch 1.7.1 here

turmeric-blend commented 3 years ago

pytorch 1.7.1+cu110

turmeric-blend commented 3 years ago

I tried with pip install . and the NaN still occurs. Maybe you could update the notebook results with a fixed seed for pytorch, mxnet, numpy, random, etc., and I will see if I can reproduce them?

kashif commented 3 years ago

I can try sure.. I just re-ran it again with the parameters from the paper and checked in the notebook, I also fixed the "cuda" device name...

turmeric-blend commented 3 years ago

My CUDA setting is:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

but this is just because I don't have a GPU on 1.

kashif commented 3 years ago

ok so setting a seed e.g. via:

np.random.seed(123456)
torch.manual_seed(123456)

also works for me and I get no NaN in training... can you perhaps try to train with fewer num_workers, e.g. 4 or 2 or even 0?
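
The seeding above can be wrapped in one helper so that both sides run the notebook from the same starting state. This is a sketch only: `seed_everything` is a hypothetical name, not a pytorch-ts or GluonTS function. It seeds the standard library generator and then NumPy and PyTorch if they are importable. Even with fixed seeds, results may still differ across machines, GPU models, or when num_workers is greater than 0.

```python
import random


def seed_everything(seed):
    """Seed every RNG that is available; return the names of those seeded.

    Hypothetical helper for illustration, not part of pytorch-ts.
    """
    random.seed(seed)
    seeded = ["random"]

    try:
        import numpy as np
        np.random.seed(seed)
        seeded.append("numpy")
    except ImportError:
        pass  # NumPy not installed; skip it.

    try:
        import torch
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            # Seed all visible GPUs as well.
            torch.cuda.manual_seed_all(seed)
        seeded.append("torch")
    except ImportError:
        pass  # PyTorch not installed; skip it.

    return seeded


# Same seed value as used in the comment above.
print(seed_everything(123456))
```

Call it once at the top of the notebook, before the dataset and the estimator are created.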

turmeric-blend commented 3 years ago

nothing works :/ I reinstalled everything in a clean env and it still doesn't work.

I installed gluonts via pip install git+https://github.com/awslabs/gluon-ts.git@master#egg=gluonts.

If I try to install PyTorchTS via pip install pytorchts, I get this error:

ERROR: Packages installed from PyPI cannot depend on packages which are not also hosted on PyPI. pytorchts depends on gluonts@ git+https://github.com/awslabs/gluon-ts.git@master#egg=gluonts

kashif commented 3 years ago

sorry to hear that... i will try to reproduce on a clean env as well!

wsdqy1234 commented 1 year ago

I get the same NaN results, and I found that the outputs are NaN too. It is very strange.

kashif commented 1 year ago

can you try to use the 0.7.0 branch?