unit8co / darts

A python library for user-friendly forecasting and anomaly detection on time series.
https://unit8co.github.io/darts/
Apache License 2.0

[BUG] Model training stuck, why? #1643

Closed: langslike closed this issue 1 year ago

langslike commented 1 year ago

Describe the bug
I am training a TFT model and the progress bar is stuck. I waited for 30 minutes with no further progress: CUDA utilization dropped to 0 and GPU memory usage stopped changing. I can't figure out why. Here is the available output:

```
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                              | Type                             | Params
----------------------------------------------------------------------------------------
0  | train_metrics                     | MetricCollection                 | 0     
1  | val_metrics                       | MetricCollection                 | 0     
2  | input_embeddings                  | _MultiEmbedding                  | 0     
3  | static_covariates_vsn             | _VariableSelectionNetwork        | 0     
4  | encoder_vsn                       | _VariableSelectionNetwork        | 15.3 K
5  | decoder_vsn                       | _VariableSelectionNetwork        | 896   
6  | static_context_grn                | _GatedResidualNetwork            | 4.3 K 
7  | static_context_hidden_encoder_grn | _GatedResidualNetwork            | 4.3 K 
8  | static_context_cell_encoder_grn   | _GatedResidualNetwork            | 4.3 K 
9  | static_context_enrichment         | _GatedResidualNetwork            | 4.3 K 
10 | lstm_encoder                      | LSTM                             | 25.3 K
11 | lstm_decoder                      | LSTM                             | 25.3 K
12 | post_lstm_gan                     | _GateAddNorm                     | 2.2 K 
13 | static_enrichment_grn             | _GatedResidualNetwork            | 5.3 K 
14 | multihead_attn                    | _InterpretableMultiHeadAttention | 2.6 K 
15 | post_attn_gan                     | _GateAddNorm                     | 2.2 K 
16 | feed_forward_block                | _GatedResidualNetwork            | 4.3 K 
17 | pre_output_gan                    | _GateAddNorm                     | 2.2 K 
18 | output_layer                      | Linear                           | 561   
----------------------------------------------------------------------------------------
103 K     Trainable params
0         Non-trainable params
103 K     Total params
0.827     Total estimated model params size (MB)
Epoch 82:  94% 61/65 [00:03<00:00, 15.26it/s, loss=0.179, train_loss=0.124, val_loss=0.211]
Validation DataLoader 0:  81% 13/16 [00:00<00:00, 45.88it/s]
```

To Reproduce
This is my code:

```python
from darts.dataprocessing.transformers import Scaler
from darts.models.forecasting.tft_model import TFTModel

# Past-covariate encoders: month and day-of-week attributes, scaled with Scaler
encoders = {"datetime_attribute": {"past": ["month", "dayofweek"]}, "transformer": Scaler()}

model = TFTModel(
    input_chunk_length=10,
    output_chunk_length=1,
    hidden_size=32,
    lstm_layers=3,
    random_state=42,
    add_encoders=encoders,
    add_relative_index=True,
    pl_trainer_kwargs={"accelerator": "gpu", "devices": [0]},
)

# targets_* and covariates_* are the pre-scaled TimeSeries (their creation is not shown)
model.fit(
    series=targets_train_scaled,
    past_covariates=covariates_train_scaled,
    val_series=targets_val_scaled,
    val_past_covariates=covariates_val_scaled,
    epochs=100,
    verbose=True,
)
```

Expected behavior
I just want to know what is going wrong under the hood and how to avoid the problem. Could Darts provide some kind of debug mode that outputs more detailed information?

madtoinou commented 1 year ago

Hi,

Your issue is probably not caused by Darts directly, but rather by the environment you use to run the training.

Are you running the training from a Python script or from a Jupyter notebook? In my experience, Jupyter notebook kernels sometimes crash when the computation is too intense (or when resources are taken by another process). In your case, the size of the model and the number of epochs seem reasonable, but could the training time series be very large?
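The log itself gives a rough bound on the data size. Below is a minimal sketch that assumes Darts' default `batch_size=32` and assumes the last partial batch is kept. Neither assumption is stated in the issue. The epoch bar shows 65 batches, and Lightning's main bar may also count the 16 validation batches. That puts an upper bound of about 2,080 samples per epoch, which points to a small dataset rather than a very large one.

```python
def sample_range(num_batches, batch_size=32):
    """Return the (min, max) sample counts that produce `num_batches` batches
    when the last, possibly partial, batch is kept (no drop_last)."""
    # ceil(n / batch_size) == num_batches  <=>  (num_batches - 1) * batch_size < n <= num_batches * batch_size
    low = (num_batches - 1) * batch_size + 1
    high = num_batches * batch_size
    return low, high


print(sample_range(65))  # batches on the "Epoch 82" bar -> (2049, 2080)
print(sample_range(16))  # validation batches -> (481, 512)
```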

According to your logs, the training was almost finished. Could you try reducing the number of training epochs to see whether it completes?
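Here is how "almost finished" works out in numbers. PyTorch Lightning numbers epochs from 0, so "Epoch 82" is the 83rd of the 100 requested epochs, and the bar stopped at step 61 of 65. The small sketch below does that arithmetic, assuming every epoch has the same 65 progress-bar steps:

```python
def fraction_done(epoch_index, step_done, steps_per_epoch, total_epochs):
    """Fraction of all progress-bar steps completed.

    `epoch_index` is 0-based, matching the "Epoch N" label of Lightning's progress bar.
    """
    completed = epoch_index * steps_per_epoch + step_done
    return completed / (total_epochs * steps_per_epoch)


done = fraction_done(epoch_index=82, step_done=61, steps_per_epoch=65, total_epochs=100)
print(f"{done:.1%}")  # 82.9%
```

So roughly 83% of the run had completed when it stalled, with about 17 full epochs still to go.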

You could use TensorBoard to track the progress of your model, or PyTorch Lightning functionalities to inspect what is happening under the hood (documentation).
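Besides TensorBoard, you can also detect a stall programmatically. The sketch below is a hypothetical, framework-agnostic watchdog; `StallWatchdog` is not part of Darts or PyTorch Lightning. You would call `beat()` whenever a batch finishes, for example from a custom Lightning `Callback.on_train_batch_end` hook passed in through `pl_trainer_kwargs={"callbacks": [...]}`, and poll `is_stalled()` from another thread. That wiring is an assumption and is not shown here. The clock is injectable so the logic can be tested without waiting.

```python
import threading
import time


class StallWatchdog:
    """Flags training as stalled when no heartbeat arrives within `timeout_s` seconds.

    Hypothetical helper for illustration; not part of Darts or PyTorch Lightning.
    """

    def __init__(self, timeout_s, clock=time.monotonic):
        self.timeout_s = timeout_s
        self._clock = clock
        self._lock = threading.Lock()  # beat() and is_stalled() may run on different threads
        self._last_beat = clock()
        self._beats = 0

    def beat(self):
        """Record progress, e.g. once per finished training batch."""
        with self._lock:
            self._last_beat = self._clock()
            self._beats += 1

    def seconds_since_last_beat(self):
        with self._lock:
            return self._clock() - self._last_beat

    def is_stalled(self):
        return self.seconds_since_last_beat() > self.timeout_s

    @property
    def beats(self):
        with self._lock:
            return self._beats
```

In a real run, a daemon thread could poll `is_stalled()` every few seconds and call `faulthandler.dump_traceback(all_threads=True)` to print where each thread is blocked. A lighter-weight alternative is the standard library's `faulthandler.dump_traceback_later(timeout)`, re-armed after each batch with `faulthandler.cancel_dump_traceback_later()`.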

langslike commented 1 year ago

The dataset is quite small: only 2k samples, each with about 20 variates. I tried a smaller learning rate, and it worked. I will try the debug methods you mentioned, thanks!
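For reference, here is a minimal sketch of how a lower learning rate and some Lightning debugging options could be passed to `TFTModel`. It only builds the keyword arguments, so it runs without Darts installed. `optimizer_kwargs`, `log_tensorboard` and `pl_trainer_kwargs` are Darts `TorchForecastingModel` parameters. `detect_anomaly`, `gradient_clip_val` and `profiler` are PyTorch Lightning `Trainer` arguments. The concrete values (`lr=1e-4`, clipping at 1.0) are illustrative assumptions, not the learning rate used in this thread. Darts' default optimizer is Adam with `lr=1e-3`.

```python
def tft_debug_kwargs(lr=1e-4, gpu_index=0):
    """Build keyword arguments for Darts' TFTModel with a lower learning rate
    and extra PyTorch Lightning diagnostics. Values are illustrative."""
    return {
        "input_chunk_length": 10,
        "output_chunk_length": 1,
        "hidden_size": 32,
        "lstm_layers": 3,
        "random_state": 42,
        "add_relative_index": True,
        # Forwarded by Darts to the optimizer (Adam by default)
        "optimizer_kwargs": {"lr": lr},
        # Writes TensorBoard logs under the model's work_dir
        "log_tensorboard": True,
        # Forwarded by Darts to the PyTorch Lightning Trainer
        "pl_trainer_kwargs": {
            "accelerator": "gpu",
            "devices": [gpu_index],
            "detect_anomaly": True,    # make autograd raise when the backward pass produces NaN
            "gradient_clip_val": 1.0,  # clip the gradient norm to limit unstable updates
            "profiler": "simple",      # report time spent in each training hook at the end of fit
        },
    }


kwargs = tft_debug_kwargs()
# With Darts installed: model = TFTModel(add_encoders=encoders, **kwargs)
```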