QData / spacetimeformer

Multivariate Time Series Forecasting with efficient Transformers. Code for the paper "Long-Range Transformers for Dynamic Spatiotemporal Forecasting."
https://arxiv.org/abs/2109.12218
MIT License

GPU memory leak #23

Suppersine closed this issue 2 years ago.

Suppersine commented 2 years ago

After 80 epochs (about 8 hours), I got this error:

Epoch 80:  87%|████████████████▌  | 750/858 [04:56<00:42,  2.53it/s, loss=0.258]
Traceback (most recent call last):
  File "train.py", line 442, in <module>
    main(args)
  File "train.py", line 424, in main
    trainer.fit(forecaster, datamodule=data_module)
  File "/home/u7701783/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 460, in fit
    self._run(model)
  File "/home/u7701783/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 758, in _run
    self.dispatch()
  File "/home/u7701783/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 799, in dispatch
    self.accelerator.start_training(self)
  File "/home/u7701783/.local/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 96, in start_training
    self.training_type_plugin.start_training(trainer)
  File "/home/u7701783/.local/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 144, in start_training
    self._results = trainer.run_stage()
  File "/home/u7701783/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 809, in run_stage
    return self.run_train()
  File "/home/u7701783/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 871, in run_train
    self.train_loop.run_training_epoch()
  File "/home/u7701783/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 569, in run_training_epoch
    self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output)
  File "/home/u7701783/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py", line 325, in log_train_epoch_end_metrics
    self.log_metrics(epoch_log_metrics, {})
  File "/home/u7701783/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py", line 208, in log_metrics
    mem_map = memory.get_memory_profile(self.log_gpu_memory)
  File "/home/u7701783/.local/lib/python3.8/site-packages/pytorch_lightning/core/memory.py", line 365, in get_memory_profile
    memory_map = get_gpu_memory_map()
  File "/home/u7701783/.local/lib/python3.8/site-packages/pytorch_lightning/core/memory.py", line 384, in get_gpu_memory_map
    result = subprocess.run(
  File "/opt/conda/lib/python3.8/subprocess.py", line 516, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['/usr/bin/nvidia-smi', '--query-gpu=memory.used', '--format=csv,nounits,noheader']' returned non-zero exit status 255.
Epoch 80:  87%|████████▋ | 750/858 [04:57<00:42,  2.52it/s, loss=0.258]         
jakegrigsby commented 2 years ago

I think the only dataset I've ever run for that many epochs is the toy dataset, but I will look into this. Is this a very small dataset, or are you using large regularization settings? If you aren't using --early_stopping, you may want to try it to avoid long training times.
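For context on why early stopping shortens runs like this one, here is a minimal sketch of the usual patience rule: stop once the validation loss has failed to improve for `patience` consecutive epochs. This is not how train.py implements `--early_stopping` (that code isn't shown here); the class and parameter names are hypothetical, and it assumes validation loss is computed once per epoch.

```python
# Minimal sketch of patience-based early stopping (hypothetical names; NOT the
# repo's --early_stopping implementation). Assumes one validation loss per epoch.
from dataclasses import dataclass


@dataclass
class EarlyStopper:
    patience: int = 5        # epochs to wait without improvement before stopping
    min_delta: float = 0.0   # minimum decrease in loss that counts as improvement
    best: float = float("inf")
    bad_epochs: int = 0

    def should_stop(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss    # new best: reset the patience counter
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1    # no improvement this epoch
        return self.bad_epochs >= self.patience


if __name__ == "__main__":
    stopper = EarlyStopper(patience=3)
    for epoch, loss in enumerate([0.9, 0.6, 0.5, 0.51, 0.52, 0.50, 0.49]):
        if stopper.should_stop(loss):
            print(f"stopping after epoch {epoch}")  # stops at epoch 5
            break
```

With `patience=3`, three non-improving epochs in a row (0.51, 0.52, 0.50 against a best of 0.5) end the run, so the later 0.49 is never reached. A larger `patience` trades longer training for robustness to noisy validation curves.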

I'm not 100% sure whether this is an issue with GPU memory itself or with the logging of GPU memory stats to wandb.

jhillhouse92 commented 2 years ago

Python is reference-counted (see https://discuss.pytorch.org/t/cuda-out-of-memory-on-the-8th-epoch/67288). This is not a GPU memory issue: a real out-of-memory condition would crash the whole Python process, and the crash is often recorded in /var/log/syslog. In other words, you wouldn't get the "returned non-zero exit status 255." error message, which comes from the nvidia-smi subprocess, not from your training process.
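The distinction shows up in the traceback itself: it ends in `subprocess.run` raising `CalledProcessError`, which the standard library only does when `check=True` is passed and the child process exits with a non-zero status. The sketch below reproduces that shape without a GPU. The child is a stand-in for nvidia-smi (an assumption for illustration) that simply exits with status 255; the helper name is hypothetical.

```python
# Sketch: reproduce the shape of the traceback's error without a GPU.
# The child process is a hypothetical stand-in for nvidia-smi that exits with a
# chosen status; check=True turns any non-zero exit into CalledProcessError.
import subprocess
import sys


def run_like_nvidia_smi(exit_code: int) -> str:
    """Run a child that exits with `exit_code`; raise CalledProcessError if non-zero."""
    result = subprocess.run(
        [sys.executable, "-c", f"import sys; sys.exit({exit_code})"],
        stdout=subprocess.PIPE,
        check=True,  # non-zero exit status -> subprocess.CalledProcessError
    )
    return result.stdout.decode()


if __name__ == "__main__":
    try:
        run_like_nvidia_smi(255)
    except subprocess.CalledProcessError as err:
        # The failure is the child's exit status, not a CUDA OOM in this process.
        print("child exited with", err.returncode)
```

The parent Python process stays alive and simply receives an exception, which is why a failing monitoring command can abort `trainer.fit` even though memory on the GPU is fine.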

This is triggered by the log_gpu_memory flag passed to the PyTorch Lightning Trainer in train.py. That flag is deprecated; if you still want to log these metrics, try switching to the DeviceStatsMonitor callback. See https://pytorch-lightning.readthedocs.io/en/stable/extensions/generated/pytorch_lightning.callbacks.DeviceStatsMonitor.html?highlight=DeviceStatsMonitor#pytorch_lightning.callbacks.DeviceStatsMonitor.

Another option is to remove that flag and run the nvidia-smi query yourself periodically from a separate bash window. I couldn't find specific examples of why nvidia-smi would return that error, but it's likely unrelated to GPU OOM.
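One way to do the external polling is a small standalone script that runs the same nvidia-smi query seen in the traceback, but logs and skips failed samples instead of raising. This is a sketch under assumptions: it calls `nvidia-smi` via PATH rather than `/usr/bin/nvidia-smi`, and the function names, timeout, and polling interval are hypothetical choices, not part of the repo.

```python
# Sketch: poll GPU memory outside the training process with the same query that
# appears in the traceback. Failed samples are logged and skipped (a design
# assumption), so a transient nvidia-smi error never affects training.
import subprocess
import time
from typing import List, Optional, Sequence

QUERY = ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader"]


def parse_memory_used(csv_text: str) -> List[int]:
    """Parse one memory.used value (MiB) per GPU from noheader/nounits CSV output."""
    return [int(line.strip()) for line in csv_text.splitlines() if line.strip()]


def query_gpu_memory(cmd: Sequence[str] = QUERY) -> Optional[List[int]]:
    """Return memory used per GPU, or None if the command is missing, fails, or hangs."""
    try:
        out = subprocess.run(
            list(cmd), stdout=subprocess.PIPE, stderr=subprocess.PIPE,
            check=True, timeout=30,
        )
    except (FileNotFoundError, subprocess.CalledProcessError, subprocess.TimeoutExpired) as err:
        print(f"GPU query failed: {err}")
        return None
    return parse_memory_used(out.stdout.decode())


def poll(interval_s: float = 60.0, samples: int = 5) -> None:
    """Print a timestamped memory sample every `interval_s` seconds."""
    for _ in range(samples):
        print(time.strftime("%H:%M:%S"), query_gpu_memory())
        time.sleep(interval_s)


if __name__ == "__main__":
    print(query_gpu_memory())  # single sample; call poll() for repeated sampling
```

Running `poll()` in a second terminal keeps monitoring fully decoupled from `trainer.fit`: if nvidia-smi exits with status 255 again, the script prints the error and tries again at the next interval.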

jakegrigsby commented 2 years ago

@jhillhouse92 Thank you, that's about what I figured. I'll just remove that option from the public version; it's only there to help me pick hparams for each dataset. The best model is usually the biggest one that fits in memory, and as long as the model doesn't crash with a clear OOM error on the first backward pass, it's fine for most situations. So it's not a super important thing to log.