bytedance / flux

A fast communication-overlapping library for tensor parallelism on GPUs.
Apache License 2.0

[BUG] incorrect shape output from AGKernel.gather() #44

Status: Open. Opened by 152334H 1 month ago.

152334H commented 1 month ago

Describe the bug

The following command seems to work fine:

scripts/launch.sh test/test_ag_kernel_functional.py 4096 4096 4096

However, if we add this shape assertion to the test:

diff --git a/test/test_ag_kernel_functional.py b/test/test_ag_kernel_functional.py
index 0f1d9ce..c7a5267 100644
--- a/test/test_ag_kernel_functional.py
+++ b/test/test_ag_kernel_functional.py
@@ -115,9 +115,14 @@ def run_test_with_args(
             bias,
             local_copy,
         )
+        assert flux_all_input.shape == gt_all_input.shape, f"{flux_all_input.shape=} vs {gt_all_input.shape=}"

Then run the test with a different N, for example:

scripts/launch.sh test/test_ag_kernel_functional.py 4096 16384 4096

Many errors like the following appear:

AssertionError: flux_all_input.shape=torch.Size([512, 4096]) vs gt_all_input.shape=torch.Size([2048, 4096])

Expected behavior

I'm not sure. I do not fully understand what this code is supposed to do.

I believe flux_all_input is supposed to contain the original gathered input, but I am uncertain.

I do not have a proposed fix either.

Additional context

Environment, if useful:

byte_flux==1.0.3
filelock==3.16.1
fsspec==2024.9.0
Jinja2==3.1.4
MarkupSafe==3.0.1
mpmath==1.3.0
networkx==3.4.1
numpy==2.1.2
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.6.77
nvidia-nvtx-cu12==12.1.105
packaging==24.1
pynvml==11.5.3
sympy==1.13.3
torch==2.4.1
triton==3.0.0
typing_extensions==4.12.2

This should be a fresh environment, with only pynvml and numpy added.

wenlei-bao commented 2 days ago

@152334H Yeah, it looks like a bug, as the shape check shows: the all_input shape should be [TP*m, N]. We will take a look at this. Meanwhile, please use test_ag_kernel.py as the main testing script instead.
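To make the expected shape concrete, here is a minimal, framework-free sketch of the shape arithmetic for an all-gather along dim 0: each of the TP ranks contributes an m-row shard, so the gathered tensor has TP*m rows and the same trailing dimensions. The helpers `expected_all_gather_shape` and `check_gather_shape` are hypothetical (they are not part of flux), and the TP size of 4 in the example is an assumption. The issue does not say how many GPUs were used, although 4 × 512 = 2048 does match the two shapes in the error message.

```python
from typing import Sequence, Tuple


def expected_all_gather_shape(local_shape: Sequence[int], tp_size: int) -> Tuple[int, ...]:
    """Shape after all-gathering `tp_size` shards of `local_shape` along dim 0.

    Hypothetical helper: each rank contributes `m` rows, so the result has
    tp_size * m rows and keeps the shard's trailing dimensions unchanged.
    """
    if tp_size < 1:
        raise ValueError("tp_size must be >= 1")
    m, *rest = local_shape
    return (tp_size * m, *rest)


def check_gather_shape(actual: Sequence[int], local_shape: Sequence[int], tp_size: int) -> None:
    """Raise AssertionError if `actual` is not the expected all-gather shape."""
    expected = expected_all_gather_shape(local_shape, tp_size)
    assert tuple(actual) == expected, f"actual={tuple(actual)} vs expected={expected}"


if __name__ == "__main__":
    # Assumed setup: a 512 x 4096 local shard on each of 4 ranks.
    print(expected_all_gather_shape((512, 4096), tp_size=4))  # (2048, 4096)
    try:
        # A gathered output that only has one shard's rows fails the check.
        check_gather_shape((512, 4096), (512, 4096), tp_size=4)
    except AssertionError as err:
        print("mismatch:", err)
```

Because the helper copies the trailing dimensions from the local shard, the check works regardless of what that dimension is called in a given test.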