martenlienen opened 3 years ago
I just hit a bug that causes validation to hang when geomloss is used with DistributedDataParallel. DistributedDataParallel checks torch.is_grad_enabled() to decide whether to synchronize gradients before the forward pass. geomloss unconditionally turns grad mode back on after the first iteration, so processes end up waiting for a synchronization that never happens.
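The grad-state leak can be reproduced without geomloss. The sketch below uses a hypothetical `geomloss_like_loss` standing in for the library's internal loop; it shows how an unconditional `set_grad_enabled(True)` overrides a caller's `torch.no_grad()`:

```python
import torch

def geomloss_like_loss(x):
    # Hypothetical stand-in for the library's internal loop: it disables
    # autograd for its iterations, then unconditionally re-enables it.
    torch.autograd.set_grad_enabled(False)
    y = (x * 2.0).sum()
    torch.autograd.set_grad_enabled(True)  # the problematic line
    return y

with torch.no_grad():
    geomloss_like_loss(torch.ones(3))
    # Grad mode is now on even though we are still inside no_grad();
    # DistributedDataParallel would see is_grad_enabled() == True here.
    print(torch.is_grad_enabled())  # True
```

This is exactly the state DistributedDataParallel inspects, so any code running after the loss call in a no_grad block behaves as if grad were enabled.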
Indiscriminately enabling autograd at the end of the loop also enables it when the user had explicitly disabled it beforehand. This is common when the loss is computed over a validation set.
In my application this stopped training with an out-of-memory error: autograd on the validation data quickly exhausts memory, while the training loop is fine (sequential data, validated on full sequences but trained on subsequences).
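A minimal fix (a sketch of the idea, not the library's actual code) is to remember the caller's grad mode and restore it instead of forcing it on:

```python
import torch

def loss_restoring_grad_mode(x):
    prev = torch.is_grad_enabled()         # remember the caller's setting
    torch.autograd.set_grad_enabled(False) # disable grad for the inner loop
    y = (x * 2.0).sum()
    torch.autograd.set_grad_enabled(prev)  # restore instead of forcing True
    return y

# The caller's no_grad() now survives the call:
with torch.no_grad():
    loss_restoring_grad_mode(torch.ones(3))
    print(torch.is_grad_enabled())  # False
```

Equivalently, the inner iterations could be wrapped in `with torch.no_grad():`, whose exit restores the previous grad state automatically.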