turboderp / exllamav2

A fast inference library for running LLMs locally on modern consumer-class GPUs
MIT License

ROCM: Issues with wave64 device #510

Closed: IMbackK closed this issue 2 days ago

IMbackK commented 1 week ago

It seems there are still issues with wave64 devices.

I just tried latest master on a CDNA1/MI100 with ROCm 6.0.2 and PyTorch 2.3.0, at 97ff6cfd9c86c5c09d7ce775ab64ec5c99230f5d:

Finding flash_attn
NO flash_attn module
 -- Model: CodeLlama-13B-GPTQ/
 -- Options: ['length: 2048', 'no_flash_attn']
 -- Loading model...
 -- Loaded model in 3.3227 seconds
 -- Loading tokenizer...
 -- Warmup...
/usr/lib/python3.12/site-packages/torch/nn/attention/bias.py:205: UserWarning: 1Torch was not compiled with memory efficient attention. (Triggered internally at /build/python-pytorch/src/pytorch-opt-rocm/aten/src/ATen/native/transformers/hip/sdp_utils.cpp:505.)
  return scaled_dot_product_attention(
 -- Generating...

corrupted double-linked list
[1]    315893 IOT instruction (core dumped)  python test_inference.py -m CodeLlama-13B-GPTQ/ -p "hi there" -nfa -l 2048

Sometimes the above does not trigger and gibberish is emitted instead.

I can give access to the machine in question for debugging.

IMbackK commented 1 week ago

What also often occurs is:

 -- Model: CodeLlama-13B-GPTQ/                                                                                                                                                           
 -- Options: ['length: 2048', 'no_flash_attn']
 -- Loading model...
 -- Loaded model in 3.2615 seconds
 -- Loading tokenizer...
 -- Warmup...
[New Thread 0x7ff6ba2006c0 (LWP 320328)]
[Thread 0x7ff6ba2006c0 (LWP 320328) exited]
/usr/lib/python3.12/site-packages/torch/nn/attention/bias.py:205: UserWarning: 1Torch was not compiled with memory efficient attention. (Triggered internally at /build/python-pytorch/sr
c/pytorch-opt-rocm/aten/src/ATen/native/transformers/hip/sdp_utils.cpp:505.)
  return scaled_dot_product_attention(
 -- Generating...

Traceback (most recent call last):
  File "/home/philipp/machine-lerning/exllamav2/test_inference.py", line 206, in <module>
    output = generator.generate_simple(args.prompt, settings, args.tokens, token_healing = True, add_bos = not args.prompt_no_bos)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/philipp/machine-lerning/exllamav2/exllamav2/generator/base.py", line 322, in generate_simple
    text = self.tokenizer.decode(decode_ids, decode_special_tokens = decode_special_tokens)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/philipp/machine-lerning/exllamav2/exllamav2/tokenizer/tokenizer.py", line 510, in decode
    texts.append(self.decode_(seq, decode_special_tokens))
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/philipp/machine-lerning/exllamav2/exllamav2/tokenizer/tokenizer.py", line 455, in decode_
    return self.decode_unspecial(seq)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/philipp/machine-lerning/exllamav2/exllamav2/tokenizer/tokenizer.py", line 428, in decode_unspecial
    return self.tokenizer_model.decode(seq)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/philipp/machine-lerning/exllamav2/exllamav2/tokenizer/spm.py", line 52, in decode
    text = self.spm.decode(ids)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/philipp/machine-lerning/exllamav2/venv/lib/python3.12/site-packages/sentencepiece/__init__.py", line 801, in Decode
    return self._DecodeIds(input)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/philipp/machine-lerning/exllamav2/venv/lib/python3.12/site-packages/sentencepiece/__init__.py", line 343, in _DecodeIds
    return _sentencepiece.SentencePieceProcessor__DecodeIds(self, ids)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
IndexError: Out of range: piece id is out of range.

Unfortunately I can't seem to trigger the assert above in gdb.

IMbackK commented 1 week ago

The issue seems very variable run-to-run; another run, another result:

Finding flash_attn                                                                                                                                                                           
NO flash_attn module
 -- Model: CodeLlama-13B-GPTQ/                                                                                                                                                               
 -- Options: ['length: 2048', 'no_flash_attn']
 -- Loading model...
 -- Loaded model in 3.2214 seconds
 -- Loading tokenizer...
 -- Warmup...
[New Thread 0x7ff6ba2006c0 (LWP 321691)]
[Thread 0x7ff6ba2006c0 (LWP 321691) exited]
/usr/lib/python3.12/site-packages/torch/nn/attention/bias.py:205: UserWarning: 1Torch was not compiled with memory efficient attention. (Triggered internally at /build/python-pytorch/src/pytorch-opt-rocm/aten/src/ATen/native/transformers/hip/sdp_utils.cpp:505.)
  return scaled_dot_product_attention(
 -- Generating...

Thread 1 "pt_main_thread" received signal SIGSEGV, Segmentation fault.
0x00007ffad265c73e in multinomial_cpu (num_candidates=0, temp_probs=0x555566a21f20, temp_indices=0x5555669fec80, random=0.276738256)
    at /home/philipp/machine-lerning/exllamav2/exllamav2/exllamav2_ext/cpp/sampling.cpp:823
823             accum += temp_probs[idx];
(gdb) bt
#0  0x00007ffad265c73e in multinomial_cpu (num_candidates=0, temp_probs=0x555566a21f20, temp_indices=0x5555669fec80, random=0.276738256)
    at /home/philipp/machine-lerning/exllamav2/exllamav2/exllamav2_ext/cpp/sampling.cpp:823
#1  0x00007ffad2685c8c in sample_basic (logits=..., temperature=1, top_k=0, top_p=0.800000012, top_a=0, min_p=0, tfs=0, typical=0, random=0.276738256, output_tokens=..., output_probs=..., 
    output_kprobs=..., output_ktokens=..., logit_filter=..., mirostat=false, mirostat_mu=std::vector of length 0, capacity 0, mirostat_tau=1.5, mirostat_eta=0.100000001, 
    post_temperature=1, min_temp=0, max_temp=0, temp_exponent=1, smoothing_factor=0, skew=0) at /home/philipp/machine-lerning/exllamav2/exllamav2/exllamav2_ext/ext_sampling_hip.cpp:234
#2  0x00007ffad266c0a6 in pybind11::cpp_function::initialize<std::vector<float, std::allocator<float> > (*&)(at::Tensor, float, int, float, float, float, float, float, float, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, bool, std::vector<float, std::allocator<float> >&, float, float, float, float, float, float, float, float), std::vector<float, std::allocator<float> >, at::Tensor, float, int, float, float, float, float, float, float, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, bool, std::vector<float, std::allocator<float> >&, float, float, float, float, float, float, float, float, pybind11::name, pybind11::scope, pybind11::sibling, char [13]>(std::vector<float, std::allocator<float> > (*&)(at::Tensor, float, int, float, float, float, float, float, float, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, bool, std::vector<float, std::allocator<float> >&, float, float, float, float, float, float, float, float), std::vector<float, std::allocator<float> > (*)(at::Tensor, float, int, float, float, float, float, float, float, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, bool, std::vector<float, std::allocator<float> >&, float, float, float, float, float, float, float, float), pybind11::name const&, pybind11::scope const&, pybind11::sibling const&, char const (&) [13])::{lambda(pybind11::detail::function_call&)#1}::_FUN(pybind11::detail::function_call&) () at /usr/include/pybind11/detail/../cast.h:1613
#3  0x00007ffad26553b1 in pybind11::cpp_function::dispatcher (self=<optimized out>, args_in=0x7ffacc94e200, kwargs_in=<optimized out>) at /usr/include/pybind11/pybind11.h:987
#4  0x00007ffff79a52c6 in cfunction_call (func=0x7ffad7209fd0, args=0x7ffacc94e200, kwargs=0x0) at Objects/methodobject.c:537
#5  0x00007ffff798550b in _PyObject_MakeTpCall (tstate=0x7ffff7e22ae8 <_PyRuntime+459656>, callable=0x7ffad7209fd0, args=<optimized out>, nargs=24, keywords=0x0) at Objects/call.c:240
#6  0x00007ffff788bdfa in _PyEval_EvalFrameDefault (tstate=<optimized out>, frame=0x7ffff7f77308, throwflag=<optimized out>) at Python/bytecodes.c:2706
#7  0x00007ffff7a3d767 in PyEval_EvalCode (co=0x5555557822b0, globals=<optimized out>, locals=0x7ffff71fa000) at Python/ceval.c:578
#8  0x00007ffff7a608b7 in run_eval_code_obj (tstate=tstate@entry=0x7ffff7e22ae8 <_PyRuntime+459656>, co=co@entry=0x5555557822b0, globals=globals@entry=0x7ffff71fa000, 
    locals=locals@entry=0x7ffff71fa000) at Python/pythonrun.c:1722
#9  0x00007ffff7a5b9dc in run_mod (mod=mod@entry=0x55555572cb40, filename=filename@entry=0x7ffff7148d50, globals=globals@entry=0x7ffff71fa000, locals=locals@entry=0x7ffff71fa000, 
    flags=flags@entry=0x7fffffffd510, arena=arena@entry=0x7ffff711be30) at Python/pythonrun.c:1743
#10 0x00007ffff7a74f33 in pyrun_file (fp=fp@entry=0x55555560dc30, filename=filename@entry=0x7ffff7148d50, start=start@entry=257, globals=globals@entry=0x7ffff71fa000, 
    locals=locals@entry=0x7ffff71fa000, closeit=closeit@entry=1, flags=0x7fffffffd510) at Python/pythonrun.c:1643
#11 0x00007ffff7a74346 in _PyRun_SimpleFileObject (fp=0x55555560dc30, filename=0x7ffff7148d50, closeit=1, flags=0x7fffffffd510) at Python/pythonrun.c:433
#12 0x00007ffff7a73f88 in _PyRun_AnyFileObject (fp=0x55555560dc30, filename=0x7ffff7148d50, closeit=1, flags=0x7fffffffd510) at Python/pythonrun.c:78
#13 0x00007ffff7a6cc67 in pymain_run_file_obj (skip_source_first_line=0, filename=0x7ffff7148d50, program_name=0x7ffff712e550) at Modules/main.c:360
#14 pymain_run_file (config=0x7ffff7dc56c8 <_PyRuntime+77672>) at Modules/main.c:379
#15 pymain_run_python (exitcode=0x7fffffffd4e4) at Modules/main.c:629
#16 Py_RunMain () at Modules/main.c:709
#17 0x00007ffff7a28fab in Py_BytesMain (argc=<optimized out>, argv=<optimized out>) at Modules/main.c:763
#18 0x00007ffff7639c88 in __libc_start_call_main (main=main@entry=0x555555555120 <main>, argc=argc@entry=9, argv=argv@entry=0x7fffffffd778) at ../sysdeps/nptl/libc_start_call_main.h:58
#19 0x00007ffff7639d4c in __libc_start_main_impl (main=0x555555555120 <main>, argc=9, argv=0x7fffffffd778, init=<optimized out>, fini=<optimized out>, rtld_fini=<optimized out>, 
    stack_end=0x7fffffffd768) at ../csu/libc-start.c:360
#20 0x0000555555555045 in _start ()
turboderp commented 1 week ago

It looks like all of these issues might be related to SDPA. It's supposed to be a safe fallback as long as lower-right causal masking is supported, but it's possible that's not working for wave64 for some reason. Could you try disabling this in exllamav2/attn.py:

has_lower_right_sdpa = False
try:
    from torch.nn.attention.bias import causal_lower_right
    has_lower_right_sdpa = True   # <-- remove this
except ImportError:
    pass

That should make it stop trying to use SDPA and fall back on matmul attention.
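
For reference, the matmul fallback is just scaled dot-product attention written out with plain torch ops, along these lines (a sketch, not the exact code in attn.py; the mask is assumed to be an additive tensor with -inf above the causal boundary):

import math
import torch

def matmul_attention(q, k, v, attn_mask):
    # q, k, v: (batch, heads, seq_len, head_dim)
    scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(q.shape[-1])
    scores = scores + attn_mask                        # additive causal mask
    weights = torch.softmax(scores, dim = -1, dtype = torch.float32)
    return torch.matmul(weights.to(v.dtype), v)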

IMbackK commented 1 week ago

This appears not to help:

Finding flash_attn
NO flash_attn module
> /home/philipp/machine-lerning/exllamav2/exllamav2/attn.py(62)<module>()
-> def assert_paged_attn():
(Pdb) c
 -- Model: CodeLlama-13B-GPTQ/
 -- Options: ['length: 2048', 'no_flash_attn']
 -- Loading model...
 -- Loaded model in 3.5775 seconds
 -- Loading tokenizer...
 -- Warmup...
[New Thread 0x7ff6bd8006c0 (LWP 334956)]
[Thread 0x7ff6bd8006c0 (LWP 334956) exited]

 -- Generating...

Thread 1 "pt_main_thread" received signal SIGSEGV, Segmentation fault.
0x00007ffad265c73e in multinomial_cpu (num_candidates=0, temp_probs=0x555563ba9400, temp_indices=0x555563bc88c0, random=0.411402971)
    at /home/philipp/machine-lerning/exllamav2/exllamav2/exllamav2_ext/cpp/sampling.cpp:823
823             accum += temp_probs[idx];
(gdb) bt
#0  0x00007ffad265c73e in multinomial_cpu (num_candidates=0, temp_probs=0x555563ba9400, temp_indices=0x555563bc88c0, random=0.411402971)
    at /home/philipp/machine-lerning/exllamav2/exllamav2/exllamav2_ext/cpp/sampling.cpp:823
#1  0x00007ffad2685c8c in sample_basic (logits=..., temperature=1, top_k=0, top_p=0.800000012, top_a=0, min_p=0, tfs=0, typical=0, random=0.411402971, output_tokens=..., output_probs=..., 
    output_kprobs=..., output_ktokens=..., logit_filter=..., mirostat=false, mirostat_mu=std::vector of length 0, capacity 0, mirostat_tau=1.5, mirostat_eta=0.100000001, 
    post_temperature=1, min_temp=0, max_temp=0, temp_exponent=1, smoothing_factor=0, skew=0) at /home/philipp/machine-lerning/exllamav2/exllamav2/exllamav2_ext/ext_sampling_hip.cpp:234
#2  0x00007ffad266c0a6 in pybind11::cpp_function::initialize<std::vector<float, std::allocator<float> > (*&)(at::Tensor, float, int, float, float, float, float, float, float, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, bool, std::vector<float, std::allocator<float> >&, float, float, float, float, float, float, float, float), std::vector<float, std::allocator<float> >, at::Tensor, float, int, float, float, float, float, float, float, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, bool, std::vector<float, std::allocator<float> >&, float, float, float, float, float, float, float, float, pybind11::name, pybind11::scope, pybind11::sibling, char [13]>(std::vector<float, std::allocator<float> > (*&)(at::Tensor, float, int, float, float, float, float, float, float, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, bool, std::vector<float, std::allocator<float> >&, float, float, float, float, float, float, float, float), std::vector<float, std::allocator<float> > (*)(at::Tensor, float, int, float, float, float, float, float, float, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, bool, std::vector<float, std::allocator<float> >&, float, float, float, float, float, float, float, float), pybind11::name const&, pybind11::scope const&, pybind11::sibling const&, char const (&) [13])::{lambda(pybind11::detail::function_call&)#1}::_FUN(pybind11::detail::function_call&) () at /usr/include/pybind11/detail/../cast.h:1613
#3  0x00007ffad26553b1 in pybind11::cpp_function::dispatcher (self=<optimized out>, args_in=0x7ffaac8108b0, kwargs_in=<optimized out>) at /usr/include/pybind11/pybind11.h:987
#4  0x00007ffff79a52c6 in ?? () from /usr/lib/libpython3.12.so.1.0
#5  0x00007ffff798550b in _PyObject_MakeTpCall () from /usr/lib/libpython3.12.so.1.0
#6  0x00007ffff788bdfa in ?? () from /usr/lib/libpython3.12.so.1.0
#7  0x00007ffff7a3d767 in PyEval_EvalCode () from /usr/lib/libpython3.12.so.1.0
#8  0x00007ffff7a608b7 in ?? () from /usr/lib/libpython3.12.so.1.0
#9  0x00007ffff7a5b9dc in ?? () from /usr/lib/libpython3.12.so.1.0
#10 0x00007ffff7a74f33 in ?? () from /usr/lib/libpython3.12.so.1.0
#11 0x00007ffff7a74346 in _PyRun_SimpleFileObject () from /usr/lib/libpython3.12.so.1.0
#12 0x00007ffff7a73f88 in _PyRun_AnyFileObject () from /usr/lib/libpython3.12.so.1.0
#13 0x00007ffff7a6cc67 in Py_RunMain () from /usr/lib/libpython3.12.so.1.0
#14 0x00007ffff7a28fab in Py_BytesMain () from /usr/lib/libpython3.12.so.1.0
#15 0x00007ffff7639c88 in ?? () from /usr/lib/libc.so.6
#16 0x00007ffff7639d4c in __libc_start_main () from /usr/lib/libc.so.6
#17 0x0000555555555045 in _start ()

Weirdly, it now does "work" more often (the above segfault happens about 30% of the time). When it works it outputs only one, sometimes two, tokens, but they do seem coherent.

turboderp commented 1 week ago

num_candidates=0 is a clue that something goes wrong in the sampling, which could be the result of corrupted logits, I guess, since the only reason I can think of for ending up with no candidate tokens is some sorting error involving NaNs. You could try greedy sampling to see if it still crashes?

My guess is it won't, but it will produce nonsense output, in which case I'll have to figure out some sort of remote debugging (or maybe look for some really cheap wave64 GPU). But on the off chance greedy decoding works it's likely a CPU issue.
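
Something like this should force effectively greedy decoding (a sketch; it assumes the usual ExLlamaV2Sampler.Settings fields and an already-created generator from the test setup, with top_k = 1 so only a single candidate token ever survives filtering):

from exllamav2.generator import ExLlamaV2Sampler

settings = ExLlamaV2Sampler.Settings()
settings.temperature = 1.0
settings.top_k = 1                       # one candidate -> effectively greedy
settings.token_repetition_penalty = 1.0  # disable the repetition penalty

output = generator.generate_simple("hi there", settings, 128)
print(output)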

IMbackK commented 1 week ago

I will try that. Failing this, I can give you an MI50 as a gift for debugging; I can also give you SSH access to the MI100 machine.

IMbackK commented 1 week ago

So I appended print(f"logits nan: {torch.isnan(logits).sum()}") here: https://github.com/turboderp/exllamav2/blob/5996922a0f0937aa503efa773780f1648915d73e/exllamav2/generator/base.py#L259

This is the result of a typical non-crashing run:

python test_inference.py -m CodeLlama-13B-GPTQ/ -p "int main(" -nfa -l 2048 
Finding flash_attn
NO flash_attn module
 -- Model: CodeLlama-13B-GPTQ/
 -- Options: ['length: 2048', 'no_flash_attn']
 -- Loading model...
 -- Loaded model in 3.5639 seconds
 -- Loading tokenizer...
 -- Warmup...
 -- Generating...

logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 32016
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
logits nan: 0
int main(void) {
turboderp commented 1 week ago

Have you only tested with GPTQ models, or do EXL2 models have the same issue? They do use different matmul kernels.

IMbackK commented 1 week ago

Only GPTQ so far; let me quantize something to EXL2.

turboderp commented 1 week ago

Probably better to download a known-good model to isolate it to whatever goes wrong in inference. Like one of these

IMbackK commented 6 days ago

I was going to quantize with a wave32 device, but that works too, of course. Looks like no difference with EXL2:

Finding flash_attn
NO flash_attn module
 -- Model: ../KoboldAI-Client/models/llama-3-8b-exl2
 -- Options: ['length: 2048', 'no_flash_attn']
 -- Loading model...
 -- Loaded model in 2.8434 seconds
 -- Loading tokenizer...
 -- Warmup...
 -- Generating...

sqeuence nan: 0
logits nan: 0
sqeuence nan: 0
logits nan: 128256
[UVOSLinux:77515:0:77515] Caught signal 11 (Segmentation fault: address not mapped to object at address 0x63e2cb862000)
IMbackK commented 6 days ago

OK, so as a further development I made this small reproducer:

from exllamav2.generator import ExLlamaV2BaseGenerator
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache
import torch

config = ExLlamaV2Config()
config.model_dir = "../KoboldAI-Client/models/llama-3-8b-exl2"
config.prepare()
config.max_seq_len = 2048
config.no_flash_attn = True

model = ExLlamaV2(config)
model.load()
cache = ExLlamaV2Cache(model)

nansum = 0
while nansum == 0:
    logits = model.forward(torch.tensor([[81422]]), cache, input_mask = None, loras = None, position_offsets = None, indexed_embeddings = None).float().cpu()
    nansum = torch.isnan(logits).sum()
    print(nansum.item())

On the MI100 it looks like this:

python repo.py                 
Finding flash_attn
NO flash_attn module
0
0
0
0
0
0
0
0
0
0
0
0
0
128256
python repo.py

while on an RX 6800 XT (wave32) device it continues with 0 indefinitely.

Changing cache = ExLlamaV2Cache(model) to cache = None makes the MI100 also never produce any NaNs, so something fishy is there.

turboderp commented 6 days ago

Okay, so that does help a bit. Matmul kernel might still be failing in both cases, but it's also possible it's one of the other kernels, and my first suspect would be the layernorm kernels since they're the only places where the warp size should even matter.

One way to test would be to rename the forward_torch function in attn.py and mlp.py, so that it's always called in place of forward. This would split up the operations and allow you to find exactly where something goes wrong:

    def test_tensor(name, tensor):
        nansum = torch.isnan(tensor).sum()
        infsum = torch.isinf(tensor).sum()
        if nansum.item() > 0: print(name, "NaN", nansum.item())
        if infsum.item() > 0: print(name, "inf", infsum.item())

    # ...then, inside the (renamed) forward_torch:

        residual = hidden_states
        post_norm = self.input_layernorm.forward(hidden_states) if self.has_norm else hidden_states

        test_tensor("attn post_norm", post_norm)  # <---

        query_states = self.q_proj.forward(post_norm, loras = loras)
        key_states = self.k_proj.forward(post_norm, loras = loras)
        value_states = self.v_proj.forward(post_norm, loras = loras)

        test_tensor("attn q", query_states)  # <---
        test_tensor("attn k", key_states)  # <---
        test_tensor("attn v", value_states)  # <---

The method you have there, calling forward() over and over again, performs the same computation on each pass when there is no cache, but with a cache it will grow the cache with each forward pass. You can do cache.current_seq_len = 0 before the forward pass to prevent that.
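
Applied to the reproducer above, the loop would look something like this (a sketch, with the only change being the cache reset):

nansum = 0
while nansum == 0:
    # Reset the cache so each pass repeats the same single-token step
    # instead of growing the sequence.
    cache.current_seq_len = 0
    logits = model.forward(torch.tensor([[81422]]), cache, input_mask = None,
                           loras = None, position_offsets = None,
                           indexed_embeddings = None).float().cpu()
    nansum = torch.isnan(logits).sum()
    print(nansum.item())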

IMbackK commented 6 days ago

Looks like it goes wrong in attn post_norm:

Finding flash_attn
NO flash_attn module
0
attn post_norm NaN 4096
attn q NaN 4096
attn k NaN 1024
attn v NaN 1024
mlp post_norm NaN 4096
mlp gate_proj NaN 14336
mlp up_proj NaN 14336
mlp down_proj NaN 4096
attn hidden_states NaN 4096
attn post_norm NaN 4096
attn q NaN 4096
attn k NaN 1024
attn v NaN 1024
mlp post_norm NaN 4096
mlp gate_proj NaN 14336
mlp up_proj NaN 14336
mlp down_proj NaN 4096
attn hidden_states NaN 4096
attn hidden_states NaN 4096
attn hidden_states NaN 4096
attn hidden_states NaN 4096
attn hidden_states NaN 4096
attn hidden_states NaN 4096
attn hidden_states NaN 4096
attn hidden_states NaN 4096
attn hidden_states NaN 4096
attn post_norm NaN 4096
attn q NaN 4096
attn k NaN 1024
attn v NaN 1024
mlp post_norm NaN 4096
mlp gate_proj NaN 14336
mlp up_proj NaN 14336
mlp down_proj NaN 4096
attn hidden_states NaN 4096
attn post_norm NaN 4096
attn q NaN 4096
attn k NaN 1024
attn v NaN 1024
mlp post_norm NaN 4096
mlp gate_proj NaN 14336
mlp up_proj NaN 14336
mlp down_proj NaN 4096
attn hidden_states NaN 4096
attn post_norm NaN 4096
attn q NaN 4096
attn k NaN 1024
attn v NaN 1024
mlp post_norm NaN 4096
mlp gate_proj NaN 14336
mlp up_proj NaN 14336
mlp down_proj NaN 4096
attn hidden_states NaN 4096
attn post_norm NaN 4096
attn q NaN 4096
attn k NaN 1024
attn v NaN 1024
mlp post_norm NaN 4096
mlp gate_proj NaN 14336
mlp up_proj NaN 14336
mlp down_proj NaN 4096
attn hidden_states NaN 4096
attn post_norm NaN 4096
attn q NaN 4096
attn k NaN 1024
attn v NaN 1024
mlp post_norm NaN 4096
mlp gate_proj NaN 14336
mlp up_proj NaN 14336
mlp down_proj NaN 4096
attn hidden_states NaN 4096
attn post_norm NaN 4096
attn q NaN 4096
attn k NaN 1024
attn v NaN 1024
mlp post_norm NaN 4096
mlp gate_proj NaN 14336
mlp up_proj NaN 14336
mlp down_proj NaN 4096
attn hidden_states NaN 4096
attn post_norm NaN 4096
attn q NaN 4096
attn k NaN 1024
attn v NaN 1024
mlp post_norm NaN 4096
mlp gate_proj NaN 14336
mlp up_proj NaN 14336
mlp down_proj NaN 4096
attn hidden_states NaN 4096
attn post_norm NaN 4096
attn q NaN 4096
attn k NaN 1024
attn v NaN 1024
mlp post_norm NaN 4096
mlp gate_proj NaN 14336
mlp up_proj NaN 14336
mlp down_proj NaN 4096
attn hidden_states NaN 4096
attn post_norm NaN 4096
attn q NaN 4096
attn k NaN 1024
attn v NaN 1024
mlp post_norm NaN 4096
mlp gate_proj NaN 14336
mlp up_proj NaN 14336
mlp down_proj NaN 4096
attn hidden_states NaN 4096
attn post_norm NaN 4096
attn q NaN 4096
attn k NaN 1024
attn v NaN 1024
mlp post_norm NaN 4096
mlp gate_proj NaN 14336
mlp up_proj NaN 14336
mlp down_proj NaN 4096
attn hidden_states NaN 4096
attn post_norm NaN 4096
attn q NaN 4096
attn k NaN 1024
attn v NaN 1024
mlp post_norm NaN 4096
mlp gate_proj NaN 14336
mlp up_proj NaN 14336
mlp down_proj NaN 4096
attn hidden_states NaN 4096
attn post_norm NaN 4096
attn q NaN 4096
attn k NaN 1024
attn v NaN 1024
mlp post_norm NaN 4096
mlp gate_proj NaN 14336
mlp up_proj NaN 14336
mlp down_proj NaN 4096
attn hidden_states NaN 4096
attn post_norm NaN 4096
attn q NaN 4096
attn k NaN 1024
attn v NaN 1024
mlp post_norm NaN 4096
mlp gate_proj NaN 14336
mlp up_proj NaN 14336
mlp down_proj NaN 4096
attn hidden_states NaN 4096
attn post_norm NaN 4096
attn q NaN 4096
attn k NaN 1024
attn v NaN 1024
mlp post_norm NaN 4096
mlp gate_proj NaN 14336
mlp up_proj NaN 14336
mlp down_proj NaN 4096
attn hidden_states NaN 4096
attn post_norm NaN 4096
attn q NaN 4096
attn k NaN 1024
attn v NaN 1024
mlp post_norm NaN 4096
mlp gate_proj NaN 14336
mlp up_proj NaN 14336
mlp down_proj NaN 4096
attn hidden_states NaN 4096
attn post_norm NaN 4096
attn q NaN 4096
attn k NaN 1024
attn v NaN 1024
mlp post_norm NaN 4096
mlp gate_proj NaN 14336
mlp up_proj NaN 14336
mlp down_proj NaN 4096
attn hidden_states NaN 4096
attn post_norm NaN 4096
attn q NaN 4096
attn k NaN 1024
attn v NaN 1024
mlp post_norm NaN 4096
mlp gate_proj NaN 14336
mlp up_proj NaN 14336
mlp down_proj NaN 4096
attn hidden_states NaN 4096
attn post_norm NaN 4096
attn q NaN 4096
attn k NaN 1024
attn v NaN 1024
mlp post_norm NaN 4096
mlp gate_proj NaN 14336
mlp up_proj NaN 14336
mlp down_proj NaN 4096
attn hidden_states NaN 4096
attn post_norm NaN 4096
attn q NaN 4096
attn k NaN 1024
attn v NaN 1024
mlp post_norm NaN 4096
mlp gate_proj NaN 14336
mlp up_proj NaN 14336
mlp down_proj NaN 4096
attn hidden_states NaN 4096
attn post_norm NaN 4096
attn q NaN 4096
attn k NaN 1024
attn v NaN 1024
mlp post_norm NaN 4096
mlp gate_proj NaN 14336
mlp up_proj NaN 14336
mlp down_proj NaN 4096
attn hidden_states NaN 4096
attn post_norm NaN 4096
attn q NaN 4096
attn k NaN 1024
attn v NaN 1024
mlp post_norm NaN 4096
mlp gate_proj NaN 14336
mlp up_proj NaN 14336
mlp down_proj NaN 4096
attn hidden_states NaN 4096
attn post_norm NaN 4096
attn q NaN 4096
attn k NaN 1024
attn v NaN 1024
mlp post_norm NaN 4096
mlp gate_proj NaN 14336
mlp up_proj NaN 14336
mlp down_proj NaN 4096
128256

ExLlamaV2Cache:

'model (<class 'exllamav2.model.ExLlamaV2'>)': 
<exllamav2.model.ExLlamaV2 object at 0x7ddef339d6a0>

'max_seq_len (<class 'int'>)': 
256

'batch_size (<class 'int'>)': 
1

'dtype (<class 'torch.dtype'>)': 
torch.float16

'weights_per_element_k (<class 'int'>)': 
1

'weights_per_element_v (<class 'int'>)': 
1

'has_scales (<class 'bool'>)': 
False

'key_states (<class 'list'>)': 
nans: 1024
nans: 1024
nans: 0
nans: 0
nans: 0
nans: 0
nans: 0
nans: 0
nans: 0
nans: 0
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024

'value_states (<class 'list'>)': 
nans: 1024
nans: 1024
nans: 0
nans: 0
nans: 0
nans: 0
nans: 0
nans: 0
nans: 0
nans: 0
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024
nans: 1024

'key_scales (<class 'list'>)': 

'value_scales (<class 'list'>)': 

'num_key_value_heads (<class 'int'>)': 
8

'num_hidden_layers (<class 'int'>)': 
32

'head_dim (<class 'int'>)': 
128

'current_seq_len (<class 'int'>)': 
2

'shape_basic (<class 'tuple'>)': 
(1, 256, 8, 128)

'shape_wk (<class 'tuple'>)': 
(1, 256, 8, 128)

'shape_wv (<class 'tuple'>)': 
(1, 256, 8, 128)

'shape_s (<class 'tuple'>)': 
(1, 256, 8, 4)
turboderp commented 6 days ago

Thank you for investigating this. It really narrows it down.

One way to confirm that the layernorm is causing it: with attn.py and mlp.py already modified to force the "torch" path, you could do the same to rmsnorm.py. The forward_torch function there uses pure Torch functions for the RMS layernorm, so I would expect that to work.

I see I've left some clamping in that function too:

        hidden_states[hidden_states == -float('inf')] = -65504.0
        hidden_states[hidden_states == float('inf')] = 65504.0

You could try with and without that, I guess.
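
For reference, the pure-Torch path boils down to the standard RMSNorm computation, roughly like this (a sketch, not the exact forward_torch in rmsnorm.py; names are illustrative):

import torch

def rmsnorm_torch(hidden_states, weight, eps = 1e-6):
    # Normalize by the root-mean-square over the hidden dimension, computed
    # in float32 for stability, then scale by the learned weight.
    dtype = hidden_states.dtype
    hs = hidden_states.to(torch.float32)
    variance = hs.pow(2).mean(-1, keepdim = True)
    hs = hs * torch.rsqrt(variance + eps)
    return weight * hs.to(dtype)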

IMbackK commented 6 days ago

Indeed, renaming forward to forward_never and forward_torch to forward in rmsnorm fixes the reproducer and test_inference.

IMbackK commented 6 days ago

However, this is not the only problem, as reverting the changes to attn.py and mlp.py to again use q_attn_forward_1 and so on breaks it again.

But maybe that makes sense, since I don't know exactly what calls which kernels.

turboderp commented 6 days ago

Well then, thank you for helping out with this. I thought the wave64 issues with the layernorm function had been resolved but I guess not. It's a standard function, though, so it can't be too hard to either get it working or find some other wave64-friendly kernel to use instead.

The reason reverting the changes to attn.py and mlp.py breaks it again is that, for quantized models, those modules fuse a few operations and move them to the C++ extension to reduce overhead (Python is slow). In the process they call the C++/CUDA layernorm code directly. So it still needs to be fixed, but at least I know exactly where the problem is.

IMbackK commented 6 days ago

exllamav1 works fine, by the way; maybe you can just yoink the implementation from there as a temporary fix.

turboderp commented 2 days ago

So I think I found the error. Very smooth brain bug, but hey, at least it seems to be fixed now.

I tested on an MI300X on RunPod, and that's wave64 so I would assume it also works on the MI100, given that you tested everything else.

I did find an additional issue with SDPA which appears to be broken on the ROCm PyTorch build (?). Hard to say exactly what's going wrong there but it successfully creates a lower-right causal mask and then it just silently applies an upper-left mask instead. Or so it would seem. I added a no_sdpa option to the config for now along with an EXLLAMA_NO_SDPA env variable.
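
For anyone else hitting this on ROCm, the workaround would be used roughly like this (a sketch; it assumes the env variable is checked before the attention code is imported and that the config flag is named exactly no_sdpa as described above):

import os

# Option 1: disable SDPA globally before importing exllamav2.
os.environ["EXLLAMA_NO_SDPA"] = "1"

from exllamav2 import ExLlamaV2Config

# Option 2: disable it per-model via the new config flag.
config = ExLlamaV2Config()
config.model_dir = "../KoboldAI-Client/models/llama-3-8b-exl2"
config.prepare()
config.no_sdpa = True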

IMbackK commented 2 days ago

I can confirm that it works now. If you want, I can write a small reproducer for the causal mask issue and report it upstream. Performance is pretty poor; I will try to run exllama through Omniperf at some point.

turboderp commented 2 days ago

I saw the same thing on the MI300X. It's about 3-4x slower than a 3090 despite having specs that should blow the 3090 out of the water. Surprisingly, this was across the board, even with unquantized models that run almost entirely on PyTorch tensor operations (tested on 2.3.1+rocm6.0).

I'd like to dig more into it at some point, but I think the first step would have to be profiling and then I'm already stuck since the ROCm ecosystem is strange to me. I get the impression that it's kind of underdeveloped, too?

As for SDPA, I wouldn't mind if you could rig up an example. I'm not able to run any code right now, so this is off the top of my head and I'm not sure it's correct, but it should illustrate what the problem appears to be:

import torch
import torch.nn.functional as F
from torch.nn.attention.bias import causal_lower_right

bsz, num_heads, length, dim = 1, 32, 16, 128
right_slice = 8

q = torch.rand(bsz, num_heads, length, dim)
k = torch.rand(bsz, num_heads, length, dim)
v = torch.rand(bsz, num_heads, length, dim)

# Full-length causal attention over all `length` query positions.
mask_full = causal_lower_right(length, length)
attn_full = F.scaled_dot_product_attention(q, k, v, mask_full)

# The same attention computed for only the last `right_slice` query positions;
# with a lower-right mask, the result should match the tail of the full run.
q_right = q[:, :, -right_slice:, :]
mask_right = causal_lower_right(right_slice, length)
attn_right = F.scaled_dot_product_attention(q_right, k, v, mask_right)

assert torch.allclose(attn_full[:, :, -right_slice:, :], attn_right), ":("
IMbackK commented 1 day ago

The ROCm ecosystem is definitely underdeveloped at the moment compared to CUDA, which is not really surprising given their relative ages. With regard to PyTorch performance on MI300, that is uniquely bad right now because MI300's performance in rocBLAS is pretty poor; however, PyTorch has an env var (not sure if this is in the official release yet) that makes it use hipBLASLt for GEMM, which makes a huge difference there. Anyhow, this doesn't pertain to the MI100.

I can help you a bit with profiling if you like and have specific questions on how to use the tooling; Omniperf is the big gun to use here, and for simpler tasks I would recommend rocprofv2.

I'll try your reproducer snippet when I am back at the MI100 machine.

IMbackK commented 1 day ago

OK, I ran the (slightly modified) reproducer on the MI100 machine via SSH, and I can't see any issue. Potentially this is because flash_sdp is available on MI300 while it is not on MI100, but I don't see anything wrong here on the MI100/RX 6800 XT.