ROCm / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
http://pytorch.org

torch.multinomial causes severe stall in Hugging Face Transformers LLM generation #1407

Open Epliz opened 2 months ago

Epliz commented 2 months ago

🐛 Describe the bug

Hi,

When doing text generation with Mistral 7B using Hugging Face Transformers on an MI100 GPU, the collected torch trace shows that a lot of time is wasted due to a hipMemcpyWithStream triggered by torch.multinomial. The hipMemcpyWithStream call seems to return long after the previously queued GPU kernels have finished executing. It is responsible for a ~6 ms bubble out of ~40 ms for the generation of one token, so optimizing it looks like it would have quite an impact.
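To put the numbers above in perspective, here is a rough back-of-the-envelope sketch. It assumes the ~6 ms bubble could be removed entirely and that nothing else in the ~40 ms token time changes, so it is an upper bound, not a measurement. The helper name `max_speedup` is made up for illustration.

```python
def max_speedup(total_ms: float, stall_ms: float) -> float:
    """Upper bound on the speedup if the stall disappeared completely."""
    if not 0 <= stall_ms < total_ms:
        raise ValueError("stall_ms must be in [0, total_ms)")
    return total_ms / (total_ms - stall_ms)


# Approximate figures quoted in the report (per generated token).
total_ms, stall_ms = 40.0, 6.0

print(f"stall share of token time: {stall_ms / total_ms:.0%}")
print(f"best-case speedup:         {max_speedup(total_ms, stall_ms):.2f}x")
print(f"tokens/s: {1000 / total_ms:.1f} now, "
      f"{1000 / (total_ms - stall_ms):.1f} at best")
```

With these figures the bubble is about 15% of the per-token time, and removing it would raise throughput from roughly 25 to at most about 29 tokens per second.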

Minimal example to collect the trace, which can be visualized with, for example, https://ui.perfetto.dev:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn as nn

model_id = "mistralai/Mistral-7B-Instruct-v0.2"
tokenizer = AutoTokenizer.from_pretrained(model_id,padding_side="left")

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model: nn.Module = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to(device="cuda", dtype=torch.float16)

from typing import List, Union

def generate(model, prompt:Union[str, List[str]], max_new_tokens=20) -> Union[str, List[str]]:
  single_prompt = isinstance(prompt, str)
  if single_prompt:
    prompts = [prompt]
  else:
    prompts = prompt

  with torch.no_grad():
    inputs = tokenizer(prompts, return_tensors="pt", padding="longest").to(device="cuda")
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=True)
    texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)

  texts = [text[len(prompts[i]):] for i, text in enumerate(texts)]

  if single_prompt:
    return texts[0]
  else:
    return texts

def time_func(f):
  import time
  start_time = time.time()
  ret = f()
  end_time = time.time()
  elapsed_time = end_time - start_time
  return ret, elapsed_time

def profile_func(f, trace_path= "trace.json"):
  from torch.profiler import profile, ProfilerActivity
  with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    ret = f()
  prof.export_chrome_trace(trace_path)
  return ret

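# Run a few times so the printed timing is not dominated by one-time warm-up costs.
# The result is named `elapsed` to avoid confusion with the `time` module.
text, elapsed = time_func(lambda: generate(model, "Hello my name is", 50))
text, elapsed = time_func(lambda: generate(model, "Hello my name is", 50))
text, elapsed = time_func(lambda: generate(model, "Hello my name is", 50))
print("[Optimized] Completion: ", text)
print("[Optimized] Time: ", elapsed)
# Profile one more run and export the trace for inspection.
text, elapsed = profile_func(lambda: time_func(lambda: generate(model, "Hello my name is", 50)), trace_path="trace_orig.json")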
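
As an alternative to inspecting the trace in a viewer, the exported file can also be summarized from a script. The sketch below is not part of the original report. It assumes the file written by `export_chrome_trace` is Chrome trace JSON with a top-level `traceEvents` list, in which complete events have `"ph": "X"` and a `dur` field in microseconds. The helper name `summarize_trace` is hypothetical.

```python
import json
from collections import defaultdict


def summarize_trace(trace_path, name_filter):
    """Return {event name: (count, total duration in microseconds)} for
    complete ("ph" == "X") events whose name contains name_filter."""
    with open(trace_path) as f:
        trace = json.load(f)
    # Assumption: events live under "traceEvents"; a bare list is also accepted.
    events = trace["traceEvents"] if isinstance(trace, dict) else trace

    stats = defaultdict(lambda: [0, 0.0])
    needle = name_filter.lower()
    for ev in events:
        if ev.get("ph") != "X":
            continue  # skip metadata, flow and instant events
        name = ev.get("name", "")
        if needle in name.lower():
            stats[name][0] += 1
            stats[name][1] += ev.get("dur", 0)
    return {name: (count, total) for name, (count, total) in stats.items()}


# Example use against the trace produced above:
# for name, (count, total_us) in summarize_trace("trace_orig.json", "memcpy").items():
#     print(f"{name}: {count} calls, {total_us / 1000:.2f} ms total")
```

Filtering on a substring such as "memcpy" or "multinomial" gives a quick total to compare between runs, though the timeline view is still needed to see where the bubbles sit relative to the GPU kernels.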

Versions

Environment:

PyTorch version: 2.3.0.dev20240204+rocm5.7
Is debug build: False
CUDA used to build PyTorch: N/A
ROCM used to build PyTorch: 5.7.31921-d1770ee1b

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.29.2
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.5.0-28-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: AMD Instinct MI100 (gfx908:sramecc+:xnack-)
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: 5.7.31921
MIOpen runtime version: 2.20.0
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      48 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             16
On-line CPU(s) list:                0-15
Vendor ID:                          AuthenticAMD
Model name:                         AMD Ryzen 7 5800X3D 8-Core Processor
CPU family:                         25
Model:                              33
Thread(s) per core:                 2
Core(s) per socket:                 8
Socket(s):                          1
Stepping:                           2
Frequency boost:                    enabled
CPU max MHz:                        4548.8281
CPU min MHz:                        2200.0000
BogoMIPS:                           6800.77
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm
Virtualization:                     AMD-V
L1d cache:                          256 KiB (8 instances)
L1i cache:                          256 KiB (8 instances)
L2 cache:                           4 MiB (8 instances)
L3 cache:                           96 MiB (1 instance)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0-15
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Vulnerable: Safe RET, no microcode
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] numpy==1.24.1
[pip3] pytorch-triton-rocm==3.0.0+dafe145982
[pip3] torch==2.3.0.dev20240204+rocm5.7
[pip3] torchaudio==2.2.0.dev20240204+rocm5.7
[pip3] torchvision==0.18.0.dev20240204+rocm5.7
[conda] Could not collect

Python packages: (only transformers is relevant besides the torch packages)

accelerate==0.28.0
aiohttp==3.9.3
aiosignal==1.3.1
annotated-types==0.6.0
asttokens==2.4.1
async-timeout==4.0.3
attrs==23.2.0
build==1.2.1
certifi==2022.12.7
charset-normalizer==2.1.1
comm==0.2.1
contourpy==1.2.0
cycler==0.12.1
datasets==2.16.1
debugpy==1.8.1
decorator==5.1.1
deepspeed==0.14.0
diffusers==0.27.2
dill==0.3.7
exceptiongroup==1.2.0
executing==2.0.1
filelock==3.9.0
fonttools==4.48.1
frozenlist==1.4.1
fsspec==2023.10.0
hjson==3.1.0
huggingface-hub==0.20.3
idna==3.4
importlib_metadata==7.1.0
ipykernel==6.29.2
ipython==8.21.0
jedi==0.19.1
Jinja2==3.1.2
joblib==1.3.2
jupyter_client==8.6.0
jupyter_core==5.7.1
kiwisolver==1.4.5
MarkupSafe==2.1.3
matplotlib==3.8.2
matplotlib-inline==0.1.6
mpmath==1.2.1
multidict==6.0.5
multiprocess==0.70.15
nest-asyncio==1.6.0
networkx==3.0rc1
ninja==1.11.1.1
numpy==1.24.1
packaging==23.2
pandas==2.2.0
parso==0.8.3
peft==0.8.2
pexpect==4.9.0
Pillow==9.3.0
platformdirs==4.2.0
prompt-toolkit==3.0.43
psutil==5.9.8
ptyprocess==0.7.0
pure-eval==0.2.2
py-cpuinfo==9.0.0
pyarrow==15.0.0
pyarrow-hotfix==0.6
pydantic==2.6.4
pydantic_core==2.16.3
Pygments==2.17.2
pynvml==11.5.0
pyparsing==3.1.1
pyproject_hooks==1.0.0
python-dateutil==2.8.2
pytorch-triton-rocm==3.0.0+dafe145982
pytz==2024.1
PyYAML==6.0.1
pyzmq==25.1.2
regex==2023.12.25
requests==2.28.1
safetensors==0.4.2
scikit-learn==1.4.0
scipy==1.12.0
six==1.16.0
stack-data==0.6.3
sympy==1.11.1
threadpoolctl==3.2.0
tokenizers==0.15.1
tomli==2.0.1
torch==2.3.0.dev20240204+rocm5.7
torchaudio==2.2.0.dev20240204+rocm5.7
torchvision==0.18.0.dev20240204+rocm5.7
tornado==6.4
tqdm==4.66.1
traitlets==5.14.1
transformers==4.37.2
typing_extensions==4.8.0
tzdata==2023.4
UNKNOWN==0.0.0
urllib3==1.26.13
wcwidth==0.2.13
xxhash==3.4.1
yarl==1.9.4
zipp==3.18.1

Best, Epliz

Epliz commented 2 months ago

Also reported at https://github.com/ROCm/clr/issues/78