ggerganov / llama.cpp

LLM inference in C/C++
MIT License

Bug: Nondeterministic results on AMD RDNA3 (ROCm) despite zero temperature and fixed seed #10197

Open Googulator opened 2 hours ago

Googulator commented 2 hours ago

What happened?

We are running llama-server on a Radeon RX 7900 XT with the following command line:

./llama-server -t 4 -ngl 50 -c 13000 --host 0.0.0.0 --port 18080 --mlock -m mistral-nemo-instruct-2407-q8_0.gguf --chat-template llama2

When we call the server's /completion endpoint repeatedly with the following JSON request:

{
  "seed":1234,
  "cache_prompt":false,
  "n_predict":2048,
  "temperature":0,
  "stop":[
    "</s>"
  ],
  "image_data":[],
  "repeat_last_n":0,
  "repeat_penalty":1,
  "presence_penalty":0,
  "frequency_penalty":0,
  "min_p":0.05,
  "top_k":40,
  "top_p":0.95,
  "tfs_z":1,
  "typical_p":1,
  "mirostat":0,
  "mirostat_tau":5,
  "mirostat_eta":0.1,
  "grammar":"",
  "n_probs":0,
  "n_keep":-1,
  "penalize_nl":false,
  "penalty_prompt":null,
  "ignore_eos":false,
  "logit_bias":[],
  "stream":false,
  "prompt":"[INST] What are the two major political parties in the United States?[/INST]"
}

...we get different output from one call to the next, even though the temperature is 0 and the seed is fixed.
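The following sketch shows why this is surprising and why tiny numeric differences can still change the text. It assumes that temperature 0 reduces sampling to picking the highest-logit token (greedy decoding). In that case the seed never comes into play, so two runs can only disagree if the logits themselves differ. When the top two candidates are nearly tied, a difference in the seventh decimal place is enough to change the pick, and every later token is then generated from a different prefix. The logit values and the `greedy_pick` helper below are hypothetical and only illustrate the effect.

```python
def greedy_pick(logits):
    """Greedy decoding: index of the largest logit (the first one on an exact tie)."""
    best = 0
    for i in range(1, len(logits)):
        if logits[i] > logits[best]:
            best = i
    return best


# Hypothetical logits for the same position from two runs that differ only
# by rounding noise of about 1e-7 in the first entry.
run_a = [3.2500001, 3.2500000, 1.0]
run_b = [3.2499999, 3.2500000, 1.0]

print(greedy_pick(run_a), greedy_pick(run_b))  # 0 1
```

Once one token differs, the two completions usually diverge for the rest of the output, so run-to-run variation in the GPU math can show up as visibly different answers.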

We have found the following workarounds, which all result in deterministic output:

ROCm version is 6.2.2 (running in a Docker container); the amdgpu kernel driver is the one supplied with Ubuntu kernel 6.8.0-47-generic (x86-64).
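Below is a minimal sketch for running this check from a script, using only the Python standard library. It assumes the server is reachable at localhost:18080 (matching --port above) and that the /completion response carries the generated text in a `content` field. `fetch_completion`, `first_divergence` and `check_determinism` are hypothetical helper names, not part of llama.cpp.

```python
import json
import urllib.request
from typing import List, Optional

# Assumption: llama-server listens on this address (see --host/--port above).
SERVER_URL = "http://localhost:18080/completion"


def fetch_completion(payload: dict, url: str = SERVER_URL) -> str:
    """POST one request to /completion and return the generated text.

    Assumes the JSON response has a "content" field holding the completion.
    """
    body = json.dumps(payload).encode("utf-8")
    request = urllib.request.Request(
        url, data=body, headers={"Content-Type": "application/json"}
    )
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read().decode("utf-8"))["content"]


def first_divergence(outputs: List[str]) -> Optional[int]:
    """Earliest character index at which any output differs from outputs[0],
    or None if all outputs are identical."""
    reference = outputs[0]
    earliest = None
    for other in outputs[1:]:
        limit = min(len(reference), len(other))
        index = next((i for i in range(limit) if reference[i] != other[i]), None)
        if index is None and len(reference) != len(other):
            index = limit  # one output is a strict prefix of the other
        if index is not None and (earliest is None or index < earliest):
            earliest = index
    return earliest


def check_determinism(payload: dict, runs: int = 5) -> Optional[int]:
    """Send the same request several times and report where outputs diverge."""
    outputs = [fetch_completion(payload) for _ in range(runs)]
    return first_divergence(outputs)
```

For example, calling `check_determinism(request, runs=5)` with `request` set to the JSON body above (as a Python dict) returns `None` when every run produced identical text. Otherwise it returns the first character position at which some run diverged.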

Name and Version

$ ./llama-server --version
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    yes
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 ROCm devices:
  Device 0: Radeon RX 7900 XT, compute capability 11.0, VMM: no
version: 0 (unknown)
built with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu

$ git show
commit b8deef0ec0af5febac1d2cfd9119ff330ed0b762 (HEAD -> master, tag: b4034, origin/master, origin/HEAD)
Author: Gabe Goodhart <ghart@us.ibm.com>
Date:   Tue Nov 5 05:23:04 2024 -0700

    llama : add <|tool_call|> formatting to Granite template (#10177)

    Branch: GraniteToolCallTemplate

    Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

What operating system are you seeing the problem on?

Linux

Relevant log output

No response

slaren commented 2 hours ago

From what you are describing, my conclusion would be that the ROCm version of cuBLAS is not deterministic.
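One common reason a GPU BLAS library can return slightly different results for identical inputs is that floating-point addition is not associative. If partial sums (for example, from a reduction split across workgroups) are combined in whatever order the hardware finishes them, the rounding can differ from run to run. This thread does not establish whether that is what happens inside ROCm's BLAS here. The sketch below, with a hypothetical `accumulate` helper, only demonstrates the effect itself, using Python's double-precision floats.

```python
def accumulate(partials, order):
    """Add partial sums one at a time in the given order, the way a kernel that
    combines per-workgroup results as they complete might.

    An explicit loop is used because sum() in Python 3.12+ applies compensated
    summation, which would hide the rounding effect shown here.
    """
    total = 0.0
    for i in order:
        total += partials[i]
    return total


partials = [0.1, 0.2, 0.3]
forward = accumulate(partials, [0, 1, 2])   # (0.1 + 0.2) + 0.3
backward = accumulate(partials, [2, 1, 0])  # (0.3 + 0.2) + 0.1

print(forward, backward)  # 0.6000000000000001 0.6
```

The same three numbers give two different totals. In a large matrix multiplication, such last-bit differences end up in the logits, where they can change a greedy pick.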

Googulator commented 2 hours ago

In that case, it would be cublasGemmEx specifically, since forcing the cublasSgemm version results in deterministic output.

Looking at that compute capability (CC) check, it seems to be testing for tensor cores, which RDNA GPUs indeed don't have, so using the cublasSgemm path makes more sense on RDNA-family GPUs.

What's not clear is why, without Flash Attention, even forcing MMQ doesn't help.

slaren commented 2 hours ago

The CC check is intended to determine whether the GPU has F16 matrix multiplication fast enough that converting the operands to F16 may be worthwhile. It was written for NVIDIA GPUs, and I don't think any testing was done on AMD hardware. Without flash attention, the matrix multiplications in the attention are done with cuBLAS, which should explain the difference.
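The sketch below illustrates what "converting the operands to F16" means numerically. It rounds each input to IEEE 754 half precision (via the `struct` module's `e` format) before a dot product, the basic operation inside a matrix multiplication. It then compares the result with the same dot product on unrounded inputs. The vectors are arbitrary, and accumulation stays in double precision here, which simplifies what a GPU GEMM actually does. Rounding by itself is deterministic. The point is only that the F16 (cublasGemmEx) path and the F32 (cublasSgemm) path compute different numbers, so they should not be expected to agree bit-for-bit with each other.

```python
import struct


def to_f16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def dot(a, b):
    """Plain left-to-right dot product in double precision."""
    total = 0.0
    for x, y in zip(a, b):
        total += x * y
    return total


# Arbitrary example vectors (not taken from the model).
a = [0.1, 0.7, -0.3]
b = [1.3, 0.9, 2.1]

full = dot(a, b)
f16_inputs = dot([to_f16(x) for x in a], [to_f16(y) for y in b])

print(to_f16(0.1))       # 0.0999755859375
print(full, f16_inputs)  # not equal: rounding the inputs to F16 changes the result
```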