jasonkyuyim / se3_diffusion

Implementation for SE(3) diffusion model with application to protein backbone generation
https://arxiv.org/abs/2302.02277
MIT License

Does `DistributedTrainSampler` also need to be updated for `cluster` support? #25

Closed amorehead closed 1 year ago

amorehead commented 1 year ago

https://github.com/jasonkyuyim/se3_diffusion/blob/5be5b367bda4cc075fc9edc195188251acb846fe/experiments/train_se3_diffusion.py#L203

Hello. I was curious whether `DistributedTrainSampler` needs to be updated in response to the recent changes that added clustering support to `TrainSampler`. A related question: does the `batch_size=256` setting in `base.yml` correspond to training with a single 80 GB A100 GPU (instead of the two A100s originally used)? If so, that would explain why DDP is currently not being used, since `DistributedTrainSampler` does not appear to support clustering yet.

jasonkyuyim commented 1 year ago

Hi, I don't use `DistributedTrainSampler`, so I can't help much. I think just copying over the clustering code should work. (Ideally the samplers would inherit from a base sampler.) `batch_size=256` was used with two 40 GB A100s, so in theory a single 80 GB A100 should work.
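To illustrate what "copying over the clustering code" could look like in a distributed sampler, here is a minimal, hypothetical sketch: sample one member per cluster each epoch (the usual way cluster-based training deduplicates a protein dataset), shuffle with an epoch-seeded RNG so every rank agrees on the order, then give each rank a disjoint shard. The function name, the `clusters` structure, and the sharding details are assumptions for illustration, not the repo's actual code.

```python
import random

def cluster_epoch_indices(clusters, epoch, num_replicas=1, rank=0):
    """Pick one random member per cluster, shuffle, then shard across ranks.

    `clusters` maps a cluster id to the list of dataset indices in that
    cluster. Seeding with the epoch keeps all ranks consistent without
    communication, mirroring how torch's DistributedSampler uses set_epoch.
    """
    rng = random.Random(epoch)  # same seed on every rank -> identical choices
    picked = [rng.choice(members) for _, members in sorted(clusters.items())]
    rng.shuffle(picked)
    # Drop the tail so every rank receives the same number of indices.
    per_rank = len(picked) // num_replicas
    return picked[rank * per_rank:(rank + 1) * per_rank]

# Toy example: 4 clusters sharded across 2 ranks, 2 indices per rank.
clusters = {"c0": [0, 1], "c1": [2], "c2": [3, 4, 5], "c3": [6]}
shard0 = cluster_epoch_indices(clusters, epoch=0, num_replicas=2, rank=0)
shard1 = cluster_epoch_indices(clusters, epoch=0, num_replicas=2, rank=1)
```

The shards are disjoint by construction, and together they contain exactly one representative of each cluster for that epoch.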

amorehead commented 1 year ago

Thanks for your quick reply. Going forward, I think it may make sense to replace `DistributedTrainSampler` with Lightning Fabric's `DistributedSamplerWrapper`. This class lets you pass in an arbitrary custom `Sampler` (like `TrainSampler`), and it handles the distributed sharding of the underlying dataset's indices (as I understand it). I'll be trying this out in PyTorch Lightning to see if it works as expected.
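For readers unfamiliar with the wrapper pattern, here is a minimal sketch of the idea: wrap any sampler that yields dataset indices and give each rank an equal, disjoint, interleaved slice, padding by wrap-around so the length divides evenly. The class name mirrors the Lightning utility, but this is an illustration of the concept under my own assumptions, not Lightning's actual implementation.

```python
import math

class DistributedSamplerWrapper:
    """Shard an arbitrary index sampler across `num_replicas` ranks.

    Illustrative sketch only: real wrappers also handle epoch-based
    reshuffling, seeding, and shuffle/drop_last options.
    """

    def __init__(self, sampler, num_replicas, rank):
        self.sampler = sampler  # any iterable of dataset indices
        self.num_replicas = num_replicas
        self.rank = rank

    def __iter__(self):
        indices = list(self.sampler)
        # Pad by wrapping around so the length divides evenly across ranks.
        total = math.ceil(len(indices) / self.num_replicas) * self.num_replicas
        indices += indices[: total - len(indices)]
        # Each rank takes every num_replicas-th index, offset by its rank.
        return iter(indices[self.rank :: self.num_replicas])

    def __len__(self):
        return math.ceil(len(list(self.sampler)) / self.num_replicas)

# Toy example: 10 indices sharded across 3 ranks (rank 1 sees a padded copy).
shards = [
    list(DistributedSamplerWrapper(range(10), num_replicas=3, rank=r))
    for r in range(3)
]
```

In real DDP training each process would construct the wrapper with its own `rank`, so every example is seen by exactly one rank per epoch (modulo padding).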