ggerganov / llama.cpp

LLM inference in C/C++
MIT License

Apple Silicon GPU Support Possible? #1545

Closed bfrasure closed 1 year ago

bfrasure commented 1 year ago

The CUDA acceleration is very impressive. Does anyone know of any efforts to run this on the GPU cores of Apple's M-series processors? I'd be willing to assist, but I'd rather not start from scratch if something already exists.

evanmiller commented 1 year ago

You can see the author's previous notes on attempting to use the M1 GPU on GPT-J:

https://github.com/ggerganov/ggml/tree/master/examples/gpt-j#attempt-to-use-the-m1-gpu

Basically he found that with Apple's unified memory architecture, the bottleneck is memory bandwidth rather than pure compute.

ggerganov commented 1 year ago

@evanmiller This info is outdated (and likely wrong) - I am working on offloading the full computation on the M1 GPU with custom kernels and hope to do it properly this time around.

x4080 commented 1 year ago

Can't wait for this, @ggerganov

iandennismiller commented 1 year ago

I've been trying to understand more about MPS and I found a few resources that helped.

Philip Turner has been doing interesting work with Metal:

Also, I learned a lot about the limitations of MPS from this PyTorch thread, titled "MPS device appears much slower than CPU on M1 Mac Pro." It's an old thread, but it still has current activity:

https://github.com/pytorch/pytorch/issues/77799

I thought I would leave these links in case it helps someone else.

philipturner commented 1 year ago

Context size 2048, 512 tokens, LLaMa 6.7B (3.9 GB)

| Layers on Accelerate | Layers on CLBlast | Latency per Token | Bandwidth |
|---|---|---|---|
| 32 | 0 | 40590 µs | 96 GB/s |
| 30 | 2 | 58150 µs | 67 GB/s |
| 28 | 4 | 76310 µs | 51 GB/s |
| 24 | 8 | 113460 µs | 34 GB/s |
| 16 | 16 | 198330 µs | 20 GB/s |
| 0 | 32 | 253050 µs | 15 GB/s |

Theoretically, it should be able to utilize more bandwidth. I think I can make this an order of magnitude faster. It will require a ton of tuning to align memory transactions and utilize ~376 GB/s of the bandwidth. For long contexts, the plan is to use triangular FlashAttention, with dynamic work redistribution to keep the GPU cores fully utilized.
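The Bandwidth column above is consistent with dividing the 3.9 GB model size by the per-token latency, i.e. assuming every weight byte is streamed from memory once per generated token (which is why decoding is bandwidth-bound). A minimal Swift sketch of that arithmetic, assuming 1 GB = 10^9 bytes; `effectiveBandwidthGBps` and `projectedLatencyMicros` are hypothetical helper names. At the ~376 GB/s target, the same model would take roughly 10.4 ms per token, about 3.9x faster than the best row.

```swift
import Foundation

// Hypothetical helper: effective bandwidth, assuming all weight bytes
// are read from memory exactly once per generated token.
func effectiveBandwidthGBps(modelGB: Double, microsPerToken: Double) -> Double {
    let seconds = microsPerToken / 1_000_000
    return modelGB / seconds
}

// Hypothetical helper: per-token latency implied by a given bandwidth.
func projectedLatencyMicros(modelGB: Double, bandwidthGBps: Double) -> Double {
    return modelGB / bandwidthGBps * 1_000_000
}

let modelGB = 3.9  // LLaMa 6.7B, as quoted above
let measured: [(accelerate: Int, clblast: Int, micros: Double)] = [
    (32, 0, 40_590), (30, 2, 58_150), (28, 4, 76_310),
    (24, 8, 113_460), (16, 16, 198_330), (0, 32, 253_050),
]
for row in measured {
    let bw = effectiveBandwidthGBps(modelGB: modelGB, microsPerToken: row.micros)
    print("\(row.accelerate)/\(row.clblast) layers: \(String(format: "%.0f", bw)) GB/s")
}

// Latency the ~376 GB/s target would imply for the same model.
let target = projectedLatencyMicros(modelGB: modelGB, bandwidthGBps: 376)
print(String(format: "%.0f µs/token at 376 GB/s", target))
```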

| Variant | Layers on Accelerate | Latency per Token | Bandwidth |
|---|---|---|---|
| 6.7B | 32 | 40590 µs | 96 GB/s |
| 13.0B | 40 | 75370 µs | 103 GB/s |
| 32.5B | 60 | 173540 µs | 112 GB/s |
| 65.2B | 80 | 34842780 µs | 1 GB/s |

Also, has anyone considered lane compression, so that I can run 65B-q4 on 32 GB of RAM at a reasonable speed?

philipturner commented 1 year ago

I'm trying to fill in this table. Next is the PyTorch MPS fork of LLaMa.cpp that's slower than CPU.

Latency per 512 tokens:

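| LLaMa | 6.7B | 13.0B | 32.5B | 65.2B |
|---|---|---|---|---|
| PyTorch | | | | |
| LLaMa.cpp | 20.8 s | 38.6 s | 88.9 s | 17839.5 s |
| MLC LLM | | | | |
| MPSGraph | | | | |
| Metal FlashAttention | | | | |
| Theoretical Lower Bound | 4.4 s | 8.6 s | 21.6 s | 43.3 s |
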
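
The LLaMa.cpp row matches the per-token latencies from the earlier Accelerate-only measurements multiplied by 512 tokens (for example, 40590 µs × 512 ≈ 20.8 s). A small Swift sketch of that conversion, which also divides by the listed lower bound as a rough measure of remaining headroom; the variable and function names are illustrative only.

```swift
import Foundation

// Per-token latencies (µs) from the Accelerate-only measurements.
let microsPerToken: [String: Double] = [
    "6.7B": 40_590, "13.0B": 75_370, "32.5B": 173_540, "65.2B": 34_842_780,
]
// Theoretical lower bounds for 512 tokens (s), from the table.
let lowerBoundSeconds: [String: Double] = [
    "6.7B": 4.4, "13.0B": 8.6, "32.5B": 21.6, "65.2B": 43.3,
]

// Total latency of a generation of `tokens` tokens, in seconds.
func secondsFor(tokens: Int, microsPerToken: Double) -> Double {
    return Double(tokens) * microsPerToken / 1_000_000
}

for variant in ["6.7B", "13.0B", "32.5B", "65.2B"] {
    let total = secondsFor(tokens: 512, microsPerToken: microsPerToken[variant]!)
    let headroom = total / lowerBoundSeconds[variant]!
    print("\(variant): \(String(format: "%.1f", total)) s, \(String(format: "%.1f", headroom))x the lower bound")
}
```
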
iandennismiller commented 1 year ago

Very happy to see @philipturner in this thread. @ggerganov, philipturner is an expert with MPS, M1 architecture, and GPU computation in general. I believe there is a lot of hard-won esoteric knowledge when it comes to optimizing for this architecture. So excited to see where this leads.

philipturner commented 1 year ago

I want an honest comparison. Please make LLaMa.cpp as fast as possible; I think I can make something faster. An open-source successor to MPS.

> optimizing for this architecture

Why not AMD and Intel too?

x4080 commented 1 year ago

@philipturner have you taken a look at https://github.com/mlc-ai/web-llm, especially https://github.com/mlc-ai/mlc-llm? It runs pretty fast on the Apple Silicon GPU using WebGPU. I tested it and it seems as fast as llama.cpp, but with no heat.

philipturner commented 1 year ago

> have you taken a look at https://github.com/mlc-ai/web-llm

[Screenshot, 2023-05-26 10:44:17 PM]

> I tested it and it seems as fast as llama.cpp, but with no heat

Can you quantify how much faster?

x4080 commented 1 year ago

@philipturner yes, MLC LLM is as fast as llama.cpp, like I said above.

philipturner commented 1 year ago

I suspect it might be a slight bit slower, just like PyTorch MPS. It's one thing to use the GPU. It's another to use it right.

If you use Metal the way it's designed, it should not be 10% slower than CPU, but 300% faster.

This is why I asked specifically for a number. Is it 0% faster? 50% faster? 50% slower?

x4080 commented 1 year ago

It's about the same 😄

I was surprised as well, because I had already read about your work. By the way, I'm using an M2 Pro with 16 GB and the latest Ventura; I don't know if that makes a difference. Have you tried it, and was the result slower than llama.cpp?

Using M1?

philipturner commented 1 year ago

I have not benchmarked it yet. What speed are you getting on CPU with LLaMa.cpp? Try reporting the ms/token with -c 2048 -n 512. Run about 10 times, and report the fastest trial. Also exclude trials where it finishes early (before 450 tokens).
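The reporting rule above (discard runs that stop before 450 tokens, then take the fastest remaining run) can be sketched in Swift. This assumes each run is summarized by its generated-token count and ms/token figure; the `Trial` type and `fastestValidTrial` function are hypothetical and not part of LLaMa.cpp, and the numbers are illustrative, not measurements.

```swift
// One benchmark run: how many tokens were generated and the reported speed.
struct Trial {
    let tokensGenerated: Int
    let msPerToken: Double
}

// Keep only runs that did not finish early (at least `minTokens` tokens),
// then report the fastest (lowest ms/token) of the remaining runs.
func fastestValidTrial(_ trials: [Trial], minTokens: Int = 450) -> Trial? {
    return trials
        .filter { $0.tokensGenerated >= minTokens }
        .min { $0.msPerToken < $1.msPerToken }
}

// Illustrative numbers: the 31 ms run stopped early, so it is excluded.
let trials = [
    Trial(tokensGenerated: 512, msPerToken: 41.2),
    Trial(tokensGenerated: 300, msPerToken: 31.0),
    Trial(tokensGenerated: 512, msPerToken: 39.8),
]
if let best = fastestValidTrial(trials) {
    print("fastest valid trial: \(best.msPerToken) ms/token")
}
```

Taking the minimum rather than the mean filters out runs slowed by background activity or thermal throttling, which matches the "report the fastest trial" instruction.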

After we can quantify LLaMa.cpp on your device, I'll find a way to get a numerical measurement from your MLC experience.


To explain the benefit of using Metal correctly, you're talking about jumps from the red curve to the cyan curve. LLMs are smaller problems (1000 atoms -> 1000 tokens) and memory/latency bounded. For big stuff like Stable Diffusion, you can grossly misuse the MTLCommandQueue (ahem, PyTorch), but the problem's so big, the compute bottleneck is even larger.

[Figure: Molecular Simulation Speed]

Most people use Metal like this:

```swift
MTLCommandQueue {
  MTLCommandBuffer {
    MTLComputeCommandEncoder {
      // command that adds two 1000-wide tensors
      // finishes in 1 microsecond
    }
  } // 10 microseconds latency
  MTLCommandBuffer {
    MTLComputeCommandEncoder {
      // command that multiplies two 1000-wide tensors
      // finishes in 1 microsecond
    }
  } // 10 microseconds latency
  MTLCommandBuffer {
    MTLComputeCommandEncoder {
      // command that exponentiates two 1000-wide tensors
      // finishes in 1 microsecond
    }
  } // 10 microseconds latency
  MTLCommandBuffer.waitUntilCompleted {
    // 200 microseconds latency
  }
}
```

How it's designed to be used:

```swift
MTLCommandQueue {
  MTLCommandBuffer {
    MTLComputeCommandEncoder {
      // command that adds two 1000-wide tensors
      // finishes in 1 microsecond
      // command that multiplies two 1000-wide tensors
      // finishes in 1 microsecond
      // command that exponentiates two 1000-wide tensors
      // finishes in 1 microsecond
      // some very clever GPGPU way to run the complex
      // logic on the GPU, to prevent synchronization
      // with the CPU
      // finishes in 5 microseconds
    }
  } // 10 microseconds latency
}
```

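
Plugging the illustrative timings from the comments above into a rough cost model shows why batching matters: each command buffer pays a fixed ~10 µs on top of its kernels, and the naive version also pays ~200 µs to synchronize with the CPU. The cost model and its numbers are assumptions taken from those comments, not measurements; `naiveLatency` and `batchedLatency` are hypothetical names.

```swift
// Illustrative cost model, using the timings from the pseudocode comments.
let perBufferLatency = 10.0  // µs of overhead per committed command buffer
let cpuSyncLatency = 200.0   // µs for the waitUntilCompleted round trip

// Naive: one command buffer per elementary op, then a CPU wait.
func naiveLatency(kernelMicros: [Double]) -> Double {
    let kernels = kernelMicros.reduce(0, +)
    return kernels + Double(kernelMicros.count) * perBufferLatency + cpuSyncLatency
}

// Batched: every kernel (including the GPU-side control logic) goes into a
// single encoder in a single command buffer, with no CPU round trip.
func batchedLatency(kernelMicros: [Double]) -> Double {
    return kernelMicros.reduce(0, +) + perBufferLatency
}

let naive = naiveLatency(kernelMicros: [1, 1, 1])        // add, multiply, exp
let batched = batchedLatency(kernelMicros: [1, 1, 1, 5]) // plus GPU-side logic
print("naive: \(naive) µs, batched: \(batched) µs")
```

With these numbers the fixed overheads dominate the naive path (230 of 233 µs), which is the situation for small, latency-bound problems like single-token LLM decoding.
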
x4080 commented 1 year ago

For llama.cpp, I'm now using 13B models, so it's somewhat slower than the 7B model I use with MLC. Yes, MLC doesn't display any speed, but it really goes brrr 😄

Before using the web version, I think it's about >20 tokens/s.

philipturner commented 1 year ago

> Before using the web version, I think it's about >20 tokens/s

LLaMa.cpp command-line or MLC AI command-line? You gave me a number of 50 ms/token.

> Yes, MLC doesn't display any speed, but it really goes brrr

Can you get a screen recording of it? On your specific computer.

x4080 commented 1 year ago

[Screenshot, 2023-05-27 11:12:29]

Using WebGPU, this is what I got.

x4080 commented 1 year ago

The strange thing is that Stable Diffusion is not "that fast".

philipturner commented 1 year ago

It gives 47 ms/token decoding. Compare to LLaMa.cpp ctx=512, 38 ms.
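Tokens/s and ms/token are reciprocals (scaled by 1000), so ">20 tokens/s" corresponds to under 50 ms/token, and 47 ms versus 38 ms per token means Web LLM takes roughly 24% longer per token in this comparison. A Swift sketch of these conversions; the helper names are illustrative only.

```swift
// tokens/s and ms/token are reciprocals of each other (scaled by 1000).
func msPerToken(tokensPerSecond: Double) -> Double { 1000 / tokensPerSecond }
func tokensPerSecond(msPerToken: Double) -> Double { 1000 / msPerToken }

// Extra time per token of a candidate relative to a baseline, in percent.
func percentSlower(candidateMs: Double, baselineMs: Double) -> Double {
    (candidateMs - baselineMs) / baselineMs * 100
}

print(msPerToken(tokensPerSecond: 20))
print(tokensPerSecond(msPerToken: 47))
print(percentSlower(candidateMs: 47, baselineMs: 38))
```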

philipturner commented 1 year ago

Latency per 512 tokens:

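| LLaMA | 6.7B | 13.0B | 32.5B | 65.2B |
|---|---|---|---|---|
| PyTorch | 116.1 s | 558.1 s | OOM | OOM |
| Web LLM | 30.1 s | n/a | n/a | n/a |
| LLaMa.cpp | 20.8 s | 38.6 s | 88.9 s | 17839.5 s |
| MPSGraph | | | | |
| Metal FlashAttention | | | | |
| Theoretical Lower Bound | 4.4 s | 8.6 s | 21.6 s | 43.3 s |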
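
Read as ratios against LLaMa.cpp (latency divided by LLaMa.cpp latency, so values above 1 are slower), the table puts PyTorch at about 5.6x slower for 6.7B and 14.5x for 13.0B, and Web LLM at about 1.4x for 6.7B. A short Swift sketch of that comparison, skipping cells with no measurement; the names are illustrative.

```swift
import Foundation

// Latency per 512 tokens (s) from the table; nil where no number is listed.
let llamaCpp: [Double] = [20.8, 38.6, 88.9, 17839.5]
let pytorch: [Double?] = [116.1, 558.1, nil, nil]  // OOM for the larger models
let webLLM: [Double?] = [30.1, nil, nil, nil]
let variants = ["6.7B", "13.0B", "32.5B", "65.2B"]

// How many times longer a framework takes than LLaMa.cpp, where measured.
func slowdown(_ other: [Double?], vs baseline: [Double]) -> [Double?] {
    zip(other, baseline).map { pair in pair.0.map { $0 / pair.1 } }
}

let pytorchSlowdown = slowdown(pytorch, vs: llamaCpp)
let webSlowdown = slowdown(webLLM, vs: llamaCpp)
for (i, v) in variants.enumerated() {
    let p = pytorchSlowdown[i].map { String(format: "%.1fx", $0) } ?? "n/a"
    let w = webSlowdown[i].map { String(format: "%.1fx", $0) } ?? "n/a"
    print("\(v): PyTorch \(p), Web LLM \(w)")
}
```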

Why does the tokens per second get monotonically slower as I get farther into the conversation (Web LLM)?

philipturner commented 1 year ago

I don’t really know what I’m doing, but I translated the core of LLaMA_MPS to MPSGraph. The hard part is figuring out how to load the weights.

https://gist.github.com/philipturner/23e30121a6a898f501d03f117bfe6f92

philipturner commented 1 year ago

I got the neural network to run 3x faster. It's going to be several weeks before I publish the Metal code - can you wait until then?

x4080 commented 1 year ago

@philipturner I'll wait. Is it for your repo or for llama.cpp?

philipturner commented 1 year ago

I'm making a repo that does a lot more than just optimize quantized GEMV. It's also multi-vendor. Should be easy to integrate into llama.cpp.

x4080 commented 1 year ago

@philipturner 👍

philipturner commented 1 year ago

Basically, I'm doing everything I can so that Apple platforms can be properly supported by Modular AI. It's a long way off, but eventually we won't need to make custom AI frameworks (e.g. GGML) just to run a language model fast.

x4080 commented 1 year ago

cool

Edit: what do you think about MLC? Is it as fast as using Metal directly?

philipturner commented 1 year ago

MLC seems to dispatch to TVM, which uses neural networks to guess how to run a neural network the fastest way. There's a much simpler and faster solution to the matrix multiplication problem, which Modular implemented with flying colors. Also TVM only supports AI inference, not AI training.

x4080 commented 1 year ago

@philipturner thanks

philipturner commented 1 year ago

> @evanmiller This info is outdated (and likely wrong) - I am working on offloading the full computation on the M1 GPU with custom kernels and hope to do it properly this time around.

@ggerganov I guess it wouldn't hurt to drop the Q4 shader variant (not the full FlashAttention though). I recommend using metal-cpp instead of the ObjC or Swift bindings. Have fun making the Apple GPU go brrr 😄

CPU code:

```swift
import Metal

func testLLaMA() {
  let device = MTLCreateSystemDefaultDevice()!
  let commandQueue = device.makeCommandQueue()!
  let library = device.makeDefaultLibrary()!

  let constants = MTLFunctionConstantValues()
  var ncols: Int = 4096
  constants.setConstantValue(&ncols, type: .ushort, index: 0)

  // The LLaMA.cpp code is no longer functioning because I hard-coded the
  // different dispatching heuristics for `FeedForward.encode`.
  //let functionName = "dequantize_mul_mat_vec_q4_0"
  let functionName = "gemv_q4_0"
  var function = try! library.makeFunction(
    name: functionName, constantValues: constants)
  let w13Pipeline = try! device.makeComputePipelineState(function: function)

  ncols = 4096 * 4
  constants.setConstantValue(&ncols, type: .ushort, index: 0)
  function = try! library.makeFunction(
    name: functionName, constantValues: constants)
  let w2Pipeline = try! device.makeComputePipelineState(function: function)

  let matrixElements = 4096 * 4096 * 4
  let matrixSize = matrixElements / 2 + (matrixElements / 32) * 2
  let vectorSize = 4096 * 4
  assert(matrixSize == 4096 * 16384 / 2 * 9 / 8)

  // Run all 32 layers in quick succession to find asymptotic maximum bandwidth.
  struct Vectors {
    var x: MTLBuffer
    var w1Val: MTLBuffer // quadruple the vector size
    var w3Val: MTLBuffer // quadruple the vector size
    var output: MTLBuffer

    init(device: MTLDevice, vectorSize: Int) {
      x = device.makeBuffer(length: vectorSize)!
      w1Val = device.makeBuffer(length: vectorSize * 4)!
      w3Val = device.makeBuffer(length: vectorSize * 4)!
      output = device.makeBuffer(length: vectorSize)!
    }
  }

  struct Context {
    var vectors: Vectors
    var w13Pipeline: MTLComputePipelineState
    var w2Pipeline: MTLComputePipelineState
  }

  // Don't do anything with the vectors written back to RAM.
  struct FeedForward {
    var matrixSize: Int
    var weights1: MTLBuffer
    var weights2: MTLBuffer
    var weights3: MTLBuffer

    init(device: MTLDevice, matrixSize: Int) {
      self.matrixSize = matrixSize
      weights1 = device.makeBuffer(length: matrixSize)!
      weights2 = device.makeBuffer(length: matrixSize)!
      weights3 = device.makeBuffer(length: matrixSize)!
    }

    var totalMemory: Int {
      weights1.length + weights2.length + weights3.length
    }

    func encode(encoder: MTLComputeCommandEncoder, context ctx: Context) {
      let simdRowStride = 4
      let simdsPerGroup = 4

      encoder.setComputePipelineState(ctx.w13Pipeline)
      encoder.setThreadgroupMemoryLength(4 * 32, index: 0)
      encoder.setBuffer(weights1, offset: 0, index: 0)
      encoder.setBuffer(weights1, offset: matrixSize * 8 / 9, index: 1)
      encoder.setBuffer(ctx.vectors.x, offset: 0, index: 2)
      encoder.setBuffer(ctx.vectors.w1Val, offset: 0, index: 3)
      var ncols: UInt32 = 4096
      var nrows: UInt32 = 4096 * 4
      encoder.setBytes(&ncols, length: 4, index: 4)
      encoder.dispatchThreadgroups(
        MTLSizeMake(Int(nrows) / simdRowStride / simdsPerGroup, 1, 1),
        threadsPerThreadgroup: MTLSizeMake(32 * simdsPerGroup, 1, 1))

      encoder.setComputePipelineState(ctx.w13Pipeline)
      encoder.setThreadgroupMemoryLength(4 * 32, index: 0)
      encoder.setBuffer(weights3, offset: 0, index: 0)
      encoder.setBuffer(weights3, offset: matrixSize * 8 / 9, index: 1)
      encoder.setBuffer(ctx.vectors.x, offset: 0, index: 2)
      encoder.setBuffer(ctx.vectors.w3Val, offset: 0, index: 3)
      encoder.setBytes(&ncols, length: 4, index: 4)
      encoder.dispatchThreadgroups(
        MTLSizeMake(Int(nrows) / simdRowStride / simdsPerGroup, 1, 1),
        threadsPerThreadgroup: MTLSizeMake(32 * simdsPerGroup, 1, 1))

      encoder.setComputePipelineState(ctx.w2Pipeline)
      encoder.setThreadgroupMemoryLength(4 * 32, index: 0)
      encoder.setBuffer(weights2, offset: 0, index: 0)
      encoder.setBuffer(weights2, offset: matrixSize * 8 / 9, index: 1)
      encoder.setBuffer(ctx.vectors.w3Val, offset: 0, index: 2)
      encoder.setBuffer(ctx.vectors.output, offset: 0, index: 3)
      ncols = 4096 * 4
      nrows = 4096
      encoder.setBytes(&ncols, length: 4, index: 4)
      encoder.dispatchThreadgroups(
        MTLSizeMake(Int(nrows) / simdRowStride / simdsPerGroup, 1, 1),
        threadsPerThreadgroup: MTLSizeMake(32 * simdsPerGroup, 1, 1))
    }
  }

  let numLayers = 32
  let vectors = Vectors(device: device, vectorSize: vectorSize)
  let context = Context(
    vectors: vectors, w13Pipeline: w13Pipeline, w2Pipeline: w2Pipeline)
  var layers: [FeedForward] = []
  for _ in 0..
```

GPU code:

```metal
#include <metal_stdlib>
using namespace metal;

// Perform a feedforward layer of LLaMA-6.7B. Ensure you are cycling through
// 32 instances of the layer weights, otherwise they will fall into the
// system-level cache.
//
// Inference from both the LLaMA.cpp format and a bandwidth-optimized format.

#define COMPILE_LEGACY_LLAMA_CPP_OPENCL_SHADER 0

#if COMPILE_LEGACY_LLAMA_CPP_OPENCL_SHADER
// Reference implementation from LLaMA.cpp. The custom implementation stores
// weights in a different format.
#define QK4_0 32
#define QR4_0 2

struct __attribute__ ((packed)) block_q4_0
{
  half d;
  uint8_t qs[QK4_0 / 2];
};

void dequantize_q4_0(
  const device block_q4_0* x, const int ib, const int iqs,
  thread float* v0, thread float* v1)
{
  const float d = float(x[ib].d);
  const uint8_t vui = x[ib].qs[iqs];
  const int8_t vi0 = vui & 0xF;
  const int8_t vi1 = vui >> 4;
  *v0 = (vi0 - 8)*d;
  *v1 = (vi1 - 8)*d;
}

// Original:
// 88.5 GB/s
// Don't read out-of-bounds vector data:
// 96.0 GB/s
kernel void dequantize_mul_mat_vec_q4_0(
  device block_q4_0* x [[buffer(0)]],
  threadgroup float* tmp [[threadgroup(0)]],
  device float* y [[buffer(2)]],
  device float* dst [[buffer(3)]],
  constant uint &ncols [[buffer(4)]],
  uint block_size [[threads_per_threadgroup]],
  uint global_id [[thread_position_in_grid]],
  uint local_id [[thread_position_in_threadgroup]])
{
  const uint row = global_id / block_size;
  const uint qk = QK4_0;
  const uint qr = QR4_0;
  const int y_offset = qr == 1 ? 1 : qk/2;
  tmp[local_id] = 0;

  for (uint i = 0; i < ncols/block_size; i += 2) {
    const uint col = i*block_size + 2*local_id;
    const uint ib = (row*ncols + col)/qk; // block index
    const uint iqs = (col%qk)/qr; // quant index
    const uint iybs = col - col%qk; // y block start index

    // dequantize
    float v0, v1;
    dequantize_q4_0(x, ib, iqs, &v0, &v1);

    // matrix multiplication
    tmp[local_id] += v0 * y[iybs + iqs + 0];
    tmp[local_id] += v1 * y[iybs + iqs + y_offset];
  }

  // sum up partial sums and write back result
  threadgroup_barrier(mem_flags::mem_threadgroup);
  for (uint s=block_size/2; s>0; s>>=1) {
    if (local_id < s) {
      tmp[local_id] += tmp[local_id + s];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }
  if (local_id == 0) {
    dst[row] = tmp[0];
  }
}
#endif

// Original:
// 88.5 GB/s
// Switch from threadgroup to simdgroup sum:
// 90.2 GB/s
// Deinterleave the weights and scales:
// 98.5 GB/s
// Hard-code shader parameters and remove `if (local_id == 0)`:
// 139.4 GB/s
// Directly index `uint8_t` instead of a struct:
// 143.6 GB/s
// Don't read out-of-bounds vector data:
// 172.6 GB/s
// Read input vectors as half-precision:
// 193.8 GB/s
// Coalesce the accesses to y and read the correct value from `weights`:
// 199.4 GB/s
// Change threadgroup size from 32 to 64:
// 210.4 GB/s
// Change threadgroup size from 32 to 128:
// 211.7 GB/s
// Two rows per simd:
// 225.4 GB/s
// Four rows per simd:
// 226.7 GB/s
// Unroll two iterations of the loop / un-duplicate scale reads:
// 243.3 GB/s
// Coalesce two Y reads:
// 245.7 GB/s
// Perform both X reads at the same time:
// 253.9 GB/s
// Unroll four iterations of the loop:
// (BAD DATA) 292.4 GB/s
// Coalesce X reads:
// (BAD DATA) 353.6 GB/s
// Coalesce four Y reads and use the correct index within a row:
// (BAD DATA) 374.8 GB/s
// Change how the buffers are indexed:
// (BAD DATA) 406.9 GB/s
// Use the correct value for 'i':
// 304.0 GB/s
// Optimize how 'vui' is stored in registers:
// 306.1 GB/s
// Optimize the generation of the index for scales:
// 319.3 GB/s

constant ushort ncols [[function_constant(0)]];

kernel void gemv_q4_0(
  device uchar4 *weights [[buffer(0)]],
  device half2 *scales [[buffer(1)]],
  device half *y [[buffer(2)]],
  device half *dst [[buffer(3)]],
  uint tid [[thread_position_in_grid]])
{
  // 8-wide groupings of threads, each thread reads 8 values per iteration.
  // 'groupings of threads' != 'threadgroups'
  #define WEIGHTS_PER_UINT 8
  #define GROUPING_SIZE 8

  uint row = tid / GROUPING_SIZE;
  ushort local_id = tid % GROUPING_SIZE;
  float acc = 0;

  // Changing this to a `while` loop harms performance. Perhaps it triggers a
  // separate assembly instruction for control flow.
  for (uint i = 0; i < ncols;) {
    uchar4 vui = weights[i / WEIGHTS_PER_UINT + local_id];
    const uint blocks_in_row = ncols / 32;
    half2 d = scales[row * blocks_in_row / 2 + i / 32 / 2];

    {
      half4 y_value = *(device half4*)(y + i + 4 * local_id);
      i += 4 * GROUPING_SIZE;

      const short vi0 = vui.x & 0xF;
      const short vi1 = vui.x >> 4;
      float v0 = (vi0 - 8) * d.x;
      float v1 = (vi1 - 8) * d.x;
      acc += v0 * y_value[0];
      acc += v1 * y_value[1];

      const short vi2 = vui.y & 0xF;
      const short vi3 = vui.y >> 4;
      float v2 = (vi2 - 8) * d.x;
      float v3 = (vi3 - 8) * d.x;
      acc += v2 * y_value[2];
      acc += v3 * y_value[3];
    }
    {
      half4 y_value = *(device half4*)(y + i + 4 * local_id);
      i += 4 * GROUPING_SIZE;

      const short vi0 = vui.z & 0xF;
      const short vi1 = vui.z >> 4;
      float v0 = (vi0 - 8) * d.y;
      float v1 = (vi1 - 8) * d.y;
      acc += v0 * y_value[0];
      acc += v1 * y_value[1];

      const short vi2 = vui.w & 0xF;
      const short vi3 = vui.w >> 4;
      float v2 = (vi2 - 8) * d.y;
      float v3 = (vi3 - 8) * d.y;
      acc += v2 * y_value[2];
      acc += v3 * y_value[3];
    }
  }

  acc += quad_shuffle_xor(acc, 1);
  acc += quad_shuffle_xor(acc, 2);
  acc += simd_shuffle_xor(acc, 4);
  dst[row] = acc;

  #undef WEIGHTS_PER_UINT
  #undef GROUPING_SIZE
}
```

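
For reference, the legacy `block_q4_0` layout in the shader above packs 32 weights into one half-precision scale `d` plus 16 bytes of 4-bit values, i.e. 18 bytes per 32 weights (4.5 bits per weight). That appears to be where the host code's `matrixElements / 2 + (matrixElements / 32) * 2` size comes from, and why the deinterleaved scales are bound at offset `matrixSize * 8 / 9` (the 4-bit weights occupy the first 8/9 of the buffer). Below is a CPU-side Swift sketch of the legacy dequantization indexing from `dequantize_mul_mat_vec_q4_0`. It assumes the scale is held as a `Float` for portability (the real block uses `half`); `BlockQ4_0`, `dequantize`, and `q4_0Bytes` are hypothetical names.

```swift
// Hypothetical CPU mirror of the legacy `block_q4_0` layout (QK4_0 = 32).
// The real block stores `d` as a 16-bit half; Float is used here instead.
struct BlockQ4_0 {
    var d: Float      // per-block scale
    var qs: [UInt8]   // 16 bytes: two 4-bit quants per byte
}

// Mirrors the legacy kernel's indexing: byte j holds element j in its low
// nibble and element j + 16 in its high nibble (y_offset = qk / 2).
func dequantize(_ block: BlockQ4_0) -> [Float] {
    precondition(block.qs.count == 16)
    var out = [Float](repeating: 0, count: 32)
    for j in 0..<16 {
        let lo = Int(block.qs[j] & 0x0F)
        let hi = Int(block.qs[j] >> 4)
        out[j] = Float(lo - 8) * block.d
        out[j + 16] = Float(hi - 8) * block.d
    }
    return out
}

// Bytes for `elements` Q4_0 weights: 4 bits each plus a 2-byte scale per 32.
func q4_0Bytes(elements: Int) -> Int {
    elements / 2 + (elements / 32) * 2
}

// Feedforward matrix from the benchmark: 4096 x 16384 weights.
let elements = 4096 * 4096 * 4
let size = q4_0Bytes(elements: elements)
let scaleOffset = size * 8 / 9  // where the deinterleaved scales begin
print(size, scaleOffset, elements / 2)
```
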
ggerganov commented 1 year ago

@philipturner

Thanks for the info

> I recommend using metal-cpp instead of the ObjC or Swift bindings

What are the benefits of metal-cpp?

My M1 GPU implementation is here: https://github.com/ggerganov/ggml/pull/108

I have prepared the Metal example so that it can load the ggml compute graph together with all the necessary data. The next step is mapping it to a command buffer and implementing the custom Metal kernels as needed.

The path is overall clear, with the only question of how exactly to support dynamic shapes (i.e. tensors with size that depends on the number of input / processed tokens). The straightforward way seems to be to recreate the command buffer for each generation - not sure about the overhead of this. If the overhead is too much, would need to think about some alternative approach.

philipturner commented 1 year ago

My code example makes a single command buffer for each token generation. Even a single command buffer per layer would be reasonable, just not a single cmdbuf per elementary operation (what PyTorch does). Also don’t break the cmdbuf into multiple encoders (which removes the benefit of one cmdbuf). If you need to copy buffers via blit encoder, I wrote a very fast compute shader with the same functionality.

Regarding dynamic sizes, I highly recommend you look through my MPSGraph Swift code example a few comments above.

I prefer metal-cpp because ObjC is pretty much a deprecated language. It’s been replaced by Swift, and I refuse to learn ObjC just to write Metal code. I have a long history with that. So for C stuff, I will use C++ bindings over ObjC any day.

My choice is mostly personal preference; however, metal-cpp has the same functionality as the Metal ObjC bindings. As long as you understand the NS::SharedPtr memory model, it’s quite straightforward to use. It’s also more comprehensible to non-Apple devs ("what is an `@interface` and `@implementation`?" vs. "what is `class A: class B { public:`?"). For example, VkFFT has used the C++ bindings.

philipturner commented 1 year ago

To remove the dependency on MPS/MPSMatrixMultiplication, use the early stage SIMD-group matmul here (faster than MPS for FP16). It requires aligned and non-transposed matrices - the latter restriction is trivial to lift. For the former:

- First matmul in attention: 32/40/64 are multiples of 8, 52 is not (zero-pad to 56)
- LLaMA 6.7B: 32-wide block size for the second matmul in attention
- LLaMA 13.0B: 40-wide block size
- LLaMA 32.5B: two shader invocations, one with block size 32 and another with block size 24, and modify the code to stride the memory accesses to 56
- LLaMA 65.2B: 32-wide block size
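
The alignment rule above amounts to rounding a dimension up to the next multiple of 8 and zero-padding the difference. A tiny Swift sketch (`roundUp` is an illustrative helper); it also shows that the 32.5B split into blocks of 32 and 24 exactly covers the padded width of 56.

```swift
// Round `n` up to the next multiple of `multiple` (the zero-padding target).
func roundUp(_ n: Int, toMultipleOf multiple: Int) -> Int {
    (n + multiple - 1) / multiple * multiple
}

for width in [32, 40, 52, 64] {
    let padded = roundUp(width, toMultipleOf: 8)
    print("\(width) -> \(padded)\(padded == width ? "" : " (zero-padded)")")
}

// 32.5B: two invocations with block sizes 32 and 24 cover the padded 56.
print(32 + 24 == roundUp(52, toMultipleOf: 8))
```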

marcothedeveloper123 commented 1 year ago

I have an M2 Max with 96 GB at my disposal, should you like me to perform tests. I have some experience in ML using Python and would very much like to help.

ggerganov commented 1 year ago

Closing this as the Metal implementation has now officially landed on master