atong01 / conditional-flow-matching

TorchCFM: a Conditional Flow Matching library
https://arxiv.org/abs/2302.00482
MIT License
1.27k stars · 103 forks

add support for distributed data parallel training #116

Closed by ImahnShekhzadeh 3 months ago

ImahnShekhzadeh commented 6 months ago

This PR adds support for distributed data parallel (DDP) training and replaces DataParallel with DistributedDataParallel in train_cifar.py; DDP is enabled via the flag parallel. To support this, the code is refactored and the flags master_addr and master_port are added.

I tested the changes: on a single GPU, I get an FID of 3.74 (with the OT-CFM method); on two GPUs with DDP, I get an FID of 3.81.


kilianFatras commented 6 months ago

Hi, thank you for your contribution!

I had an internal implementation using Fabric from Lightning, but I would like to rely only on PyTorch for this example. I need some time to review it (a few days/weeks); I will come back to it soon.

kilianFatras commented 3 months ago

I like the new changes. @atong01 do you mind having a look? I also think it would be great to keep the original train_cifar10.py.

While I like this code, it is slightly more complicated than the previous one, so I would keep both. The idea of this package is that any master's student can easily understand it in 1 hour. @ImahnShekhzadeh, can you rename this file train_cifar10_ddp.py and re-add the previous file? Thanks

ImahnShekhzadeh commented 3 months ago

@ImahnShekhzadeh can you rename this file train_cifar10_ddp.py please? and re-add the previous file? Thanks

Done

atong01 commented 3 months ago

LGTM. Thanks for the contribution @ImahnShekhzadeh