unit8co / darts

A python library for user-friendly forecasting and anomaly detection on time series.
https://unit8co.github.io/darts/
Apache License 2.0

Processes are terminated in multi-GPU setting when using multiple models and seeds #2519

Closed · KunzstBGR closed this issue 1 month ago

KunzstBGR commented 1 month ago

Hi, When comparing multiple models and multiple seeds using a nested loop, all processes are terminated when the loop switches from one model class to the next. Does anyone have an idea why? Maybe I'm doing this wrong. Or is this a pytorch-lightning issue?

Error message: Child process with PID 652 terminated with code 1. Forcefully terminating all other processes to avoid zombies

Relevant code snippet:

# ...
import gc
import os

import torch
from darts.models import NHiTSModel, TiDEModel, TFTModel
from pytorch_lightning import seed_everything
from pytorch_lightning.strategies import DDPStrategy

def create_params(input_chunk_length,
                  output_chunk_length, 
                  quantiles,
                  batch_size,
                  n_epochs,
                  dropout):

    # ...

    pl_trainer_kwargs = {
        # DDPStrategy only configures the distributed strategy; the accelerator
        # and device count are passed to the Trainer directly.
        'strategy': DDPStrategy(process_group_backend='gloo'),
        'accelerator': 'gpu',
        'devices': 4,
        # ...
    }
    # ...

def dl_model_training(df, 
                      seeds, 
                      input_chunk_length,
                      output_chunk_length, 
                      quantiles,
                      batch_size,
                      n_epochs,
                      dropout):

    # Some data processing ...

    for model_arch, model_class in [('NHiTS', NHiTSModel), ('TiDE', TiDEModel), ('TFT', TFTModel)]:
        for i in seeds:
            # Set the seed
            seed_everything(i, workers=True)

            # Define the model name with seed
            model_arch_seed = f'{model_arch}_gws_{i}'

            # Train the model
            model = model_class(
                **create_params(
                    input_chunk_length,
                    output_chunk_length,
                    quantiles,
                    batch_size,
                    n_epochs,
                    dropout
                ),
                model_name=model_arch_seed,
                work_dir=os.path.join(MODEL_PATH, model_arch)
            )

            # Fit the model
            model.fit(
                series=train_gws,
                past_covariates=train_cov,
                future_covariates=train_cov if model_arch in ['TFT', 'TiDE'] else None,
                val_series=val_gws,
                val_past_covariates=val_cov,
                val_future_covariates=val_cov if model_arch in ['TFT', 'TiDE'] else None,
                verbose=True
            )

            # Clean up to prevent memory issues
            del model
            gc.collect()
            torch.cuda.empty_cache()

if __name__ == '__main__':
    torch.multiprocessing.freeze_support()
    dl_model_training(df=gws_bb_subset,
                      seeds=seeds,
                      input_chunk_length=52,
                      output_chunk_length=16,
                      quantiles=None,
                      batch_size=4096,
                      n_epochs=10,
                      dropout=0.2)

madtoinou commented 1 month ago

Hi @KunzstBGR,

This issue seems to come from PyTorch Lightning rather than from Darts.

It might also be related to the multi-GPU setup. Can you check whether the error persists when you use devices=[0]?

Have you tried changing the num_nodes parameter used with DDP (based on the PyTorch Lightning docs)?
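
Something along these lines, just as a sketch (only the strategy/accelerator/devices entries matter here; everything else stays as in your create_params):

# Single-GPU sanity check: rule out a DDP-specific problem first.
pl_trainer_kwargs = {
    'accelerator': 'gpu',
    'devices': [0],
}

# If multi-GPU is needed, num_nodes (number of machines) can be set explicitly
# alongside the DDP strategy; it defaults to 1.
pl_trainer_kwargs = {
    'strategy': DDPStrategy(process_group_backend='gloo'),
    'accelerator': 'gpu',
    'devices': 4,
    'num_nodes': 1,
}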

Also, is it intentional that you don't save checkpoints or generate any forecasts in your code snippet?

KunzstBGR commented 1 month ago

Hi @madtoinou, thanks for your quick response!

Swapping the order of the loops seems to have fixed it: with the seeds in the outer loop and the model classes in the inner loop, the processes are no longer terminated. I do save checkpoints and run the evaluation in a separate loop, so that part is just not shown in the snippet.

madtoinou commented 1 month ago

Nice! I wouldn't be able to tell you exactly why swapping the order of the loops fixed it, but as long as it works, that's great!
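
For other readers running into this, the swapped order would look roughly like the sketch below (reusing the variables from the snippet above; re-seeding before every model keeps the runs comparable):

# Sketch of the swapped loop order: seeds outside, model classes inside.
for i in seeds:
    for model_arch, model_class in [('NHiTS', NHiTSModel), ('TiDE', TiDEModel), ('TFT', TFTModel)]:
        seed_everything(i, workers=True)  # re-seed before every model
        model = model_class(
            **create_params(input_chunk_length, output_chunk_length,
                            quantiles, batch_size, n_epochs, dropout),
            model_name=f'{model_arch}_gws_{i}',
            work_dir=os.path.join(MODEL_PATH, model_arch),
        )
        model.fit(
            series=train_gws,
            past_covariates=train_cov,
            future_covariates=train_cov if model_arch in ['TFT', 'TiDE'] else None,
            val_series=val_gws,
            val_past_covariates=val_cov,
            val_future_covariates=val_cov if model_arch in ['TFT', 'TiDE'] else None,
            verbose=True,
        )
        del model
        gc.collect()
        torch.cuda.empty_cache()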

All good if you save the checkpoints and perform evaluation in a separate loop, I was just curious since it was not visible in the code snippet. It's indeed better to do it separately.
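
For completeness, the separate evaluation step could look something like this sketch, assuming the models were created with save_checkpoints=True (the forecast horizon n and the series passed to predict() below are placeholders):

from darts.models import NHiTSModel, TiDEModel, TFTModel

# Sketch: reload the best checkpoint of each trained model and forecast.
for model_arch, model_class in [('NHiTS', NHiTSModel), ('TiDE', TiDEModel), ('TFT', TFTModel)]:
    for i in seeds:
        model = model_class.load_from_checkpoint(
            model_name=f'{model_arch}_gws_{i}',
            work_dir=os.path.join(MODEL_PATH, model_arch),
            best=True,  # checkpoint with the best validation loss rather than the last one
        )
        preds = model.predict(
            n=16,  # placeholder horizon (here equal to output_chunk_length)
            series=train_gws,
            past_covariates=train_cov,
            future_covariates=train_cov if model_arch in ['TFT', 'TiDE'] else None,
        )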

If the issue is solved, can you please close it?