bytedance / flux

A fast communication-overlapping library for tensor parallelism on GPUs.
Apache License 2.0

[BUG] gemm and reduce-scatter are not overlapped #7

Closed wenscarl closed 4 months ago

wenscarl commented 4 months ago

Describe the bug
The GEMM and ncclKernelReduceScatter kernels are not overlapped in the provided reduce-scatter example. The GEMM kernels run on the default stream. (Profiler screenshot attached.)

To Reproduce
On an Ampere machine with 2 GPUs: ./scripts/launch.sh test/test_gemm_rs.py 4096 12288 49152 --dtype=float16 --iters=10

Expected behavior
The GEMM and reduce-scatter kernels should overlap.

wenlei-bao commented 4 months ago

@wenscarl Thanks for your feedback. Can you please also post the full output of your command? The test script runs two cases: one is the baseline (PyTorch), the other is Flux. For TP=2 (2 GPUs) GEMM RS, Flux shouldn't launch an NCCL kernel IIRC, so my guess is that your screenshot shows the baseline (torch) profile, which does have the NCCL RS kernel; the Flux profile should appear later in the timeline. Also, is this an A100 NVLink or PCIe machine?
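For reference, a minimal sketch of what the baseline (PyTorch) case amounts to under the usual GEMM-RS tensor-parallel layout; names and shapes here are illustrative and not taken from test_gemm_rs.py:

```python
# Baseline path only: a plain GEMM followed by an NCCL reduce-scatter.
# This is the part of the profile that shows ncclKernelReduceScatter; the
# Flux case runs later in the timeline and, for intra-node NVLink, should
# not launch an NCCL kernel.
import torch
import torch.distributed as dist

def baseline_gemm_rs(x, w, tp_group):
    # x: [M, K_local], w: [K_local, N] -- K is sharded across the TP group,
    # so every rank produces a partial [M, N] result.
    partial = torch.matmul(x, w)
    m_local = partial.shape[0] // dist.get_world_size(tp_group)
    out = torch.empty((m_local, partial.shape[1]),
                      dtype=partial.dtype, device=partial.device)
    # Sum the partials across ranks and scatter rows of the result:
    # this launches the NCCL reduce-scatter kernel seen in the profile.
    dist.reduce_scatter_tensor(out, partial, group=tp_group)
    return out
```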

wenscarl commented 4 months ago

It's on A100 NVLink. Should I look at cutlass::Kernel2<bytedance::flux::FluxGemmKernel<bytedance::flux::GemmMeta<bytedance::flux::GemmDTypeConfig<cute::C<(bytedance::flux::DataTypeEnum)1>, cute::C<(bytedance::flux::DataTypeEnum)1> and barrier_on_stream_kernel_threadgroup<(threadgroup_t)1>(int, int), where the first one fuses the GEMM and reduce-scatter and the second one is a barrier? (Profiler screenshot attached.)

wenlei-bao commented 4 months ago

Yes. BTW, you can also use the --profile option. cc @zheng-ningxin

wenscarl commented 4 months ago

From the profile result shown above, one FluxGemmKernel + barrier_on_stream_kernel_threadgroup takes about 65 ms, while one sm80_xmma_gemm_f16f16_f16f32_f32_tn_n_tilesize192x128x32_stage3_warpsize4x2x1_tensor16x8x16_kernel + ncclDevKernel_AllReduce_Sum_u8_RING_LL takes about 8.6 ms. Is that normal?

wenlei-bao commented 4 months ago

> From the profile result shown above, one FluxGemmKernel + barrier_on_stream_kernel_threadgroup takes about 65 ms, while one sm80_xmma_gemm_f16f16_f16f32_f32_tn_n_tilesize192x128x32_stage3_warpsize4x2x1_tensor16x8x16_kernel + ncclDevKernel_AllReduce_Sum_u8_RING_LL takes about 8.6 ms. Is that normal?

No, the profiled result doesn't look right. As I asked, can you please paste the whole output of the command? It reports the times for Flux vs PyTorch, e.g. on an 8-A100 NVLink cluster:

./launch.sh test/test_gemm_rs.py 4096 12288 49152 --dtype=float16 --iters=10
SOL time for GEMM(M=4096,N=12288,K=49152,TP=8): 1.982ms
torch #0: gemm 2.771 ms, comm 0.490 ms, total 3.262 ms
torch #1: gemm 2.779 ms, comm 0.484 ms, total 3.262 ms
torch #2: gemm 2.472 ms, comm 0.790 ms, total 3.262 ms
torch #3: gemm 2.414 ms, comm 0.845 ms, total 3.259 ms
torch #4: gemm 2.435 ms, comm 0.825 ms, total 3.260 ms
torch #5: gemm 2.779 ms, comm 0.483 ms, total 3.262 ms
torch #6: gemm 2.473 ms, comm 0.790 ms, total 3.262 ms
torch #7: gemm 2.771 ms, comm 0.491 ms, total 3.262 ms
flux  #0: gemm 2.749 ms, comm 0.029 ms, total 2.778 ms
flux  #1: gemm 2.757 ms, comm 0.020 ms, total 2.778 ms
flux  #2: gemm 2.492 ms, comm 0.286 ms, total 2.778 ms
flux  #3: gemm 2.427 ms, comm 0.351 ms, total 2.778 ms
flux  #4: gemm 2.435 ms, comm 0.343 ms, total 2.778 ms
flux  #5: gemm 2.756 ms, comm 0.021 ms, total 2.777 ms
flux  #6: gemm 2.487 ms, comm 0.291 ms, total 2.778 ms
flux  #7: gemm 2.750 ms, comm 0.028 ms, total 2.777 ms
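
(Side note: a per-rank gemm/comm split like the one above can also be reproduced independently with CUDA events. The sketch below is illustrative only and is not necessarily how test_gemm_rs.py does its timing.)

```python
# Illustrative only: average wall time of a callable in milliseconds,
# measured with CUDA events; not the actual measurement code of the test.
import torch

def time_ms(fn, iters=10):
    start = torch.cuda.Event(enable_timing=True)
    stop = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    stop.record()
    torch.cuda.synchronize()
    return start.elapsed_time(stop) / iters
```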
wenscarl commented 4 months ago

Here is the full output on an A100x2 with PCIe. My test was on commit 96b2e03adeee52ebcac0199b4f327e8415ce84b3.


# ./scripts/launch.sh test/test_gemm_rs.py 4096 12288 49152 --dtype=float16 --iters=10
torchrun --node_rank=0 --nproc_per_node=2 --nnodes=1 --rdzv_endpoint=127.0.0.1:23456 test/test_gemm_rs.py 4096 12288 49152 --dtype=float16 --iters=10
W0627 16:19:38.106000 139630205469824 torch/distributed/run.py:757]
W0627 16:19:38.106000 139630205469824 torch/distributed/run.py:757] *****************************************
W0627 16:19:38.106000 139630205469824 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0627 16:19:38.106000 139630205469824 torch/distributed/run.py:757] *****************************************
before pynvshmem.init_with_c10d_pg
before pynvshmem.init_with_c10d_pg
WARN: init failed for remote transport: ibrc
WARN: init failed for remote transport: ibrc
after pynvshmem.init_with_c10d_pgafter pynvshmem.init_with_c10d_pg

SOL time for GEMM(M=4096,N=12288,K=49152,TP=2): 7.929ms
torch #0: gemm 10.084 ms, comm 6.278 ms, total 16.362 ms
torch #1: gemm 9.840 ms, comm 6.517 ms, total 16.357 ms
flux  #0: gemm 20.291 ms, comm 0.407 ms, total 20.698 ms
flux  #1: gemm 20.197 ms, comm 0.500 ms, total 20.697 ms
wenlei-bao commented 4 months ago

> Here is the full output on an A100x2 with PCIe. My test was on commit 96b2e03.
>
> # ./scripts/launch.sh test/test_gemm_rs.py 4096 12288 49152 --dtype=float16 --iters=10
> torchrun --node_rank=0 --nproc_per_node=2 --nnodes=1 --rdzv_endpoint=127.0.0.1:23456 test/test_gemm_rs.py 4096 12288 49152 --dtype=float16 --iters=10
> W0627 16:19:38.106000 139630205469824 torch/distributed/run.py:757]
> W0627 16:19:38.106000 139630205469824 torch/distributed/run.py:757] *****************************************
> W0627 16:19:38.106000 139630205469824 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
> W0627 16:19:38.106000 139630205469824 torch/distributed/run.py:757] *****************************************
> before pynvshmem.init_with_c10d_pg
> before pynvshmem.init_with_c10d_pg
> WARN: init failed for remote transport: ibrc
> WARN: init failed for remote transport: ibrc
> after pynvshmem.init_with_c10d_pgafter pynvshmem.init_with_c10d_pg
>
> SOL time for GEMM(M=4096,N=12288,K=49152,TP=2): 7.929ms
> torch #0: gemm 10.084 ms, comm 6.278 ms, total 16.362 ms
> torch #1: gemm 9.840 ms, comm 6.517 ms, total 16.357 ms
> flux  #0: gemm 20.291 ms, comm 0.407 ms, total 20.698 ms
> flux  #1: gemm 20.197 ms, comm 0.500 ms, total 20.697 ms

I thought you said NVLink. We haven't released PCIe support yet.

wenscarl commented 4 months ago

1. Is there any plan to support the GEMM + AllReduce pattern? Or, if it were built from GEMM-RS plus an AllGather (see the sketch below), how much performance gain would be expected over a non-overlapped GEMM + AllReduce?
2. Is there a code pointer showcasing how to apply Flux to vLLM? I think it's mentioned in the paper.
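
On question 1: an all-reduce is equivalent to a reduce-scatter followed by an all-gather, which is presumably what building GEMM + AllReduce on top of GEMM-RS would look like. Below is a minimal sketch using plain PyTorch collectives; it is illustrative only and not Flux API.

```python
# Illustrative only: composing an all-reduce from reduce-scatter + all-gather
# with plain PyTorch collectives. An overlapped GEMM+AllReduce built on Flux
# would presumably fuse the reduce-scatter half into the GEMM epilogue and
# append an all-gather; Flux itself does not expose this function.
import torch
import torch.distributed as dist

def allreduce_via_rs_ag(partial, tp_group):
    world = dist.get_world_size(tp_group)
    shard = torch.empty((partial.shape[0] // world, *partial.shape[1:]),
                        dtype=partial.dtype, device=partial.device)
    dist.reduce_scatter_tensor(shard, partial, group=tp_group)  # sum + scatter
    out = torch.empty_like(partial)
    dist.all_gather_into_tensor(out, shard, group=tp_group)     # gather shards
    return out  # equals all_reduce(partial) up to reduction order
```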
wenlei-bao commented 4 months ago

@wenscarl Not yet, to both 1 and 2.

I will close this issue, as the reported bug does not apply. You are welcome to open a new issue to continue discussing other questions.