siegelaaron94 opened this issue 1 year ago
If "augmentations" are done in the dataloader worker processes, we already handle perturbing the seed for each worker. This can, and in specific use cases like this one should, also be applied to the main processes, as @siegelaaron94 explains. I wouldn't add this as a LightningModule method, but rather add it to seed_everything as an option:
```python
seed_everything(
    seed=111,         # the base value for the seed
    workers=True,     # dataloader workers derive their seed from the base seed
    per_device=True,  # e.g. set the seed to base_seed + global_rank, as in the dataloader workers
)
```
We would need a suitable name for `per_device`.
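The `per_device` option does not exist in Lightning today; as a hypothetical sketch of the proposed semantics (the `GLOBAL_RANK` environment variable and the additive offset are assumptions, not Lightning's actual implementation), deriving the per-process seed could look like:

```python
import os
import random

import torch


def seed_everything_sketch(seed: int, per_device: bool = False) -> int:
    # Hypothetical sketch of the proposed option: offset the base seed
    # by the process's global rank so every device gets its own RNG
    # stream, mirroring what is already done for dataloader workers.
    if per_device:
        seed = seed + int(os.environ.get("GLOBAL_RANK", "0"))
    random.seed(seed)
    torch.manual_seed(seed)
    return seed
```

With `per_device=False` every rank seeds identically; with `per_device=True` each rank's stream diverges by its rank offset.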
@awaelchli By "doing augmentation in the training loop" I mean stylegan2-ada and models like it. I think an API like the one you suggest would lead to more issues. With what you propose, most people would write:
```python
seed_everything(
    seed=111,
    workers=True,
    per_device=True,
)
data = MyDataModule(...)
model = MyLightningModule(...)
trainer = Trainer(...)
trainer.fit(model, data)
```
and now each GPU would have a different model initialization. To use your code correctly, you would have to do this:
```python
seed_everything(
    seed=111,
    workers=True,
    per_device=False,
)
data = MyDataModule(...)
model = MyLightningModule(...)
seed_everything(
    seed=111,
    workers=True,
    per_device=True,
)
trainer = Trainer(...)
trainer.fit(model, data)
```
otherwise each GPU would end up with a different model initialization.
It would have a different initialization, yes. That would be ok, because in DDP PyTorch broadcasts the weights from rank 0 before running the first forward pass, so you would never actually run an optimization step with diverged weights. PyTorch already guarantees that the weights stay identical across ranks (given the user does not perturb them manually during optimization). Does that make sense? Or do you have evidence that something isn't working in that area?
I'm just not convinced that it has to be a (stateful) method on the module. I think a function would be much better suited for this. Something like:
```python
def _sample_latent(self, imgs):
    augmentation_seed()
    return torch.randn(imgs.shape[0], self.hparams.latent_dim)
```
(Optional) One could even make it a context manager so that it does not affect the random state of anything outside:
```python
def _sample_latent(self, imgs):
    with augmentation_seed():
        return torch.randn(imgs.shape[0], self.hparams.latent_dim)
```
Bug description
If you run the GAN example model https://github.com/Lightning-AI/lightning/blob/master/examples/pl_domain_templates/generative_adversarial_net.py with DDP, you are effectively training with only a single GPU, because all processes sample the same latent vector. The offending code is https://github.com/Lightning-AI/lightning/blob/3ff3ec3fdef92fa2f187f06eca41bf08dcc4eb19/examples/pl_domain_templates/generative_adversarial_net.py#L155. To fix this, I do something like the following in every GAN model I create.
It would be nice if Lightning could either (1) warn the user about this, or (2), which I think would be better, call something like _augmentation_seed internally on the LightningModule after all models have been initialized. I call _augmentation_seed in training_step because I am not 100% sure when all models have been initialized. I also want to point out that this is not restricted to GANs; it would affect anyone doing augmentation in the training loop.
How to reproduce the bug
No response
Error messages and logs
Environment
Current environment
```
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 1.10):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning (`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):
```
More info
No response
cc @borda @awaelchli