ROCm / clr

[Issue]: hipMemcpyWithStream causes severe stall in Hugging Face Transformers LLM generation with PyTorch #78

Open Epliz opened 1 month ago

Epliz commented 1 month ago

Problem Description

Hi,

When doing text generation with Mistral 7B using Hugging Face Transformers on an MI100 GPU, I can see in the collected torch trace that a lot of time is wasted due to a hipMemcpyWithStream call triggered by torch.multinomial. The hipMemcpyWithStream operation seems to return only well after the previously queued GPU kernels have finished executing. For reference, it is responsible for a ~6 ms bubble out of ~40 ms for the generation of one token, so optimizing it would have quite an impact on LLM generation (a trendy topic these days).

I suspect there may be some kind of exponential backoff somewhere that saturates at a far too long wait time.
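
For reference, here is a minimal PyTorch-only sketch of the pattern I believe is involved: a sampling kernel followed by a host read that has to wait for the whole stream to drain. This is an illustrative assumption about where the copy comes from (the workload and variable names are made up), not code extracted from Transformers:

```python
import time
import torch

# Assumption: the bubble comes from the blocking device-to-host copy that
# follows sampling, issued while the stream is still busy with queued kernels.
device = "cuda"  # HIP device on ROCm builds of PyTorch
probs = torch.softmax(torch.randn(1, 32000, device=device), dim=-1)

# Queue enough GPU work to keep the stream busy, similar to the decoder
# forward pass that precedes sampling in model.generate().
a = torch.randn(4096, 4096, device=device, dtype=torch.float16)
for _ in range(50):
    a = a @ a

next_token = torch.multinomial(probs, num_samples=1)

start = time.time()
token_id = next_token.item()  # blocking device-to-host read (shows up as hipMemcpyWithStream)
print(f"host read blocked for {(time.time() - start) * 1e3:.2f} ms")
```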


### Operating System

Ubuntu 22.04.3 LTS (x86_64)

### CPU

AMD Ryzen 7 5800X3D 8-Core Processor

### GPU

AMD Instinct MI100

### ROCm Version

ROCm 6.0.0

### ROCm Component

_No response_

### Steps to Reproduce

Minimal example to collect a trace, which can be visualized with, for example, https://ui.perfetto.dev:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn as nn

model_id = "mistralai/Mistral-7B-Instruct-v0.2"
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model: nn.Module = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to(device="cuda", dtype=torch.float16)

from typing import List, Union

# Generate completions for a single prompt or a batch of prompts.
def generate(model, prompt:Union[str, List[str]], max_new_tokens=20) -> Union[str, List[str]]:
  single_prompt = isinstance(prompt, str)
  if single_prompt:
    prompts = [prompt]
  else:
    prompts = prompt

  with torch.no_grad():
    inputs = tokenizer(prompts, return_tensors="pt", padding="longest").to(device="cuda")
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=True)
    texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)

  texts = [text[len(prompts[i]):] for i, text in enumerate(texts)]

  if single_prompt:
    return texts[0]
  else:
    return texts

# Run f() and return its result together with the elapsed wall-clock time.
def time_func(f):
  import time
  start_time = time.time()
  ret = f()
  end_time = time.time()
  elapsed_time = end_time - start_time
  return ret, elapsed_time

# Run f() under the PyTorch profiler and export a Chrome trace (viewable in Perfetto).
def profile_func(f, trace_path= "trace.json"):
  from torch.profiler import profile, ProfilerActivity
  with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    ret = f()
  prof.export_chrome_trace(trace_path)
  return ret

# Warm-up runs so the profiled run does not include one-time initialization costs.
text, time = time_func(lambda: generate(model, "Hello my name is", 50))
text, time = time_func(lambda: generate(model, "Hello my name is", 50))
text, time = time_func(lambda: generate(model, "Hello my name is", 50))
print("[Optimized] Completion: ", text)
print("[Optimized] Time: ", time)
# Profile one more timed run and export the trace for inspection.
text, time = profile_func(lambda: time_func(lambda: generate(model, "Hello my name is", 50)), trace_path="trace_orig.json")
```

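To quantify the stall without opening the trace in the UI, the exported JSON can also be scanned directly. A rough post-processing sketch, assuming the standard Chrome trace event format emitted by torch.profiler (event durations in microseconds):

```python
import json

# Sum the time attributed to hipMemcpyWithStream events in the exported trace.
with open("trace_orig.json") as f:
    trace = json.load(f)

copies = [e for e in trace.get("traceEvents", [])
          if e.get("name") == "hipMemcpyWithStream" and "dur" in e]
total_ms = sum(e["dur"] for e in copies) / 1000.0  # "dur" is in microseconds
print(f"{len(copies)} hipMemcpyWithStream calls, {total_ms:.2f} ms total")
```
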
### (Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

_No response_

### Additional Information

Environment:

```
PyTorch version: 2.3.0.dev20240204+rocm5.7
Is debug build: False
CUDA used to build PyTorch: N/A
ROCM used to build PyTorch: 5.7.31921-d1770ee1b

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.29.2
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.5.0-28-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: AMD Instinct MI100 (gfx908:sramecc+:xnack-)
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: 5.7.31921
MIOpen runtime version: 2.20.0
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 16
On-line CPU(s) list: 0-15
Vendor ID: AuthenticAMD
Model name: AMD Ryzen 7 5800X3D 8-Core Processor
CPU family: 25
Model: 33
Thread(s) per core: 2
Core(s) per socket: 8
Socket(s): 1
Stepping: 2
Frequency boost: enabled
CPU max MHz: 4548.8281
CPU min MHz: 2200.0000
BogoMIPS: 6800.77
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm
Virtualization: AMD-V
L1d cache: 256 KiB (8 instances)
L1i cache: 256 KiB (8 instances)
L2 cache: 4 MiB (8 instances)
L3 cache: 96 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-15
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Vulnerable: Safe RET, no microcode
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] numpy==1.24.1
[pip3] pytorch-triton-rocm==3.0.0+dafe145982
[pip3] torch==2.3.0.dev20240204+rocm5.7
[pip3] torchaudio==2.2.0.dev20240204+rocm5.7
[pip3] torchvision==0.18.0.dev20240204+rocm5.7
[conda] Could not collect
```


Python packages: (only transformers is relevant besides the torch packages)

accelerate==0.28.0 aiohttp==3.9.3 aiosignal==1.3.1 annotated-types==0.6.0 asttokens==2.4.1 async-timeout==4.0.3 attrs==23.2.0 build==1.2.1 certifi==2022.12.7 charset-normalizer==2.1.1 comm==0.2.1 contourpy==1.2.0 cycler==0.12.1 datasets==2.16.1 debugpy==1.8.1 decorator==5.1.1 deepspeed==0.14.0 diffusers==0.27.2 dill==0.3.7 exceptiongroup==1.2.0 executing==2.0.1 filelock==3.9.0 fonttools==4.48.1 frozenlist==1.4.1 fsspec==2023.10.0 hjson==3.1.0 huggingface-hub==0.20.3 idna==3.4 importlib_metadata==7.1.0 ipykernel==6.29.2 ipython==8.21.0 jedi==0.19.1 Jinja2==3.1.2 joblib==1.3.2 jupyter_client==8.6.0 jupyter_core==5.7.1 kiwisolver==1.4.5 MarkupSafe==2.1.3 matplotlib==3.8.2 matplotlib-inline==0.1.6 mpmath==1.2.1 multidict==6.0.5 multiprocess==0.70.15 nest-asyncio==1.6.0 networkx==3.0rc1 ninja==1.11.1.1 numpy==1.24.1 packaging==23.2 pandas==2.2.0 parso==0.8.3 peft==0.8.2 pexpect==4.9.0 Pillow==9.3.0 platformdirs==4.2.0 prompt-toolkit==3.0.43 psutil==5.9.8 ptyprocess==0.7.0 pure-eval==0.2.2 py-cpuinfo==9.0.0 pyarrow==15.0.0 pyarrow-hotfix==0.6 pydantic==2.6.4 pydantic_core==2.16.3 Pygments==2.17.2 pynvml==11.5.0 pyparsing==3.1.1 pyproject_hooks==1.0.0 python-dateutil==2.8.2 pytorch-triton-rocm==3.0.0+dafe145982 pytz==2024.1 PyYAML==6.0.1 pyzmq==25.1.2 regex==2023.12.25 requests==2.28.1 safetensors==0.4.2 scikit-learn==1.4.0 scipy==1.12.0 six==1.16.0 stack-data==0.6.3 sympy==1.11.1 threadpoolctl==3.2.0 tokenizers==0.15.1 tomli==2.0.1 torch==2.3.0.dev20240204+rocm5.7 torchaudio==2.2.0.dev20240204+rocm5.7 torchvision==0.18.0.dev20240204+rocm5.7 tornado==6.4 tqdm==4.66.1 traitlets==5.14.1 transformers==4.37.2 typing_extensions==4.8.0 tzdata==2023.4 UNKNOWN==0.0.0 urllib3==1.26.13 wcwidth==0.2.13 xxhash==3.4.1 yarl==1.9.4 zipp==3.18.1

Epliz commented 1 month ago

Also reported at https://github.com/ROCm/pytorch/issues/1407.