turboderp / exllama

A more memory-efficient rewrite of the HF transformers implementation of Llama for use with quantized weights.
MIT License

Optimize q4_matmul #275

Closed. QuarticCat closed this pull request 10 months ago.

QuarticCat commented 10 months ago

Performance changes

Before:

$ python test_benchmark_inference.py -p -d models/LLaMA-7B-4bit-128g -cs
/home/qc/Workspace/NotMe/exllama/cuda_ext.py:82: UserWarning: Failed to initialize NumPy: No module named 'numpy' (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:84.)
  none_tensor = torch.empty((1, 1), device = "meta")
 -- Tokenizer: models/LLaMA-7B-4bit-128g/tokenizer.model
 -- Model config: models/LLaMA-7B-4bit-128g/config.json
 -- Model: models/LLaMA-7B-4bit-128g/llama-7b-4bit-128g.safetensors
 -- Sequence length: 2048
 -- Tuning:
 -- --sdp_thd: 8
 -- --matmul_recons_thd: 8
 -- --fused_mlp_thd: 2
 -- --concurrent_streams
 -- Options: ['perf']
 ** Time, Load model: 1.37 seconds
 ** Time, Load tokenizer: 0.01 seconds
 -- Groupsize (inferred): 128
 -- Act-order (inferred): yes
 ** VRAM, Model: [cuda:0] 3,638.47 MB
 ** VRAM, Cache: [cuda:0] 1,024.00 MB
 -- Warmup pass 1...
 ** Time, Warmup: 0.44 seconds
 -- Warmup pass 2...
 ** Time, Warmup: 0.42 seconds
 -- Inference, first pass.
 ** Time, Inference: 0.60 seconds
 ** Speed: 3212.74 tokens/second
 -- Generating 128 tokens, 1920 token prompt...
 ** Speed: 37.35 tokens/second
 -- Generating 128 tokens, 4 token prompt...
 ** Speed: 50.67 tokens/second
 ** VRAM, Inference: [cuda:0] 143.92 MB
 ** VRAM, Total: [cuda:0] 4,806.38 MB

After:

$ python test_benchmark_inference.py -p -d models/LLaMA-7B-4bit-128g -cs
/home/qc/Workspace/NotMe/exllama/cuda_ext.py:82: UserWarning: Failed to initialize NumPy: No module named 'numpy' (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:84.)
  none_tensor = torch.empty((1, 1), device = "meta")
 -- Tokenizer: models/LLaMA-7B-4bit-128g/tokenizer.model
 -- Model config: models/LLaMA-7B-4bit-128g/config.json
 -- Model: models/LLaMA-7B-4bit-128g/llama-7b-4bit-128g.safetensors
 -- Sequence length: 2048
 -- Tuning:
 -- --sdp_thd: 8
 -- --matmul_recons_thd: 8
 -- --fused_mlp_thd: 2
 -- --concurrent_streams
 -- Options: ['perf']
 ** Time, Load model: 1.42 seconds
 ** Time, Load tokenizer: 0.01 seconds
 -- Groupsize (inferred): 128
 -- Act-order (inferred): yes
 ** VRAM, Model: [cuda:0] 3,638.47 MB
 ** VRAM, Cache: [cuda:0] 1,024.00 MB
 -- Warmup pass 1...
 ** Time, Warmup: 0.44 seconds
 -- Warmup pass 2...
 ** Time, Warmup: 0.42 seconds
 -- Inference, first pass.
 ** Time, Inference: 0.60 seconds
 ** Speed: 3210.79 tokens/second
 -- Generating 128 tokens, 1920 token prompt...
 ** Speed: 80.41 tokens/second
 -- Generating 128 tokens, 4 token prompt...
 ** Speed: 152.34 tokens/second
 ** VRAM, Inference: [cuda:0] 143.92 MB
 ** VRAM, Total: [cuda:0] 4,806.38 MB

Benchmarked on an RTX 2070 Super; other models don't fit in its VRAM. Expect a smaller speedup for models that make less use of x_map (i.e., little or no act-order remapping).

PPL changes

Before:

$ python test_benchmark_inference.py -ppl -d models/LLaMA-7B-4bit-128g  
/home/qc/Workspace/NotMe/exllama/cuda_ext.py:82: UserWarning: Failed to initialize NumPy: No module named 'numpy' (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:84.)
  none_tensor = torch.empty((1, 1), device = "meta")
 -- Perplexity:
 -- - Dataset: datasets/wikitext2_val_sample.jsonl
 -- - Chunks: 100
 -- - Chunk size: 2048 -> 2048
 -- - Chunk overlap: 0
 -- - Min. chunk size: 50
 -- - Key: text
 -- Tokenizer: models/LLaMA-7B-4bit-128g/tokenizer.model
 -- Model config: models/LLaMA-7B-4bit-128g/config.json
 -- Model: models/LLaMA-7B-4bit-128g/llama-7b-4bit-128g.safetensors
 -- Sequence length: 2048
 -- Tuning:
 -- --sdp_thd: 8
 -- --matmul_recons_thd: 8
 -- --fused_mlp_thd: 2
 -- Options: ['perplexity']
 ** Time, Load model: 1.39 seconds
 ** Time, Load tokenizer: 0.01 seconds
 -- Groupsize (inferred): 128
 -- Act-order (inferred): yes
 ** VRAM, Model: [cuda:0] 3,638.47 MB
 ** VRAM, Cache: [cuda:0] 1,024.00 MB
 -- Loading dataset...
 -- Testing 100 chunks..........
 ** Perplexity: 6.0227

After:

$ python test_benchmark_inference.py -ppl -d models/LLaMA-7B-4bit-128g
/home/qc/Workspace/NotMe/exllama/cuda_ext.py:82: UserWarning: Failed to initialize NumPy: No module named 'numpy' (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:84.)
  none_tensor = torch.empty((1, 1), device = "meta")
 -- Perplexity:
 -- - Dataset: datasets/wikitext2_val_sample.jsonl
 -- - Chunks: 100
 -- - Chunk size: 2048 -> 2048
 -- - Chunk overlap: 0
 -- - Min. chunk size: 50
 -- - Key: text
 -- Tokenizer: models/LLaMA-7B-4bit-128g/tokenizer.model
 -- Model config: models/LLaMA-7B-4bit-128g/config.json
 -- Model: models/LLaMA-7B-4bit-128g/llama-7b-4bit-128g.safetensors
 -- Sequence length: 2048
 -- Tuning:
 -- --sdp_thd: 8
 -- --matmul_recons_thd: 8
 -- --fused_mlp_thd: 2
 -- Options: ['perplexity']
 ** Time, Load model: 1.40 seconds
 ** Time, Load tokenizer: 0.01 seconds
 -- Groupsize (inferred): 128
 -- Act-order (inferred): yes
 ** VRAM, Model: [cuda:0] 3,638.47 MB
 ** VRAM, Cache: [cuda:0] 1,024.00 MB
 -- Loading dataset...
 -- Testing 100 chunks..........
 ** Perplexity: 6.0232

Delta = 0.0005

turboderp commented 10 months ago

Thanks for this.

Interestingly, I already have more or less the same optimization in V2. The difference in my tests has been minimal, though (probably because I have nothing to test on that's older than Ampere), and I'm doing it mostly for the sake of the new quant format.

I'm really surprised there's this much of a difference here, given that the data you're explicitly caching in SMEM in these tests is all of 8 kB in total, and Turing is also supposed to have a shared architecture for L1 cache and SMEM, with about the same performance on both.

I'll do some tests and merge this in a few hours if it doesn't break anything. But in the meantime, could you test if there's a further difference in performance with the --matmul_fused_remap argument? That's what would trigger the use_x_map flag to the kernel and skip launching the column_remap_cuda kernel separately.
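
For readers unfamiliar with the flag, here is a minimal sketch of the dispatch described above. The function names echo the kernels mentioned in this thread, but the signatures are stand-ins invented for illustration, not the actual exllama API.

```cuda
// Hypothetical stand-ins for the kernels named above -- signatures assumed.
void column_remap_stub(const void* x, void* x_remapped, const int* x_map);
void q4_matmul_stub(const void* x, void* out, const int* x_map_or_null);

void q4_matmul_dispatch(const void* x, void* x_tmp, void* out,
                        const int* x_map, bool matmul_fused_remap)
{
    if (x_map && !matmul_fused_remap)
    {
        // default path: a separate remap kernel launch plus an extra pass over x
        column_remap_stub(x, x_tmp, x_map);
        q4_matmul_stub(x_tmp, out, /*x_map=*/nullptr);
    }
    else
    {
        // --matmul_fused_remap: use_x_map is set and the matmul kernel
        // applies the act-order permutation while loading x
        q4_matmul_stub(x, out, x_map);
    }
}
```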

QuarticCat commented 10 months ago

Before this PR:

$ python test_benchmark_inference.py -p -d models/LLaMA-7B-4bit-128g -cs --matmul_fused_remap
/home/qc/Workspace/NotMe/exllama/cuda_ext.py:82: UserWarning: Failed to initialize NumPy: No module named 'numpy' (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:84.)
  none_tensor = torch.empty((1, 1), device = "meta")
 -- Tokenizer: models/LLaMA-7B-4bit-128g/tokenizer.model
 -- Model config: models/LLaMA-7B-4bit-128g/config.json
 -- Model: models/LLaMA-7B-4bit-128g/llama-7b-4bit-128g.safetensors
 -- Sequence length: 2048
 -- Tuning:
 -- --sdp_thd: 8
 -- --matmul_recons_thd: 8
 -- --fused_mlp_thd: 2
 -- --matmul_fused_remap
 -- --concurrent_streams
 -- Options: ['perf']
 ** Time, Load model: 1.36 seconds
 ** Time, Load tokenizer: 0.01 seconds
 -- Groupsize (inferred): 128
 -- Act-order (inferred): yes
 ** VRAM, Model: [cuda:0] 3,638.47 MB
 ** VRAM, Cache: [cuda:0] 1,024.00 MB
 -- Warmup pass 1...
 ** Time, Warmup: 0.44 seconds
 -- Warmup pass 2...
 ** Time, Warmup: 0.42 seconds
 -- Inference, first pass.
 ** Time, Inference: 0.59 seconds
 ** Speed: 3233.46 tokens/second
 -- Generating 128 tokens, 1920 token prompt...
 ** Speed: 34.70 tokens/second
 -- Generating 128 tokens, 4 token prompt...
 ** Speed: 46.09 tokens/second
 ** VRAM, Inference: [cuda:0] 143.92 MB
 ** VRAM, Total: [cuda:0] 4,806.38 MB

After this PR:

$ python test_benchmark_inference.py -p -d models/LLaMA-7B-4bit-128g -cs --matmul_fused_remap
/home/qc/Workspace/NotMe/exllama/cuda_ext.py:82: UserWarning: Failed to initialize NumPy: No module named 'numpy' (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:84.)
  none_tensor = torch.empty((1, 1), device = "meta")
 -- Tokenizer: models/LLaMA-7B-4bit-128g/tokenizer.model
 -- Model config: models/LLaMA-7B-4bit-128g/config.json
 -- Model: models/LLaMA-7B-4bit-128g/llama-7b-4bit-128g.safetensors
 -- Sequence length: 2048
 -- Tuning:
 -- --sdp_thd: 8
 -- --matmul_recons_thd: 8
 -- --fused_mlp_thd: 2
 -- --matmul_fused_remap
 -- --concurrent_streams
 -- Options: ['perf']
 ** Time, Load model: 1.40 seconds
 ** Time, Load tokenizer: 0.01 seconds
 -- Groupsize (inferred): 128
 -- Act-order (inferred): yes
 ** VRAM, Model: [cuda:0] 3,638.47 MB
 ** VRAM, Cache: [cuda:0] 1,024.00 MB
 -- Warmup pass 1...
 ** Time, Warmup: 0.45 seconds
 -- Warmup pass 2...
 ** Time, Warmup: 0.42 seconds
 -- Inference, first pass.
 ** Time, Inference: 0.60 seconds
 ** Speed: 3216.22 tokens/second
 -- Generating 128 tokens, 1920 token prompt...
 ** Speed: 79.21 tokens/second
 -- Generating 128 tokens, 4 token prompt...
 ** Speed: 159.22 tokens/second
 ** VRAM, Inference: [cuda:0] 143.92 MB
 ** VRAM, Total: [cuda:0] 4,806.38 MB
QuarticCat commented 10 months ago

> I'm really surprised there's this much of a difference here, given that the data you're explicitly caching in SMEM in these tests is all of 8 kB in total, and Turing is also supposed to have a shared architecture for L1 cache and SMEM, with about the same performance on both.

Your original memory access pattern cannot be coalesced. In addition, x_map values are calculated multiple times, which is redundant.
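
A minimal sketch of the staging pattern this PR argues for (illustrative only, not the actual q4_matmul kernel; names and sizes are assumptions): each block gathers a GROUP_STEP-wide slice of x through x_map exactly once, with consecutive threads issuing consecutive index reads, and all threads then reuse the staged values from shared memory instead of re-reading and re-permuting global x for every partial product.

```cuda
#include <cuda_fp16.h>

#define GROUP_STEP 128   // launch with blockDim.x == GROUP_STEP

__global__ void stage_x_through_x_map(const half* __restrict__ x,
                                      const int*  __restrict__ x_map,
                                      float* __restrict__ out,
                                      int dim)
{
    __shared__ half x_tile[GROUP_STEP];
    float acc = 0.0f;

    for (int base = 0; base < dim; base += GROUP_STEP)
    {
        int k = base + threadIdx.x;
        if (k < dim)
            x_tile[threadIdx.x] = x[x_map ? x_map[k] : k];   // gather once per tile
        __syncthreads();

        // Real kernel: accumulate partial dot products against the quantized
        // weight columns using x_tile. A trivial stand-in keeps this runnable:
        if (k < dim) acc += __half2float(x_tile[threadIdx.x]);
        __syncthreads();
    }
    if (threadIdx.x == 0) out[blockIdx.x] = acc;
}
```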

turboderp commented 10 months ago

Well, without the fused remap parameter the remapping is done exactly once, from global to global memory, but the state should be in at least L2 cache by that point, and subsequently reading from L1 should not be slower. However, I haven't done nearly as much profiling on the V1 kernel as I have on V2, so I may have missed a lot. Perhaps coalescing matters more than I thought in places.

And upon some further testing, this is faster, even on the 4090, but only for some models. For others it's considerably slower, and I'll need a moment to figure out why, or if it should be switchable to get the best of both worlds depending on e.g. model size.

QuarticCat commented 10 months ago

Let me guess. Do those models have a huge group size or no group size?

turboderp commented 10 months ago

You're right, they have no group size, which is to say they have one group as large as the hidden dim of the model. So they'll be using a lot of SMEM and occupancy will be terrible with this approach. But there's no reason it couldn't just be limited to 128 rows or something in that case, so performance should be the same.

I am getting broken output, though, so there's definitely something amiss. Note that the perplexity test runs sequences larger than the threshold that triggers reconstruction, where the custom kernel is bypassed in favor of just temporarily reconstructing the FP16 weights and using cuBLAS, since it's invariably faster.

If you run the benchmark script with -v it'll test perplexity in both modes to verify that the result is similar (it shouldn't be exact since the kernel uses atomicAdd rather than reduction as cuBLAS does), and do a quick generation. Here's what I'm getting:

 -- Tokenizer: /mnt/str/models/_test_models/TheBloke_Llama-2-7B-GPTQ/tokenizer.model
 -- Model config: /mnt/str/models/_test_models/TheBloke_Llama-2-7B-GPTQ/config.json
 -- Model: /mnt/str/models/_test_models/TheBloke_Llama-2-7B-GPTQ/model.safetensors
 -- Sequence length: 2048
 -- Tuning:
 -- --sdp_thd: 8
 -- --matmul_recons_thd: 8
 -- --fused_mlp_thd: 2
 -- Options: ['perf', 'validate']
 ** Time, Load model: 1.49 seconds
 ** Time, Load tokenizer: 0.01 seconds
 -- Groupsize (inferred): 128
 -- Act-order (inferred): yes
 ** VRAM, Model: [cuda:0] 3,638.47 MB - [cuda:1] 0.00 MB
 ** VRAM, Cache: [cuda:0] 1,024.00 MB - [cuda:1] 0.00 MB
 -- Warmup pass 1...
 ** Time, Warmup: 0.58 seconds
 -- Warmup pass 2...
 ** Time, Warmup: 0.15 seconds
 -- Inference, first pass.
 ** Time, Inference: 0.15 seconds
 ** Speed: 13058.57 tokens/second
 -- Generating 128 tokens, 1920 token prompt...
 ** Speed: 219.08 tokens/second
 -- Generating 128 tokens, 4 token prompt...
 ** Speed: 371.99 tokens/second
 ** VRAM, Inference: [cuda:0] 143.92 MB - [cuda:1] 0.00 MB
 ** VRAM, Total: [cuda:0] 4,806.38 MB - [cuda:1] 0.00 MB
 -- Testing 8 chunks.
 ** Perplexity (reconstruct): 6.0388
 -- Testing 8 chunks.
 ** Perplexity (quant, token): 129202.8114
 ** Generation: 'To be or not to be, that is the same exactitudeelnabilatembiaancômesampleionaHRtz sierpészlin FelNUfolglach'

Obviously that isn't right. I'm really hoping it's fixable because those speeds are... well, technically they're higher than the theoretical maximum for 1 TB/s VRAM bandwidth, which I guess is concerning, too. I'm investigating.
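
(On the atomicAdd point above: a tiny illustration, not the exllama kernel, of why the two perplexity figures are never expected to match bit-for-bit. Partial results are added to the output in whatever order the blocks happen to run, and float addition is not associative, so the sum can differ slightly from cuBLAS's fixed-order reduction.)

```cuda
__global__ void accumulate_partials(const float* __restrict__ partials,
                                    int n, float* __restrict__ out)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) atomicAdd(out, partials[i]);   // summation order is unspecified
}
```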

QuarticCat commented 10 months ago

I'll look into it.

QuarticCat commented 10 months ago

@turboderp It's fixed. Quite a stupid mistake. :cry:

And now the speed is much lower than with the broken kernel, but still faster than before this PR.

$ python test_benchmark_inference.py -p -d models/LLaMA-7B-4bit-128g -cs
/home/qc/Workspace/NotMe/exllama/cuda_ext.py:82: UserWarning: Failed to initialize NumPy: No module named 'numpy' (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:84.)
  none_tensor = torch.empty((1, 1), device = "meta")
 -- Tokenizer: models/LLaMA-7B-4bit-128g/tokenizer.model
 -- Model config: models/LLaMA-7B-4bit-128g/config.json
 -- Model: models/LLaMA-7B-4bit-128g/llama-7b-4bit-128g.safetensors
 -- Sequence length: 2048
 -- Tuning:
 -- --sdp_thd: 8
 -- --matmul_recons_thd: 8
 -- --fused_mlp_thd: 2
 -- --concurrent_streams
 -- Options: ['perf']
 ** Time, Load model: 1.37 seconds
 ** Time, Load tokenizer: 0.01 seconds
 -- Groupsize (inferred): 128
 -- Act-order (inferred): yes
 ** VRAM, Model: [cuda:0] 3,638.47 MB
 ** VRAM, Cache: [cuda:0] 1,024.00 MB
 -- Warmup pass 1...
 ** Time, Warmup: 0.44 seconds
 -- Warmup pass 2...
 ** Time, Warmup: 0.42 seconds
 -- Inference, first pass.
 ** Time, Inference: 0.59 seconds
 ** Speed: 3264.15 tokens/second
 -- Generating 128 tokens, 1920 token prompt...
 ** Speed: 45.67 tokens/second
 -- Generating 128 tokens, 4 token prompt...
 ** Speed: 67.89 tokens/second
 ** VRAM, Inference: [cuda:0] 143.92 MB
 ** VRAM, Total: [cuda:0] 4,806.38 MB
$ python test_benchmark_inference.py -v -d models/LLaMA-7B-4bit-128g
/home/qc/Workspace/NotMe/exllama/cuda_ext.py:82: UserWarning: Failed to initialize NumPy: No module named 'numpy' (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:84.)
  none_tensor = torch.empty((1, 1), device = "meta")
 -- Tokenizer: models/LLaMA-7B-4bit-128g/tokenizer.model
 -- Model config: models/LLaMA-7B-4bit-128g/config.json
 -- Model: models/LLaMA-7B-4bit-128g/llama-7b-4bit-128g.safetensors
 -- Sequence length: 2048
 -- Tuning:
 -- --sdp_thd: 8
 -- --matmul_recons_thd: 8
 -- --fused_mlp_thd: 2
 -- Options: ['validate']
 ** Time, Load model: 1.38 seconds
 ** Time, Load tokenizer: 0.01 seconds
 -- Groupsize (inferred): 128
 -- Act-order (inferred): yes
 ** VRAM, Model: [cuda:0] 3,638.47 MB
 ** VRAM, Cache: [cuda:0] 1,024.00 MB
 -- Testing 8 chunks.
 ** Perplexity (reconstruct): 6.0643
 -- Testing 8 chunks.
 ** Perplexity (quant, token): 6.0777
 ** Generation: 'To be or not to be, that is the question.\nThe answer is: Yes and no. The first part of this sentence is a simple'

I'm now working on improving performance for other group sizes.

QuarticCat commented 10 months ago

@turboderp Could you benchmark the latest commit on your models?

turboderp commented 10 months ago

I ran some tests on the 4090, but the latest version is about 10% slower than the original. However, if I change the blocksize back to 32, the performance is comparable, and somewhat better for 3B models:

| model | orig | orig/mmfr | new/128 | new/32 | new/32/mmfr |
|---|---|---|---|---|---|
| 3B/128g/act | 221 | 210 | 220 | 242 | 246 |
| 7B/128g | 170 | 162 | 154 | 172 | 159 |
| 13B/128g | 97 | 97 | 83 | 100 | 98 |
| 30B/128g | 46 | 45 | 30 | 46 | 45 |
| 30B/128g/act | 43 | 38 | 29 | 44 | 45 |
| 30B/32g/act | 40 | 35 | 30 | 41 | 40 |
| 13B/128g/act | 96 | 81 | 80 | 93 | 96 |
| 30B | 45 | 45 | 31 | 45 | 45 |

Doing the same tests on the 3090 paints a different picture, with similar performance for 32 and 128 threads per block, and overall improvements on the order of 10%:

| model | orig | orig/mmfr | new/128 | new/32 | new/32/mmfr |
|---|---|---|---|---|---|
| 3B/128g/act | 125 | 118 | 146 | 153 | 159 |
| 7B/128g | 100 | 100 | 106 | 106 | 107 |
| 13B/128g | 63 | 61 | 69 | 70 | 70 |
| 30B/128g | 33 | 33 | 35 | 37 | 37 |
| 30B/128g/act | 32 | 20 | 39 | 37 | 37 |
| 30B/32g/act | 27 | 17 | 34 | 28 | 27 |
| 13B/128g/act | 70 | 45 | 77 | 78 | 80 |
| 30B | 28 | 28 | 36 | 34 | 34 |

I'm inclined to say that with the 32 block dim it's an improvement overall, if not on Ada then at least on Ampere, which is still great. Does setting THREADS_X = 32 and GROUP_STEP = 32 make it noticeably slower on the 2070?

QuarticCat commented 10 months ago

THREADS_X = 32 & GROUP_STEP = 32 is 5~15% slower than THREADS_X = 128 & GROUP_STEP = 128 on 2070S. Tested with Llama-2-7B-GPTQ-4bit-{32g,64g,128g}. The smaller the group size, the more significant the difference. Anyway, it's still much faster than the original, so let's set them to 32.

To improve performance on 30/40-series cards, maybe you can try something like __pipeline_memcpy_async to hide the latency of memory accesses. I believe this kernel is memory-bound.
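
A rough sketch of that suggestion (illustrative, not exllama code): double-buffer the shared-memory tile with __pipeline_memcpy_async so the next tile's global loads overlap with compute on the current one. The intrinsic lives in <cuda_pipeline_primitives.h> (CUDA 11+), the dedicated async-copy hardware only exists on sm_80 and newer, and the "compute" below is a stand-in.

```cuda
#include <cuda_fp16.h>
#include <cuda_pipeline_primitives.h>

#define TILE 128   // halves per tile; launch with blockDim.x == TILE; dim % TILE == 0 assumed

__global__ void tiled_sum_async(const half* __restrict__ x, int dim,
                                float* __restrict__ out)
{
    __shared__ half buf[2][TILE];
    int t = threadIdx.x;
    float acc = 0.0f;

    // Prefetch tile 0: 16 threads each copy 16 bytes (8 halves).
    if (t < TILE / 8)
        __pipeline_memcpy_async(&buf[0][t * 8], &x[t * 8], 16);
    __pipeline_commit();

    for (int base = 0; base < dim; base += TILE)
    {
        int cur = (base / TILE) & 1, nxt = cur ^ 1;

        // Kick off the next tile's loads before touching the current tile.
        if (base + TILE < dim && t < TILE / 8)
            __pipeline_memcpy_async(&buf[nxt][t * 8], &x[base + TILE + t * 8], 16);
        __pipeline_commit();

        __pipeline_wait_prior(1);   // wait for everything except the batch just committed
        __syncthreads();
        acc += __half2float(buf[cur][t]);   // stand-in for the real dot product
        __syncthreads();
    }
    if (t == 0) out[blockIdx.x] = acc;
}
```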

turboderp commented 10 months ago

I'm pretty sure it's latency bound right now, at least on the 4090. I'm getting around 700 GB/s of throughput, which, even allowing for a little overhead, suggests there's room for improvement. But I'm also reaching 100% occupancy, which means that to hide the latency further, the ratio of compute to memory access has to go up. Loading int4s and computing 4 partial row-column products per thread might be an idea, to try to saturate the bus between L2 and GMEM.
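
A sketch of that int4 idea (assumptions: a GPTQ-style layout where one uint32 packs eight 4-bit weights of a single column along consecutive rows, with scales and zero-points omitted for brevity): one 128-bit load covers four adjacent columns, so each thread accumulates four partial column products per global transaction.

```cuda
#include <cuda_fp16.h>

__global__ void int4_partials(const int4* __restrict__ w_packed,  // [rows/8][cols/4]
                              const half* __restrict__ x,         // [rows]
                              float* __restrict__ acc,            // [cols]
                              int rows, int cols)
{
    int col4 = blockIdx.x * blockDim.x + threadIdx.x;   // this thread's group of 4 columns
    if (col4 >= cols / 4) return;

    float partial[4] = {0.0f, 0.0f, 0.0f, 0.0f};
    for (int r8 = 0; r8 < rows / 8; r8++)
    {
        int4 w = w_packed[r8 * (cols / 4) + col4];      // one 128-bit load, 32 weights
        const int words[4] = { w.x, w.y, w.z, w.w };
        #pragma unroll
        for (int c = 0; c < 4; c++)
        {
            #pragma unroll
            for (int k = 0; k < 8; k++)                 // unpack 8 rows per packed word
                partial[c] += (float)((words[c] >> (k * 4)) & 0x0F)
                            * __half2float(x[r8 * 8 + k]);
        }
    }
    for (int c = 0; c < 4; c++) acc[col4 * 4 + c] = partial[c];
}
```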

ardfork commented 10 months ago

It's quite a bit slower on a 6700 XT with a 13B 128g model: around a 5% slowdown. From experience, I expect this to also be a slowdown on Pascal cards.

QuarticCat commented 10 months ago

That's frustrating. Seems like we need more discussion before merging. Do you have any idea what causes the slowdown? The only card I have is a 2070S, so I can't dig any further.

ardfork commented 10 months ago

I don't have the answer to that. But with exllama v2 around the corner, is it really worth it to spend time trying to optimize exllama v1?

QuarticCat commented 10 months ago

Fine. I didn't know exllama v2 existed when I wrote this PR. You can close this PR if you want. @turboderp

ardfork commented 10 months ago

I'm not an exllama maintainer or developer; turboderp is the only one who should decide whether this optimization is worth it. What I wanted to convey is that I, personally, don't want to spend time investigating why it's a slowdown on my card when exllama v2 might be released soon. Maybe exllama v2 will never be released, or will be different enough (for example, focusing on 2/3-bit quantization) that it doesn't make exllama v1 obsolete, and then my decision would have been wrong.

turboderp commented 10 months ago

I think it's worth keeping, but it should probably be switchable one way or another if there's performance degradation on ROCm. Probably it could just switch at compile time based on __CUDA_ARCH__ and USE_ROCM.

As for V2, it supports GPTQ models as well as the new quant format, and performance is considerably better, at least on the cards I have access to. On 4090 it's 10-15% faster than V1 (so far), and the 3090-Ti is only 6% slower than the 4090, which tracks with the kernel being close to optimal, for some definition of optimal. So I also don't want to devote too much time to optimizing V1. If anything I'd rather back-port the new kernel at some point. Though in the meantime there's nothing wrong with making better use of SMEM, as in this PR.

Ph0rk0z commented 10 months ago

Pascal is nigh unusable for this regardless. I'd be happy with a speedup for supported cards. Looking forward to v2, perhaps to get quants with better perplexity. I'm noticing the difference between GGML and GPTQ now, but the latter is much better at memory management.

turboderp commented 10 months ago

Well, it has to be switchable regardless if it's an issue for ROCm. So I'll just switch on both the CUDA arch version and USE_ROCM. Should get around to it later today. Just need to juggle some forks around so everyone gets credit.

turboderp commented 10 months ago

There, I finally figured out how GitHub works! And thanks for the optimization @QuarticCat.

It should now be enabled when CUDA_ARCH >= 700 but not on ROCm, using the old version as a fallback. It gets a bit messy of course, but I'm not expecting to add too much more functionality to this version anyway.

If there's a further performance benefit to be had on certain arch versions, it should be simple enough to add some more conditional code to select a different GROUP_STEP or whatever is appropriate for the 2070, though I still have no way to test that myself.
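
For what it's worth, a sketch of that kind of compile-time switch (USE_ROCM and __CUDA_ARCH__ are the flags mentioned in the thread; the other macro names are made up for illustration):

```cuda
// __CUDA_ARCH__ is only defined during device compilation, so this picks the
// code path per architecture at build time, as described above.
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700)
    #define USE_SMEM_PATH 0          // original kernel as the fallback
#else
    #define USE_SMEM_PATH 1          // SMEM-tiled kernel from this PR
    #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
        #define GROUP_STEP 32        // measured best on the 3090 / 4090 above
    #else
        #define GROUP_STEP 128       // measured best on the 2070 Super
    #endif
#endif
```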