QData / spacetimeformer

Multivariate Time Series Forecasting with efficient Transformers. Code for the paper "Long-Range Transformers for Dynamic Spatiotemporal Forecasting."
https://arxiv.org/abs/2109.12218
MIT License

Memory requirements to replicate on Pems-Bay #88

Open steve3nto opened 9 months ago

steve3nto commented 9 months ago

Thanks to the authors for the great work and for sharing the code!

I am interested in replicating the results on PEMS-Bay before trying the model on a custom dataset of similar size. I set accelerator="gpu" in the Trainer and run the suggested command from the README:

python train.py spacetimeformer pems-bay --batch_size 32 --warmup_steps 1000 --d_model 200 --d_ff 700 --enc_layers 5 --dec_layers 6 --dropout_emb .1 --dropout_ff .3 --run_name pems-bay-spatiotemporal --base_lr 1e-3 --l2_coeff 1e-3 --loss mae --data_path /path/to/pems_bay/ --d_qk 30 --d_v 30 --n_heads 10 --patience 10 --decay_factor .8

It gives an OOM error on my GPU (which has more than 20GB of memory). How many GB of GPU memory are required? Did you use multiple GPUs to train spacetimeformer on PEMS-Bay?

I also tried keeping accelerator="dp", but it gives this error:

Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=gloo
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------

Traceback (most recent call last):
  File "train.py", line 873, in <module>
    main(args)
  File "train.py", line 851, in main
    trainer.fit(forecaster, datamodule=data_module)
  File "/home/simoscopi/miniconda3/envs/spacetimeformer/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 771, in fit
    self._call_and_handle_interrupt(
  File "/home/simoscopi/miniconda3/envs/spacetimeformer/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 722, in _call_and_handle_interrupt
    return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
  File "/home/simoscopi/miniconda3/envs/spacetimeformer/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 93, in launch
    return function(*args, **kwargs)
  File "/home/simoscopi/miniconda3/envs/spacetimeformer/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 812, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/home/simoscopi/miniconda3/envs/spacetimeformer/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1218, in _run
    self.strategy.setup(self)
  File "/home/simoscopi/miniconda3/envs/spacetimeformer/lib/python3.8/site-packages/pytorch_lightning/strategies/ddp.py", line 172, in setup
    self.configure_ddp()
  File "/home/simoscopi/miniconda3/envs/spacetimeformer/lib/python3.8/site-packages/pytorch_lightning/strategies/ddp.py", line 294, in configure_ddp
    self.model = self._setup_model(LightningDistributedModule(self.model))
  File "/home/simoscopi/miniconda3/envs/spacetimeformer/lib/python3.8/site-packages/pytorch_lightning/strategies/ddp.py", line 178, in _setup_model
    return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
  File "/home/simoscopi/miniconda3/envs/spacetimeformer/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 650, in __init__
    self._ddp_init_helper(parameters, expect_sparse_gradient, param_to_name_mapping)
  File "/home/simoscopi/miniconda3/envs/spacetimeformer/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 738, in _ddp_init_helper
    self._passing_sync_batchnorm_handle(self.module)
  File "/home/simoscopi/miniconda3/envs/spacetimeformer/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1637, in _passing_sync_batchnorm_handle
    self._log_and_throw(
  File "/home/simoscopi/miniconda3/envs/spacetimeformer/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 674, in _log_and_throw
    raise err_type(err_msg)
ValueError: SyncBatchNorm layers only work with GPU modules

I hope someone can help me run spacetimeformer in multi-gpu mode.
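
For reference, here is a minimal sketch of how a multi-GPU run is typically configured with the Lightning 1.x API shown in the traceback, where "gpu" is the accelerator and "dp"/"ddp" are strategies; how train.py actually wires up its Trainer and CLI flags is an assumption I have not checked:

```python
import pytorch_lightning as pl

# Minimal sketch, not the repo's actual train.py wiring.
trainer = pl.Trainer(
    accelerator="gpu",   # move the module to CUDA (the SyncBatchNorm error suggests it stayed on CPU)
    devices=2,           # number of GPUs
    strategy="ddp",      # one process per GPU
)
```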

In case that is not possible, do you have any suggestions for reducing memory requirements while maintaining good performance? For now, I could only fit the model on a single GPU by setting batch_size=4, but I fear this would lead to a bad model fit.
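
One standard Lightning-level workaround when only small batches fit is gradient accumulation, sketched below; accumulate_grad_batches is a stock Trainer argument, but whether train.py exposes it as a flag is an assumption:

```python
import pytorch_lightning as pl

# Hedged sketch: 4 samples per forward pass x 8 accumulation steps gives an
# effective batch of 32 at roughly the memory cost of batch_size=4.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    accumulate_grad_batches=8,
)
```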

Thank you!

steve3nto commented 9 months ago

I managed to replicate the result on PEMS-Bay: with a batch_size of 4 it trains fine, and I got a test/mae of around 1.59 logged on wandb.

Reading the paper, I saw these suggestions for low-memory scenarios:

I quote: "Most results in this paper were collected using fast attention alone with less than 40GBs of GPU memory. Strided convolutions are only necessary on the largest datasets." From this I infer that strided convolutions were used on PEMS-BAY.
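
To make the strided-convolution idea concrete, here is a rough illustration with shapes and module names of my own choosing (not the repo's actual modules): a Conv1d with stride > 1 shortens the flattened sequence between attention blocks, so later layers pay the attention cost on L/stride tokens instead of L.

```python
import torch
from torch import nn

# Illustrative only: downsample the token sequence with a strided convolution.
d_model, stride = 200, 2
downsample = nn.Conv1d(d_model, d_model, kernel_size=3, stride=stride, padding=1)

x = torch.randn(32, 1000, d_model)                        # (batch, length, d_model)
x_short = downsample(x.transpose(1, 2)).transpose(1, 2)   # (batch, length // stride, d_model)
print(x_short.shape)                                      # torch.Size([32, 500, 200])
```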

This part is not clear to me:

"Shifted window attention saves meaningful amounts of memory when using quadratic attention, so we primarily use it when we need to mask padded sequences." When did you need to mask padded sequences? And why is it needed only for full quadratic attention?

jakegrigsby commented 9 months ago

I'm surprised batch size 4 worked that well; that's interesting.

The PEMS-Bay results in v3 of the paper (the current one) were run on multiple A100s. I don't think this was necessary; I just had compute at the end of the project and scaled way up to see what would happen. All the other results (including the arXiv v1 and v2 PEMS-Bay numbers) were run on more accessible GPUs. I meant to circle back and publish training commands for low-GPU settings, but this project dragged on in peer review for so long that I had switched institutions and research topics by the time it was over and couldn't return to it :)

"Shifted window attention saves meaningful amounts of memory when using quadratic attention, so we primarily use it when we need to mask padded sequences." When did you need to mask padded sequences? And why is it needed only for full quadratic attention?

The shifted window attention moves memory requirements from the length dimension (where it's normally quadratic) to the batch dimension (where it's linear). So if you are using a linear attention approximation like performer, you really aren't saving much memory here overall.
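
As a rough sketch of that reshaping trick (tensor names and the helper below are hypothetical, and the shift itself, which alternates window boundaries between layers, is omitted):

```python
import torch

def windowed_attention(q, k, v, window: int):
    # q, k, v: (batch, length, dim) with length divisible by `window`
    B, L, D = q.shape
    # fold windows into the batch dimension: (B * L // window, window, D)
    fold = lambda x: x.reshape(B, L // window, window, D).reshape(-1, window, D)
    qw, kw, vw = fold(q), fold(k), fold(v)
    # quadratic attention, but only within each window:
    # memory is O(B * (L / window) * window^2), i.e. linear in L for a fixed window
    scores = torch.bmm(qw, kw.transpose(1, 2)) / D ** 0.5
    out = torch.bmm(scores.softmax(dim=-1), vw)
    # unfold back to (B, L, D)
    return out.reshape(B, L // window, window, D).reshape(B, L, D)
```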

We mask padded sequences for mixed-length datasets which are included in the codebase (m4, wikipedia) but are not heavily discussed in the paper. Basically if your context sequences have mixed lengths, the flattened spatiotemporal sequences require an unusual attention mask that does not fit most efficient attention implementations. In this situation it makes sense to revert to vanilla quadratic attention where we can easily use any mask we'd like, and reduce computation with shifted windows.
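
For illustration, here is a minimal version of the kind of padding mask that is straightforward with vanilla quadratic attention (the layout and helper name are assumptions, not the repo's masking code):

```python
import torch

def masked_quadratic_attention(q, k, v, valid_len):
    # q, k, v: (batch, length, dim); valid_len: (batch,) true lengths before padding
    B, L, D = q.shape
    # True marks padded positions
    pad = torch.arange(L, device=q.device)[None, :] >= valid_len[:, None]   # (B, L)
    scores = torch.einsum("bqd,bkd->bqk", q, k) / D ** 0.5                  # (B, L, L)
    scores = scores.masked_fill(pad[:, None, :], float("-inf"))             # block padded keys
    return scores.softmax(dim=-1) @ v
```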

If you are actively working on this now, I think it'd be an interesting update to do this with Flash Attention instead of vanilla attention. You still save compute with shifted windows, and the performance gains are significant enough that I think it is hard to justify approximations like Performer today, at least at these sequence lengths.
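
For anyone picking this up: in PyTorch 2.x, torch.nn.functional.scaled_dot_product_attention dispatches to a fused FlashAttention kernel when the inputs allow it, which is probably the simplest way to try the suggestion above. A sketch under that assumption, not something tested against this repo:

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim); half precision on a CUDA device so the
# flash kernel is eligible
q = torch.randn(8, 10, 1024, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# PyTorch 2.x selects a fused kernel (FlashAttention when possible); note that
# arbitrary attn_mask arguments typically force a fallback to the
# memory-efficient or math backends, which echoes the padding-mask caveat above.
out = F.scaled_dot_product_attention(q, k, v)
```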