Batchnorm for PatchDiscriminator running in DDP #451

sRassmann commented 6 months ago

Thanks for this amazing work, helps a lot in accelerating experiments!

I tried training an AE using PatchDiscriminator and ran into this issue when switching to DistributedDataParallel (DDP):

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [512]] is at version 4; expected version 3 instead.

Running it with torch.autograd.set_detect_anomaly(True) gives:

UserWarning: Error detected in CudnnBatchNormBackward0

With some troubleshooting I found that the issue is the BatchNorm. So running

discriminator = PatchDiscriminator(**kwargs)

solves it.
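For readers following along: the conversion meant here is presumably `torch.nn.SyncBatchNorm.convert_sync_batchnorm`, which is named further down the thread. Below is a minimal sketch of that call on a small stand-in network. The `toy_discriminator` layers are hypothetical and are not the actual PatchDiscriminator architecture.

```python
import torch.nn as nn

# Hypothetical stand-in for a discriminator that contains BatchNorm layers;
# the real PatchDiscriminator is not reproduced here.
toy_discriminator = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.LeakyReLU(0.2),
    nn.Conv2d(8, 1, kernel_size=3, padding=1),
)

# convert_sync_batchnorm walks the module tree and returns a module in which
# every BatchNorm*d layer has been replaced by SyncBatchNorm.
toy_discriminator = nn.SyncBatchNorm.convert_sync_batchnorm(toy_discriminator)

# Count the layers that were converted.
n_sync = sum(isinstance(m, nn.SyncBatchNorm) for m in toy_discriminator.modules())
n_plain = sum(isinstance(m, nn.BatchNorm2d) for m in toy_discriminator.modules())
```

In a real DDP script the converted model would typically then be moved to the local GPU and wrapped in `DistributedDataParallel`.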

Might it be worthwhile to put this into the constructor of the PatchDiscriminator to avoid similar issues in the future? e.g.

class PatchDiscriminator(nn.Sequential):
    def __init__(self, **kwargs) -> None:

marksgraham commented 6 months ago

Hi @sRassmann

I find that using convert_sync_batchnorm in init causes errors if I am not training using DDP, so it doesn't seem like a general solution. I was thinking about printing a warning when users initialise the model with batchnorm and we detect a distributed environment, something like:

if norm.lower() == 'batch' and torch.distributed.is_initialized():
    print("WARNING: Discriminator is using BatchNorm and a distributed training environment has been detected. "
          "To train with DDP, convert discriminator to SyncBatchNorm using "

What do you think?
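A self-contained sketch of the warning idea above could look like the following. The helper name `warn_if_batchnorm_in_distributed` is hypothetical, and the wording pointing at `convert_sync_batchnorm` is an assumption based on the function named earlier in this comment. It uses `warnings.warn` rather than `print` so that callers can filter or capture the message.

```python
import warnings

import torch.distributed as dist


def warn_if_batchnorm_in_distributed(norm: str) -> bool:
    """Hypothetical helper: warn if BatchNorm is requested while a process group is initialised.

    Returns True if a warning was emitted, False otherwise.
    """
    # is_available() guards against PyTorch builds without distributed support.
    in_distributed = dist.is_available() and dist.is_initialized()
    if norm.lower() == "batch" and in_distributed:
        warnings.warn(
            "Discriminator is using BatchNorm and a distributed training environment "
            "has been detected. To train with DDP, convert the discriminator to "
            "SyncBatchNorm using torch.nn.SyncBatchNorm.convert_sync_batchnorm.",
            UserWarning,
            stacklevel=2,
        )
        return True
    return False
```

Outside a distributed run the helper stays silent, so single-GPU and CPU users are unaffected.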

sRassmann commented 6 months ago

Thanks for having a look. Weird, it didn't give me an error when I tried it without DDP. But sure, better to just give a warning under DDP than to potentially mess things up for everyone.