Closed. ImahnShekhzadeh closed this pull request 3 months ago.
Hi, thank you for your contribution!
I had an internal implementation with Fabric from Lightning, but I'd like to rely only on PyTorch for this example. I need some time to review it (a few days to a few weeks); I will come back to it soon.
I like the new changes. @atong01 do you mind having a look? I also think it would be great to keep the original train_cifar10.py.
While I like this code, it is slightly more complicated than the previous one, so I would keep both. The idea of this package is that any master's student can easily understand it in one hour. @ImahnShekhzadeh, can you rename this file to train_cifar10_ddp.py and re-add the previous file? Thanks
> @ImahnShekhzadeh, can you rename this file to train_cifar10_ddp.py and re-add the previous file? Thanks
Done
LGTM. Thanks for the contribution @ImahnShekhzadeh
This PR adds support for distributed data parallel (DDP) and replaces `DataParallel` with `DistributedDataParallel` in `train_cifar10.py`, which can be used via the flag `parallel`. To achieve this, the code is refactored, and the flags `master_addr` and `master_port` are added. I tested the changes: on a single GPU, I get an FID of 3.74 (with the OT-CFM method); on two GPUs with DDP, I get an FID of 3.81.
Before submitting
- Did you test your PR locally with the `pytest` command?
- Did you run pre-commit hooks with the `pre-commit run -a` command?