Project-MONAI / GenerativeModels

MONAI Generative Models makes it easy to train, evaluate, and deploy generative models and related applications
Apache License 2.0

Batchnorm for PatchDiscriminator running in DDP #451

Closed sRassmann closed 5 months ago

sRassmann commented 6 months ago

Thanks for this amazing work, helps a lot in accelerating experiments!

I tried training an AE using PatchDiscriminator and ran into this issue when switching to DistributedDataParallel (DDP):

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [512]] is at version 4; expected version 3 instead.

Running it with torch.autograd.set_detect_anomaly(True) gives:

UserWarning: Error detected in CudnnBatchNormBackward0

With some troubleshooting I found that the issue is the BatchNorm layers. So running

discriminator = PatchDiscriminator(**kwargs)
discriminator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(discriminator)

solves it (note that convert_sync_batchnorm returns a new module, so the result needs to be reassigned).
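For context, a minimal self-contained sketch of that fix. It uses a plain nn.Sequential with a BatchNorm layer as a stand-in, since the actual PatchDiscriminator kwargs are omitted above; any module tree containing BatchNorm layers is converted the same way:

```python
import torch
import torch.nn as nn

# Stand-in for PatchDiscriminator (its kwargs are not shown here); what matters
# for the conversion is only that the module tree contains BatchNorm layers.
discriminator = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.LeakyReLU(0.2),
)

# convert_sync_batchnorm returns a *new* module tree -- the result must be
# reassigned, otherwise the original BatchNorm layers are still the ones used.
discriminator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(discriminator)

# Every BatchNorm*d layer has now been replaced by SyncBatchNorm:
print(any(isinstance(m, nn.SyncBatchNorm) for m in discriminator.modules()))  # True
```

The conversion itself runs without an initialized process group; only the forward pass of a SyncBatchNorm layer in training mode requires DDP to be set up.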

Might it be worthwhile to put this into the constructor of the PatchDiscriminator to avoid similar issues in the future? e.g.

class PatchDiscriminator(nn.Sequential):
    def __init__(self, **kwargs) -> None:
        super().__init__()
        [...]
        self.apply(self.initialise_weights)

        torch.nn.SyncBatchNorm.convert_sync_batchnorm(self)
marksgraham commented 6 months ago

Hi @sRassmann

I find that using convert_sync_batchnorm in init causes errors if I am not training using DDP, so it doesn't seem like a general solution. I was thinking about printing a warning when users initialise the model with batchnorm and we detect a distributed environment, something like:

self.apply(self.initialise_weights)
if norm.lower() == 'batch' and torch.distributed.is_initialized():
    print("WARNING: Discriminator is using BatchNorm and a distributed training environment has been detected. "
          "To train with DDP, convert discriminator to SyncBatchNorm using "
          "torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).")
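As a self-contained sketch of that check (the helper name is hypothetical, and this uses warnings.warn instead of print so users can filter the message):

```python
import warnings

import torch

def warn_if_batchnorm_under_ddp(norm: str) -> None:
    """Hypothetical helper mirroring the check above: warn when a model is
    built with BatchNorm while a distributed environment is active."""
    if (
        norm.lower() == "batch"
        and torch.distributed.is_available()
        and torch.distributed.is_initialized()
    ):
        warnings.warn(
            "Discriminator is using BatchNorm and a distributed training "
            "environment has been detected. To train with DDP, convert the "
            "discriminator to SyncBatchNorm using "
            "torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)."
        )
```

Outside a distributed run the check is a no-op, so the helper is safe to call unconditionally from the constructor.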

What do you think?

sRassmann commented 6 months ago

Thanks for having a look. Weird, it didn't give me an error when I tried it without DDP. But sure, better to just warn in the DDP case than to potentially mess things up for everyone.