bytedance / flux

A fast communication-overlapping library for tensor parallelism on GPUs.
Apache License 2.0

[QUESTION] Why is flux gemm_rs not faster than torch? #34

Open hxdtest opened 3 weeks ago

hxdtest commented 3 weeks ago

Your question: on 4x A100-SXM GPUs, flux gemm_rs is not noticeably faster than the torch baseline:

$./scripts/launch.sh test/test_gemm_rs.py 4096 12288 49152 --dtype=bfloat16 --iters=10
torchrun --node_rank=0 --nproc_per_node=4 --nnodes=1 --rdzv_endpoint=127.0.0.1:23456 test/test_gemm_rs.py 4096 12288 49152 --dtype=bfloat16 --iters=10
W0821 11:09:36.222000 139678045570880 torch/distributed/run.py:757] 
W0821 11:09:36.222000 139678045570880 torch/distributed/run.py:757] *****************************************
W0821 11:09:36.222000 139678045570880 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0821 11:09:36.222000 139678045570880 torch/distributed/run.py:757] *****************************************
before flux shm initialization
before flux shm initialization
after flux shm initialization
before flux shm initialization
after flux shm initialization
after flux shm initialization
before flux shm initialization
after flux shm initialization
SOL time for GEMM(M=4096,N=12288,K=49152,TP=4): 3.965ms
torch #0: gemm 4.575 ms, comm 0.882 ms, total 5.456 ms
torch #1: gemm 4.992 ms, comm 0.462 ms, total 5.454 ms
torch #2: gemm 4.528 ms, comm 0.929 ms, total 5.457 ms
torch #3: gemm 5.026 ms, comm 0.436 ms, total 5.462 ms
flux  #0: gemm 5.372 ms, comm 0.049 ms, total 5.421 ms
flux  #1: gemm 5.379 ms, comm 0.042 ms, total 5.421 ms
flux  #2: gemm 5.372 ms, comm 0.049 ms, total 5.421 ms
flux  #3: gemm 5.373 ms, comm 0.047 ms, total 5.421 ms
[root workflow_42980535 /hetero_infer/hanxudong.hxd/flux] Wed Aug 21 11:09:46
$nvidia-smi
Wed Aug 21 11:10:15 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.82.01    Driver Version: 470.82.01    CUDA Version: 12.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-SXM...  On   | 00000000:5A:00.0 Off |                    0 |
| N/A   34C    P0    64W / 400W |      3MiB / 81251MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM...  On   | 00000000:6B:00.0 Off |                    0 |
| N/A   35C    P0    65W / 400W |      3MiB / 81251MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   2  NVIDIA A100-SXM...  On   | 00000000:C3:00.0 Off |                    0 |
| N/A   34C    P0    66W / 400W |      3MiB / 81251MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   3  NVIDIA A100-SXM...  On   | 00000000:DA:00.0 Off |                    0 |
| N/A   34C    P0    65W / 400W |      3MiB / 81251MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
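The "SOL" (speed-of-light) times printed by the test appear to follow a simple compute bound: each rank performs 2·M·N·K/TP floating-point operations, divided by the A100's dense BF16 tensor-core peak of 312 TFLOPS. The sketch below reproduces the printed 3.965 ms and converts measured GEMM times into achieved TFLOPS. The 312 TFLOPS constant and the formula are assumptions inferred from the printed numbers, not read from test_gemm_rs.py.

```python
# Assumed peak: NVIDIA A100 dense BF16 tensor-core throughput (no structured sparsity).
A100_BF16_PEAK_TFLOPS = 312.0


def gemm_flops_per_rank(m: int, n: int, k: int, tp: int) -> float:
    """FLOPs of one rank's GEMM shard (a multiply-add counts as 2 FLOPs).

    Assumption: the K dimension is sharded across the TP ranks, as in a
    row-parallel linear layer followed by reduce-scatter.
    """
    return 2.0 * m * n * k / tp


def sol_ms(m: int, n: int, k: int, tp: int,
           peak_tflops: float = A100_BF16_PEAK_TFLOPS) -> float:
    """Lower bound on the per-rank GEMM time in ms if the GPU ran at peak."""
    return gemm_flops_per_rank(m, n, k, tp) / (peak_tflops * 1e12) * 1e3


def achieved_tflops(m: int, n: int, k: int, tp: int, measured_ms: float) -> float:
    """Throughput implied by a measured per-rank GEMM time."""
    return gemm_flops_per_rank(m, n, k, tp) / (measured_ms * 1e-3) / 1e12


if __name__ == "__main__":
    m, n, k = 4096, 12288, 49152
    # (TP, torch rank-0 GEMM ms, flux rank-0 GEMM ms) taken from the logs in this issue.
    for tp, torch_ms, flux_ms in [(4, 4.575, 5.372), (2, 9.239, 13.111)]:
        print(f"TP={tp}: SOL {sol_ms(m, n, k, tp):.3f} ms, "
              f"torch {achieved_tflops(m, n, k, tp, torch_ms):.0f} TFLOPS, "
              f"flux {achieved_tflops(m, n, k, tp, flux_ms):.0f} TFLOPS")
```

Under these assumptions, at TP=4 the torch GEMM reaches about 270 TFLOPS (roughly 87% of peak) while the Flux GEMM reaches about 230 TFLOPS, so the communication time Flux hides is largely offset by a slower GEMM.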
hxdtest commented 3 weeks ago

If I set tp_size=2, the test results are:

$./scripts/launch.sh test/test_gemm_rs.py 4096 12288 49152 --dtype=bfloat16 --iters=10
torchrun --node_rank=0 --nproc_per_node=2 --nnodes=1 --rdzv_endpoint=127.0.0.1:23456 test/test_gemm_rs.py 4096 12288 49152 --dtype=bfloat16 --iters=10
W0821 11:21:10.272000 139828886370112 torch/distributed/run.py:757] 
W0821 11:21:10.272000 139828886370112 torch/distributed/run.py:757] *****************************************
W0821 11:21:10.272000 139828886370112 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0821 11:21:10.272000 139828886370112 torch/distributed/run.py:757] *****************************************
before flux shm initialization
after flux shm initialization
before flux shm initialization
after flux shm initialization
SOL time for GEMM(M=4096,N=12288,K=49152,TP=2): 7.929ms
torch #0: gemm 9.239 ms, comm 0.676 ms, total 9.915 ms
torch #1: gemm 9.557 ms, comm 0.359 ms, total 9.916 ms
flux  #0: gemm 13.111 ms, comm -0.243 ms, total 12.868 ms
flux  #1: gemm 13.341 ms, comm -0.472 ms, total 12.869 ms
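In these logs the "comm" column equals total minus gemm on every line (for example, 12.868 - 13.111 = -0.243 ms for flux #0). For Flux, where the reduce-scatter overlaps the GEMM, it is therefore the residual time not hidden behind the GEMM, and a negative value plausibly means the separately timed GEMM was slower than the fused run. This reading is inferred from the numbers, not from the test source. The sketch below parses lines in the printed format, checks that relation, and computes the end-to-end speedup of Flux over torch.

```python
import re
from statistics import mean

# Matches lines such as "flux  #0: gemm 13.111 ms, comm -0.243 ms, total 12.868 ms".
LINE_RE = re.compile(
    r"^(?P<impl>torch|flux)\s+#(?P<rank>\d+): gemm (?P<gemm>-?[\d.]+) ms, "
    r"comm (?P<comm>-?[\d.]+) ms, total (?P<total>-?[\d.]+) ms$"
)


def parse_timings(log: str) -> dict:
    """Return {impl: [(rank, gemm_ms, comm_ms, total_ms), ...]} for each matching line."""
    result = {}
    for line in log.splitlines():
        match = LINE_RE.match(line.strip())
        if match:
            row = (int(match["rank"]), float(match["gemm"]),
                   float(match["comm"]), float(match["total"]))
            result.setdefault(match["impl"], []).append(row)
    return result


def speedup(timings: dict) -> float:
    """Mean torch total divided by mean flux total (> 1 means Flux is faster)."""
    torch_total = mean(row[3] for row in timings["torch"])
    flux_total = mean(row[3] for row in timings["flux"])
    return torch_total / flux_total


TP2_LOG = """\
torch #0: gemm 9.239 ms, comm 0.676 ms, total 9.915 ms
torch #1: gemm 9.557 ms, comm 0.359 ms, total 9.916 ms
flux  #0: gemm 13.111 ms, comm -0.243 ms, total 12.868 ms
flux  #1: gemm 13.341 ms, comm -0.472 ms, total 12.869 ms
"""

if __name__ == "__main__":
    timings = parse_timings(TP2_LOG)
    for impl, rows in timings.items():
        for rank, gemm, comm, total in rows:
            # Inferred relation: comm == total - gemm, up to print rounding.
            assert abs((total - gemm) - comm) < 2e-3, (impl, rank)
    print(f"flux speedup over torch: {speedup(timings):.3f}x")
```

For the TP=2 run this gives a ratio of about 0.77, i.e. Flux is roughly 23% slower end to end.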
wenlei-bao commented 3 weeks ago

@hxdtest Thanks for your interest in Flux. From what you reported, it looks like the GEMM config in Flux is not optimal on your test machine: the GEMM alone takes 9.239 ms for torch #0 vs. 13.111 ms for flux #0, so the GEMM is too slow here. Currently, the tuning config under the reduce-scatter folder contains the GEMM config (tiling shape, etc.). You can try the tuning tool at https://github.com/bytedance/flux/blob/main/tools/tune_gemm_rs.py to find better configs for your case. cc @zheng-ningxin
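A GEMM tuner of this kind typically benchmarks several candidate kernel configurations for a fixed problem shape and keeps the fastest. The sketch below shows only that selection loop in plain Python; the candidate "tile shapes", the `run` callable, and the warm-up/iteration counts are hypothetical and do not reflect the actual interface or config format of tools/tune_gemm_rs.py.

```python
import time
from statistics import median
from typing import Callable, Dict, Hashable, Iterable, Tuple


def pick_best_config(
    candidates: Iterable[Hashable],
    run: Callable[[Hashable], None],
    warmup: int = 2,
    iters: int = 5,
) -> Tuple[Hashable, Dict[Hashable, float]]:
    """Time run(cfg) for every candidate; return the fastest and all median timings in ms."""
    timings: Dict[Hashable, float] = {}
    for cfg in candidates:
        for _ in range(warmup):  # untimed warm-up runs (caches, lazy init, clocks)
            run(cfg)
        samples = []
        for _ in range(iters):
            start = time.perf_counter()
            run(cfg)
            samples.append((time.perf_counter() - start) * 1e3)
        timings[cfg] = median(samples)  # the median is robust to occasional outliers
    best = min(timings, key=timings.get)
    return best, timings


if __name__ == "__main__":
    # Hypothetical tile shapes whose cost is simulated by CPU work of different sizes.
    work = {(128, 128): 200_000, (128, 256): 2_000_000, (256, 128): 4_000_000}
    best, timings = pick_best_config(work, lambda cfg: sum(range(work[cfg])))
    for cfg, ms in sorted(timings.items(), key=lambda kv: kv[1]):
        print(cfg, f"{ms:.2f} ms")
    print("best:", best)
```

When timing real GPU kernels this way, the device has to be synchronized before reading the host timer (for example with torch.cuda.synchronize() or CUDA events), because kernel launches are asynchronous.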

hxdtest commented 3 weeks ago

Thank you for your reply. I tried to run ./scripts/launch.sh tools/tune_gemm_rs.py --output_dir ./output --check, but got the following errors:

launch.sh  tools/tune_gemm_rs.py  --output_dir ./output --check
torchrun --node_rank=0 --nproc_per_node=8 --nnodes=1 --rdzv_endpoint=127.0.0.1:23456 tools/tune_gemm_rs.py --output_dir ./output --check
W0821 06:59:53.159000 140297798518592 torch/distributed/run.py:778] 
W0821 06:59:53.159000 140297798518592 torch/distributed/run.py:778] *****************************************
W0821 06:59:53.159000 140297798518592 torch/distributed/run.py:778] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0821 06:59:53.159000 140297798518592 torch/distributed/run.py:778] *****************************************
==== #1/1 Tuning for TuningConfig(M=4096, N=12288, K=49152, fuse_reduction=False, transpose_weight=False, dtype=torch.bfloat16, has_bias=False)
./new_3d/flux/src/ths_op/flux_shm.cc:138 Check failed: error == cudaSuccess. Got bad cuda status: invalid argument(1) at: cudaIpcOpenMemHandle(&ptrs[i], handles_h[i], cudaIpcMemLazyEnablePeerAccess)

./new_3d/flux/src/ths_op/flux_shm.cc:138 Check failed: error == cudaSuccess. Got bad cuda status: invalid argument(1) at: cudaIpcOpenMemHandle(&ptrs[i], handles_h[i], cudaIpcMemLazyEnablePeerAccess)

./new_3d/flux/src/ths_op/flux_shm.cc:138 Check failed: error == cudaSuccess. Got bad cuda status: invalid argument(1) at: cudaIpcOpenMemHandle(&ptrs[i], handles_h[i], cudaIpcMemLazyEnablePeerAccess)

./new_3d/flux/src/ths_op/flux_shm.cc:138 Check failed: error == cudaSuccess. Got bad cuda status: invalid argument(1) at: cudaIpcOpenMemHandle(&ptrs[i], handles_h[i], cudaIpcMemLazyEnablePeerAccess)

./new_3d/flux/src/ths_op/flux_shm.cc:138 Check failed: error == cudaSuccess. Got bad cuda status: invalid argument(1) at: cudaIpcOpenMemHandle(&ptrs[i], handles_h[i], cudaIpcMemLazyEnablePeerAccess)

[rank6]: Traceback (most recent call last):
[rank6]:   File ".new_3d/flux/tools/tune_gemm_rs.py", line 212, in <module>
[rank6]:     tune_one_config(prof_ctx=prof_ctx, config=config)
[rank6]:   File ".new_3d/flux/tools/tune_gemm_rs.py", line 167, in tune_one_config
[rank6]:     flux_output = run_flux_profiling(prof_ctx, input, weight, bias, config)
[rank6]:   File ".new_3d/flux/tools/tune_gemm_rs.py", line 135, in run_flux_profiling
[rank6]:     gemm_rs_op = flux.GemmRS(
[rank6]: RuntimeError: ./new_3d/flux/src/ths_op/flux_shm.cc:138 Check failed: error == cudaSuccess. Got bad cuda status: invalid argument(1) at: cudaIpcOpenMemHandle(&ptrs[i], handles_h[i], cudaIpcMemLazyEnablePeerAccess)

[rank1]: Traceback (most recent call last):
[rank1]:   File "/ /new_3d/flux/tools/tune_gemm_rs.py", line 212, in <module>
[rank1]:     tune_one_config(prof_ctx=prof_ctx, config=config)
[rank1]:   File  /new_3d/flux/tools/tune_gemm_rs.py", line 167, in tune_one_config
[rank1]:     flux_output = run_flux_profiling(prof_ctx, input, weight, bias, config)
[rank1]:   File " new_3d/flux/tools/tune_gemm_rs.py", line 135, in run_flux_profiling
[rank1]:     gemm_rs_op = flux.GemmRS(
[rank1]: RuntimeError: ./new_3d/flux/src/ths_op/flux_shm.cc:138 Check failed: error == cudaSuccess. Got bad cuda status: invalid argument(1) at: cudaIpcOpenMemHandle(&ptrs[i], handles_h[i], cudaIpcMemLazyEnablePeerAccess)

[rank2]: Traceback (most recent call last):
[rank2]:   File "/input/kunlong.ckl/lizhi/new_3d/flux/tools/tune_gemm_rs.py", line 212, in <module>
[rank2]:     tune_one_config(prof_ctx=prof_ctx, config=config)
[rank2]:   File " /new_3d/flux/tools/tune_gemm_rs.py", line 167, in tune_one_config
[rank2]:     flux_output = run_flux_profiling(prof_ctx, input, weight, bias, config)
[rank2]:   File "/new_3d/flux/tools/tune_gemm_rs.py", line 135, in run_flux_profiling
[rank2]:     gemm_rs_op = flux.GemmRS(
[rank2]: RuntimeError: ./new_3d/flux/src/ths_op/flux_shm.cc:138 Check failed: error == cudaSuccess. Got bad cuda status: invalid argument(1) at: cudaIpcOpenMemHandle(&ptrs[i], handles_h[i], cudaIpcMemLazyEnablePeerAccess)
hxdtest commented 3 weeks ago

In addition, I ran ./test_gemm_rs directly with tp_size=2 (arguments: m n k tp_size). It succeeds for m=n=k=1024 but fails for m=n=k=4096:

$./test_gemm_rs 1024 1024 1024 2
#0 time elapsed: 0.028 ms
#1 time elapsed: 0.029 ms

$./test_gemm_rs 4096 4096 4096 2
.flux/include/flux/cuda/gemm_impls/gemm_operator_base_default_impl.hpp:74 Check failed: error == cutlass::Status::kSuccess. Got cutlass error: Error Workspace Null(6) at: gemm_op.initialize(gemm_args, gemm_workspace, cu_stream)
./flux/include/flux/cuda/gemm_impls/gemm_operator_base_default_impl.hpp:74 Check failed: error == cutlass::Status::kSuccess. Got cutlass error: Error Workspace Null(6) at: gemm_op.initialize(gemm_args, gemm_workspace, cu_stream)

terminate called recursively
terminate called after throwing an instance of 'std::runtime_error'
Aborted (core dumped)
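Since the same binary succeeds at 1024 and aborts at 4096, one way to narrow down where the CUTLASS "Error Workspace Null" starts is to sweep square sizes and record which runs exit with an error. The sketch assumes, as in the commands above, that ./test_gemm_rs takes M N K TP as positional arguments and terminates abnormally on failure; the helper names and the size list are hypothetical.

```python
import os
import subprocess
from typing import Callable, List, Optional, Sequence, Tuple

BINARY = "./test_gemm_rs"  # path used in the commands above


def run_test_gemm_rs(m: int, n: int, k: int, tp: int) -> bool:
    """Run one shape; True if the process exited with status 0."""
    proc = subprocess.run([BINARY, str(m), str(n), str(k), str(tp)],
                          capture_output=True, text=True)
    # The failure above ends in "Aborted (core dumped)", which shows up
    # here as a non-zero (negative, signal-based) return code.
    return proc.returncode == 0


def sweep(sizes: Sequence[int], tp: int,
          runner: Callable[[int, int, int, int], bool] = run_test_gemm_rs
          ) -> List[Tuple[int, bool]]:
    """Try square problems M = N = K = s for each s and record whether each passed."""
    return [(s, runner(s, s, s, tp)) for s in sizes]


def first_failure(results: List[Tuple[int, bool]]) -> Optional[int]:
    """Smallest tested size that failed, or None if every size passed."""
    return next((s for s, ok in results if not ok), None)


if __name__ == "__main__":
    if not os.path.exists(BINARY):
        print(f"{BINARY} not found; run this next to the compiled test binary")
    else:
        results = sweep([1024, 1536, 2048, 3072, 4096], tp=2)
        for size, ok in results:
            print(f"{size:>5}: {'ok' if ok else 'FAILED'}")
        print("first failing size:", first_failure(results))
```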
wenlei-bao commented 2 weeks ago

@hxdtest It looks like peer access is not enabled. Could you please double-check whether P2P access is enabled on your test machine?
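One way to answer that from Python is to query every GPU pair with torch.cuda.can_device_access_peer, which reports whether one device can directly access another's memory; `nvidia-smi topo -m` shows the underlying interconnect (NVLink vs. PCIe) as a complementary check. The sketch below builds the full matrix and lists the pairs without P2P. The helper names are hypothetical, and the query function is passed in so the logic can be exercised without GPUs.

```python
from typing import Callable, List, Tuple

Matrix = List[List[bool]]


def p2p_matrix(n: int, can_access: Callable[[int, int], bool]) -> Matrix:
    """matrix[i][j] is True if device i can directly access memory on device j."""
    # A device trivially reaches its own memory, so the diagonal is not queried.
    return [[i == j or bool(can_access(i, j)) for j in range(n)] for i in range(n)]


def missing_pairs(matrix: Matrix) -> List[Tuple[int, int]]:
    """Ordered (i, j) pairs for which peer access is reported as unavailable."""
    return [(i, j) for i, row in enumerate(matrix) for j, ok in enumerate(row) if not ok]


def check_local_gpus() -> Tuple[int, List[Tuple[int, int]]]:
    """Return (visible GPU count, pairs without P2P) using PyTorch's CUDA queries."""
    import torch  # imported lazily so the helpers above work without PyTorch

    n = torch.cuda.device_count()
    return n, missing_pairs(p2p_matrix(n, torch.cuda.can_device_access_peer))


if __name__ == "__main__":
    try:
        count, bad = check_local_gpus()
    except ImportError:
        print("PyTorch is not installed; cannot query the GPUs")
    else:
        if count < 2:
            print(f"only {count} GPU(s) visible; nothing to check")
        elif bad:
            print(f"no P2P access for pairs: {bad}")
        else:
            print(f"P2P access reported between all {count} GPUs")
```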