unit8co / darts

A python library for user-friendly forecasting and anomaly detection on time series.
https://unit8co.github.io/darts/
Apache License 2.0

[BUG] nbeats.py works better if you add .cuda() to each torch.zeros #251

Closed catskillsresearch closed 3 years ago

catskillsresearch commented 3 years ago

Describe the bug: PyTorch complains that tensors are on both the GPU and the CPU. nbeats.py works better if you add .cuda() to each torch.zeros call.

To Reproduce: Run the NBEATS-examples notebook.

Expected behavior: The notebook should run without a device mismatch error.
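
For illustration, here is a minimal sketch of the kind of device mismatch described above and a device-agnostic alternative to hard-coding .cuda(). It is not the actual darts code: the helper name accumulate_forecast and the tensor shapes are hypothetical. It assumes only that the accumulator should live on the same device as the tensors it is combined with.

```python
import torch


def accumulate_forecast(block_outputs):
    """Sum per-block outputs into one tensor (hypothetical helper, not darts code)."""
    first = block_outputs[0]
    # torch.zeros(...) allocates on the CPU by default, and .cuda() forces the GPU.
    # Passing device/dtype from an existing tensor keeps everything on one device
    # whether or not a GPU is in use.
    total = torch.zeros(first.shape, device=first.device, dtype=first.dtype)
    for out in block_outputs:
        total = total + out
    return total


# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
outputs = [torch.ones(2, 3, device=device) for _ in range(4)]
forecast = accumulate_forecast(outputs)
print(forecast.device.type == device.type)
```

Unlike adding .cuda() to every allocation, this pattern also runs unchanged on a machine without a GPU.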

hrzn commented 3 years ago

Thanks for raising this. We were already aware of issues with N-BEATS when running on GPU. They should be fixed in an upcoming release (being addressed here: https://github.com/unit8co/darts/pull/231).

hrzn commented 3 years ago

This should be fixed as of 0.6.1; closing.