tlrmchlsmth closed this issue 4 months ago.
@tlrmchlsmth Did you build with --nvshmem or not? I can't reproduce this on my end; I will try the other build option.
@wenlei-bao I'm not using --nvshmem
@tlrmchlsmth I am trying that build. In the meantime, can you please try adding --nvshmem when you build?
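For reference, a minimal rebuild sketch; only the --nvshmem option comes from this thread, while the ./build.sh entry point is an assumption about the flux repo layout:

```bash
# Hedged rebuild sketch: only --nvshmem is confirmed in this thread;
# the ./build.sh entry point is an assumption about the repo layout.
cd flux
./build.sh --nvshmem
# Then re-run the failing test through the launch script:
./scripts/launch.sh test/test_gemm_rs.py 1024 1024 1024 --iters=10
```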
@tlrmchlsmth Never mind; please use the script to run it, I think that may be the cause. Below is the result without the --nvshmem option.
$ ./scripts/launch.sh test/test_gemm_rs.py 1024 1024 1024 --iters=10
SOL time for GEMM(M=1024,N=1024,K=1024,TP=8): 0.001ms
torch #0: gemm 0.015 ms, comm 0.448 ms, total 0.464 ms
torch #1: gemm 0.332 ms, comm 0.116 ms, total 0.448 ms
torch #2: gemm 0.332 ms, comm 0.116 ms, total 0.448 ms
torch #3: gemm 0.013 ms, comm 0.452 ms, total 0.465 ms
torch #4: gemm 0.013 ms, comm 0.452 ms, total 0.465 ms
torch #5: gemm 0.016 ms, comm 0.449 ms, total 0.465 ms
torch #6: gemm 0.013 ms, comm 0.452 ms, total 0.466 ms
torch #7: gemm 0.015 ms, comm 0.449 ms, total 0.464 ms
flux #0: gemm 0.021 ms, comm 0.067 ms, total 0.088 ms
flux #1: gemm 0.037 ms, comm 0.051 ms, total 0.088 ms
flux #2: gemm 0.036 ms, comm 0.036 ms, total 0.072 ms
flux #3: gemm 0.021 ms, comm 0.068 ms, total 0.089 ms
flux #4: gemm 0.021 ms, comm 0.067 ms, total 0.089 ms
flux #5: gemm 0.022 ms, comm 0.066 ms, total 0.088 ms
flux #6: gemm 0.021 ms, comm 0.068 ms, total 0.089 ms
flux #7: gemm 0.021 ms, comm 0.067 ms, total 0.088 ms
Whoops, I mis-copied the command I used to run it. I am using ./scripts/launch.sh test/test_gemm_rs.py 1024 1024 1024
BTW it's failing the exact same way when I try to use it in my vLLM PR
Ah, OK. So is this a separate build and test run of flux (test_gemm_rs.py) on your machine, or after the vLLM integration? Depending on that, we might need to use the vLLM branch to reproduce. cc @zheng-ningxin
No, this was with a completely fresh clone of flux and a fresh venv, so the repro shouldn't depend on vLLM.
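For anyone else trying to reproduce, a hedged sketch of that fresh setup; the venv and build steps are assumptions, and the launch command is the one quoted above:

```bash
# Hedged repro sketch of a fresh clone + fresh venv (no vLLM involved).
git clone https://github.com/bytedance/flux.git
cd flux
python -m venv .venv && source .venv/bin/activate
pip install torch                    # the torch version turns out to matter (see below)
./build.sh                           # assumed build entry point, here without --nvshmem
./scripts/launch.sh test/test_gemm_rs.py 1024 1024 1024
```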
Just now, as suggested, I rebuilt with --nvshmem and now I see success:
./scripts/launch.sh test/test_gemm_rs.py 4096 12288 49152 --iters=10
torchrun --node_rank=0 --nproc_per_node=8 --nnodes=1 --rdzv_endpoint=127.0.0.1:23456 test/test_gemm_rs.py 4096 12288 49152 --iters=10
W0628 21:14:24.716000 140121298232448 torch/distributed/run.py:757]
W0628 21:14:24.716000 140121298232448 torch/distributed/run.py:757] *****************************************
W0628 21:14:24.716000 140121298232448 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0628 21:14:24.716000 140121298232448 torch/distributed/run.py:757] *****************************************
before flux shm initialization
before flux shm initialization
before flux shm initialization
before flux shm initialization
before flux shm initialization
before flux shm initialization
before flux shm initialization
before flux shm initialization
WARN: init failed for remote transport: ibrc
WARN: init failed for remote transport: ibrc
WARN: init failed for remote transport: ibrc
WARN: init failed for remote transport: ibrc
WARN: init failed for remote transport: ibrc
WARN: init failed for remote transport: ibrc
WARN: init failed for remote transport: ibrc
WARN: init failed for remote transport: ibrc
after flux shm initialization
after flux shm initialization
after flux shm initialization
after flux shm initialization
after flux shm initialization
after flux shm initialization
after flux shm initialization
after flux shm initialization
SOL time for GEMM(M=4096,N=12288,K=49152,TP=8): 1.982ms
torch #0: gemm 2.764 ms, comm 0.522 ms, total 3.286 ms
torch #1: gemm 2.764 ms, comm 0.522 ms, total 3.286 ms
torch #2: gemm 2.768 ms, comm 0.518 ms, total 3.286 ms
torch #3: gemm 2.769 ms, comm 0.520 ms, total 3.289 ms
torch #4: gemm 2.646 ms, comm 0.638 ms, total 3.284 ms
torch #5: gemm 2.723 ms, comm 0.563 ms, total 3.286 ms
torch #6: gemm 2.785 ms, comm 0.502 ms, total 3.287 ms
torch #7: gemm 2.768 ms, comm 0.519 ms, total 3.287 ms
flux #0: gemm 2.780 ms, comm 0.090 ms, total 2.870 ms
flux #1: gemm 2.777 ms, comm 0.092 ms, total 2.870 ms
flux #2: gemm 2.785 ms, comm 0.085 ms, total 2.870 ms
flux #3: gemm 2.782 ms, comm 0.088 ms, total 2.870 ms
flux #4: gemm 2.308 ms, comm 0.561 ms, total 2.870 ms
flux #5: gemm 2.322 ms, comm 0.548 ms, total 2.870 ms
flux #6: gemm 2.469 ms, comm 0.401 ms, total 2.870 ms
flux #7: gemm 2.325 ms, comm 0.545 ms, total 2.870 ms
@tlrmchlsmth OK, interesting. This is on A100 with NVLink, right? As you can see from my other comment, I also built without nvshmem. There might be some issue; let me try a clean build.
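As a quick, flux-independent check of the interconnect, nvidia-smi can print the GPU topology matrix:

```bash
# NV# entries in the matrix indicate NVLink between GPU pairs;
# PIX/PHB/NODE/SYS indicate PCIe/system paths.
nvidia-smi topo -m
```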
Yeah, that's right. I'm adding the output of vLLM's collect_env.py as well.
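For completeness, a hedged sketch of how that report is typically gathered; the script lives in the vLLM source tree, and the exact path is an assumption:

```bash
# Run vLLM's environment collector from a vLLM checkout and save the report.
python collect_env.py > collect_env.txt
```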
@tlrmchlsmth I tried a clean build without nvshmem on my side and still cannot reproduce the issue; it runs fine here. @zheng-ningxin have you seen this before?
No, I haven’t seen this issue before. I’ll try to reproduce it today and take a look.
I have reproduced this issue. It only occurs when using torch==2.3; I switched to torch==2.1 and the problem no longer exists. I will continue to investigate. @tlrmchlsmth @wenlei-bao
Great! Yes, I can confirm I am on torch==2.3.
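Based on the observation above, a hedged workaround sketch until the fix lands: check the installed torch version and, if it is 2.3, pin back to a 2.1 release (the exact point release is an assumption):

```bash
# Check which torch the environment is using.
python -c "import torch; print(torch.__version__)"
# Workaround: pin torch back to a 2.1 release (exact point release is an assumption),
# since 2.3 is the version that triggers the failure in this thread.
pip install "torch==2.1.*"
```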
Solved here: https://github.com/bytedance/flux/pull/13
Describe the bug
I'm trying to use the latest update, but running into a new issue.
Running:
Results in the following:
Environment Information