unit8co / darts

A python library for user-friendly forecasting and anomaly detection on time series.
https://unit8co.github.io/darts/
Apache License 2.0

[BUG] dataloader_kwargs.num_workers exceptions during fit() #2525

Closed: briandecamp closed this issue 2 months ago

briandecamp commented 2 months ago

Describe the bug
Getting multiple exceptions during fit() when setting num_workers in dataloader_kwargs.

To Reproduce

import numpy as np
import pandas as pd

from darts.dataprocessing.transformers import Scaler
from darts.models import TFTModel
from darts.datasets import AirPassengersDataset

import torch
torch.set_float32_matmul_precision('medium')

# Read data:
series = AirPassengersDataset().load()
series = series.astype(np.float32)

# Create training and validation sets:
train, val = series.split_after(pd.Timestamp("19590101"))

# Normalize the time series (note: we avoid fitting the transformer on the validation set)
transformer = Scaler()
train_transformed = transformer.fit_transform(train)
val_transformed = transformer.transform(val)
series_transformed = transformer.transform(series)

my_model = TFTModel(
    model_name="TFT",
    input_chunk_length=24,
    output_chunk_length=12,
    hidden_size=100,
    dropout=0.1,
    batch_size=1,
    n_epochs=10,
    add_relative_index=False,
    add_encoders={
        'cyclic': {'future': ['month']}
    },
    pl_trainer_kwargs={
        "accelerator": "gpu",
        "devices": [0],
    }
)

my_model.fit(
    train_transformed,
    dataloader_kwargs={
        "num_workers": 8,
        "persistent_workers": True
    },
    verbose=True
)

Expected behavior
The model trains without raising exceptions.


madtoinou commented 2 months ago

Hi @briandecamp,

I could not reproduce the issue with my setup. Can you please also share the exceptions you are getting?

briandecamp commented 2 months ago

Sorry to bother you, @madtoinou. I'm running on Windows, and it works fine as long as I protect the code with if __name__ == "__main__":.
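Why the guard helps: on Windows, Python's multiprocessing starts worker processes with the "spawn" method, and PyTorch DataLoader workers (which darts creates when num_workers > 0 is passed through dataloader_kwargs) use it too. A spawned worker re-imports the main script, so unguarded top-level code that starts workers runs again inside each worker while it is still bootstrapping. The sketch below is a standard-library stand-in, not darts' actual code path: a multiprocessing pool forced to "spawn" plays the role of the DataLoader workers, and math.factorial plays the role of the per-batch loading work. The function name main() is a hypothetical choice for illustration.

```python
import math
import multiprocessing as mp

# Top-level statements like this one run again in every worker launched with
# the "spawn" start method, because each worker re-imports the main module
# (under the name "__mp_main__" rather than "__main__").
print(f"module loaded as {__name__!r}")


def main() -> list:
    # Force "spawn" so the demo behaves the same on every OS; it is already
    # the default start method on Windows.
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=2) as pool:
        # math.factorial stands in for the work a DataLoader worker would do.
        return pool.map(math.factorial, range(6))


if __name__ == "__main__":
    # Without this guard, each spawned worker would try to start its own pool
    # while still bootstrapping, and the worker fails with a RuntimeError.
    print(main())
```

Applied to the script in the report, the same pattern means keeping only imports and definitions at the top level and moving everything from loading the data through my_model.fit(...) into a function like main() (or directly under the if __name__ == "__main__": block).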