ROCm / clr

[Issue]: hipMemcpyWithStream causes severe stall in Hugging Face Transformers LLM generation with PyTorch #78

Open Epliz opened 6 months ago

Epliz commented 6 months ago

Problem Description

Hi,

When doing text generation with Mistral 7B with Hugging Face Transformers on an MI100 GPU, the collected torch trace shows a lot of time wasted in a hipMemcpyWithStream call triggered by torch.multinomial. The hipMemcpyWithStream call seems to return long after the previously queued GPU kernels have finished executing. For reference, it is responsible for a ~6 ms bubble out of the ~40 ms it takes to generate one token. Optimizing it looks like it would have quite an impact on LLM generation (a trendy topic these days).

I suspect some kind of exponential backoff somewhere that saturates at far too long a wait time.


### Operating System

Ubuntu 22.04.3 LTS (x86_64)

### CPU

AMD Ryzen 7 5800X3D 8-Core Processor

### GPU

AMD Instinct MI100

### ROCm Version

ROCm 6.0.0

### ROCm Component

_No response_

### Steps to Reproduce

Minimal example to collect the trace that can be visualized for example with https://ui.perfetto.dev:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn as nn

model_id = "mistralai/Mistral-7B-Instruct-v0.2"
tokenizer = AutoTokenizer.from_pretrained(model_id,padding_side="left")

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model: nn.Module = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to(device="cuda", dtype=torch.float16)

from typing import List, Union

def generate(model, prompt:Union[str, List[str]], max_new_tokens=20) -> Union[str, List[str]]:
  single_prompt = isinstance(prompt, str)
  if single_prompt:
    prompts = [prompt]
  else:
    prompts = prompt

  with torch.no_grad():
    inputs = tokenizer(prompts, return_tensors="pt", padding="longest").to(device="cuda")
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=True)
    texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)

  texts = [text[len(prompts[i]):] for i, text in enumerate(texts)]

  if single_prompt:
    return texts[0]
  else:
    return texts

def time_func(f):
  import time
  start_time = time.time()
  ret = f()
  end_time = time.time()
  elapsed_time = end_time - start_time
  return ret, elapsed_time

def profile_func(f, trace_path= "trace.json"):
  from torch.profiler import profile, ProfilerActivity
  with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    ret = f()
  prof.export_chrome_trace(trace_path)
  return ret

text, time = time_func(lambda: generate(model, "Hello my name is", 50))
text, time = time_func(lambda: generate(model, "Hello my name is", 50))
text, time = time_func(lambda: generate(model, "Hello my name is", 50))
print("[Optimized] Completion: ", text)
print("[Optimized] Time: ", time)
text, time = profile_func(lambda: time_func(lambda: generate(model, "Hello my name is", 50)), trace_path="trace_orig.json")
```

### (Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

_No response_

### Additional Information

Environment:

PyTorch version: 2.3.0.dev20240204+rocm5.7
Is debug build: False
CUDA used to build PyTorch: N/A
ROCM used to build PyTorch: 5.7.31921-d1770ee1b

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.29.2
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.5.0-28-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: AMD Instinct MI100 (gfx908:sramecc+:xnack-)
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: 5.7.31921
MIOpen runtime version: 2.20.0
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 16
On-line CPU(s) list: 0-15
Vendor ID: AuthenticAMD
Model name: AMD Ryzen 7 5800X3D 8-Core Processor
CPU family: 25
Model: 33
Thread(s) per core: 2
Core(s) per socket: 8
Socket(s): 1
Stepping: 2
Frequency boost: enabled
CPU max MHz: 4548.8281
CPU min MHz: 2200.0000
BogoMIPS: 6800.77
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm
Virtualization: AMD-V
L1d cache: 256 KiB (8 instances)
L1i cache: 256 KiB (8 instances)
L2 cache: 4 MiB (8 instances)
L3 cache: 96 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-15
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Vulnerable: Safe RET, no microcode
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] numpy==1.24.1
[pip3] pytorch-triton-rocm==3.0.0+dafe145982
[pip3] torch==2.3.0.dev20240204+rocm5.7
[pip3] torchaudio==2.2.0.dev20240204+rocm5.7
[pip3] torchvision==0.18.0.dev20240204+rocm5.7
[conda] Could not collect


Python packages: (only transformers is relevant besides the torch packages)

accelerate==0.28.0 aiohttp==3.9.3 aiosignal==1.3.1 annotated-types==0.6.0 asttokens==2.4.1 async-timeout==4.0.3 attrs==23.2.0 build==1.2.1 certifi==2022.12.7 charset-normalizer==2.1.1 comm==0.2.1 contourpy==1.2.0 cycler==0.12.1 datasets==2.16.1 debugpy==1.8.1 decorator==5.1.1 deepspeed==0.14.0 diffusers==0.27.2 dill==0.3.7 exceptiongroup==1.2.0 executing==2.0.1 filelock==3.9.0 fonttools==4.48.1 frozenlist==1.4.1 fsspec==2023.10.0 hjson==3.1.0 huggingface-hub==0.20.3 idna==3.4 importlib_metadata==7.1.0 ipykernel==6.29.2 ipython==8.21.0 jedi==0.19.1 Jinja2==3.1.2 joblib==1.3.2 jupyter_client==8.6.0 jupyter_core==5.7.1 kiwisolver==1.4.5 MarkupSafe==2.1.3 matplotlib==3.8.2 matplotlib-inline==0.1.6 mpmath==1.2.1 multidict==6.0.5 multiprocess==0.70.15 nest-asyncio==1.6.0 networkx==3.0rc1 ninja==1.11.1.1 numpy==1.24.1 packaging==23.2 pandas==2.2.0 parso==0.8.3 peft==0.8.2 pexpect==4.9.0 Pillow==9.3.0 platformdirs==4.2.0 prompt-toolkit==3.0.43 psutil==5.9.8 ptyprocess==0.7.0 pure-eval==0.2.2 py-cpuinfo==9.0.0 pyarrow==15.0.0 pyarrow-hotfix==0.6 pydantic==2.6.4 pydantic_core==2.16.3 Pygments==2.17.2 pynvml==11.5.0 pyparsing==3.1.1 pyproject_hooks==1.0.0 python-dateutil==2.8.2 pytorch-triton-rocm==3.0.0+dafe145982 pytz==2024.1 PyYAML==6.0.1 pyzmq==25.1.2 regex==2023.12.25 requests==2.28.1 safetensors==0.4.2 scikit-learn==1.4.0 scipy==1.12.0 six==1.16.0 stack-data==0.6.3 sympy==1.11.1 threadpoolctl==3.2.0 tokenizers==0.15.1 tomli==2.0.1 torch==2.3.0.dev20240204+rocm5.7 torchaudio==2.2.0.dev20240204+rocm5.7 torchvision==0.18.0.dev20240204+rocm5.7 tornado==6.4 tqdm==4.66.1 traitlets==5.14.1 transformers==4.37.2 typing_extensions==4.8.0 tzdata==2023.4 UNKNOWN==0.0.0 urllib3==1.26.13 wcwidth==0.2.13 xxhash==3.4.1 yarl==1.9.4 zipp==3.18.1

Epliz commented 6 months ago

also reported at https://github.com/ROCm/pytorch/issues/1407

ppanchad-amd commented 2 months ago

Hi @Epliz, internal ticket has been created to investigate your issue. Thanks!

schung-amd commented 2 months ago

Hi @Epliz, are you still experiencing this issue? If so, can you check if IOMMU is enabled in your BIOS? If it is, then make sure you have iommu=pt set in your kernel boot options (see: https://rocm.docs.amd.com/projects/install-on-linux/en/latest/reference/install-faq.html#multi-gpu). If not, then enabling IOMMU along with iommu=pt might help with performance. Also, is your ROCm 6.0 or 5.7? You have ROCm 6.0 listed but your pytorch is for ROCm 5.7. Regardless, you should also see if this slowdown is present in ROCm 6.2 (with the ROCm 6.2 build for torch, torchaudio, and torchvision).
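
A quick way to check the boot option mentioned above is to look at the running kernel's command line. The following is a minimal sketch, assuming a Linux host where `/proc/cmdline` is readable; the helper names `has_iommu_passthrough` and `read_kernel_cmdline` are made up for illustration and are not part of any ROCm tool.

```python
from pathlib import Path


def has_iommu_passthrough(cmdline: str) -> bool:
    """Return True if the kernel command line contains the exact token iommu=pt."""
    # Split on whitespace so that look-alike tokens such as amd_iommu=on do not match.
    return "iommu=pt" in cmdline.split()


def read_kernel_cmdline(path: str = "/proc/cmdline") -> str:
    """Read the command line the running kernel was booted with (Linux only)."""
    return Path(path).read_text()


if __name__ == "__main__":
    if has_iommu_passthrough(read_kernel_cmdline()):
        print("iommu=pt is set")
    else:
        print("iommu=pt is NOT set")
```

For the version question, `torch.version.hip` reports the HIP version a ROCm build of PyTorch was compiled against, which can be compared with the installed ROCm release.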

Epliz commented 1 month ago

Thank you @schung-amd for following up. I did not have iommu=pt set at that time. I will check whether that was the culprit.

schung-amd commented 1 month ago

Had a chance to profile this myself on MI100 and saw the same thing. https://discuss.pytorch.org/t/large-time-consuming-cudamemcpyasync-between-forward-and-backward-pass-in-collected-pytorch-profiler-trace/190242 looks the same as the trace I'm seeing, and it is stated there that aten::item signifies device synchronization, so the memcpy is being performed synchronously and the recorded duration of the memcpy includes waiting for GPU execution to finish plus some overhead. This seems to be happening under the hood in mistral/pytorch, likely by design. I suspect this can probably be made faster on MI300, since with unified memory this memcpy would be redundant.

Let me know if you have further questions, I can reach out to our internal teams for more info if necessary.

Epliz commented 1 month ago

I got access to a machine with MI300X GPUs and ROCm 6.2.2, and the stall there is even more severe: 17 ms. By that I mean that, according to the trace, there are 17 ms during which the GPU is doing nothing (no kernel running) while the CPU is still blocked in hipMemcpyWithStream.
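
The idle time described above can be estimated from the exported trace rather than read off by eye. The following is a rough sketch, assuming the file is a Chrome-trace JSON whose complete events (`"ph": "X"`) carry `ts` and `dur` in microseconds, and that GPU work is tagged with the categories `kernel` and `gpu_memcpy`. Those category names are an assumption about the profiler output and are exposed as a parameter; the function names are hypothetical.

```python
import bisect
import json
from typing import Dict, List, Sequence, Tuple


def load_trace_events(path: str) -> List[Dict]:
    """Load events from a Chrome-trace JSON file (object with 'traceEvents', or a bare list)."""
    with open(path) as f:
        data = json.load(f)
    return data["traceEvents"] if isinstance(data, dict) else data


def find_sync_bubbles(
    events: Sequence[Dict],
    runtime_name: str = "hipMemcpyWithStream",
    gpu_categories: Tuple[str, ...] = ("kernel", "gpu_memcpy"),  # assumed category names
) -> List[Tuple[float, float]]:
    """For every call named `runtime_name`, return (call_start, bubble).

    `bubble` is the time the CPU call kept running after the last GPU
    activity that finished before the call returned.
    """
    complete = [e for e in events if e.get("ph") == "X" and "ts" in e and "dur" in e]
    gpu_ends = sorted(e["ts"] + e["dur"] for e in complete if e.get("cat") in gpu_categories)

    bubbles = []
    for e in complete:
        if e.get("name") != runtime_name:
            continue
        start, end = e["ts"], e["ts"] + e["dur"]
        idx = bisect.bisect_right(gpu_ends, end)  # GPU events that ended by the time the call returned
        if idx == 0:
            continue  # no GPU activity finished before this call returned
        gpu_idle_since = max(gpu_ends[idx - 1], start)
        bubbles.append((start, end - gpu_idle_since))
    return bubbles
```

With the reproduction script above, this would be called as `find_sync_bubbles(load_trace_events("trace_orig.json"))`. It ignores GPU work submitted from other threads or streams, so it is only an approximation.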

Epliz commented 1 month ago

According to the trace, the Memcpy DtoH kernel (or command; I am not sure whether it is an actual kernel or a DMA command) took only 16 µs.

schung-amd commented 1 month ago

Sorry if my statement about MI300 was misleading; I meant that it's probably possible to improve/optimize the code for MI300, not that it should run better there at the moment, as this memcpy will still be performed even though it is redundant.

It is concerning that there is seemingly a large performance loss here on MI300 though, I'll take a look.

Epliz commented 1 month ago

As said in my first message, it seems like the CPU keeps waiting well past the point where the GPU has finished the work. I understand that there might be a stream or device synchronization underneath, and conceptually that's fine. But that synchronization seems poorly implemented. From the symptoms, I would suspect a poorly implemented exponential backoff. If there is an exponential backoff, it should either be capped at a maximum wait time, or the increments should be scheduled so that each additional wait is always small compared to the time already waited, ideally both.

One additional hint that it might be that is that I can see in traces that the stall is shorter if the CPU starts the wait closer to when the GPU work actually ends.
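
To make the two properties asked for above concrete (a capped wait, and increments that stay small relative to the time already waited), here is a small simulation. It is not the ROCclr wait loop, only a model of a waiter that checks for completion after each sleep; all names and parameter values are hypothetical, and times are in arbitrary units (read them as milliseconds).

```python
from typing import Iterator


def doubling_backoff(first_sleep: float) -> Iterator[float]:
    """Classic exponential backoff: every sleep is twice the previous one, with no cap."""
    sleep = first_sleep
    while True:
        yield sleep
        sleep *= 2.0


def bounded_backoff(first_sleep: float, ratio: float, cap: float) -> Iterator[float]:
    """Backoff where each sleep is max(first_sleep, ratio * time already waited), capped at `cap`."""
    waited = 0.0
    while True:
        sleep = min(cap, max(first_sleep, ratio * waited))
        yield sleep
        waited += sleep


def overshoot(done_at: float, schedule: Iterator[float]) -> float:
    """Return how long after `done_at` a waiter following `schedule` notices completion.

    The waiter checks at time 0 and then once after every sleep.
    """
    now = 0.0
    for sleep in schedule:
        if now >= done_at:
            return now - done_at
        now += sleep
    raise RuntimeError("schedule ended before completion")


if __name__ == "__main__":
    done_at = 34.0  # the awaited work finishes 34 time units after the wait starts
    print("doubling:", overshoot(done_at, doubling_backoff(0.001)))
    print("bounded: ", overshoot(done_at, bounded_backoff(0.001, ratio=0.1, cap=0.05)))
```

With plain doubling, the next sleep is about as long as everything waited so far, so the overshoot can approach the total wait. With the bounded variant the overshoot never exceeds `cap`.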

Epliz commented 1 month ago

I can only use my browser right now and not a debugger to check the execution path, but if it follows hipStreamSynchronize_common() -> hipstream->finish() -> HostQueue::finish() -> command->awaitCompletion() -> lock.wait() -> Monitor::wait(), then I would not be surprised, as there is a sleep of 10 ms at some point. The thresholds for the different spinning methods are probably not chosen so that each next wait is negligible compared to the total time already waited.

Epliz commented 1 month ago

I tried to follow the execution with dynamic debugging (i.e. printf debugging) and it seems that when using hipMemcpyWithStream, it waits for the completion at https://github.com/ROCm/clr/blob/f9f995c6d0aecec28debd3db0f41df0c15ec003d/rocclr/platform/command.cpp#L248 , doing "an active wait". However, amd::Os::yield() seems to call sched_yield on Linux, which as far as I know deprioritizes the thread and, if called repeatedly, makes it sleep more and more. That might be what causes the very long sleeps here. I think a more active wait scheme would be preferable.
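
One shape a "more active" wait could take is to spin for a short budget before starting to yield the CPU. The sketch below only illustrates that idea on the host side; it is not what ROCclr does, the function name and default budgets are made up, and a real implementation would poll a completion signal instead of a Python callable. `os.sched_yield` is only available on some Unix platforms, so a fallback is used elsewhere.

```python
import os
import threading
import time
from typing import Callable

# Fall back to a zero-length sleep where sched_yield is unavailable.
_yield_cpu = getattr(os, "sched_yield", lambda: time.sleep(0))


def hybrid_wait(
    is_done: Callable[[], bool],
    spin_budget_s: float = 0.0005,
    timeout_s: float = 1.0,
) -> bool:
    """Poll `is_done`: pure spinning at first, yielding the CPU once the spin budget is used up.

    Returns True when `is_done()` becomes true, False if `timeout_s` elapses first.
    """
    start = time.perf_counter()
    while not is_done():
        elapsed = time.perf_counter() - start
        if elapsed > timeout_s:
            return False
        if elapsed > spin_budget_s:
            _yield_cpu()  # only give up the CPU after the active phase
    return True


if __name__ == "__main__":
    flag = threading.Event()
    threading.Timer(0.01, flag.set).start()  # stands in for the GPU finishing its work
    print("completed:", hybrid_wait(flag.is_set))
```

The spin budget trades CPU usage for wake-up latency: short waits are caught while still spinning, and only long waits pay the scheduler's price for yielding.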

Epliz commented 1 month ago

OK, so I tried using the VTune profiler to see what is happening on the CPU side. I am not entirely sure I am right, but it seems that ihipMemcpy calls awaitCompletion, which calls Command::enqueue(), which calls submitMarker(), which calls flush(), which indirectly calls CpuWaitForSignal, which calls hsa_signal_wait_scacquire and spends a lot of time there. Not sure if that is the problematic place, but maybe.

Epliz commented 4 weeks ago

It seems to me that, for some reason, the GPU block responsible for signaling the completion of commands is not sending the signal back to the CPU... However, if I implement my own hipMemcpyWithStream method, internally using a hipMemcpyAsync call plus a GPU kernel that writes to a locked (pinned) host memory buffer, and spin on the CPU side until I see the expected value, the synchronization is fast.

The problem is that I can't avoid having a hipDeviceSynchronize or hipMemcpyWithStream call at some point, and the stalling happens there, negating any improvement from using a custom hipMemcpyWithStream.

For reference, because I find it kinda cool, my code for replacing hipMemcpyWithStream:


```cpp
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>

#include <xmmintrin.h>  // for _mm_pause

#include <stdint.h>
#include <string.h>
#include <stdio.h>

struct muillm_synchronizer {
  // locked CPU memory to do transfers to
  void* staging_buffer;
  // locked CPU memory for the GPUs to signal completions
  volatile int* signal_buffer;

  // size of the staging buffer in bytes
  size_t staging_buffer_size;

  // current sequential number for the signal
  int seq_no;
};

// type name used by the functions below
typedef struct muillm_synchronizer muillm_synchronizer_t;

static void __allocate_signal_buffer(muillm_synchronizer_t* sync) {
  if (hipHostMalloc((void**) &sync->signal_buffer, sizeof(int), 0) != hipSuccess) {
    printf("allocating the signal buffer failed!\n");
    // do not touch the buffer if the allocation failed
    return;
  }
  // initialize to 0 as the first sequential number to wait on will be 1
  memset((void*) sync->signal_buffer, 0, sizeof(int));
}

static size_t next_power_of_2(size_t n) {
  size_t p = 1;

  while (p < n) {
    p *= 2;
  }

  return p;
}

static void __ensure_staging_buffer_capacity(
  muillm_synchronizer_t* sync,
  size_t count,
  hipStream_t stream
) {
  if (sync->staging_buffer_size >= count) {
    // enough space
    return;
  }

  // deallocate the previous memory
  if ((sync->staging_buffer != nullptr) && (hipHostFree(sync->staging_buffer) != hipSuccess)) {
    printf("freeing the staging buffer failed!\n");
  }
  sync->staging_buffer = nullptr;
  sync->staging_buffer_size = 0;

  // find the next power of two for the size to not have to do re-allocations over and over
  count = next_power_of_2(count);

  // allocate a new buffer of the required size
  if (hipHostMalloc((void**) &sync->staging_buffer, count, 0) != hipSuccess) {
    printf("allocating the staging buffer failed!\n");
    // leave the recorded size at 0 so that the next call retries the allocation
    sync->staging_buffer = nullptr;
    return;
  }
  sync->staging_buffer_size = count;
}

muillm_synchronizer_t* muillm_sync_init() {

  // create the comm object
  muillm_synchronizer_t* sync = new muillm_synchronizer_t;
  sync->staging_buffer = nullptr;
  sync->signal_buffer = nullptr;
  sync->staging_buffer_size = 0;
  sync->seq_no = 1; // first value we will wait on

  // allocate the wait buffer
  __allocate_signal_buffer(sync);

  // allocate an initial buffer with an initial size of 4kB
  __ensure_staging_buffer_capacity(sync, 4*1024, 0);

  return sync;
}

// make the GPU signal the completion
// (we use a kernel instead of the signals the HIP runtime uses as it seems that the
// dedicated HW block might hang sometimes?)
__global__ void __signal_kernel(
    volatile int* signal_buffer,
    int seq_no
) {
  if (threadIdx.x != 0) {
    return;
  }
  *signal_buffer = seq_no;
}

static void __spin_pause() {
  _mm_pause();
}

static void __spin_gpu_cpu_sync(
    muillm_synchronizer_t* sync,
    hipStream_t stream) {
  int wait_no = sync->seq_no;

  const int threads_per_blocks = 64;
  const int num_blocks = 1;
  __signal_kernel<<<num_blocks, threads_per_blocks, 0, stream>>>(
    sync->signal_buffer,
    wait_no
  );

  sync->seq_no++; // next sequential number we will wait on

  // spin until the GPU has signaled completion
  while (*sync->signal_buffer != wait_no) {
    __spin_pause();
  }
}

void muillm_sync_copy(
    muillm_synchronizer_t* sync,
    hipStream_t stream,
    void* dst,
    const void* src,
    size_t count
) {
  // ensure the staging buffer is big enough
  __ensure_staging_buffer_capacity(sync, count, stream);

  // do the copy to the staging buffer
  if (hipMemcpyAsync(sync->staging_buffer, src, count, hipMemcpyDeviceToHost, stream) != hipSuccess) {
    printf("async copy failed\n");
  }

  // sync the CPU with the GPU
  __spin_gpu_cpu_sync(sync, stream);

  // do the final copy from the staging buffer to the CPU memory
  memcpy(dst, sync->staging_buffer, count);
}
```

schung-amd commented 3 weeks ago

Thanks for your interest in this! Just a quick update from my end: we'll be taking a look into this internally, but it might be a couple of weeks before we have a response.

Epliz commented 3 weeks ago

Thank you @schung-amd,

In the meantime I have just tried to lower the number of such calls. It seems quite hard to determine how far from optimal the current situation is, given that there might also be the small issue that PyTorch traces are not entirely reliable regarding the start/completion times of kernels, especially in how the GPU timeline is aligned with the CPU timeline.