IST-DASLab / torch_cgx

Pytorch distributed backend extension with compression support
GNU Affero General Public License v3.0

"Error at extracting the layers from bucket." with the DDP communication hook, when there are other all_reduce calls outside of DDP #5

Open jrcavani opened 1 year ago

jrcavani commented 1 year ago

Hello,

I am using this library to speed up card-to-card communication. I was able to run my code, modified to support CGX, using the provided Dockerfile (based on pytorch:22.10-py3).

The code without CGX modification is here

Supplying the envvars worked fine:

mpirun --allow-run-as-root -np 8 -x CGX_COMPRESSION_QUANTIZATION_BITS=8 -x CGX_COMPRESSION_BUCKET_SIZE=128 -x CGX_INNER_COMMUNICATOR_TYPE=SHM python3 train_v2.py 

However, when I added the DDP communication hook

    if cfg.use_cgx:
        # In case the user wants to perform layer-wise gradient compression and filter
        # small layers out of compression, CGX provides an allreduce hook for that.

        # The user specifies the minimal size of a layer to compress, then sets the compression parameters.
        state = CGXState(torch.distributed.group.WORLD,
                         compression_params={"bits": 8, "bucket_size": 128})
        backbone.register_comm_hook(state, cgx_hook)

I got this error:

Traceback (most recent call last):
  File "train_v2.py", line 305, in <module>
    main(parser.parse_args())
  File "train_v2.py", line 258, in main
    loss: torch.Tensor = module_fc(local_embeddings_part, local_labels_part)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1185, in _call_impl
    return forward_call(*input, **kwargs)
  File "/training/rocket-trainer/partial_fc_v2.py", line 161, in forward
    loss = DistCrossEntropyFunc.apply(logits, labels)
  File "/training/rocket-trainer/distributed_autograd.py", line 109, in forward
    dist.all_reduce(sum_logits_exp, dist.ReduceOp.SUM)
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 1388, in all_reduce
    work.wait()
RuntimeError: Error at extracting the layers from bucket. Number of elements in bucket is not equal to number in the layers expected to be in the bucket.

The code that throws the error is in DistCrossEntropyFunc. The line numbers and file names don't correspond exactly to the exception above because I was running internal code, but they are close enough for illustration.

This trainer is essentially a big distributed classifier, similar to an ImageNet classifier, using DDP for data parallelism over the ResNet backbone. In addition, it uses model parallelism for the last classification layer, which involves a few all_reduce calls outside of the backbone's DDP.
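
To make the pattern concrete, here is a minimal sketch (hypothetical stand-in modules and shapes, not the actual trainer code):

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Hypothetical stand-ins for the real trainer's modules and tensors.
local_rank = int(os.environ.get("LOCAL_RANK", 0))
device = torch.device("cuda", local_rank)

backbone = DDP(torch.nn.Linear(512, 512).to(device),   # stand-in for the ResNet
               device_ids=[local_rank])
shard_fc = torch.nn.Linear(512, 1000).to(device)        # this rank's shard of the classifier

images = torch.randn(32, 512, device=device)

# Data-parallel part: gradients flow through DDP buckets (and the CGX hook, if registered).
local_embeddings = backbone(images)

# Model-parallel part: each rank computes logits for its own class shard; the softmax
# denominator is then summed across shards with a plain all_reduce, i.e. a collective
# that never goes through DDP's gradient buckets.
local_logits = shard_fc(local_embeddings)
sum_logits_exp = local_logits.exp().sum(dim=1, keepdim=True)
dist.all_reduce(sum_logits_exp, dist.ReduceOp.SUM)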

So my understanding of the situation is as follows: it works without the hook and errors with the hook because the hook presumably makes assumptions about which layers go into each bucket, and since there are all_reduce calls outside of DDP, the hook cannot anticipate them.

Is this true? I can live without the DDP hook. What I really want is to exclude these all_reduce calls outside of the DDP backbone from compression. In my experience, even fp16 compression for these DistCrossEntropyFunc all_reduce calls makes the loss forward/backward inaccurate. If I could turn compression off for them through something like exclude_layers(layer_type_name) (I can't find the function anymore), that would be great.
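
One workaround I might try (purely a sketch using standard torch.distributed APIs; I don't know whether torch_cgx supports mixing backends this way) is to create a second process group on an uncompressed backend and route these exact-precision collectives through it:

import torch.distributed as dist

# Default group is initialized with the CGX backend, e.g.:
#   dist.init_process_group(backend="cgx", ...)

# Second group over plain NCCL for collectives that must stay bit-exact.
# (Assumption: a new_group with a different backend can coexist with CGX.)
exact_group = dist.new_group(backend="nccl")

# Inside DistCrossEntropyFunc, pass the uncompressed group explicitly:
dist.all_reduce(sum_logits_exp, dist.ReduceOp.SUM, group=exact_group)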

jrcavani commented 1 year ago

Here are a few runs with various compression settings for my code out of the box. I resumed a run and measured speed and loss values. Without excluding those DistCrossEntropyFunc calls from compression, 4-bit / 8-bit quantization gave very unstable outcomes. If only I could exclude them from being compressed...

nccl
torchrun --
2023-09-08 04:26:02,866 Speed  4951/s  Loss  7.36  LR: 3.65e-03  Epoch:  4  Step:  28040  FP16_GS:  16384  End: 14.62hrs
2023-09-08 04:26:11,250 Speed  4890/s  Loss  7.41  LR: 3.65e-03  Epoch:  4  Step:  28060  FP16_GS:  16384  End: 14.80hrs

plain mpi
mpirun --allow-run-as-root -np 8 python3 train_v2.py
2023-09-08 04:02:40,534 Speed  1720/s  Loss  7.36  LR: 3.65e-03  Epoch:  4  Step:  28040  FP16_GS:  16384  End: 42.06hrs
2023-09-08 04:03:04,643 Speed  1700/s  Loss  7.41  LR: 3.65e-03  Epoch:  4  Step:  28060  FP16_GS:  16384  End: 42.57hrs

mpirun --allow-run-as-root -np 8 -x CGX_COMPRESSION_QUANTIZATION_BITS=32 -x CGX_COMPRESSION_BUCKET_SIZE=128 -x CGX_INNER_COMMUNICATOR_TYPE=SHM python3 train_v2.py
2023-09-08 04:13:18,354 Speed  3378/s  Loss  7.46  LR: 3.65e-03  Epoch:  4  Step:  28040  FP16_GS:  32768  End: 21.43hrs
2023-09-08 04:13:30,441 Speed  3391/s  Loss  7.46  LR: 3.65e-03  Epoch:  4  Step:  28060  FP16_GS:  16384  End: 21.34hrs

mpirun --allow-run-as-root -np 8 -x CGX_COMPRESSION_QUANTIZATION_BITS=16 -x CGX_COMPRESSION_BUCKET_SIZE=128 -x CGX_INNER_COMMUNICATOR_TYPE=SHM python3 train_v2.py
2023-09-08 04:17:10,358 Speed  3445/s  Loss  7.46  LR: 3.65e-03  Epoch:  4  Step:  28040  FP16_GS:  32768  End: 21.01hrs
2023-09-08 04:17:22,374 Speed  3411/s  Loss  7.46  LR: 3.65e-03  Epoch:  4  Step:  28060  FP16_GS:  16384  End: 21.21hrs

mpirun --allow-run-as-root -np 8 -x CGX_COMPRESSION_QUANTIZATION_BITS=16 -x CGX_COMPRESSION_BUCKET_SIZE=64 -x CGX_INNER_COMMUNICATOR_TYPE=SHM python3 train_v2.py
2023-09-08 04:22:54,114 Speed  3319/s  Loss  7.46  LR: 3.65e-03  Epoch:  4  Step:  28040  FP16_GS:  32768  End: 21.80hrs
2023-09-08 04:23:06,513 Speed  3305/s  Loss  7.46  LR: 3.65e-03  Epoch:  4  Step:  28060  FP16_GS:  16384  End: 21.89hrs

mpirun --allow-run-as-root -np 8 -x CGX_COMPRESSION_QUANTIZATION_BITS=16 -x CGX_COMPRESSION_BUCKET_SIZE=256 -x CGX_INNER_COMMUNICATOR_TYPE=SHM python3 train_v2.py
2023-09-08 04:24:13,788 Speed  3319/s  Loss  7.46  LR: 3.65e-03  Epoch:  4  Step:  28040  FP16_GS:  32768  End: 21.81hrs
2023-09-08 04:24:25,615 Speed  3465/s  Loss  7.46  LR: 3.65e-03  Epoch:  4  Step:  28060  FP16_GS:  16384  End: 20.88hrs

mpirun --allow-run-as-root -np 8 -x CGX_COMPRESSION_QUANTIZATION_BITS=8 -x CGX_COMPRESSION_BUCKET_SIZE=128 -x CGX_INNER_COMMUNICATOR_TYPE=SHM python3 train_v2.py
2023-09-08 04:20:14,864 Speed  5739/s  Loss 34.16  LR: 3.65e-03  Epoch:  4  Step:  28040  FP16_GS:  16384  End: 12.61hrs
2023-09-08 04:20:22,004 Speed  5742/s  Loss 34.35  LR: 3.65e-03  Epoch:  4  Step:  28060  FP16_GS:  16384  End: 12.60hrs

mpirun --allow-run-as-root -np 8 -x CGX_COMPRESSION_QUANTIZATION_BITS=4 -x CGX_COMPRESSION_BUCKET_SIZE=128 -x CGX_INNER_COMMUNICATOR_TYPE=SHM python3 train_v2.py
2023-09-08 04:21:24,422 Speed  5822/s  Loss 45.10  LR: 3.65e-03  Epoch:  4  Step:  28040  FP16_GS:  16384  End: 12.43hrs
2023-09-08 04:21:31,472 Speed  5817/s  Loss 45.73  LR: 3.65e-03  Epoch:  4  Step:  28060  FP16_GS:  16384  End: 12.44hrs