vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Bug]: vLLM crashes with larger context sizes on TPUs #8318

Open · francescov1 opened this issue 1 week ago

francescov1 commented 1 week ago

Your current environment

The output of `python collect_env.py`:

```text
Collecting environment information...
INFO 09-10 05:05:28 importing.py:10] Triton not installed; certain GPU-related functions will not be available.
PyTorch version: 2.5.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: Debian GNU/Linux 11 (bullseye) (x86_64)
GCC version: (Debian 10.2.1-6) 10.2.1 20210110
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.31
Python version: 3.10.14 (main, Aug 13 2024, 02:16:06) [GCC 10.2.1 20210110] (64-bit runtime)
Python platform: Linux-6.1.85+-x86_64-with-glibc2.31
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 48 bits physical, 48 bits virtual
CPU(s): 224
On-line CPU(s) list: 0-223
Thread(s) per core: 2
Core(s) per socket: 56
Socket(s): 2
NUMA node(s): 2
Vendor ID: AuthenticAMD
CPU family: 25
Model: 1
Model name: AMD EPYC 7B13
Stepping: 0
CPU MHz: 2449.998
BogoMIPS: 4899.99
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 3.5 MiB
L1i cache: 3.5 MiB
L2 cache: 56 MiB
L3 cache: 448 MiB
NUMA node0 CPU(s): 0-55,112-167
NUMA node1 CPU(s): 56-111,168-223
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP conditional; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext invpcid_single ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr arat npt nrip_save umip vaes vpclmulqdq rdpid fsrm

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] pyzmq==26.2.0
[pip3] torch==2.5.0
[pip3] torch-xla==2.5.0+git17a4ef5
[pip3] torchvision==0.19.0a0+d23a6e1
[pip3] transformers==4.44.2
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.6.0@1447c97e753919709b613590d7267c93d07d9382
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
Could not collect
```

🐛 Describe the bug

I am trying to deploy Llama 3.1 8B Instruct on GKE 2x4 v5e TPUs. The vLLM server boots normally and works properly if `--max-model-len` is very low (e.g. 1024), but crashes with larger context sizes (I tried 16k and above). I'm using `Dockerfile.tpu` and the command `python3.10 -u -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --model meta-llama/Meta-Llama-3.1-8B-Instruct --tensor-parallel-size 8 --swap-space 16 --disable-log-requests --max-model-len=32768`.

Here is my full log file: vllm-llama31-tpu.log

Here is my full k8s manifest:

```yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: llama3-8b-instruct-vllm-deployment
  namespace: default
  labels:
    app: llama3-8b-instruct-vllm
spec:
  replicas: 1
  selector:
    matchLabels:
      app: llama3-8b-instruct-vllm
  template:
    metadata:
      labels:
        app: llama3-8b-instruct-vllm
    spec:
      nodeSelector:
        cloud.google.com/gke-tpu-topology: 2x4
        cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
      hostIPC: true
      hostNetwork: true
      containers:
      - name: llama3-8b-instruct-vllm
        image: us-central1-docker.pkg.dev/project-lighthouse-403916/lighthouse-production/vllm-server/vllm-tpu-image:latest
        command: ["/bin/sh", "-c", "python3.10 -u -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --model meta-llama/Meta-Llama-3.1-8B-Instruct --tensor-parallel-size 8 --swap-space 16 --disable-log-requests --max-model-len=32768"]
        env:
        - name: HUGGING_FACE_HUB_TOKEN
          value: <hf-token>
        - name: VLLM_LOGGING_LEVEL
          value: "DEBUG"
        - name: VLLM_TRACE_FUNCTION
          value: "1"
        securityContext:
          privileged: true
        ports:
        - containerPort: 8000
          protocol: TCP
        resources:
          requests:
            google.com/tpu: 8
          limits:
            google.com/tpu: 8
```


youkaichao commented 1 week ago

```
RuntimeError: Bad StatusOr access: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space smem. Used 1.00M of 1.00M smem. Exceeded smem capacity by 1.7K.
```

Looks like a TPU compilation error.

JackCaoG commented 1 week ago

Seems like an SMEM OOM compile-time error. Given that v5e per-chip HBM is only 16 GB, OOMing at a 16k sequence length is somewhat expected.
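
For reference, a rough back-of-envelope estimate (my own numbers, assuming Llama 3.1 8B's usual configuration of 32 layers, 8 KV heads, head dim 128, and a bf16 KV cache):

$$
2\ (\text{K and V}) \times 32\ \text{layers} \times 8\ \text{KV heads} \times 128\ \text{head dim} \times 2\ \text{B} = 128\ \text{KiB per token},
\qquad
16384 \times 128\ \text{KiB} = 2\ \text{GiB of KV cache per 16k sequence},
$$

on top of roughly 16 GB of bf16 weights that tensor parallelism spreads across the 8 chips.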

miladm commented 1 week ago

What's the largest seq len that you can run?

francescov1 commented 1 week ago

@miladm I was able to get it running with `--max-model-len 8192`, but 16k crashed it.

JackCaoG commented 1 week ago

Hey @francescov1, can you set

```
XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 XLA_SAVE_TENSORS_FMT="hlo" XLA_SAVE_TENSORS_FILE="/tmp/save1.hlo"
```

and share `/tmp/save1.hlo` (I think it will be `/tmp/save1.hlo.0` in this case) with us? The error here is that the XLA compiler complains that it doesn't have enough SMEM space for this HLO (we need to move parameters from HBM to SMEM before executing each op). We can open a bug with the XLA compiler team and see if there is anything we can do about it.
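
For example, a sketch of how those variables could be set around the same launch command used in the manifest above (everything else unchanged):

```sh
# Sketch: the server command from the Deployment, with the XLA debug/dump
# variables set so the compiled HLO graphs are written out as /tmp/save1.hlo.*
XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 \
XLA_SAVE_TENSORS_FMT="hlo" XLA_SAVE_TENSORS_FILE="/tmp/save1.hlo" \
python3.10 -u -m vllm.entrypoints.openai.api_server \
  --host 0.0.0.0 \
  --model meta-llama/Meta-Llama-3.1-8B-Instruct \
  --tensor-parallel-size 8 \
  --swap-space 16 \
  --disable-log-requests \
  --max-model-len=32768
```

In the Kubernetes Deployment above, the same four variables could equally be added to the container's `env` list next to `VLLM_LOGGING_LEVEL`.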

francescov1 commented 6 days ago

@JackCaoG Here you go. The file was too big to upload, so I put it on Google Drive: https://drive.google.com/file/d/1oSY7QqCIsQTZS2tGxJbFV-R_rOepDwHl/view?usp=sharing

There are a number of other save1.hlo.X files; let me know if you need any others.

JackCaoG commented 1 day ago

Thanks @francescov1, I downloaded the HLO. From it, it seems like vLLM is warming up the cache and throws a compilation error in the last compilation, and that compilation error is the `RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space smem. Used 1.00M of 1.00M smem`, right? If so, I can just take that last HLO and open a bug with the XLA team.

francescov1 commented 1 day ago

Yep @JackCaoG, that's what I saw too. Sounds good, let me know if you need anything else from me.

Also, could you send the link to the issue or tag this one in it so I can track it?