pytorch-labs / gpt-fast

Simple and efficient pytorch-native transformer text generation in <1000 LOC of python.
BSD 3-Clause "New" or "Revised" License

Question about the ENABLE_INTRA_NODE_COMM for speculative decoding #183

Open jianc99 opened 1 week ago

jianc99 commented 1 week ago

Hi, I just tried to use the custom all-reduce kernel for speculative decoding by setting ENABLE_INTRA_NODE_COMM=1, but the code gets stuck after several iterations. Is there a bug in this kernel's support for speculative decoding? The code runs successfully with the original NCCL ring kernel. Thanks!

jianc99 commented 1 week ago

This problem only occurs on 8xA100 and 4xA100. I tested on other machines, such as 2xA100 and 8xL40, and the problem didn't occur there.

yifuwang commented 6 days ago

Hey @jianc99, I wasn't able to reproduce the issue on my setup. Can you run nvidia-smi topo -m and post your GPU connectivity here?

jianc99 commented 6 days ago

Hi @yifuwang, here is the GPU connectivity info:

        GPU0    GPU1    GPU2    GPU3    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NV12    NV12    NV12    0,48            0               N/A
GPU1    NV12     X      NV12    NV12    0,48            0               N/A
GPU2    NV12    NV12     X      NV12    0,48            0               N/A
GPU3    NV12    NV12    NV12     X      0,48            0               N/A
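
For anyone comparing machines where the hang does and does not occur, the matrix above can be checked programmatically. The following is a minimal sketch, not part of gpt-fast: `parse_topo_matrix` and `all_pairs_nvlink` are hypothetical helper names, and the parser assumes the whitespace-separated layout pasted above (it ignores the affinity columns). It only reports whether every GPU pair shows an `NV*` link, which is the case in this 4xA100 matrix.

```python
# Sketch: parse an `nvidia-smi topo -m`-style matrix and check GPU-to-GPU links.
# Helper names are hypothetical; the sample is the matrix posted in this thread.

SAMPLE = """\
        GPU0    GPU1    GPU2    GPU3    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NV12    NV12    NV12    0,48            0               N/A
GPU1    NV12     X      NV12    NV12    0,48            0               N/A
GPU2    NV12    NV12     X      NV12    0,48            0               N/A
GPU3    NV12    NV12    NV12     X      0,48            0               N/A
"""


def parse_topo_matrix(text):
    """Return {(row_gpu, col_gpu): link_type} for every off-diagonal GPU pair."""
    lines = [ln for ln in text.splitlines() if ln.strip()]
    # Keep only header tokens of the form GPU<digits>; "CPU", "Affinity", etc. are skipped.
    gpu_cols = [t for t in lines[0].split() if t.startswith("GPU") and t[3:].isdigit()]
    links = {}
    for row in lines[1:]:
        cells = row.split()
        if not cells or cells[0] not in gpu_cols:
            continue  # not a GPU row (e.g. a legend line)
        for col, cell in zip(gpu_cols, cells[1:]):
            if col != cells[0]:  # skip the "X" on the diagonal
                links[(cells[0], col)] = cell
    return links


def all_pairs_nvlink(links):
    """True if at least one pair exists and every pair is reported as NV<n>."""
    return bool(links) and all(v.startswith("NV") for v in links.values())


links = parse_topo_matrix(SAMPLE)
print(len(links), "GPU pairs; all NVLink:", all_pairs_nvlink(links))
```
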

And the torch version I am using is 2.5.0.dev20240613+cu121. Whether or not I use compile doesn't affect whether the problem occurs.

export MODEL_REPO=meta-llama/Llama-2-70b-hf
export DRAFT_MODEL_REPO=meta-llama/Llama-2-7b-hf
ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=4 generate.py  --draft_checkpoint_path checkpoints/$DRAFT_MODEL_REPO/model.pth  --checkpoint_path checkpoints/$MODEL_REPO/model.pth --speculate_k 5 --prompt "def quicksort(arr):" --max_new_tokens 200 --num_samples 50 --temperature 0

The output is below. It runs for several iterations and then gets stuck, with utilization on all GPUs staying at 100%.

def quicksort(arr):
    if len(arr) <= 1:
        return arr
    pivot = arr[0]
    left = [x for x in arr[1:] if x < pivot]
    right = [x for x in arr[1:] if x >= pivot]
    return quicksort(left) + [pivot] + quicksort(right)

if __name__ == '__main__':
    arr = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    print(quicksort(arr))
 package com.example.android.miwok;

import android.content.Context;
import android.media.AudioManager;
import android.media.MediaPlayer;
import android.os.Bundle;
import android.support.v4.app.Fragment;
import android
Time for inference 1: 19.81 sec total, 10.10 tokens/sec
Bandwidth achieved: 350.92 GB/s
def quicksort(arr):
    if len(arr) <= 1:
        return arr
    pivot = arr[0]
    left = [x for x in arr[1:] if x < pivot]
    right = [x for x in arr[1:] if x >= pivot]
    return quicksort(left) + [pivot] + quicksort(right)

if __name__ == '__main__':
    arr = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    print(quicksort(arr))
 package com.example.android.miwok;

import android.content.Context;
import android.media.AudioManager;
import android.media.MediaPlayer;
import android.os.Bundle;
import android.support.v4.app.Fragment;
import android
Time for inference 2: 12.96 sec total, 15.43 tokens/sec
Bandwidth achieved: 536.10 GB/s
def quicksort(arr):
    if len(arr) <= 1:
        return arr
    pivot = arr[0]
    left = [x for x in arr[1:] if x < pivot]
    right = [x for x in arr[1:] if x >= pivot]
    return quicksort(left) + [pivot] + quicksort(right)

if __name__ == '__main__':
    arr = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    print(quicksort(arr))
 package com.example.android.miwok;

import android.content.Context;
import android.media.AudioManager;
import android.media.MediaPlayer;
import android.os.Bundle;
import android.support.v4.app.Fragment;
import android
Time for inference 3: 13.00 sec total, 15.39 tokens/sec
Bandwidth achieved: 534.68 GB/s
def quicksort(arr):
    if len(arr) <= 1:
        return arr
    pivot = arr[0]
    left = [x for x in arr[1:] if x < pivot]
    right = [x for x in arr[1:] if x >= pivot]
    return quicksort(left) + [pivot] + quicksort(right)

if __name__ == '__main__':
    arr = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    print(quicksort(arr))
 package com.example.android.miwok;

import android.content.Context;
import android.media.AudioManager;
import android.media.MediaPlayer;
import android.os.Bundle;
import android.support.v4.app.Fragment;
import android
Time for inference 4: 12.59 sec total, 15.88 tokens/sec
Bandwidth achieved: 551.99 GB/s
def quicksort(arr):
    if len(arr) <= 1:
        return arr
    pivot = arr[0]
    left = [x for x in arr[1:] if x < pivot]
    right = [x for x in arr[1:] if x >= pivot]
    return quicksort(left) + [pivot] + quicksort(right)

if __name__ == '__main__':
    arr = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    print(quicksort(arr))
 package com.example.android.miwok;

import android.content.Context;
import android.media.AudioManager;
import android.media.MediaPlayer;
import android.os.Bundle;
import android.support.v4.app.Fragment;
import android
Time for inference 5: 12.64 sec total, 15.82 tokens/sec
Bandwidth achieved: 549.77 GB/s
def quicksort(arr):
    if len(arr) <= 1:
        return arr
    pivot = arr[0]
    left = [x for x in arr[1:] if x < pivot]
    right = [x for x in arr[1:] if x >= pivot]
    return quicksort(left) + [pivot] + quicksort(right)

if __name__ == '__main__':
    arr = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    print(quicksort(arr))
 package com.example.android.miwok;

import android.content.Context;
import android.media.AudioManager;
import android.media.MediaPlayer;
import android.os.Bundle;
import android.support.v4.app.Fragment;
import android
Time for inference 6: 12.95 sec total, 15.45 tokens/sec
Bandwidth achieved: 536.78 GB/s
def quicksort(arr):
    if len(arr) <= 1:
        return arr
    pivot = arr[0]
    left = [x for x in arr[1:] if x < pivot]
    right = [x for x in arr[1:] if x >= pivot]
    return quicksort(left) + [pivot] + quicksort(right)

if __name__ == '__main__':
    arr = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    print(quicksort(arr))
 package com.example.android.miwok;

import android.content.Context;
import android.media.AudioManager;
import android.media.MediaPlayer;
import android.os.Bundle;
import android.support.v4.app.Fragment;
import android
Time for inference 7: 12.73 sec total, 15.71 tokens/sec
Bandwidth achieved: 545.87 GB/s
def quicksort(arr):
    if len(arr) <= 1:
        return arr
    pivot = arr[0]
    left = [x for x in arr[1:] if x < pivot]
    right = [x for x in arr[1:] if x >= pivot]
    return quicksort(left) + [pivot] + quicksort(right)

if __name__ == '__main__':
    arr = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    print(quicksort(arr))
 package com.example.android.miwok;

import android.content.Context;
import android.media.AudioManager;
import android.media.MediaPlayer;
import android.os.Bundle;
import android.support.v4.app.Fragment;
import android
Time for inference 8: 12.73 sec total, 15.71 tokens/sec
Bandwidth achieved: 546.02 GB/s

jianc99 commented 6 days ago

@yifuwang With the same code and ENABLE_INTRA_NODE_COMM=1 removed, the code runs successfully.
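
The with/without comparison described here can be scripted so that a hang is reported instead of blocking the terminal. This is a rough sketch under stated assumptions: `build_env` and `run_with_timeout` are hypothetical helpers, and the timeout value is arbitrary. Note that `subprocess.run` only kills the direct child when the timeout expires, so leftover torchrun worker processes may still need to be cleaned up by hand.

```python
# Sketch: run the same command with and without ENABLE_INTRA_NODE_COMM=1
# and classify each run as "ok", "failed", or "hang". Helper names are hypothetical.
import os
import subprocess


def build_env(enable_intra_node_comm, base=None):
    """Copy the environment and set or remove ENABLE_INTRA_NODE_COMM."""
    env = dict(os.environ if base is None else base)
    if enable_intra_node_comm:
        env["ENABLE_INTRA_NODE_COMM"] = "1"
    else:
        env.pop("ENABLE_INTRA_NODE_COMM", None)
    return env


def run_with_timeout(cmd, env, timeout_s):
    """Run cmd; return "hang" if it exceeds timeout_s, else "ok" or "failed"."""
    try:
        proc = subprocess.run(cmd, env=env, timeout=timeout_s)
    except subprocess.TimeoutExpired:
        # Only the direct child is killed here; worker processes may linger.
        return "hang"
    return "ok" if proc.returncode == 0 else "failed"


# Example usage with the command from this thread (paths as posted above):
# cmd = ["torchrun", "--standalone", "--nproc_per_node=4", "generate.py",
#        "--draft_checkpoint_path", "checkpoints/meta-llama/Llama-2-7b-hf/model.pth",
#        "--checkpoint_path", "checkpoints/meta-llama/Llama-2-70b-hf/model.pth",
#        "--speculate_k", "5", "--prompt", "def quicksort(arr):",
#        "--max_new_tokens", "200", "--num_samples", "50", "--temperature", "0"]
# for flag in (False, True):
#     print(flag, run_with_timeout(cmd, build_env(flag), timeout_s=3600))
```
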

jianc99 commented 6 days ago

And here is the env version @yifuwang:

PyTorch version: 2.5.0.dev20240613+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.27.7
Libc version: glibc-2.31

Python version: 3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:36:13) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB
GPU 4: NVIDIA A100-SXM4-80GB
GPU 5: NVIDIA A100-SXM4-80GB
GPU 6: NVIDIA A100-SXM4-80GB
GPU 7: NVIDIA A100-SXM4-80GB

Nvidia driver version: 535.104.12
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 96
On-line CPU(s) list: 0-95
Thread(s) per core: 2
Core(s) per socket: 24
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) Platinum 8275CL CPU @ 3.00GHz
Stepping: 7
CPU MHz: 2999.998
BogoMIPS: 5999.99
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 1.5 MiB
L1i cache: 1.5 MiB
L2 cache: 48 MiB
L3 cache: 71.5 MiB
NUMA node0 CPU(s): 0-23,48-71
NUMA node1 CPU(s): 24-47,72-95
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit: KVM: Mitigation: VMX unsupported
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed: Vulnerable
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] pytorch-triton==3.0.0+45fff310c8
[pip3] torch==2.5.0.dev20240613+cu121
[pip3] torchaudio==2.4.0.dev20240617+cu121
[pip3] torchvision==0.19.0.dev20240617+cu121
[conda] numpy 1.26.4 pypi_0 pypi
[conda] pytorch-triton 3.0.0+45fff310c8 pypi_0 pypi
[conda] torch 2.5.0.dev20240613+cu121 pypi_0 pypi
[conda] torchaudio 2.4.0.dev20240617+cu121 pypi_0 pypi
[conda] torchvision 0.19.0.dev20240617+cu121 pypi_0 pypi

jianc99 commented 6 days ago

Adding some error information. Sometimes an error message like the one below appears, and sometimes it just hangs without any error message. Hope this is helpful, @yifuwang.

2%|▏ | 4/200 [07:09<3:14:17, 59.48s/it]
[rank2]:[E622 01:14:05.084831011 ProcessGroupNCCL.cpp:607] [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=9, OpType=ALLREDUCE, NumelIn=1, NumelOut=1, Timeout(ms)=600000) ran for 600064 milliseconds before timing out.
[rank0]:[E622 01:14:05.084836161 ProcessGroupNCCL.cpp:607] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=9, OpType=ALLREDUCE, NumelIn=1, NumelOut=1, Timeout(ms)=600000) ran for 600064 milliseconds before timing out.
[rank3]:[E622 01:14:05.084837275 ProcessGroupNCCL.cpp:607] [Rank 3] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=9, OpType=ALLREDUCE, NumelIn=1, NumelOut=1, Timeout(ms)=600000) ran for 600065 milliseconds before timing out.
[rank0]:[E622 01:14:05.145697835 ProcessGroupNCCL.cpp:1664] [PG 0 (default_pg) Rank 0] Exception (either an error or timeout) detected by watchdog at work: 9, last enqueued NCCL work: 9, last completed NCCL work: 8.
[rank2]:[E622 01:14:05.145700867 ProcessGroupNCCL.cpp:1664] [PG 0 (default_pg) Rank 2] Exception (either an error or timeout) detected by watchdog at work: 9, last enqueued NCCL work: 9, last completed NCCL work: 8.
[rank3]:[E622 01:14:05.145704160 ProcessGroupNCCL.cpp:1664] [PG 0 (default_pg) Rank 3] Exception (either an error or timeout) detected by watchdog at work: 9, last enqueued NCCL work: 9, last completed NCCL work: 8.
[rank0]:[E622 01:26:52.726366069 ProcessGroupNCCL.cpp:1375] [PG 0 (default_pg) Rank 0] First PG on this rank that detected no heartbeat of its watchdog.
[rank2]:[E622 01:26:52.726387839 ProcessGroupNCCL.cpp:1375] [PG 0 (default_pg) Rank 2] First PG on this rank that detected no heartbeat of its watchdog.
[rank3]:[E622 01:26:52.726374054 ProcessGroupNCCL.cpp:1375] [PG 0 (default_pg) Rank 3] First PG on this rank that detected no heartbeat of its watchdog.
[rank0]:[E622 01:26:52.733272608 ProcessGroupNCCL.cpp:1413] [PG 0 (default_pg) Rank 0] Heartbeat monitor timed out! Process will be terminated after dumping debug info. workMetaList.size()=1
[rank2]:[E622 01:26:52.733276826 ProcessGroupNCCL.cpp:1413] [PG 0 (default_pg) Rank 2] Heartbeat monitor timed out! Process will be terminated after dumping debug info. workMetaList.size()=1
[rank3]:[E622 01:26:52.733286283 ProcessGroupNCCL.cpp:1413] [PG 0 (default_pg) Rank 3] Heartbeat monitor timed out! Process will be terminated after dumping debug info. workMetaList.size()=1
[rank3]:[F622 01:36:52.743006702 ProcessGroupNCCL.cpp:1224] [PG 0 (default_pg) Rank 3] [PG 0 (default_pg) Rank 3] ProcessGroupNCCL's watchdog got stuck for 600 seconds without making progress in monitoring enqueued collectives. This typically indicates a NCCL/CUDA API hang blocking the watchdog, and could be triggered by another thread holding the GIL inside a CUDA api, or other deadlock-prone behaviors.If you suspect the watchdog is not actually stuck and a longer timeout would help, you can either increase the timeout (TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC) to a larger value or disable the heartbeat monitor (TORCH_NCCL_ENABLE_MONITORING=0).If either of aforementioned helps, feel free to file an issue to PyTorch about the short timeout or false positive abort; otherwise, please attempt to debug the hang. workMetaList.size() = 1
[rank0]:[F622 01:36:52.743044824 ProcessGroupNCCL.cpp:1224] [PG 0 (default_pg) Rank 0] [PG 0 (default_pg) Rank 0] ProcessGroupNCCL's watchdog got stuck for 600 seconds without making progress in monitoring enqueued collectives. This typically indicates a NCCL/CUDA API hang blocking the watchdog, and could be triggered by another thread holding the GIL inside a CUDA api, or other deadlock-prone behaviors.If you suspect the watchdog is not actually stuck and a longer timeout would help, you can either increase the timeout (TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC) to a larger value or disable the heartbeat monitor (TORCH_NCCL_ENABLE_MONITORING=0).If either of aforementioned helps, feel free to file an issue to PyTorch about the short timeout or false positive abort; otherwise, please attempt to debug the hang. workMetaList.size() = 1
[rank2]:[F622 01:36:52.743294147 ProcessGroupNCCL.cpp:1224] [PG 0 (default_pg) Rank 2] [PG 0 (default_pg) Rank 2] ProcessGroupNCCL's watchdog got stuck for 600 seconds without making progress in monitoring enqueued collectives. This typically indicates a NCCL/CUDA API hang blocking the watchdog, and could be triggered by another thread holding the GIL inside a CUDA api, or other deadlock-prone behaviors.If you suspect the watchdog is not actually stuck and a longer timeout would help, you can either increase the timeout (TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC) to a larger value or disable the heartbeat monitor (TORCH_NCCL_ENABLE_MONITORING=0).If either of aforementioned helps, feel free to file an issue to PyTorch about the short timeout or false positive abort; otherwise, please attempt to debug the hang. workMetaList.size() = 1
W0622 01:37:23.259000 139910449522496 torch/distributed/elastic/multiprocessing/api.py:858] Sending process 1288873 closing signal SIGTERM
W0622 01:37:23.264000 139910449522496 torch/distributed/elastic/multiprocessing/api.py:858] Sending process 1288875 closing signal SIGTERM
W0622 01:37:53.265000 139910449522496 torch/distributed/elastic/multiprocessing/api.py:875] Unable to shutdown process 1288875 via 15, forcefully exiting via 9
E0622 01:37:53.711000 139910449522496 torch/distributed/elastic/multiprocessing/api.py:833] failed (exitcode: -6)
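
In the log above, ranks 0, 2 and 3 report the watchdog timeout for the same collective (SeqNum=9, ALLREDUCE, NumelIn=1), while rank 1 does not appear in those lines. What that implies is not established in this thread, but pulling this summary out of a long log can be automated. The sketch below is an illustration only: `summarize_timeouts` is a hypothetical helper, and the regular expression assumes the exact message wording shown in this log, which may differ across PyTorch versions.

```python
# Sketch: summarize "Watchdog caught collective operation timeout" lines.
# The helper name is hypothetical; the pattern follows the log text in this thread.
import re

TIMEOUT_RE = re.compile(
    r"\[Rank (?P<rank>\d+)\] Watchdog caught collective operation timeout: "
    r"WorkNCCL\(SeqNum=(?P<seq>\d+), OpType=(?P<op>\w+), "
    r"NumelIn=(?P<numel_in>\d+), NumelOut=(?P<numel_out>\d+), "
    r"Timeout\(ms\)=(?P<timeout_ms>\d+)\)"
)


def summarize_timeouts(log_text, world_size):
    """Return which ranks reported a timeout, which did not, and the ops involved."""
    hits = [m.groupdict() for m in TIMEOUT_RE.finditer(log_text)]
    reported = sorted({int(h["rank"]) for h in hits})
    silent = [r for r in range(world_size) if r not in reported]
    ops = sorted({(int(h["seq"]), h["op"], int(h["numel_in"])) for h in hits})
    return {"reported": reported, "silent": silent, "ops": ops}


# Three of the lines from the log above, shortened to the matching part.
SAMPLE_LOG = (
    "[rank2]:[E622 01:14:05.084831011 ProcessGroupNCCL.cpp:607] [Rank 2] Watchdog caught collective operation timeout: "
    "WorkNCCL(SeqNum=9, OpType=ALLREDUCE, NumelIn=1, NumelOut=1, Timeout(ms)=600000) ran for 600064 milliseconds before timing out.\n"
    "[rank0]:[E622 01:14:05.084836161 ProcessGroupNCCL.cpp:607] [Rank 0] Watchdog caught collective operation timeout: "
    "WorkNCCL(SeqNum=9, OpType=ALLREDUCE, NumelIn=1, NumelOut=1, Timeout(ms)=600000) ran for 600064 milliseconds before timing out.\n"
    "[rank3]:[E622 01:14:05.084837275 ProcessGroupNCCL.cpp:607] [Rank 3] Watchdog caught collective operation timeout: "
    "WorkNCCL(SeqNum=9, OpType=ALLREDUCE, NumelIn=1, NumelOut=1, Timeout(ms)=600000) ran for 600065 milliseconds before timing out.\n"
)

print(summarize_timeouts(SAMPLE_LOG, world_size=4))
```
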

yifuwang commented 6 days ago

Hmm, I was able to repro. You are right that speculative decoding + ENABLE_INTRA_NODE_COMM=1 can result in a hang after a certain number of iterations. I poked around a bit and had the following observations:

These make me think that the issue is data dependent, though I don't have a lead on what it could be. Need to spend more time digging.

jianc99 commented 6 days ago

Yes, in my case it also always hung at a specific iteration. I actually encountered this problem with autoregressive decoding without speculation as well, but there it doesn't occur every time; it seems random.

yifuwang commented 5 days ago

I believe https://github.com/pytorch/pytorch/pull/129501 should fix this issue