microsoft / tutel

Tutel MoE: An Optimized Mixture-of-Experts Implementation
MIT License

100x slower when using 4 nodes than 1 node to run the helloworld_ddp example #160

Closed: a157801 closed this issue 2 years ago

a157801 commented 2 years ago

Hello, I have a problem: running the helloworld_ddp example on 4 nodes is 100x slower than on 1 node. I compiled Tutel with CUDA 11.3, PyTorch 1.11 and NCCL 2.9.9 on an NVIDIA A100 GPU cluster with 100G IB. When I run tutel.examples.helloworld_ddp on a single node with 8 GPUs and batch size 16, the speed matches the results in your table (step_time = 0.012315). But when I test with 4 nodes, the step time becomes about 1 second, which is about 100x slower. Other multi-node tasks run normally on my cluster, so I suspect something went wrong with the environment when I built the project. It would be very helpful if you could share your detailed environment information, such as the PyTorch version, CUDA version, g++ version, etc. Thanks.

EricWangCN commented 2 years ago

Have you tried the latest PyTorch container from NGC?

ghostplant commented 2 years ago

This typically looks like libNCCL for PyTorch is not configured correctly to use IB, so PyTorch's NCCL falls back to slower Ethernet interfaces or similar. @abuccts Do you have any suggestions?

abuccts commented 2 years ago

Could you provide the detailed command for the "helloworld_ddp example" and the log after setting the NCCL_DEBUG=INFO environment variable?
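One way to read such a log is to look for the network NCCL selected and the transport used on each channel. NCCL_DEBUG, NCCL_IB_DISABLE and NCCL_SOCKET_IFNAME are real NCCL environment variables, but the sketch below is only illustrative: its regular expressions assume log lines of the form `NCCL INFO Using network IB` and `... via NET/IB/0` (or `NET/Socket` when IB is not used), and the function name and sample lines are hypothetical, not taken from this issue's log. A `Socket` network in the summary would support the Ethernet-fallback hypothesis above.

```python
import re
from collections import Counter

# Assumed (illustrative) log formats: "NCCL INFO Using network IB" and channel
# lines ending in "via NET/IB/0", "via NET/Socket/0", "via P2P/IPC", "via SHM/...".
NETWORK_RE = re.compile(r"NCCL INFO Using network (\w+)")
TRANSPORT_RE = re.compile(r"via (NET/\w+|P2P/\w+|SHM)")


def summarize_nccl_log(lines):
    """Return the set of selected networks and a count of per-channel transports."""
    networks = set()
    transports = Counter()
    for line in lines:
        match = NETWORK_RE.search(line)
        if match:
            networks.add(match.group(1))
        # findall returns the captured transport name for every match on the line
        transports.update(TRANSPORT_RE.findall(line))
    return networks, transports


if __name__ == "__main__":
    # Made-up sample lines, not from the reporter's log
    sample = [
        "node1:123:123 [0] NCCL INFO Using network Socket",
        "node1:123:456 [0] NCCL INFO Channel 00 : 0[7000] -> 8[7000] [send] via NET/Socket/0",
        "node1:123:456 [0] NCCL INFO Channel 00 : 0[7000] -> 1[b000] via P2P/IPC",
    ]
    networks, transports = summarize_nccl_log(sample)
    print(networks, dict(transports))
```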

a157801 commented 2 years ago

Our cluster is based on Slurm, so I added the following code to helloworld_ddp:

import os
import subprocess

import torch

# `args` is the example's parsed argparse namespace.
# Rank info already provided by a launcher such as torchrun:
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
    dist_rank = int(os.environ['RANK'])
    dist_world_size = int(os.environ['WORLD_SIZE'])
    args.local_rank = int(os.environ['LOCAL_RANK'])
    os.environ['MASTER_PORT'] = '29529'
# Launched by srun: derive the torch.distributed variables from Slurm's
if 'SLURM_PROCID' in os.environ:
    num_gpus = torch.cuda.device_count()
    dist_rank = int(os.environ['SLURM_PROCID'])
    args.local_rank = dist_rank % num_gpus
    dist_world_size = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
    # Use the first host of the allocation as the rendezvous master
    addr = subprocess.getoutput(
        f'scontrol show hostname {node_list} | head -n1')
    os.environ['MASTER_PORT'] = '29529'
    os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(dist_world_size)
    os.environ['LOCAL_RANK'] = str(args.local_rank)
    os.environ['RANK'] = str(dist_rank)

and ran the command:

srun --mpi=pmi2 -n 32 --gres=gpu:8 --ntasks-per-node=8 python -u -m tutel.examples.helloworld_ddp --batch_size=16

Other multi-node tasks, such as Swin Transformer, run normally under this setting. I have also attached the log file. Thanks.
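A possible refinement of the launcher snippet, assuming the same Slurm setup: factoring the mapping into a pure function makes it testable without a cluster, and reading SLURM_LOCALID (Slurm's node-local task index) avoids relying on ranks being packed contiguously per node. The function name `slurm_to_torch_env` is hypothetical and not part of Tutel; the hostname lookup via `scontrol` is left to the caller.

```python
def slurm_to_torch_env(env, gpus_per_node, master_addr, master_port="29529"):
    """Translate Slurm task variables into the variables read by torch.distributed's
    env:// initialization (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT), plus
    LOCAL_RANK, which the training script itself uses to pick a GPU."""
    rank = int(env["SLURM_PROCID"])        # global task index
    world_size = int(env["SLURM_NTASKS"])  # total number of tasks
    # Prefer Slurm's per-node task index; fall back to the modulo rule
    # if SLURM_LOCALID is not set.
    local_rank = int(env.get("SLURM_LOCALID", rank % gpus_per_node))
    return {
        "RANK": str(rank),
        "WORLD_SIZE": str(world_size),
        "LOCAL_RANK": str(local_rank),
        "MASTER_ADDR": master_addr,
        "MASTER_PORT": str(master_port),
    }
```

In the script one would call `os.environ.update(slurm_to_torch_env(os.environ, torch.cuda.device_count(), addr))` before `torch.distributed.init_process_group("nccl")`.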

ghostplant commented 2 years ago

Can you run tutel.examples.helloworld instead of the DDP version to help us collect the TFLOPS value for each of these cases?

1) srun --mpi=pmi2 -n 32 --gres=gpu:8 --ntasks-per-node=8 --export=SKIP_A2A=1 python -u -m tutel.examples.helloworld --batch_size=16

2) srun --mpi=pmi2 -n 32 --gres=gpu:8 --ntasks-per-node=8 python -u -m tutel.examples.helloworld --batch_size=16 --allreduce_degree=-1

It'll tell whether the regression comes from NCCL's Allreduce or AlltoAll. Thanks!

a157801 commented 2 years ago

After setting SKIP_A2A=1, the speed on 32 GPUs is similar to that on 8 GPUs (step_time = 0.010818). Setting only allreduce_degree=-1 does not affect the speed (still about 1 second). So A2A communication appears to be the bottleneck.
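Treating the reported step times as additive (an assumption: it ignores any overlap between communication and computation, and "about 1 second" is rounded), the numbers above attribute almost the entire step to all-to-all. The helper name is hypothetical:

```python
def a2a_share(step_full, step_without_a2a):
    """Rough attribution: time removed by skipping A2A, and its share of the step."""
    a2a_time = step_full - step_without_a2a
    return a2a_time, a2a_time / step_full


# Numbers from this thread: ~1 s with A2A on 32 GPUs, 0.010818 s with SKIP_A2A=1
a2a_time, share = a2a_share(1.0, 0.010818)
print(f"A2A ~ {a2a_time:.3f} s ({share:.1%} of the step)")
# prints: A2A ~ 0.989 s (98.9% of the step)
```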

ghostplant commented 2 years ago

That's much clearer. It seems the problem is PyTorch's A2A not utilizing IB correctly. @abuccts

abuccts commented 2 years ago

Hi @a157801, thanks for the details.

According to the log, the intra-node connections in your cluster are NVLink (300 GB/s if you are using A100), while inter-node traffic goes through only one InfiniBand NIC (100 Gbps, not 100 GB/s as in your original info). So the poor scaling you observed is to be expected. There are two types of NCCL operations in MoE, allreduce and alltoall, and here are the estimations for your cluster:

You can verify the estimations above by running allreduce_perf and alltoall_perf from nccl-tests with MPI. Please also correct me if I have the topology of your cluster or the busbw numbers wrong.
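For reading allreduce_perf/alltoall_perf output: to our understanding, nccl-tests reports algbw = size / time and scales it into busbw with a per-collective factor so that the result can be compared against link bandwidth, 2(n-1)/n for allreduce and (n-1)/n for alltoall. The helper below is a hypothetical sketch of that conversion, assuming `size_bytes` is the size nccl-tests uses for the bandwidth calculation and 1 GB = 1e9 bytes:

```python
def nccl_tests_bandwidth(op, size_bytes, seconds, nranks):
    """Return (algbw, busbw) in GB/s, following nccl-tests' conventions."""
    algbw = size_bytes / seconds / 1e9
    # Per-collective correction factors relating algbw to per-link traffic
    factors = {
        "allreduce": 2 * (nranks - 1) / nranks,
        "alltoall": (nranks - 1) / nranks,
    }
    return algbw, algbw * factors[op]


# Example: 1 GB moved in 1 s across 32 ranks
print(nccl_tests_bandwidth("allreduce", 1e9, 1.0, 32))
```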


Depending on the computation/communication ratio, the slowdown will be smaller if you run an end-to-end model instead of a single MoE layer test. As far as I can see, the cluster you are using has extremely imbalanced intra- and inter-node connections (you can also check this comment) and is not suitable for multi-node MoE workloads, where alltoall is used heavily. I'm not sure which numbers you are comparing against, but the Azure cluster we used has 8x 200 Gbps NICs, so 1.6 Tbps or 200 GB/s in total.
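The bandwidth gap described above can be turned into a back-of-envelope estimate. This is a sketch under loudly simplifying assumptions, not the missing estimation table from the comment: uniform all-to-all traffic, perfect overlap of intra- and inter-node transfers, no latency or protocol overhead, the quoted 300 GB/s NVLink per GPU, one 100 Gbps NIC vs 8x 200 Gbps per node, and a hypothetical 256 MiB exchanged per GPU (the ratio does not depend on this size). Function names are hypothetical.

```python
GBIT = 1e9 / 8  # bytes per second in 1 Gbit/s


def alltoall_seconds(bytes_per_gpu, world, local_size, nvlink_gbytes, nic_gbits_per_node):
    """Idealized, bandwidth-only all-to-all time under the assumptions above."""
    chunk = bytes_per_gpu / world  # bytes sent to each destination rank
    intra = chunk * (local_size - 1) / (nvlink_gbytes * 1e9)
    remote_ranks = world - local_size
    if remote_ranks == 0:
        return intra
    # All GPUs of a node share the node's NIC bandwidth for remote traffic
    inter = local_size * chunk * remote_ranks / (nic_gbits_per_node * GBIT)
    return max(intra, inter)


def end_to_end_slowdown(compute_s, comm_single_s, comm_multi_s):
    """Slowdown of a training step with no compute/communication overlap."""
    return (compute_s + comm_multi_s) / (compute_s + comm_single_s)


size = 256 * 2**20  # hypothetical bytes exchanged per GPU
one_node = alltoall_seconds(size, 8, 8, 300, 100)
this_cluster = alltoall_seconds(size, 32, 8, 300, 100)
azure_like = alltoall_seconds(size, 32, 8, 300, 1600)
print(f"1x100 Gbps: {this_cluster / one_node:.0f}x, 8x200 Gbps: {azure_like / one_node:.0f}x")
```

Under these assumptions a single 100 Gbps NIC makes the 32-GPU all-to-all roughly two orders of magnitude slower than the single-node one, while 8x 200 Gbps narrows the gap to about 10x; `end_to_end_slowdown` shows why adding compute time to the step shrinks the ratio. Real numbers depend on message sizes, latency and NCCL's algorithms.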

ghostplant commented 2 years ago

Can you try this and let us know the step_time difference?

srun --mpi=pmi2 -n 32 --gres=gpu:8 --ntasks-per-node=8 --export=LOCAL_SIZE=8 python -u -m tutel.examples.helloworld --batch_size=16 --use_2dh
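`--use_2dh` is meant to select Tutel's two-dimensional hierarchical (2DH) all-to-all. As we understand the idea, GPUs first exchange data inside a node so that each GPU holds everything destined for GPUs with its own local index, and then only GPUs with the same local index exchange across nodes. The pure-Python simulation below is a hypothetical sketch of that idea, not Tutel's implementation: it checks that the two phases reproduce a flat all-to-all and counts cross-node traffic. 2DH sends fewer, larger messages but the same number of chunks across nodes, so it mainly helps latency-bound small messages rather than a bandwidth-limited NIC.

```python
def flat_all_to_all(chunks):
    """chunks[src][dst] -> result[dst][src]; one message per (src, dst) pair."""
    world = len(chunks)
    return [[chunks[src][dst] for src in range(world)] for dst in range(world)]


def flat_cross_node_messages(world, local_size):
    """Number of (src, dst) messages that cross a node boundary in a flat all-to-all."""
    return sum(1 for s in range(world) for d in range(world)
               if s // local_size != d // local_size)


def hierarchical_all_to_all(chunks, local_size):
    """Simulate intra-node then inter-node exchange.

    Returns (results, cross-node messages, cross-node chunks), where a message is
    everything one rank sends to one other rank in the inter-node phase.
    """
    world = len(chunks)
    assert world % local_size == 0
    # Phase 1 (intra-node): forward chunk (src, dst) to the GPU on src's node
    # whose local index equals dst's local index.
    staged = [dict() for _ in range(world)]
    for src in range(world):
        node = src // local_size
        for dst in range(world):
            holder = node * local_size + dst % local_size
            staged[holder][(src, dst)] = chunks[src][dst]
    # Phase 2 (inter-node): holder and dst share a local index, so each holder
    # sends one aggregated message per remote node.
    received = [dict() for _ in range(world)]
    messages, cross_chunks = set(), 0
    for holder in range(world):
        for (src, dst), chunk in staged[holder].items():
            received[dst][src] = chunk
            if holder // local_size != dst // local_size:
                messages.add((holder, dst))
                cross_chunks += 1
    results = [[received[dst][src] for src in range(world)] for dst in range(world)]
    return results, len(messages), cross_chunks


world, local_size = 32, 8
chunks = [[(src, dst) for dst in range(world)] for src in range(world)]
results, msgs_2dh, chunks_2dh = hierarchical_all_to_all(chunks, local_size)
assert results == flat_all_to_all(chunks)
print(flat_cross_node_messages(world, local_size), msgs_2dh, chunks_2dh)
# prints: 768 96 768
```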

a157801 commented 2 years ago

The step time is still about 1 second.

a157801 commented 2 years ago

We profiled the inter-node speed of our cluster, and the all2all speed is limited by the IB hardware. I trained SwinV2-MoE-S with 32 experts and the throughput is about 128 qps, about 2.3x slower than the 295 in your paper. It may be necessary to upgrade our hardware, e.g. by adding more IB NICs, to train sparse models efficiently across multiple nodes. Thanks a lot! @ghostplant @abuccts

ghostplant commented 2 years ago

I'll close this issue since it is not actually a Tutel issue. Hopefully libNCCL's primitive ncclSend/ncclRecv will become more efficient on your IB hardware.