Infini-AI-Lab / MagicDec

Breaking Throughput-Latency Trade-off for Long Sequences with Speculative Decoding
Apache License 2.0

Hanging on multiple GPU clusters #2

Status: Open. YJHMITWEB opened this issue 2 months ago.

YJHMITWEB commented 2 months ago

Hi, thanks for your great work.

I followed the instructions to install MagicDec and run the test scripts.

I tried two systems: one with 4× A100 40GB and the other with 4× A100 80GB.

I ran the following command:

ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=4 tests/longspec_benchmark.py --target checkpoints/togethercomputer/LLaMA-2-7B-32K/model.pth --model checkpoints/TinyLlama/TinyLlama_v1.1/model.pth --model_name /home/users/jyao/ai_general_research/yueying/MagicDec/checkpoints/togethercomputer/LLaMA-2-7B-32K --rank_group 0 1 2 3 --draft_ranks 0 1 2 3 --gamma 3 --B 4 --prefix_len 16000 --gen_len 64 --streamingllm_budget 256 --benchmark

Both systems hang at some point: [screenshot]

For example, it stops here. nvidia-smi shows 100% utilization, but judging by the low power draw, I believe the GPUs are actually idle. [screenshot]
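The "100% utilization but low power" pattern can be checked across all GPUs with nvidia-smi's CSV query mode. The sketch below is a hypothetical helper, not part of MagicDec. The 90% utilization and 30%-of-power-limit thresholds are arbitrary assumptions, and the pattern only suggests that the GPUs may be waiting rather than computing. It does not prove it.

```python
import subprocess

# Illustrative thresholds (assumptions, not values from this thread):
# "busy" means >= 90% utilization, "low power" means < 30% of the power limit.
BUSY_UTIL_PERCENT = 90.0
LOW_POWER_FRACTION = 0.3

QUERY_FIELDS = "index,utilization.gpu,power.draw,power.limit"


def query_nvidia_smi():
    """Return nvidia-smi CSV output for QUERY_FIELDS (needs an NVIDIA driver)."""
    return subprocess.run(
        ["nvidia-smi", f"--query-gpu={QUERY_FIELDS}", "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=True,
    ).stdout


def find_suspect_gpus(csv_text):
    """Return indices of GPUs that report high utilization but low power draw."""
    suspects = []
    for line in csv_text.strip().splitlines():
        index, util, draw, limit = (field.strip() for field in line.split(","))
        try:
            busy = float(util) >= BUSY_UTIL_PERCENT
            low_power = float(draw) < LOW_POWER_FRACTION * float(limit)
        except ValueError:
            # Fields such as "[N/A]" (unsupported on some boards) are skipped.
            continue
        if busy and low_power:
            suspects.append(int(index))
    return suspects
```

On the hung node, find_suspect_gpus(query_nvidia_smi()) lists the GPUs that match the pattern described above.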

YJHMITWEB commented 2 months ago

A follow-up: this is what appears after it hangs for a while: [screenshot]

jianc99 commented 2 months ago

Hi, are you using the latest PyTorch nightly? We ran into this issue before, and it was fixed in the latest PyTorch nightly.

YJHMITWEB commented 2 months ago

Hi @jianc99, I believe so. I followed the setup instructions:

cd MagicDec
pip install -r requirements.txt
pip install flash-attn --no-build-isolation
pip install torch==2.5.0.dev20240813+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121/
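A small check can confirm that the interpreter torchrun launches actually imports the pinned nightly, and not a different torch from another environment. This is a sketch under assumptions. It compares only the .devYYYYMMDD build date in torch.__version__ against the build named in the install command, and it treats any later nightly as acceptable. The thread does not say which nightly contains the fix jianc99 mentions.

```python
import datetime
import re

# Nightly build named in the install command above.
EXPECTED_NIGHTLY = "2.5.0.dev20240813+cu121"


def nightly_date(version):
    """Extract the build date from a PyTorch nightly version string.

    Returns None for release builds such as "2.4.0+cu121", which carry no
    ".devYYYYMMDD" segment.
    """
    match = re.search(r"\.dev(\d{8})", version)
    if match is None:
        return None
    return datetime.datetime.strptime(match.group(1), "%Y%m%d").date()


def is_at_least(installed, expected=EXPECTED_NIGHTLY):
    """True if `installed` is a nightly built on or after the `expected` one."""
    installed_date = nightly_date(installed)
    return installed_date is not None and installed_date >= nightly_date(expected)


if __name__ == "__main__":
    try:
        import torch
        installed = torch.__version__
    except ImportError:
        installed = "not installed"
    print(installed, "->", "OK" if is_at_least(installed) else "older or not a nightly")
```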

Here is the output of conda list:

$ conda list
# packages in environment at /home/miniconda3/envs/magicdec:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main  
_openmp_mutex             5.1                       1_gnu  
aiohappyeyeballs          2.4.0                    pypi_0    pypi
aiohttp                   3.10.5                   pypi_0    pypi
aiosignal                 1.3.1                    pypi_0    pypi
attrs                     24.2.0                   pypi_0    pypi
bzip2                     1.0.8                h5eee18b_6  
ca-certificates           2024.7.2             h06a4308_0  
certifi                   2024.8.30                pypi_0    pypi
charset-normalizer        3.3.2                    pypi_0    pypi
datasets                  2.16.1                   pypi_0    pypi
dill                      0.3.7                    pypi_0    pypi
einops                    0.8.0                    pypi_0    pypi
filelock                  3.16.0                   pypi_0    pypi
flash-attn                2.6.3                    pypi_0    pypi
frozenlist                1.4.1                    pypi_0    pypi
fsspec                    2023.10.0                pypi_0    pypi
huggingface-hub           0.24.6                   pypi_0    pypi
idna                      3.8                      pypi_0    pypi
jinja2                    3.1.4                    pypi_0    pypi
ld_impl_linux-64          2.38                 h1181459_1  
libffi                    3.4.4                h6a678d5_1  
libgcc-ng                 11.2.0               h1234567_1  
libgomp                   11.2.0               h1234567_1  
libstdcxx-ng              11.2.0               h1234567_1  
libuuid                   1.41.5               h5eee18b_0  
markupsafe                2.1.5                    pypi_0    pypi
mpmath                    1.3.0                    pypi_0    pypi
multidict                 6.0.5                    pypi_0    pypi
multiprocess              0.70.15                  pypi_0    pypi
ncurses                   6.4                  h6a678d5_0  
networkx                  3.3                      pypi_0    pypi
ninja                     1.11.1.1                 pypi_0    pypi
numpy                     1.26.3                   pypi_0    pypi
nvidia-cublas-cu12        12.1.3.1                 pypi_0    pypi
nvidia-cuda-cupti-cu12    12.1.105                 pypi_0    pypi
nvidia-cuda-nvrtc-cu12    12.1.105                 pypi_0    pypi
nvidia-cuda-runtime-cu12  12.1.105                 pypi_0    pypi
nvidia-cudnn-cu12         9.1.0.70                 pypi_0    pypi
nvidia-cufft-cu12         11.0.2.54                pypi_0    pypi
nvidia-curand-cu12        10.3.2.106               pypi_0    pypi
nvidia-cusolver-cu12      11.4.5.107               pypi_0    pypi
nvidia-cusparse-cu12      12.1.0.106               pypi_0    pypi
nvidia-nccl-cu12          2.21.5                   pypi_0    pypi
nvidia-nvjitlink-cu12     12.6.68                  pypi_0    pypi
nvidia-nvtx-cu12          12.1.105                 pypi_0    pypi
openssl                   3.0.15               h5eee18b_0  
packaging                 24.1                     pypi_0    pypi
pandas                    2.2.2                    pypi_0    pypi
pip                       24.2            py311h06a4308_0  
protobuf                  5.28.0                   pypi_0    pypi
pyarrow                   17.0.0                   pypi_0    pypi
pyarrow-hotfix            0.6                      pypi_0    pypi
python                    3.11.9               h955ad1f_0  
python-dateutil           2.9.0.post0              pypi_0    pypi
pytorch-triton            3.0.0+dedb7bdf33          pypi_0    pypi
pytz                      2024.1                   pypi_0    pypi
pyyaml                    6.0.2                    pypi_0    pypi
readline                  8.2                  h5eee18b_0  
regex                     2024.7.24                pypi_0    pypi
requests                  2.32.3                   pypi_0    pypi
safetensors               0.4.5                    pypi_0    pypi
sentencepiece             0.2.0                    pypi_0    pypi
setuptools                72.1.0          py311h06a4308_0  
six                       1.16.0                   pypi_0    pypi
sqlite                    3.45.3               h5eee18b_0  
sympy                     1.13.1                   pypi_0    pypi
tiktoken                  0.7.0                    pypi_0    pypi
tk                        8.6.14               h39e8969_0  
tokenizers                0.15.2                   pypi_0    pypi
torch                     2.5.0.dev20240813+cu121          pypi_0    pypi
tqdm                      4.66.5                   pypi_0    pypi
transformers              4.36.2                   pypi_0    pypi
triton                    3.0.0                    pypi_0    pypi
typing-extensions         4.12.2                   pypi_0    pypi
tzdata                    2024.1                   pypi_0    pypi
urllib3                   2.2.2                    pypi_0    pypi
wheel                     0.43.0          py311h06a4308_0  
xxhash                    3.5.0                    pypi_0    pypi
xz                        5.4.6                h5eee18b_1  
yarl                      1.10.0                   pypi_0    pypi
zlib                      1.2.13               h5eee18b_1

NonvolatileMemory commented 1 month ago

I may be facing the same problem.

When the batch size (bsz) is not large enough, e.g., 1, 2, or 16, the run hangs.