Closed: amorehead closed this issue 1 year ago.
Hi, I don't use `DistributedTrainSampler` so I can't help much. I think just copying over the clustering code should work. (Ideally the samplers would inherit from a base sampler.) `batch_size=256` was used with two 40GB A100s, so in theory a single 80GB A100 should work.
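Something like this is what I mean by a base sampler (just a sketch; the class names, constructor arguments, and clustering placeholder below are illustrative, not the actual repo code):

```python
from torch.utils.data import Sampler


class BaseTrainSampler(Sampler):
    """Holds the index-selection (clustering) logic shared by both samplers."""

    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size

    def _clustered_indices(self):
        # Placeholder: the clustering code currently in TrainSampler
        # (e.g. pick one chain per cluster, then repeat/shuffle) would live here.
        return list(range(len(self.dataset)))


class TrainSampler(BaseTrainSampler):
    def __iter__(self):
        return iter(self._clustered_indices())

    def __len__(self):
        return len(self._clustered_indices())


class DistributedTrainSampler(BaseTrainSampler):
    """Same clustering, but each rank only iterates over its own shard."""

    def __init__(self, dataset, batch_size, rank, num_replicas):
        super().__init__(dataset, batch_size)
        self.rank = rank
        self.num_replicas = num_replicas

    def __iter__(self):
        # Stride the shared index list across ranks so shards don't overlap.
        indices = self._clustered_indices()
        return iter(indices[self.rank::self.num_replicas])

    def __len__(self):
        return len(self._clustered_indices()[self.rank::self.num_replicas])
```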
Thanks for your quick reply. I think, going forward, it may make sense to replace `DistributedTrainSampler` with Lightning Fabric's `DistributedSamplerWrapper`. This neat class allows you to pass it an arbitrary custom `Sampler` (like `TrainSampler`), and it will handle all distributed sharding of the underlying dataset's indices (as I understand it). I'll personally be trying this out in PyTorch Lightning to see if it works as expected.
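For concreteness, here is roughly what I have in mind. The import path, constructor arguments, and variable names are my assumptions and would need to be adapted to the actual code in `train_se3_diffusion.py`:

```python
import torch
from torch.utils.data import DataLoader
# Import path assumed for recent Lightning 2.x releases; older versions expose
# a similar wrapper from pytorch_lightning.overrides.distributed.
from lightning.fabric.utilities.distributed import DistributedSamplerWrapper

# Existing single-process sampler from this repo (constructor args illustrative).
train_sampler = TrainSampler(data_conf=data_conf, dataset=train_dataset, batch_size=256)

# Wrap it so each DDP rank only sees its own shard of the clustered indices.
dist_sampler = DistributedSamplerWrapper(
    train_sampler,
    num_replicas=torch.distributed.get_world_size(),  # assumes DDP is already initialized
    rank=torch.distributed.get_rank(),
    shuffle=False,  # keep the ordering produced by the wrapped sampler
)

train_loader = DataLoader(train_dataset, sampler=dist_sampler, batch_size=256)

# Call dist_sampler.set_epoch(epoch) at the start of every epoch so all ranks
# re-shuffle their shards consistently.
```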
https://github.com/jasonkyuyim/se3_diffusion/blob/5be5b367bda4cc075fc9edc195188251acb846fe/experiments/train_se3_diffusion.py#L203
Hello. I was curious if the `DistributedTrainSampler` needs to be updated in response to the most recent changes to `TrainSampler` for clustering support. A related question: does the `batch_size=256` choice in `base.yml` correspond to training with a single A100 (80GB) GPU (instead of two A100s as originally done), by chance? If so, that would explain how and why DDP is currently not being used (since `DistributedTrainSampler` appears to not support clustering yet).