jasonkyuyim / se3_diffusion

Implementation for SE(3) diffusion model with application to protein backbone generation
https://arxiv.org/abs/2302.02277
MIT License
305 stars 50 forks

Add DDP support, main change: 1. add get_ddp_info() in Experiment class.… #17

Closed Wangchentong closed 1 year ago

Wangchentong commented 1 year ago

Hi, I would like to provide DDP mode support for FrameDiff, and I have tested that it runs stably. Here are the main changes:

1. Modified the Experiment class:

2. Created distributedSample so that the train sampler cooperates with torch.utils.data.distributed.DistributedSampler
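To illustrate what a distributed sampler does for each rank, here is a minimal pure-Python sketch of rank-based index sharding. It is a simplified illustration without shuffling, not the code from this PR and not PyTorch's implementation; the function name `shard_indices` is hypothetical.

```python
import math

def shard_indices(dataset_len, num_replicas, rank):
    """Return the dataset indices one rank would see (no shuffling).

    Simplified illustration: pad the index list so it divides evenly,
    then give each rank every num_replicas-th index starting at its rank.
    """
    if dataset_len <= 0:
        raise ValueError("dataset_len must be positive")
    per_rank = math.ceil(dataset_len / num_replicas)
    total = per_rank * num_replicas
    indices = list(range(dataset_len))
    # Pad by wrapping around to the start so every rank gets the same count.
    indices = (indices * math.ceil(total / dataset_len))[:total]
    return indices[rank:total:num_replicas]

# Ten examples split across four ranks: each rank gets three indices,
# and two indices are repeated as padding.
shards = [shard_indices(10, num_replicas=4, rank=r) for r in range(4)]
for r, shard in enumerate(shards):
    print(f"rank {r}: {shard}")
```

Because every rank receives the same number of indices, all processes run the same number of steps per epoch, which keeps the gradient synchronization in DDP aligned.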

Another small but important change: process_csv_row(self, processed_file_path) is decorated with @fn.lru_cache(maxsize=50000), which costs a huge amount of memory. This is tolerable in DP mode, since there is only one dataloader (although it is still quite expensive, because the dataloader workers do not share this cache): with DP and 10 workers (num_worker), it can use more than 200 GB of memory. In DDP mode, each GPU owns its own dataset, so this memory cost is multiplied by the number of GPUs, which cannot be afforded. So I changed it to a smaller value, @fn.lru_cache(maxsize=100), and there is no slowdown.
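The effect of the cache size can be sketched as follows. This is an illustration under stated assumptions, not the repository's code: `make_cached_loader` and `max_cached_entries` are hypothetical names, the 1 KiB payload is a stand-in for a parsed feature file, and `fn` is assumed to be `functools` imported under that alias. The upper bound assumes, as described above, that every dataloader worker in every process keeps its own cache.

```python
import functools as fn

def make_cached_loader(maxsize):
    """Build a loader whose results are kept in an LRU cache of the given size."""
    @fn.lru_cache(maxsize=maxsize)
    def load(path):
        # Stand-in for reading and parsing one processed example from disk.
        return bytes(1024)
    return load

def max_cached_entries(maxsize, num_workers, num_gpus):
    """Upper bound on cached examples when each worker of each process has its own cache."""
    return maxsize * num_workers * num_gpus

# The cache never holds more than maxsize entries, however many paths are loaded.
loader = make_cached_loader(100)
for i in range(1000):
    loader(f"example_{i}.pkl")
stats = loader.cache_info()
print(stats.currsize)  # 100

# Upper bounds for the two settings discussed above.
before = max_cached_entries(50000, num_workers=10, num_gpus=1)  # DP, old size
after = max_cached_entries(100, num_workers=10, num_gpus=2)     # DDP, new size
print(before, after)  # 500000 2000
```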

Wangchentong commented 1 year ago

The training code can now be run with torchrun on both CPU and GPU: torchrun --nproc_per_node 4 experiments/train_se3_diffusion.py experiment.use_ddp=True experiment.num_gpus=4
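For context, torchrun passes rank information to each worker process through the environment variables RANK, LOCAL_RANK and WORLD_SIZE. The sketch below shows one way a training script could read them; `read_ddp_env` is a hypothetical helper for illustration and is not the get_ddp_info() method from this PR.

```python
import os

def read_ddp_env(env=None):
    """Read the rank information that torchrun exports to each worker.

    Falls back to single-process defaults when the variables are absent,
    for example when the script is started with plain `python`.
    """
    env = os.environ if env is None else env
    return {
        "rank": int(env.get("RANK", 0)),
        "local_rank": int(env.get("LOCAL_RANK", 0)),
        "world_size": int(env.get("WORLD_SIZE", 1)),
    }

# Example: the fourth of four workers, second process on its node.
info = read_ddp_env({"RANK": "3", "LOCAL_RANK": "1", "WORLD_SIZE": "4"})
is_main = info["rank"] == 0  # typically only rank 0 logs and saves checkpoints
print(info, is_main)
```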

Wangchentong commented 1 year ago

I think the logic is now quite clear. I have tested it in both DP and DDP mode. Here are the test commands.

DP mode: torchrun --nproc_per_node 1 experiments/train_se3_diffusion.py experiment.num_gpus=2 experiment.use_ddp=False experiment.use_wandb=False experiment.ckpt_dir=null experiment.eval_dir=null

Output:

[2023-05-09 13:08:47,557][data.so3_diffuser][INFO] - Using cached IGSO3 in .cache/eps_1000_omega_1000_min_sigma_0_1_max_sigma_1_5_schedule_logarithmic
[2023-05-09 13:08:48,292][__main__][INFO] - Number of model parameters 17446190
[2023-05-09 13:08:48,293][__main__][INFO] - Checkpoint not being saved.
[2023-05-09 13:08:48,293][__main__][INFO] - Evaluation will not be saved.
[2023-05-09 13:08:48,326][__main__][INFO] - Multi-GPU training on GPUs in DP mode: ['cuda:0', 'cuda:1']
[2023-05-09 13:08:50,362][data.pdb_data_loader][INFO] - Training: 21737 examples
[2023-05-09 13:08:50,689][data.pdb_data_loader][INFO] - Validation: 40 examples with lengths [ 60 108 157 206 255 304 353 402 451 512]
[2023-05-09 13:08:55,054][__main__][INFO] - [1]:  total_loss=12.5422 rot_loss=2.9649 trans_loss=3.5666 bb_atom_loss=3.3127 dist_mat_loss=2.6980 examples_per_step=20.0000 res_length=122.0000, steps/sec=329.15038steps/sec=229.48265

DDP mode: torchrun --nproc_per_node 2 experiments/train_se3_diffusion.py experiment.num_gpus=2 experiment.use_wandb=False experiment.ckpt_dir=null experiment.use_ddp=True experiment.eval_dir=null

Output:

[2023-05-09 13:11:54,462][torch.distributed.distributed_c10d][INFO] - Added key: store_based_barrier_key:1 to store for rank: 0
[2023-05-09 13:11:54,462][torch.distributed.distributed_c10d][INFO] - Added key: store_based_barrier_key:1 to store for rank: 1
[2023-05-09 13:11:54,463][torch.distributed.distributed_c10d][INFO] - Rank 1: Completed store-based barrier for key:store_based_barrier_key:1 with 2 nodes.
[2023-05-09 13:11:54,463][torch.distributed.distributed_c10d][INFO] - Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 2 nodes.
[2023-05-09 13:11:54,464][data.so3_diffuser][INFO] - Using cached IGSO3 in .cache/eps_1000_omega_1000_min_sigma_0_1_max_sigma_1_5_schedule_logarithmic
[2023-05-09 13:11:54,464][data.so3_diffuser][INFO] - Using cached IGSO3 in .cache/eps_1000_omega_1000_min_sigma_0_1_max_sigma_1_5_schedule_logarithmic
[2023-05-09 13:11:55,174][__main__][INFO] - Number of model parameters 17446190
[2023-05-09 13:11:55,175][__main__][INFO] - Checkpoint not being saved.
[2023-05-09 13:11:55,175][__main__][INFO] - Evaluation will not be saved.
[2023-05-09 13:11:57,275][__main__][INFO] - Multi-GPU training on GPUs in DDP mode, node_id : 0, devices: ['cuda:0', 'cuda:1']
[2023-05-09 13:11:57,685][data.pdb_data_loader][INFO] - Training: 21737 examples
[2023-05-09 13:11:57,685][data.pdb_data_loader][INFO] - Training: 21737 examples
[2023-05-09 13:11:58,041][data.pdb_data_loader][INFO] - Validation: 40 examples with lengths [ 60 108 157 206 255 304 353 402 451 512]
[2023-05-09 13:11:58,043][data.pdb_data_loader][INFO] - Validation: 40 examples with lengths [ 60 108 157 206 255 304 353 402 451 512]
[2023-05-09 13:12:01,086][__main__][INFO] - [1]: total_loss=12.7105 rot_loss=3.0103 trans_loss=2.9752 bb_atom_loss=3.6877 dist_mat_loss=3.0373 examples_per_step=20.0000 res_length=122.0000, steps/sec=138.46955

RAM usage is also quite stable with the smaller lru_cache size: image

This is the memory usage in DP mode before the update (nearly out of memory with 10 workers): image