skorch-dev / skorch

A scikit-learn compatible neural network library that wraps PyTorch
BSD 3-Clause "New" or "Revised" License

DistributedDataParallel and Early stopping does not work together #737

Open thomasjpfan opened 3 years ago

thomasjpfan commented 3 years ago

I recently talked with a user who tried to use DistributedDataParallel with skorch's early stopping, and this caused the process to hang. My guess is that since DDP workers spawn their own jobs, skorch's early stopping mechanism stops one worker, but the parent node never gets this information. This leaves the parent waiting for a child that has stopped running.
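One common way to avoid this kind of hang is to make every rank reach the same stopping decision before anyone leaves the training loop. The sketch below is not skorch code; `agree_to_stop` is a hypothetical helper that assumes `torch.distributed` has already been initialized by the DDP launcher. Each rank contributes its local "stop" flag and an `all_reduce` with `MAX` gives every process the same answer. With the NCCL backend, the tensor has to live on the rank's CUDA device, hence the `device` argument.

```python
import torch
import torch.distributed as dist


def agree_to_stop(local_stop: bool, device: str = "cpu") -> bool:
    """Hypothetical helper: return True on every rank if any rank wants to stop.

    Each rank contributes 1.0 (stop) or 0.0 (continue). all_reduce with MAX
    hands every process the same result, so no rank leaves the epoch loop
    while the others are still blocked in a collective waiting for it.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return local_stop  # plain single-process training: nothing to sync
    flag = torch.tensor([1.0 if local_stop else 0.0], device=device)
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return bool(flag.item() > 0)
```

In a hand-written loop this would replace a bare `break` with something like `if agree_to_stop(patience_exceeded): break`, where `patience_exceeded` is whatever the local early-stopping logic computed.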

There may also be an issue with checking the validation loss under DistributedDataParallel: each worker computes the loss only on its own shard of the data, so the per-worker losses would need to be gathered to get the actual validation loss for a given epoch.
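A minimal sketch of that gathering step, assuming each rank tracks the sum of its per-sample losses and how many samples it saw (for example `loss_sum += loss.item() * len(X_batch)`). Averaging the per-rank means is only correct when every shard has the same size, so this hypothetical `global_mean_loss` helper sums both quantities across ranks with `all_reduce` and divides once.

```python
import torch
import torch.distributed as dist


def global_mean_loss(local_loss_sum: float, local_count: int, device: str = "cpu") -> float:
    """Hypothetical helper: combine per-rank validation losses into one epoch loss.

    local_loss_sum: sum of per-sample losses seen by this rank.
    local_count:    number of validation samples seen by this rank.
    """
    stats = torch.tensor(
        [local_loss_sum, float(local_count)], dtype=torch.float64, device=device
    )
    if dist.is_available() and dist.is_initialized():
        # After this call every rank holds the totals over all ranks.
        dist.all_reduce(stats, op=dist.ReduceOp.SUM)
    total_loss, total_count = stats.tolist()
    return total_loss / total_count
```

Because every rank ends up with the same number, any decision based on it (such as early stopping) is also the same on every rank.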

BenjaminBossan commented 3 years ago

Thanks for reporting. Do you know how this is solved more generally (say, using only PyTorch without any framework)? I could imagine that similar errors occur easily, given how tricky parallel code is in general. Unfortunately, I don't have access to a setup to experiment with this.

thomasjpfan commented 3 years ago

I do not have a setup to experiment with this either. I've seen two solutions.

  1. During validation, move everything to one GPU and compute the loss/metrics there: https://github.com/Lance0218/Pytorch-DistributedDataParallel-Training-Tricks/blob/fa709835c7bf5e62f48c72b90eb12f3b795ef07d/DDP_warmup.py#L140-L151
  2. During validation, distribute the data to all GPUs and use a barrier to wait until validation is complete: https://github.com/allenai/allennlp/blob/39c40fe38cd2fd36b3465b0b3c031f54ec824160/allennlp/training/trainer.py#L1022-L1025
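Option 1 can be sketched roughly as follows. This is an illustration, not the code from the linked repository: `validate_on_rank0` and the `evaluate` callable are hypothetical, and `evaluate` is assumed to run the whole validation set on rank 0's device and return a float. The `broadcast` is itself a collective, so the other ranks block there until rank 0 finishes, which is the synchronisation point. Inside `evaluate` it is probably safer to call the underlying `model.module` rather than the DDP wrapper, so that rank 0 does not enter DDP's own collectives alone.

```python
import torch
import torch.distributed as dist


def validate_on_rank0(evaluate, device: str = "cpu") -> float:
    """Hypothetical helper: run `evaluate()` on rank 0 only, share the loss with all ranks."""
    if not (dist.is_available() and dist.is_initialized()):
        return float(evaluate())  # non-distributed run: just evaluate locally
    loss = torch.zeros(1, dtype=torch.float64, device=device)
    if dist.get_rank() == 0:
        loss[0] = float(evaluate())
    # Every rank waits here until rank 0 has finished validating, then all
    # ranks receive the same loss value.
    dist.broadcast(loss, src=0)
    return loss.item()
```

The trade-off is that the other GPUs sit idle during validation; option 2 keeps them busy but needs the loss to be reduced across ranks afterwards.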

BTW, there are a number of barrier calls in that file to handle the distributed case.

Since we do not have the resources to test DDP, I think it would be hard to officially support it.

BenjaminBossan commented 3 years ago

I know too little to really comment on that. Ideally, I would like skorch to get out of the way enough that users can use DistributedDataParallel if they wish. Regarding barriers, is that something that could be achieved through callbacks?