unit8co / darts

A Python library for user-friendly forecasting and anomaly detection on time series.
https://unit8co.github.io/darts/
Apache License 2.0

[BUG] TypeError: __init__() got an unexpected keyword argument 'tpus' #808

Closed: gsamaras closed this issue 2 years ago

gsamaras commented 2 years ago

Describe the bug

Cannot fit an N-BEATS model to my data with a TPU in Google Colab.

To Reproduce

!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl
!pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchtext==0.10.0 -f https://download.pytorch.org/whl/cu111/torch_stable.html
!pip install 'u8darts[torch]==0.17.1'
!pip install pyyaml==5.4.1

And then, after restarting the runtime:

import numpy as np

from darts import TimeSeries
from darts.models import NBEATSModel

series = TimeSeries.from_values(np.random.random_sample((500,)))

trainset_size = 0.6
train, val = series.split_after(trainset_size)

model_nbeats = NBEATSModel(
  input_chunk_length=8,
  output_chunk_length=2,
  n_epochs=20,
  pl_trainer_kwargs={
    "accelerator": "tpu",
    "tpus": [0]
  },
)
model_nbeats.fit(series=train, val_series=val, verbose=True) # Error here

I get the following error:

TypeError                                 Traceback (most recent call last)
<ipython-input-4-1c57a1bae7f2> in <module>
     25   },
     26 )
---> 27 model_nbeats.fit(series=train, val_series=val, verbose=True)

/usr/local/lib/python3.7/dist-packages/darts/utils/torch.py in decorator(self, *args, **kwargs)
     68         with fork_rng():
     69             manual_seed(self._random_instance.randint(0, high=MAX_TORCH_SEED_VALUE))
---> 70             return decorated(self, *args, **kwargs)
     71 
     72     return decorator

/usr/local/lib/python3.7/dist-packages/darts/models/forecasting/torch_forecasting_model.py in fit(self, series, past_covariates, future_covariates, val_series, val_past_covariates, val_future_covariates, trainer, verbose, epochs, max_samples_per_ts, num_loader_workers)
    769 
    770         return self.fit_from_dataset(
--> 771             train_dataset, val_dataset, trainer, verbose, epochs, num_loader_workers
    772         )
    773 

/usr/local/lib/python3.7/dist-packages/darts/utils/torch.py in decorator(self, *args, **kwargs)
     68         with fork_rng():
     69             manual_seed(self._random_instance.randint(0, high=MAX_TORCH_SEED_VALUE))
---> 70             return decorated(self, *args, **kwargs)
     71 
     72     return decorator

/usr/local/lib/python3.7/dist-packages/darts/models/forecasting/torch_forecasting_model.py in fit_from_dataset(self, train_dataset, val_dataset, trainer, verbose, epochs, num_loader_workers)
    913 
    914         # setup trainer
--> 915         self._setup_trainer(trainer, verbose, train_num_epochs)
    916 
    917         # TODO: multiple training without loading from checkpoint is not trivial (I believe PyTorch-Lightning is still

/usr/local/lib/python3.7/dist-packages/darts/models/forecasting/torch_forecasting_model.py in _setup_trainer(self, trainer, verbose, epochs)
    522         self.trainer = (
    523             self._init_trainer(trainer_params=self.trainer_params, max_epochs=epochs)
--> 524             if trainer is None
    525             else trainer
    526         )

/usr/local/lib/python3.7/dist-packages/darts/models/forecasting/torch_forecasting_model.py in _init_trainer(trainer_params, max_epochs)
    535             trainer_params_copy["max_epochs"] = max_epochs
    536 
--> 537         return pl.Trainer(**trainer_params_copy)
    538 
    539     @abstractmethod

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/connectors/env_vars_connector.py in insert_env_defaults(self, *args, **kwargs)
     36 
     37         # all args were already moved to kwargs
---> 38         return fn(self, **kwargs)
     39 
     40     return insert_env_defaults

TypeError: __init__() got an unexpected keyword argument 'tpus'

Expected behavior

I thought I could use a TPU, based on this comment.

dennisbader commented 2 years ago

Hi @gsamaras, to clarify: whatever you pass to pl_trainer_kwargs is used directly to instantiate a PyTorch Lightning (PL) Trainer. These parameters are neither changed nor maintained by Darts, so you'll need to look in the PyTorch Lightning documentation to find out which Trainer parameters to use for TPUs.
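Because the dictionary is forwarded unchanged to `pl.Trainer(**kwargs)`, an unknown key only fails once `fit()` builds the Trainer. Below is a minimal sketch of catching this earlier by comparing the keys against the constructor signature with Python's standard `inspect` module. `unknown_kwargs` and `FakeTrainer` are hypothetical names for illustration; with PyTorch Lightning installed you would pass `pl.Trainer` instead of the stand-in class, assuming its constructor signature is inspectable.

```python
import inspect
from typing import Any, Dict, List


class FakeTrainer:
    """Hypothetical stand-in for pl.Trainer with a few keyword parameters."""

    def __init__(self, accelerator=None, tpu_cores=None, gpus=None,
                 precision=32, max_epochs=None):
        self.accelerator = accelerator
        self.tpu_cores = tpu_cores
        self.gpus = gpus
        self.precision = precision
        self.max_epochs = max_epochs


def unknown_kwargs(cls: type, kwargs: Dict[str, Any]) -> List[str]:
    """Return the keys in `kwargs` that the constructor of `cls` does not accept."""
    params = inspect.signature(cls).parameters
    # If the constructor takes **kwargs, every key is accepted.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return []
    return sorted(k for k in kwargs if k not in params)


trainer_kwargs = {"accelerator": "tpu", "tpus": [0]}
print(unknown_kwargs(FakeTrainer, trainer_kwargs))  # ['tpus']

try:
    FakeTrainer(**trainer_kwargs)
except TypeError as err:
    # The same kind of error as in the traceback above.
    print(err)
```

Running such a check before constructing the model surfaces the bad key without waiting for training to start.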

See the PyTorch Lightning Trainer parameters here

That being said, judging from the error, PL's Trainer doesn't have a tpus kwarg. I saw a tpu_cores kwarg; maybe it works with that? This might also help: https://pytorch-lightning.readthedocs.io/en/stable/advanced/tpu.html#tpu-core-training
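Renaming the key may not be enough on its own. As far as we know, PyTorch Lightning 1.5 accepts `tpu_cores` only as `1`, `8`, or a one-element list naming a single core from 1 to 8, so `[0]` would likely be rejected as well. Below is a minimal sketch of that rule, assuming the 1.5-era semantics (later PL releases deprecate `tpu_cores` in favour of `accelerator="tpu"` plus `devices`). `is_valid_tpu_cores` and `tpu_trainer_kwargs` are hypothetical helpers, not part of Darts or PL.

```python
from typing import Any, Dict, List, Union

TpuCores = Union[int, List[int]]


def is_valid_tpu_cores(tpu_cores: TpuCores) -> bool:
    """Assumed PL 1.5 rule: 1 or 8 cores, or [i] with 1 <= i <= 8."""
    if isinstance(tpu_cores, bool):
        return False  # bools are ints in Python; reject them explicitly
    if isinstance(tpu_cores, int):
        return tpu_cores in (1, 8)
    if isinstance(tpu_cores, (list, tuple)) and len(tpu_cores) == 1:
        core = tpu_cores[0]
        return isinstance(core, int) and 1 <= core <= 8
    return False


def tpu_trainer_kwargs(tpu_cores: TpuCores = 8) -> Dict[str, Any]:
    """Hypothetical helper building pl_trainer_kwargs for a TPU run."""
    if not is_valid_tpu_cores(tpu_cores):
        raise ValueError(f"tpu_cores must be 1, 8 or [<1-8>], got {tpu_cores!r}")
    return {"accelerator": "tpu", "tpu_cores": tpu_cores}


print(tpu_trainer_kwargs(8))    # {'accelerator': 'tpu', 'tpu_cores': 8}
print(tpu_trainer_kwargs([1]))  # {'accelerator': 'tpu', 'tpu_cores': [1]}
# tpu_trainer_kwargs([0]) raises ValueError: core indices start at 1 here.
```

The resulting dict would then be passed as `NBEATSModel(..., pl_trainer_kwargs=tpu_trainer_kwargs(8))`, assuming the installed PL version still accepts `tpu_cores`.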

We haven't tested whether the models run on TPUs, so we don't know for sure yet.

Please let us know if/how you get it running :)

gsamaras commented 2 years ago

Hi @dennisbader, thanks!

So I was able to get it working after reading the PL docs, but I eventually got this:

MisconfigurationException: `Trainer(accelerator='tpu', precision=64)` is not implemented. Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues` requesting this feature.

As a result, I used 32-bit (float32) precision to actually make it run. That was fun!
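The `precision=64` most likely comes from the data: `np.random.random_sample` returns float64 values, and Darts appears to derive the Trainer precision from the dtype of the training series. Below is a minimal sketch of that mapping and of the cast to float32, assuming this dtype-to-precision behavior. `trainer_precision_for` is a hypothetical helper; in Darts itself the cast would be applied to the series with `TimeSeries.astype(np.float32)`.

```python
import numpy as np


def trainer_precision_for(values: np.ndarray) -> int:
    """Hypothetical mapping from the data dtype to a PL `precision` value."""
    if values.dtype == np.float32:
        return 32
    if values.dtype == np.float64:
        return 64
    raise ValueError(f"unsupported dtype: {values.dtype}")


values = np.random.random_sample((500,))  # float64 by default
print(values.dtype, trainer_precision_for(values))      # float64 64

values32 = values.astype(np.float32)      # cast before building the series
print(values32.dtype, trainer_precision_for(values32))  # float32 32
```

With Darts, the equivalent would be something like `train, val = series.astype(np.float32).split_after(0.6)` before calling `fit()`.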

If you think it makes sense, I could write a short tutorial about this (plus GPU usage), or maybe extend https://unit8co.github.io/darts/examples/07-NBEATS-examples.html? That way I can contribute to the project and give back!

dennisbader commented 2 years ago

Hey @gsamaras, glad to hear that it worked out! Yes, sure, thanks! We're always happy to receive contributions.

I think this would make a great new user guide (a new .md file in /docs/userguide/). We could also explain how to run models on a GPU, and mention that models run on CPU by default :)

You can use the following user guide as a reference: https://unit8co.github.io/darts/userguide/torch_forecasting_models.html

Let us know if you need help!