unit8co / darts

A python library for user-friendly forecasting and anomaly detection on time series.
https://unit8co.github.io/darts/
Apache License 2.0

[QUESTION][REQUEST] Way to move model to a specific device #998

Closed: ywein closed this issue 1 year ago

ywein commented 2 years ago

In PyTorch you can move a model to a specific device via my_model.to(device). Is there currently a way to do that with Darts?

This would be especially useful for M1 Macs, which don't use CUDA; you need to specify the MPS backend instead. In pure PyTorch it's very easy to do, but I couldn't find a way to do it with Darts.
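For reference, the pure-PyTorch pattern the question refers to looks roughly like the following. It is a minimal sketch: the toy nn.Linear model is illustrative only, and the getattr guard assumes that torch.backends.mps may be missing on older PyTorch releases (it was introduced in PyTorch 1.12).

```python
import torch

# Fall back to CPU when MPS is unavailable (non-Apple hardware or older PyTorch)
mps_backend = getattr(torch.backends, "mps", None)
use_mps = mps_backend is not None and mps_backend.is_available()
device = torch.device("mps" if use_mps else "cpu")

model = torch.nn.Linear(4, 2).to(device)  # move the model's parameters to the device
x = torch.randn(3, 4, device=device)      # inputs must live on the same device
y = model(x)
print(y.device, tuple(y.shape))
```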

zwelitunyiswa commented 2 years ago

I tried many different ways and failed; it turns out the underlying PyTorch Lightning library does not support the Apple GPU (device = mps) outside of a dev build. Ergo, the setting is not available to Darts.

hrzn commented 2 years ago

Here is the user guide section on using hardware accelerators. It relies on PyTorch Lightning, and we haven't tested on Apple GPUs, so we're not sure it's supported (we'd be interested to know if it works!)
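One way to apply that guide on Apple silicon is to build pl_trainer_kwargs from what PyTorch reports as available. This is a sketch, not Darts API: pick_trainer_kwargs is a hypothetical helper, and it assumes a PyTorch Lightning release that accepts accelerator="mps" (to my knowledge added around Lightning 1.7; earlier non-dev builds may reject it, as noted above).

```python
import torch


def pick_trainer_kwargs() -> dict:
    """Hypothetical helper: choose pl_trainer_kwargs for the available hardware."""
    if torch.cuda.is_available():
        return {"accelerator": "gpu", "devices": [0]}
    # torch.backends.mps only exists on newer PyTorch versions, so guard the lookup
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        # Assumption: Lightning accepts "mps"; MPS lacks float64, so keep 32-bit precision
        return {"accelerator": "mps", "devices": 1, "precision": 32}
    return {"accelerator": "cpu"}


print(pick_trainer_kwargs())
```

The result would then be passed as RNNModel(..., pl_trainer_kwargs=pick_trainer_kwargs()), matching how pl_trainer_kwargs is used elsewhere in this thread.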

skeenan commented 2 years ago

On an Apple M1 GPU I get:

TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
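The error comes from double-precision data: NumPy arrays default to float64, and MPS cannot hold float64 tensors. The fix reported later in this thread is to cast the TimeSeries with astype(np.float32); the plain-PyTorch sketch below shows the same cast at the tensor level. The array contents are arbitrary illustration data.

```python
import numpy as np
import torch

values = np.linspace(0.0, 1.0, num=10)  # NumPy defaults to float64
t64 = torch.from_numpy(values)          # dtype is preserved: torch.float64
t32 = t64.to(torch.float32)             # cast to float32 before moving to MPS

mps_backend = getattr(torch.backends, "mps", None)
if mps_backend is not None and mps_backend.is_available():
    # Moving t64 here instead is expected to fail, since MPS has no float64 support
    t32 = t32.to("mps")

print(t64.dtype, t32.dtype, t32.device)
```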

zwelitunyiswa commented 2 years ago

Can you share some code and the versions of Darts, PyTorch and Lightning you are using? I cannot even get Darts to see my MPS device, although PyTorch sees it.

(Sent as an email reply to skeenan's comment of Fri, Aug 26, 2022, 4:42 PM.)

skeenan commented 2 years ago

Got this working.

I followed the example in https://unit8.com/resources/time-series-forecasting-using-past-and-future-external-data-with-darts.

My changes:

Converted the TimeSeries objects to np.float32:

import numpy as np

flow = flow.astype(np.float32)
rainfalls = rainfalls.astype(np.float32)

and configured the model as follows:

from darts.models import RNNModel

rnn_rain = RNNModel(input_chunk_length=30,
                    training_length=40,
                    pl_trainer_kwargs={
                        "accelerator": "gpu",
                        "devices": [0],
                        "precision": 32
                    },
                    n_rnn_layers=2)
skeenan commented 2 years ago

Oddly, training takes 21 seconds on the CPU but 3 minutes 17 seconds when the GPU is enabled.

Unexpected.

hrzn commented 2 years ago


Could you try changing the value of num_loader_workers and see whether it helps? There are also a few other performance considerations written up here.
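To act on that suggestion systematically, one option is to time a fresh training run for each candidate worker count. time_runs is a hypothetical helper, not part of Darts; it only measures wall-clock time around whatever callable you pass in, and the stand-in workload below just lets the sketch run without Darts installed.

```python
import time
from typing import Callable, Dict, Iterable


def time_runs(run: Callable[[int], object], worker_counts: Iterable[int]) -> Dict[int, float]:
    """Return the wall-clock seconds taken by run(n) for each worker count n."""
    timings: Dict[int, float] = {}
    for n in worker_counts:
        start = time.perf_counter()
        run(n)  # e.g. build a new model and call fit(..., num_loader_workers=n)
        timings[n] = time.perf_counter() - start
    return timings


if __name__ == "__main__":
    # Stand-in workload in place of a real Darts training run
    results = time_runs(lambda n: sum(range(100_000 * (n + 1))), [0, 2, 4])
    for n, seconds in results.items():
        print(f"num_loader_workers={n}: {seconds:.3f}s")
```

With Darts, run would be a function that builds the RNNModel and calls my_model.fit(train_transformed, val_series=val_transformed, verbose=False, num_loader_workers=n); rebuilding the model on each call keeps the runs comparable.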

skeenan commented 2 years ago

Thanks for the pointer. I followed the example at https://unit8co.github.io/darts/userguide/gpu_and_tpu_usage.html#use-a-gpu

Unfortunately, with num_loader_workers > 1, model.fit() never returns (even with epoch=1); it just hangs indefinitely. I tried a range of values and it always hangs.

my_model.fit(train_transformed,
             val_series=val_transformed,
             verbose=False,
             num_loader_workers=1)
hrzn commented 2 years ago


Oh, then that's an issue between PyTorch and MPS. Make sure you have the latest version of PyTorch, if that's not already the case.

skeenan commented 2 years ago

I'm using the latest nightly build of PyTorch. I tried digging deeper, but no luck. If I ever resolve it, I'll post here.

I'm wondering whether this is related: https://github.com/Lightning-AI/lightning/issues/4289

skeenan commented 2 years ago

Got this working. The solution to get num_loader_workers working was to wrap the code execution logic in an if __name__ == '__main__': guard.

However, even with num_workers > 0, training is still very slow on the GPU: 2 seconds on CPU, whereas I killed the still-running GPU process after 30 minutes. I can only conclude that there is some sort of issue with the current state of the M1 PyTorch implementation. In my opinion, it's not worth bothering with in the current state of support. Hope this helps someone in the future.

Complete code below.

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

from darts.dataprocessing.transformers import Scaler
from darts.models import RNNModel
from darts.metrics import mape
from darts.datasets import AirPassengersDataset

def main():
    # Read data:
    series = AirPassengersDataset().load()
    series = series.astype(np.float32)

    # Create training and validation sets:
    train, val = series.split_after(pd.Timestamp("19590101"))

    # Normalize the time series (note: we avoid fitting the transformer on the validation set)
    transformer = Scaler()
    train_transformed = transformer.fit_transform(train)
    val_transformed = transformer.transform(val)
    series_transformed = transformer.transform(series)

    my_model = RNNModel(
        model="RNN",
        hidden_dim=20,
        dropout=0,
        batch_size=40,
        n_epochs=300,
        optimizer_kwargs={"lr": 1e-3},
        model_name="Air_RNN",
        log_tensorboard=True,
        random_state=42,
        training_length=20,
        input_chunk_length=14,
        force_reset=True,
        pl_trainer_kwargs={
            "accelerator": "gpu",
            "devices": [0],
        },
    )

    my_model.fit(train_transformed, val_series=val_transformed, verbose=False, num_loader_workers=4)

if __name__ == '__main__':
    main()
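A likely reason the if __name__ == '__main__' guard in this script fixes the hang: on macOS, Python 3.8+ starts multiprocessing workers with the "spawn" method by default, so each data-loader worker re-imports the main script, and without the guard that import would try to run the training code again. The sketch below uses only the standard library (not Darts or PyTorch) to show the same pattern; square and main are illustrative names.

```python
import multiprocessing as mp


def square(x: int) -> int:
    # Defined at module level so that spawned workers can import it
    return x * x


def main() -> None:
    # "spawn" is the default start method on macOS for Python 3.8+;
    # request it explicitly so the example behaves the same on every OS
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=2) as pool:
        print(pool.map(square, [1, 2, 3]))  # prints [1, 4, 9]


if __name__ == "__main__":
    # Spawned workers import this file as "__mp_main__", so this block runs
    # only in the parent process and the workers never start pools of their own
    main()
```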