Open 152334H opened 1 month ago
Running the command
scripts/launch.sh test/test_ag_kernel_functional.py 4096 4096 4096
seems to work fine.
However, if we add this line:
diff --git a/test/test_ag_kernel_functional.py b/test/test_ag_kernel_functional.py
index 0f1d9ce..c7a5267 100644
--- a/test/test_ag_kernel_functional.py
+++ b/test/test_ag_kernel_functional.py
@@ -115,9 +115,14 @@ def run_test_with_args(
         bias,
         local_copy,
     )
+    assert flux_all_input.shape == gt_all_input.shape, f"{flux_all_input.shape=} vs {gt_all_input.shape=}"
Then run it with a different N, such as:
scripts/launch.sh test/test_ag_kernel_functional.py 4096 16384 4096
This produces many errors like:
AssertionError: flux_all_input.shape=torch.Size([512, 4096]) vs gt_all_input.shape=torch.Size([2048, 4096])
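One possible reading of the two sizes in that error, sketched below. It assumes the run used 8 ranks and that the three launch arguments are M, N, K in that order; neither is stated in the report, and `rows_per_rank` is a hypothetical helper, not part of flux.

```python
# Sketch only: WORLD_SIZE = 8 and the M, N, K argument order are assumptions,
# not facts taken from the report.
WORLD_SIZE = 8
M, N, K = 4096, 16384, 4096  # the three sizes passed to scripts/launch.sh


def rows_per_rank(dim: int, world_size: int) -> int:
    """Hypothetical helper: rows each rank holds when `dim` is split evenly."""
    if dim % world_size != 0:
        raise ValueError(f"{dim} is not divisible by {world_size}")
    return dim // world_size


flux_rows = rows_per_rank(M, WORLD_SIZE)  # equals the 512 in flux_all_input.shape
gt_rows = rows_per_rank(N, WORLD_SIZE)    # equals the 2048 in gt_all_input.shape
print(flux_rows, gt_rows)
```

Under these assumptions the two row counts coincide whenever M == N, which would be consistent with the 4096 4096 4096 run passing the added assertion.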
As for the expected behavior, I am not sure. I do not fully understand what this code is supposed to do.
I believe it is supposed to contain the original gathered input, but I am uncertain.
I do not have any proposed fix either.
Environment if useful:
byte_flux==1.0.3
filelock==3.16.1
fsspec==2024.9.0
Jinja2==3.1.4
MarkupSafe==3.0.1
mpmath==1.3.0
networkx==3.4.1
numpy==2.1.2
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.6.77
nvidia-nvtx-cu12==12.1.105
packaging==24.1
pynvml==11.5.3
sympy==1.13.3
torch==2.4.1
triton==3.0.0
typing_extensions==4.12.2
This should be a fresh environment, except with pynvml+numpy added.
@152334H Yeah, it looks like a bug when checking the shape; the all_input shape should be [TP*m, N]. We will take a look at this. In the meantime, please use test_ag_kernel.py instead as the main testing script.
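The shape rule in that reply can be written out as a small check. This is a minimal sketch: `gathered_shape` is a hypothetical helper, not part of flux, and the values TP = 8 and a per-rank shard of [512, 4096] are assumptions taken from the error message rather than from the test itself.

```python
from typing import Tuple


def gathered_shape(local_shape: Tuple[int, int], tp: int) -> Tuple[int, int]:
    """Hypothetical helper: shape after gathering `tp` shards along dim 0.

    Each rank contributes `rows` rows, so the result has tp * rows rows and
    keeps the shard's column count unchanged.
    """
    rows, cols = local_shape
    return (tp * rows, cols)


TP = 8                      # assumption: number of tensor-parallel ranks
local_shape = (512, 4096)   # assumption: per-rank shard, as in the error message

expected = gathered_shape(local_shape, TP)
print(expected)
```

The reply names the column dimension N. This sketch does not take a position on that and simply keeps whatever column count the local shard has.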