aqlaboratory / openfold

Trainable, memory-efficient, and GPU-friendly PyTorch reproduction of AlphaFold 2
Apache License 2.0

CUDA out of memory during training with A100 GPUs #473

Open abhinavb22 opened 1 month ago

abhinavb22 commented 1 month ago

I am trying to fine-tune AF-multimer, starting with a mock dataset of just 3 proteins, and I get the following error:

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.95 GiB. GPU 0 has a total capacty of 39.39 GiB of which 1.90 GiB is free. Including non-PyTorch memory, this process has 37.48 GiB memory in use. Of the allocated memory 35.28 GiB is allocated by PyTorch, and 1.44 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
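The figures in this message can be decomposed to see where the memory went. Below is a minimal sketch of that. The helper `summarize_oom` and its regexes are hypothetical and written against the exact wording quoted above, so other PyTorch versions may phrase the message differently. It extracts the GiB values and derives the memory held outside PyTorch's allocator and the size of the shortfall:

```python
import re

# Hypothetical helper written against the exact wording of the message above;
# other PyTorch versions may phrase the OOM message differently.
PATTERNS = {
    "requested": r"Tried to allocate ([\d.]+) GiB",
    "total": r"total capac\w* of ([\d.]+) GiB",  # matches "capacty" and "capacity"
    "free": r"of which ([\d.]+) GiB is free",
    "in_use": r"this process has ([\d.]+) GiB memory in use",
    "allocated": r"([\d.]+) GiB is allocated by PyTorch",
    "reserved_unallocated": r"([\d.]+) GiB is reserved by PyTorch but unallocated",
}


def summarize_oom(message: str) -> dict:
    """Extract the GiB figures from an OOM message and derive two more."""
    values = {}
    for key, pattern in PATTERNS.items():
        match = re.search(pattern, message)
        if match is None:
            raise ValueError(f"could not find {key!r} in the message")
        values[key] = float(match.group(1))
    # Memory the process holds outside PyTorch's caching allocator
    # (CUDA context, library workspaces, ...).
    values["non_pytorch"] = round(
        values["in_use"] - values["allocated"] - values["reserved_unallocated"], 2
    )
    # How far the request exceeded the memory the driver still reported free.
    values["shortfall"] = round(values["requested"] - values["free"], 2)
    return values
```

On the message in this issue, the helper gives `non_pytorch` = 0.76 GiB and `shortfall` = 0.05 GiB. The 1.95 GiB request therefore missed the free memory by about 50 MB, while 1.44 GiB was reserved but could not be used for it. That gap is why the message points at fragmentation. However, the failure happens in the middle of the backward pass, so getting this one allocation to succeed would not necessarily make the whole step fit.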

Here is the full error:

[rank: 0] Seed set to 77843
WARNING:root:load from versionmodel_1_multimer_v3
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-PCIE-40GB') that has Tensor Cores. To properly utilize them, you should set torch.set_float32_matmul_precision('medium' | 'high') which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
[rank: 0] Seed set to 77843
initializing deepspeed distributed: GLOBAL_RANK: 0, MEMBER: 1/1
/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:652: Checkpoint directory /work/10110/abhinav22/ls6/openfold_test/training/output_dir/checkpoints exists and is not empty.
Enabling DeepSpeed BF16. Model parameters and inputs will be cast to bfloat16.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]

  | Name  | Type          | Params | Mode
0 | model | AlphaFold     | 93.2 M | train
1 | loss  | AlphaFoldLoss | 0      | train

93.2 M    Trainable params
0         Non-trainable params
93.2 M    Total params
372.895   Total estimated model params size (MB)
SLURM auto-requeueing enabled. Setting signal handlers.
/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/core/module.py:518: You called self.log('train/distogram', ..., logger=True) but have no logger configured. You can enable one by doing Trainer(logger=ALogger(...))
/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/core/module.py:518: You called self.log('train/distogram_epoch', ..., logger=True) but have no logger configured. You can enable one by doing Trainer(logger=ALogger(...))
/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/core/module.py:518: You called self.log('train/experimentally_resolved', ..., logger=True) but have no logger configured. You can enable one by doing Trainer(logger=ALogger(...))
/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/core/module.py:518: You called self.log('train/experimentally_resolved_epoch', ..., logger=True) but have no logger configured. You can enable one by doing Trainer(logger=ALogger(...))
/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/core/module.py:518: You called self.log('train/fape', ..., logger=True) but have no logger configured. You can enable one by doing Trainer(logger=ALogger(...))
/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/core/module.py:518: You called self.log('train/fape_epoch', ..., logger=True) but have no logger configured. You can enable one by doing Trainer(logger=ALogger(...))
/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/core/module.py:518: You called self.log('train/plddt_loss', ..., logger=True) but have no logger configured. You can enable one by doing Trainer(logger=ALogger(...))
/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/core/module.py:518: You called self.log('train/plddt_loss_epoch', ..., logger=True) but have no logger configured. You can enable one by doing Trainer(logger=ALogger(...))
/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/core/module.py:518: You called self.log('train/masked_msa', ..., logger=True) but have no logger configured. You can enable one by doing Trainer(logger=ALogger(...))
/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/core/module.py:518: You called self.log('train/masked_msa_epoch', ..., logger=True) but have no logger configured. You can enable one by doing Trainer(logger=ALogger(...))
/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/core/module.py:518: You called self.log('train/supervised_chi', ..., logger=True) but have no logger configured. You can enable one by doing Trainer(logger=ALogger(...))
/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/core/module.py:518: You called self.log('train/supervised_chi_epoch', ..., logger=True) but have no logger configured. You can enable one by doing Trainer(logger=ALogger(...))
/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/core/module.py:518: You called self.log('train/violation', ..., logger=True) but have no logger configured. You can enable one by doing Trainer(logger=ALogger(...))
/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/core/module.py:518: You called self.log('train/violation_epoch', ..., logger=True) but have no logger configured. You can enable one by doing Trainer(logger=ALogger(...))
/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/core/module.py:518: You called self.log('train/tm', ..., logger=True) but have no logger configured. You can enable one by doing Trainer(logger=ALogger(...))
/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/core/module.py:518: You called self.log('train/tm_epoch', ..., logger=True) but have no logger configured. You can enable one by doing Trainer(logger=ALogger(...))
/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/core/module.py:518: You called self.log('train/chain_center_of_mass', ..., logger=True) but have no logger configured. You can enable one by doing Trainer(logger=ALogger(...))
/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/core/module.py:518: You called self.log('train/chain_center_of_mass_epoch', ..., logger=True) but have no logger configured. You can enable one by doing Trainer(logger=ALogger(...))
/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/core/module.py:518: You called self.log('train/unscaled_loss', ..., logger=True) but have no logger configured. You can enable one by doing Trainer(logger=ALogger(...))
/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/core/module.py:518: You called self.log('train/unscaled_loss_epoch', ..., logger=True) but have no logger configured. You can enable one by doing Trainer(logger=ALogger(...))
/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/core/module.py:518: You called self.log('train/loss', ..., logger=True) but have no logger configured. You can enable one by doing Trainer(logger=ALogger(...))
/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/core/module.py:518: You called self.log('train/loss_epoch', ..., logger=True) but have no logger configured. You can enable one by doing Trainer(logger=ALogger(...))
/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/core/module.py:518: You called self.log('train/lddt_ca', ..., logger=True) but have no logger configured. You can enable one by doing Trainer(logger=ALogger(...))
/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/core/module.py:518: You called self.log('train/drmsd_ca', ..., logger=True) but have no logger configured. You can enable one by doing Trainer(logger=ALogger(...))
Traceback (most recent call last):
  File "/work/10110/abhinav22/ls6/openfold/train_openfold.py", line 706, in <module>
    main(args)
  File "/work/10110/abhinav22/ls6/openfold/train_openfold.py", line 455, in main
    trainer.fit(
  File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 543, in fit
    call._call_and_handle_interrupt(
  File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 105, in launch
    return function(*args, **kwargs)
  File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 579, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 986, in _run
    results = self._run_stage()
  File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1030, in _run_stage
    self.fit_loop.run()
  File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 205, in run
    self.advance()
  File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 363, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 140, in run
    self.advance(data_fetcher)
  File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 250, in advance
    batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
  File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 190, in run
    self._optimizer_step(batch_idx, closure)
  File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 268, in _optimizer_step
    call._call_lightning_module_hook(
  File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 159, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1308, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/core/optimizer.py", line 153, in step
    step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
  File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/strategies/ddp.py", line 270, in optimizer_step
    optimizer_output = super().optimizer_step(optimizer, closure, model, **kwargs)
  File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 238, in optimizer_step
    return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
  File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/deepspeed.py", line 129, in optimizer_step
    closure_result = closure()
  File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 144, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 138, in closure
    self._backward_fn(step_output.closure_loss)
  File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 239, in backward_fn
    call._call_strategy_hook(self.trainer, "backward", loss, optimizer)
  File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 311, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 212, in backward
    self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, *args, **kwargs)
  File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/deepspeed.py", line 117, in backward
    deepspeed_engine.backward(tensor, *args, **kwargs)
  File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1944, in backward
    self.optimizer.backward(loss, retain_graph=retain_graph)
  File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 2019, in backward
    self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
  File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
    scaled_loss.backward(retain_graph=retain_graph)
  File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/torch/autograd/function.py", line 288, in apply
    return user_fn(self, *args)
  File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/deepspeed/runtime/activation_checkpointing/checkpointing.py", line 691, in backward
    torch.autograd.backward(output_tensors, grad_tensors)
  File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.95 GiB. GPU 0 has a total capacty of 39.39 GiB of which 1.90 GiB is free. Including non-PyTorch memory, this process has 37.48 GiB memory in use. Of the allocated memory 35.28 GiB is allocated by PyTorch, and 1.44 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

The following are my input parameters:

export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:256
DATA_DIR=/work/10110/abhinav22/ls6/openfold_test/training
TEMPLATE_MMCIF_DIR=/work/10110/abhinav22/ls6/database_2024_07_06/pdb_mmcif
CHECKPOINT_PATH=openfold/resources/params/params_model_1_multimer_v3.npz
CACHE_DIR=$DATA_DIR/cache_dir

python3 train_openfold.py $DATA_DIR/mmcif_dir/train $DATA_DIR/alignments_dir/train $TEMPLATE_MMCIF_DIR/mmcif_files $DATA_DIR/output_dir 2018-04-30 \
    --config_preset model_1_multimer_v3 \
    --template_release_dates_cache_path $CACHE_DIR/mmcif_cache.json \
    --seed 77843 \
    --obsolete_pdbs_file_path $TEMPLATE_MMCIF_DIR/obsolete.dat \
    --num_nodes 1 \
    --resume_from_jax_params $CHECKPOINT_PATH \
    --resume_model_weights_only False \
    --train_mmcif_data_cache_path $CACHE_DIR/train_mmcif_data_cache.json \
    --val_mmcif_data_cache_path $CACHE_DIR/val_mmcif_data_cache.json \
    --val_data_dir $DATA_DIR/mmcif_dir/val \
    --val_alignment_dir $DATA_DIR/alignments_dir/val \
    --gpus 3 \
    --train_epoch_len 5 \
    --max_epochs 1 \
    --checkpoint_every_epoch \
    --precision bf16-mixed \
    --deepspeed_config_path ./deepspeed_config.json
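The exported PYTORCH_CUDA_ALLOC_CONF takes comma-separated option:value pairs. It is read when PyTorch's CUDA caching allocator is initialized, so it has to be set before the first CUDA allocation, which exporting it in the shell does. The sketch below uses a hypothetical helper, `alloc_conf`, to build that string. It swaps in `expandable_segments:True`, an allocator option documented in recent PyTorch releases, as an alternative to `max_split_size_mb:256`, because it targets the kind of fragmentation the OOM message mentions. Whether either setting gets this run under 40 GB is untested:

```python
import os


def alloc_conf(**options) -> str:
    """Format options in the comma-separated key:value syntax that
    PYTORCH_CUDA_ALLOC_CONF expects, e.g. alloc_conf(max_split_size_mb=256)."""
    return ",".join(f"{key}:{value}" for key, value in options.items())


if __name__ == "__main__":
    # Must happen before torch makes its first CUDA allocation in this process;
    # child processes started afterwards inherit the environment variable.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = alloc_conf(expandable_segments=True)
    print(os.environ["PYTORCH_CUDA_ALLOC_CONF"])
    # prints: expandable_segments:True
```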

The same training runs fine on H100 GPUs; the problem only occurs on the 40 GB A100s. Is there a workaround for this?
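The contents of the `deepspeed_config.json` passed via `--deepspeed_config_path` are not shown in the issue. The traceback passes through `deepspeed/runtime/zero/stage_1_and_2.py` and DeepSpeed's activation-checkpointing module, so ZeRO stage 1 or 2 and DeepSpeed activation checkpointing are evidently in use. One option to experiment with is offloading checkpointed activations to host memory. The config below is hypothetical: the keys are standard DeepSpeed options, but the values are assumptions. It is also unverified whether the Lightning/OpenFold code path applies the `activation_checkpointing` section in this setup, and offloading typically slows training:

```python
import json
from typing import Optional

# Hypothetical deepspeed_config.json; the user's actual file is not shown.
DEEPSPEED_CONFIG = {
    "train_micro_batch_size_per_gpu": 1,
    "bf16": {"enabled": True},  # consistent with --precision bf16-mixed
    "zero_optimization": {"stage": 2},  # traceback goes through stage_1_and_2.py
    "activation_checkpointing": {
        # Assumption: DeepSpeed only offloads *partitioned* activations to CPU,
        # so partition_activations is enabled alongside cpu_checkpointing.
        "partition_activations": True,
        "cpu_checkpointing": True,
        "contiguous_memory_optimization": False,
    },
}


def write_config(path: str, config: Optional[dict] = None) -> None:
    """Write a DeepSpeed config as JSON (defaults to DEEPSPEED_CONFIG)."""
    config = DEEPSPEED_CONFIG if config is None else config
    with open(path, "w") as handle:
        json.dump(config, handle, indent=2)


if __name__ == "__main__":
    write_config("deepspeed_config.json")
```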

MarjanHJ commented 1 month ago

Hi,

What is your config["evoformer_stack"]['blocks_per_ckpt']? It should be 1.
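For context, `blocks_per_ckpt` sets how many consecutive Evoformer blocks share one activation checkpoint. During the forward pass only each segment's input is kept, and the activations inside a segment are recomputed during the backward pass. A pure-Python sketch of the grouping follows. The helper is hypothetical and is not OpenFold's implementation:

```python
from typing import List


def checkpoint_segments(no_blocks: int, blocks_per_ckpt: int) -> List[range]:
    """Group block indices into activation-checkpoint segments.

    Illustrative only: the input of each segment is stored during the forward
    pass, and everything inside a segment is recomputed during backward.
    """
    if blocks_per_ckpt < 1:
        raise ValueError("blocks_per_ckpt must be >= 1")
    return [
        range(start, min(start + blocks_per_ckpt, no_blocks))
        for start in range(0, no_blocks, blocks_per_ckpt)
    ]


if __name__ == "__main__":
    # 48 Evoformer blocks ("no_blocks": 48 in the config) with blocks_per_ckpt = 1
    print(len(checkpoint_segments(48, 1)))
```

With 48 blocks and `blocks_per_ckpt` = 1 there are 48 single-block segments, so at most one block's internal activations are alive during recomputation. Larger values store fewer segment inputs but keep more activations alive at once.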

abhinavb22 commented 1 month ago

Yes, it is 1; I have not changed the config file. Here is what it says:

if train:
    c.globals.blocks_per_ckpt = 1

"evoformer_stack": { "c_m": c_m, "c_z": c_z, "c_hidden_msa_att": 32, "c_hidden_opm": 32, "c_hidden_mul": 128, "c_hidden_pair_att": 32, "c_s": c_s, "no_heads_msa": 8, "no_heads_pair": 4, "no_blocks": 48, "transition_n": 4, "msa_dropout": 0.15, "pair_dropout": 0.25, "no_column_attention": False, "opm_first": False, "fuse_projection_weights": False, "blocks_per_ckpt": blocks_per_ckpt, "clear_cache_between_blocks": False, "tune_chunk_size": tune_chunk_size, "inf": 1e9, "eps": eps, # 1e-10,