bytedance / flux

A fast communication-overlapping library for tensor parallelism on GPUs.
Apache License 2.0

[QUESTION] Can Gemm_V3 be used in SM80? #38

Open · ginowu opened this issue 2 weeks ago

ginowu commented 2 weeks ago

Hello, I tried to run Gemm_V3 on an A100 (after some trivial changes to the registration so that Gemm_V3 can be used on SM80), but the kernel hits an out-of-bounds access exception, as you can see in the compute-sanitizer log below.

Can Gemm_V3 run on SM80? If not, which parts need further modification? Thanks a lot for your help.
[Screenshot: compute-sanitizer log]

ginowu commented 2 weeks ago

Command I used to run the test: `compute-sanitizer torchrun --node_rank=0 --nproc_per_node=1 --nnodes=1 --rdzv_endpoint=127.0.0.1:23456 test/test_gemm_rs.py 4096 12288 6144 --dtype=float16 --iters=10`

ginowu commented 2 weeks ago

I replaced `MainloopSm80CpAsyncUnpredicated` with `MainloopSm80CpAsync`. After more tests, I found that gemm_only works, but gemm_rs fails.

The following line in epilogue_vectorized_reducescatter.hpp triggers the out-of-bounds read in the kernel: `copy(CopyAtomR2G{}, tDrD(_, m, n), tDgDmn(_, m, n));`

However, `EpilogueReduceScatterVectorized` is no longer used for sm80 or sm90 in the latest code. Have you ever used it successfully?

ginowu commented 2 weeks ago

BTW, there is a bug in `select_mma_atom`: the MMA atoms for FP16 and BF16 are swapped and need to be exchanged.

ginowu commented 1 week ago

The root cause is in gemm_reduce_scatter.cc: `output_scatter_ptrs` lives in host memory, so reading this pointer array from the kernel in epilogue_vectorized_reduce_scatter.hpp crashes. After copying the pointers into device memory, test/test_gemm_rs.py runs successfully.
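The idea behind the fix can be sketched from the PyTorch side. This is only an illustration, not the code in gemm_reduce_scatter.cc: the helper `make_device_ptr_table` is hypothetical. It packs each output buffer's address into an int64 tensor created on the target device, so the pointer table itself sits in device memory before a kernel reads it.

```python
import torch


def make_device_ptr_table(buffers, device):
    """Hypothetical helper: pack the data pointers of `buffers` into an
    int64 tensor that lives on `device`.

    A kernel in general cannot dereference a pointer array stored in ordinary
    host memory, so the array of output pointers must itself be copied to
    device memory before the launch.
    """
    ptrs = [b.data_ptr() for b in buffers]
    # Passing device=... makes torch copy the host-side list to that device.
    return torch.tensor(ptrs, dtype=torch.int64, device=device)
```

In C++ the equivalent is a `cudaMemcpy` of the host pointer array into a buffer allocated with `cudaMalloc` (or a device tensor), and passing that device address to the epilogue instead of the host array.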

ginowu commented 1 week ago

Another observation: when K is large (e.g. 20480), the accuracy check passes most of the time, but when K is small (e.g. 6144), it fails most of the time (about 1/100000 of the results are wrong).

With large K: `torchrun --node_rank=0 --nproc_per_node=2 --nnodes=1 --rdzv_endpoint=127.0.0.1:23456 test/test_gemm_rs.py 4096 12288 20480 --dtype=float16 --iters=10`
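To put a number on "about 1/100000 of the results are wrong", a check could report the mismatch fraction and the first failing indices instead of a single pass/fail. This is a sketch: `mismatch_fraction` and `first_mismatches` are hypothetical helpers, and the `rtol`/`atol` defaults are placeholders, not the tolerances used by test/test_gemm_rs.py.

```python
import torch


def mismatch_fraction(out, ref, rtol=1e-2, atol=1e-2):
    """Hypothetical helper: fraction of elements failing
    |out - ref| <= atol + rtol * |ref| (the torch.isclose criterion)."""
    # Compare in fp32 so the comparison itself adds no fp16/bf16 rounding.
    close = torch.isclose(out.float(), ref.float(), rtol=rtol, atol=atol)
    return (~close).sum().item() / close.numel()


def first_mismatches(out, ref, rtol=1e-2, atol=1e-2, k=5):
    """Hypothetical helper: up to k indices of failing elements,
    useful for seeing where in the output the wrong values land."""
    close = torch.isclose(out.float(), ref.float(), rtol=rtol, atol=atol)
    return (~close).nonzero()[:k].tolist()
```

For example, a 1000-element output with a single wrong value gives a mismatch fraction of 0.001.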