IST-DASLab / torch_cgx

PyTorch distributed backend extension with compression support
GNU Affero General Public License v3.0

Correct use of new API for WGAN training? #3

Closed: igor-krawczuk closed this issue 1 year ago

igor-krawczuk commented 1 year ago

Hi, we had previously run experiments with the August 2022 version of the code, which is now offline. We are training a WGAN-GP with simultaneous ExtraAdam, and I was wondering what the correct usage of the new API would be. I'm currently trying the following (roughly sketched below the list):

  1. two CGXStates, one per optimizer/model
  2. DDP.no_sync() when computing gradient penalties and for everything except the final .backward() on both models
  3. detaching the shared tensor paths (i.e. the discriminator loss is computed on detached images from the generator)
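
Roughly, the training step looks like this (toy models and a simplified gradient-penalty term; the `torch_cgx` import and the `"cgx"` backend name are assumptions based on the README, not code from this issue):

```python
import torch
import torch.distributed as dist
import torch_cgx  # assumption: importing registers the "cgx" backend, as in the README
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group("cgx", init_method="env://")

# Toy stand-ins for the real generator / critic.
generator = DDP(torch.nn.Linear(8, 8).cuda())
critic = DDP(torch.nn.Linear(8, 1).cuda())

noise = torch.randn(4, 8).cuda()
real = torch.randn(4, 8).cuda()

# (2) No gradient synchronization while computing the gradient penalty.
with critic.no_sync():
    interp = real.clone().requires_grad_(True)
    grads, = torch.autograd.grad(critic(interp).sum(), interp, create_graph=True)
    gp = ((grads.norm(2, dim=1) - 1) ** 2).mean()
    gp.backward()

# (3) Critic loss on detached generator output; this backward triggers the allreduce
#     of all gradients accumulated so far, including those from the no_sync block.
fake = generator(noise).detach()
loss = critic(fake).mean() - critic(real).mean()
loss.backward()
```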

However, we continuously get the error:

```
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Got the following error when running the callback: RuntimeError: Error at extracting the layers from bucket. Number of elements in bucket is not equal to number in the layers expected to be in the bucket.
```

Do you have any recommendations? My current theory is that it gets confused by the two states, but a single state gave the same error.

ilmarkov commented 1 year ago

Hi. CGX does not support more than one CGXState, and the state is only used in the allreduce hook. I don't think several allreduce hooks make sense in a DDP context. We developed CGX for DDP-AllReduce-like workloads, so we expect that only gradients are synchronized.

If you don't need to break the communicated buckets into separate layers or to filter out small layers, you can avoid using CGXState altogether and control compression with environment variables. If you don't want to compress some part of your communication, you can create another torch.distributed ProcessGroup (e.g. with the nccl backend) and use it for that type of communication.
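
A minimal sketch of that setup (the environment variable name below is an assumption from memory, so check the README for the authoritative list; `torch_cgx` is assumed to register the `cgx` backend on import):

```python
import os
import torch
import torch.distributed as dist
import torch_cgx  # assumption: importing registers the "cgx" backend

# With no CGXState, compression settings come from environment variables.
# The variable name here is from memory; the authoritative list is in the README.
os.environ.setdefault("CGX_COMPRESSION_QUANTIZATION_BITS", "4")

# Default group: CGX, so DDP gradient allreduce is compressed.
dist.init_process_group("cgx", init_method="env://")

# Separate NCCL group for communication that should stay uncompressed.
nccl_group = dist.new_group(backend="nccl")

t = torch.randn(1024).cuda()
dist.all_reduce(t)                    # compressed allreduce via CGX
dist.all_reduce(t, group=nccl_group)  # plain NCCL allreduce, uncompressed
```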

igor-krawczuk commented 1 year ago

Hi, okay, so if I want to compress the discriminator gradients and the generator gradients one after another, it should work without using the state? Or would it still expect all tensors to be sent every time, and possibly error out if it only encounters e.g. the discriminator backward call?

ilmarkov commented 1 year ago

Without a CGXState, CGX will compress all tensors that are allreduced (i.e. those for which torch.distributed.all_reduce is called), except very small ones with fewer than 16 elements.
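
In other words, with the process group from the sketch above already initialized, the only distinction is tensor size:

```python
big = torch.randn(1024).cuda()
tiny = torch.randn(8).cuda()   # fewer than 16 elements

dist.all_reduce(big)   # compressed by CGX
dist.all_reduce(tiny)  # below the size threshold, sent uncompressed
```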

igor-krawczuk commented 1 year ago

Thanks, removing the state and using the environment variables did the trick.