awslabs / gluonts

Probabilistic time series modeling in Python
Apache License 2.0

PyTorch Lightning logs are not synchronised when using distributed training #3157

Open · admivsn opened this issue 2 months ago

admivsn commented 2 months ago


As described in the PyTorch Lightning documentation, epoch-level logs need to be synchronised across devices using `sync_dist=True`.

For example, in DeepAR I think `sync_dist=True` should be passed as an extra parameter when logging during distributed training.

I notice that when training on multi-GPU SageMaker instances I don't see a performance uplift compared to a single-GPU instance. I also get the following warning from PyTorch Lightning:

It is recommended to use `self.log('train_loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
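To make the warning concrete, here is a minimal, stdlib-only simulation of what `sync_dist=True` changes for an epoch-level metric. It assumes each of the 4 ranks has already averaged its own `train_loss` over the epoch, and it treats the cross-device reduction as a plain mean of the per-rank values (mean is the default `reduce_fx` for `self.log`; Lightning's real reduction is more involved). The per-rank loss values and the helper `logged_train_loss` are hypothetical and are not taken from the job below.

```python
from statistics import fmean

# Hypothetical per-rank epoch-level mean losses (illustrative values only).
per_rank_epoch_loss = {0: 0.70, 1: 0.80, 2: 0.75, 3: 0.79}


def logged_train_loss(per_rank, rank, sync_dist):
    """Return the 'train_loss' value a given rank would see after the epoch.

    Without sync_dist, each rank keeps only its local value.
    With sync_dist, the values are mean-reduced across all ranks
    (a stand-in for an all-reduce), so every rank sees the same number.
    """
    if sync_dist:
        return fmean(per_rank.values())
    return per_rank[rank]


for sync in (False, True):
    values = [
        logged_train_loss(per_rank_epoch_loss, r, sync)
        for r in sorted(per_rank_epoch_loss)
    ]
    print(f"sync_dist={sync}: {values}")
```

The sketch only shows that, without synchronisation, the `train_loss` each rank logs reflects just that rank's share of the data, while with synchronisation all ranks agree on one reduced value.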

To Reproduce

It's difficult to reproduce as I'm running a SageMaker Training job.


The setup has 4 GPUs, and the logs show that PyTorch Lightning detects them:

2024-04-17 20:37:48 Starting - Starting the training job...
2024-04-17 20:38:05 Starting - Preparing the instances for training......
2024-04-17 20:39:10 Downloading - Downloading input data...
2024-04-17 20:39:29 Downloading - Downloading the training image............
2024-04-17 20:41:50 Training - Training image download completed. Training in progress........
2024-04-17 20:42:45,440 sagemaker-training-toolkit INFO     No Neurons detected (normal if no neurons installed)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/usr/local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/ Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
/usr/local/lib/python3.10/site-packages/lightning/pytorch/trainer/ You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/4
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/4
Missing logger folder: /opt/ml/code/lightning_logs
Initializing distributed: GLOBAL_RANK: 3, MEMBER: 4/4
Missing logger folder: /opt/ml/code/lightning_logs
Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/4
All distributed processes registered. Starting with 4 processes
Missing logger folder: /opt/ml/code/lightning_logs
Missing logger folder: /opt/ml/code/lightning_logs
  | Name  | Type        | Params | In sizes                                                        | Out sizes  
0 | model | DeepARModel | 25.9 K | [[1, 1], [1, 1], [1, 1102, 4], [1, 1102], [1, 1102], [1, 1, 4]] | [1, 100, 1]
25.9 K    Trainable params
0         Non-trainable params
25.9 K    Total params
0.104     Total estimated model params size (MB)
/usr/local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/ It is recommended to use `self.log('train_loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
Epoch 0, global step 50: 'train_loss' reached 0.75536 (best 0.75536), saving model to '/opt/ml/code/lightning_logs/version_0/checkpoints/epoch=0-step=50.ckpt' as top 1
Epoch 1, global step 100: 'train_loss' reached 0.72144 (best 0.72144), saving model to '/opt/ml/code/lightning_logs/version_0/checkpoints/epoch=1-step=100.ckpt' as top 1
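One quick way to confirm from the output above that all four distributed processes came up is to collect the `GLOBAL_RANK` and `MEMBER` fields from the `Initializing distributed` lines. This is a minimal sketch, assuming the log text is available as a string; the regular expression and the helper `registered_ranks` are hypothetical, written against the line format shown here, and are not an official Lightning interface.

```python
import re

# Excerpt of the "Initializing distributed" lines from the job output above.
LOG = """\
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/4
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/4
Missing logger folder: /opt/ml/code/lightning_logs
Initializing distributed: GLOBAL_RANK: 3, MEMBER: 4/4
Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/4
All distributed processes registered. Starting with 4 processes
"""

# Matches e.g. "GLOBAL_RANK: 3, MEMBER: 4/4", capturing rank, member index and world size.
PATTERN = re.compile(r"GLOBAL_RANK: (\d+), MEMBER: (\d+)/(\d+)")


def registered_ranks(log_text):
    """Return (sorted global ranks, world size) parsed from the DDP init lines."""
    ranks, world_sizes = set(), set()
    for rank, _member, world in PATTERN.findall(log_text):
        ranks.add(int(rank))
        world_sizes.add(int(world))
    if not world_sizes:
        raise ValueError("no 'Initializing distributed' lines found")
    if len(world_sizes) != 1:
        raise ValueError(f"inconsistent world sizes: {sorted(world_sizes)}")
    return sorted(ranks), world_sizes.pop()


ranks, world_size = registered_ranks(LOG)
print(ranks, world_size)
```

If every rank from 0 to world size minus 1 appears, all processes registered, which matches the `Starting with 4 processes` line.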

Error message or code output

The particular warning of interest is:

/usr/local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/ It is recommended to use `self.log('train_loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.



admivsn commented 2 months ago

Updated this with some more info. Originally I thought it only happened when using validation data; however, on further investigation it seems to be a wider issue.