sktime / pytorch-forecasting

Time series forecasting with PyTorch
https://pytorch-forecasting.readthedocs.io/
MIT License
4.02k stars 639 forks

[ENH] Add `device` parameter to classes. #1647

Open fnhirwa opened 2 months ago

fnhirwa commented 2 months ago

Expected behavior

The current implementation of several classes handles devices in a way that is opaque to users: there is no straightforward way to control which device is used or to switch between devices.

https://github.com/jdb78/pytorch-forecasting/blob/81aee6650ed3de0c3071c9ce1fce19eec7fc24a7/pytorch_forecasting/metrics/distributions.py#L470

For example, here the device is taken from the input tensor, which makes the behavior hard to reason about from the interface alone.

I suggest introducing a `device` parameter on these classes. This would give users explicit control over device selection and switching, rather than relying entirely on the device of the input data.
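A minimal sketch of what such a parameter could look like. `DistributionMetric` here is a hypothetical stand-in, not the actual pytorch-forecasting API: an explicit `device` takes precedence, and `None` preserves the current follow-the-input behavior.

```python
import torch


class DistributionMetric:
    """Hypothetical metric-like class with an explicit, optional device."""

    def __init__(self, device=None):
        # Explicit device if given; None means "follow the input's device",
        # which is the current behavior being discussed.
        self.device = torch.device(device) if device is not None else None

    def rescale(self, x: torch.Tensor) -> torch.Tensor:
        # An explicit device takes precedence; otherwise stay on x.device.
        target = self.device if self.device is not None else x.device
        return (x * 2.0).to(target)


metric = DistributionMetric(device="cpu")
out = metric.rescale(torch.ones(3))
```

With `device=None` the class behaves exactly as today, so such a change could be backward compatible.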

XinyuWuu commented 2 months ago

The distribution of tensors across devices is handled by Lightning; I am not sure we can set it manually. The accelerator is chosen when constructing the Trainer:

import lightning.pytorch as pl
trainer = pl.Trainer(
    accelerator="cpu",
)
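If the Trainer is indeed the only place the accelerator can be set, a user-facing `device` string could still be mapped onto Trainer arguments. A hedged sketch, where the helper name `device_to_trainer_kwargs` is hypothetical:

```python
def device_to_trainer_kwargs(device: str) -> dict:
    """Translate a torch-style device string into lightning.pytorch.Trainer
    keyword arguments (`accelerator` / `devices`)."""
    if device == "cpu":
        return {"accelerator": "cpu"}
    if device.startswith("cuda"):
        # "cuda" -> any single GPU, "cuda:1" -> GPU index 1
        _, _, index = device.partition(":")
        devices = [int(index)] if index else 1
        return {"accelerator": "gpu", "devices": devices}
    raise ValueError(f"unsupported device string: {device}")


print(device_to_trainer_kwargs("cuda:1"))  # {'accelerator': 'gpu', 'devices': [1]}
```

The resulting dict could then be splatted into `pl.Trainer(**kwargs)`, keeping a single user-facing `device` argument.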

fnhirwa commented 2 months ago

I think we should find a better way to handle the accelerator case; I don't think it is sensible to hard-code these values for users: https://github.com/jdb78/pytorch-forecasting/blob/bb6c8a2243c35ca35c2c0e14093d352430fee6d0/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py#L177

We could pass it as an optional parameter instead.
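One way that could look, sketched with a simplified stand-in for the tuning routine (the real `optimize_hyperparameters` has many more parameters; this is only an illustration of the optional-argument pattern):

```python
def optimize_hyperparameters(n_trials: int, accelerator: str = "auto") -> dict:
    """Simplified stand-in for a tuning routine: the accelerator is an
    optional parameter instead of being hard-coded inside the function."""
    # "auto" lets Lightning pick a sensible default; callers can override.
    trainer_kwargs = {"accelerator": accelerator}
    # ... a real implementation would build pl.Trainer(**trainer_kwargs)
    # and run the n_trials optimization loop here ...
    return trainer_kwargs


print(optimize_hyperparameters(10))                     # {'accelerator': 'auto'}
print(optimize_hyperparameters(10, accelerator="cpu"))  # {'accelerator': 'cpu'}
```

Defaulting to `"auto"` keeps current users unaffected while removing the hard-coded value.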

fkiraly commented 2 months ago

This is the tuning code, so I think it is extraneous to the models themselves; as long as the models do not hard-code the device, I consider this less of a problem. That said, this one tuning routine does a lot of hard-coding and other things extraneous to an otherwise very consistent architectural design.