sarah-ek / gemm

MIT License

Candle example uses 10% of CPU when fma is active for x86 #20

Open kstavro opened 11 months ago

kstavro commented 11 months ago

Coming here after noticing that CPU inference in the llama example over at candle only utilizes 10% of my CPU (AMD Ryzen 5800X3D). As I mentioned over at the candle repo, this might be because the gemm implementation only really needs a specific number of cores due to stack management/limitations? Could 10% CPU utilization make sense? I have noticed there is another PR in this repo where the number of threads gets upper-bounded for some reason that might have to do with the stack, which is not something one can understand just by glancing at the code. So I'm not sure if this is related.

I tried to write a minimal gemm example that simulates the matmul from the llama example, copying all the parameters of the corresponding gemm call made during inference, but I get a stack overflow. I am already a bit out of my league here, since I have no idea why this happens.

For reference, here is the issue from candle: (huggingface/candle#1103)

And here is the example I tried to recreate, in case you can correct it on the spot or it might help to reproduce the low CPU utilization.

A llama gemm attempt (that sadly overflows the stack)

```rust
use gemm_common::Parallelism;
use gemm_f16::gemm::f16::fma::gemm_basic; // I made fma public so that I can import it
use half::f16;
use rand_distr::{Distribution, Normal};
use std::convert::TryInto;

fn convert_to_array<T, const N: usize>(v: Vec<T>) -> [T; N] {
    v.try_into().unwrap_or_else(|v: Vec<T>| {
        panic!("Expected a Vec of length {} but it was {}", N, v.len())
    })
}

fn main() {
    let parallelism = Parallelism::Rayon(16); // my CPU has 16 threads, no need for additional deps to get the value
    let (m, k, n) = (4096, 1, 64);

    // Parameters logged from candle during inference:
    // [candle-core\src\cpu_backend.rs:1522] &dst_p.len() = 262144
    // [candle-core\src\cpu_backend.rs:1523] &dst_cs = 1
    // [candle-core\src\cpu_backend.rs:1524] &dst_rs = 64
    // [candle-core\src\cpu_backend.rs:1525] &lhs_p.len() = 4096
    // [candle-core\src\cpu_backend.rs:1526] &lhs_cs = 1
    // [candle-core\src\cpu_backend.rs:1527] &lhs_rs = 1
    // [candle-core\src\cpu_backend.rs:1528] &rhs_p.len() = 64
    // [candle-core\src\cpu_backend.rs:1529] &rhs_cs = 1
    // [candle-core\src\cpu_backend.rs:1530] &rhs_rs = 64
    const DST_LEN: usize = 262144;
    const LHS_LEN: usize = 4096;
    const RHS_LEN: usize = 64;

    let mut rng = rand::thread_rng();
    let normal = Normal::new(0.0, 1.0).expect("Can't create dist");

    let vals_dst: Vec<f16> = (0..DST_LEN)
        .map(|_| f16::from_f32(normal.sample(&mut rng)))
        .collect();
    println!("Created dst");
    let vals_lhs: Vec<f16> = (0..LHS_LEN)
        .map(|_| f16::from_f32(normal.sample(&mut rng)))
        .collect();
    println!("Created lhs");
    let vals_rhs: Vec<f16> = (0..RHS_LEN)
        .map(|_| f16::from_f32(normal.sample(&mut rng)))
        .collect();
    println!("Created rhs");

    // These fixed-size arrays live on the stack (the dst array alone is 512 KiB).
    let mut dst_p: [f16; DST_LEN] = convert_to_array(vals_dst);
    let dst_cs = 1;
    let dst_rs = 64;
    let lhs_p: [f16; LHS_LEN] = convert_to_array(vals_lhs);
    let lhs_cs = 1;
    let lhs_rs = 1;
    let rhs_p: [f16; RHS_LEN] = convert_to_array(vals_rhs);
    let rhs_cs = 1;
    let rhs_rs = 64;

    println!("Starting gemm16 loop");
    loop {
        unsafe {
            gemm_basic(
                /* m: usize = */ m,
                /* n: usize = */ n,
                /* k: usize = */ k,
                /* dst: *mut T = */ dst_p.as_mut_ptr(),
                /* dst_cs: isize = */ dst_cs as isize,
                /* dst_rs: isize = */ dst_rs as isize,
                /* read_dst: bool = */ false,
                /* lhs: *const T = */ lhs_p.as_ptr(),
                /* lhs_cs: isize = */ lhs_cs as isize,
                /* lhs_rs: isize = */ lhs_rs as isize,
                /* rhs: *const T = */ rhs_p.as_ptr(),
                /* rhs_cs: isize = */ rhs_cs as isize,
                /* rhs_rs: isize = */ rhs_rs as isize,
                /* alpha: T = */ f16::from_f32(0.0),
                /* beta: T = */ f16::from_f32(1.0),
                /* conj_dst: bool = */ false,
                /* conj_lhs: bool = */ false,
                /* conj_rhs: bool = */ false,
                parallelism,
            )
        }
    }
}
```
sarah-ek commented 11 months ago

@kstavro can you try without the array stuff? Putting data on the heap should avoid the stack overflow.
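A minimal sketch of what putting the data on the heap could look like, using only std. Since `f16` is 2 bytes wide, a `[f16; 262144]` dst array alone takes 512 KiB of stack, before counting the other arrays and temporaries. One option is to keep the data in a `Vec` and pass `as_mut_ptr()`; if a fixed-size type is wanted, it can be boxed. Assumptions: `u16` stands in for `half::f16` so the snippet has no dependencies, and `to_boxed_array` is a hypothetical helper, not part of gemm or candle.

```rust
use std::convert::TryInto;

// m * n = 4096 * 64 elements, as in the example above.
const DST_LEN: usize = 262_144;

/// Hypothetical helper: move a Vec into a heap-allocated fixed-size array.
/// Unlike a plain `[T; N]`, a `Box<[T; N]>` keeps the elements on the heap.
fn to_boxed_array<T, const N: usize>(v: Vec<T>) -> Box<[T; N]> {
    let len = v.len();
    v.into_boxed_slice()
        .try_into()
        .unwrap_or_else(|_| panic!("Expected a Vec of length {} but it was {}", N, len))
}

fn main() {
    // `u16` stands in for `half::f16` here; both are 2 bytes wide.
    let stack_bytes = std::mem::size_of::<[u16; DST_LEN]>();
    println!("[f16; {}] on the stack: {} KiB", DST_LEN, stack_bytes / 1024);

    // Option 1: keep a plain Vec and hand its pointer to gemm.
    let mut dst_vec = vec![0u16; DST_LEN];
    let _dst_ptr: *mut u16 = dst_vec.as_mut_ptr();

    // Option 2: keep a fixed-size type, but boxed, so only a pointer lives on the stack.
    let mut dst_box: Box<[u16; DST_LEN]> = to_boxed_array(vec![0u16; DST_LEN]);
    let _box_ptr: *mut u16 = dst_box.as_mut_ptr();
    println!("Box<[f16; N]> on the stack: {} bytes", std::mem::size_of_val(&dst_box));
}
```

Either way the gemm call itself does not change, since it only ever receives raw pointers and strides.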

kstavro commented 11 months ago

> can you try without the array stuff?

@sarah-ek Could you elaborate a bit on what you mean? Do you mean getting rid of the conversions from Vecs to arrays? Without those, the gemm function complains, as it expects arrays as inputs.

By the way, once you explained that I was allocating too much on the stack with the arrays, I tried smaller parameters and the overflow went away. Unfortunately, now I am getting `exit code: 0xc000013a, STATUS_CONTROL_C_EXIT`, e.g. with

```rust
const DST_LEN: usize = 65536;
const LHS_LEN: usize = 4096;
const RHS_LEN: usize = 64;
```
kstavro commented 11 months ago

@sarah-ek after doing a bit of my own research to understand what each argument of the gemm call means, and debugging the above, I realized that:

  1. I don't actually need to pass arrays to your gemm function; the cpu_backend in candle seems to pass `&mut [T]` or `&[T]` slices to gemm for its matmul. (But why don't I get a stack overflow there, even if I copy the exact same parameters and generate matrices of the same dimensions?)
  2. `DST_LEN` (dst presumably standing for destination?) has to equal `m*n`. Setting it back to 262144 = 4096*64 made the loop work.
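The second point follows from how strided matrices are addressed. The sketch below assumes the usual convention that element (i, j) lives at offset `i * rs + j * cs` (row stride `rs`, column stride `cs`, both non-negative); `offset`, `required_len` and `at` are hypothetical helpers, not gemm API. It also shows that a slice borrowed from a heap `Vec` is enough to describe a matrix, with no fixed-size array involved.

```rust
/// Offset of element (i, j) in a strided matrix (assumes non-negative strides).
fn offset(i: usize, j: usize, rs: usize, cs: usize) -> usize {
    i * rs + j * cs
}

/// Smallest buffer length that can hold an m x n matrix with the given strides.
fn required_len(m: usize, n: usize, rs: usize, cs: usize) -> usize {
    if m == 0 || n == 0 {
        return 0;
    }
    offset(m - 1, n - 1, rs, cs) + 1
}

/// Read element (i, j) from a borrowed slice; indexing panics if out of bounds.
fn at(buf: &[f32], i: usize, j: usize, rs: usize, cs: usize) -> f32 {
    buf[offset(i, j, rs, cs)]
}

fn main() {
    let (m, n) = (4096, 64);
    let (dst_rs, dst_cs) = (64, 1);
    // 4095 * 64 + 63 + 1 = 262144 = 4096 * 64, the DST_LEN that made the loop work.
    let len = required_len(m, n, dst_rs, dst_cs);
    println!("dst needs {} elements", len);

    // A heap Vec borrowed as a slice is all that is needed.
    let dst: Vec<f32> = vec![0.0; len];
    println!("last element: {}", at(&dst, m - 1, n - 1, dst_rs, dst_cs));

    // DST_LEN = 65536 is too short for a 4096 x 64 dst with these strides.
    assert!(65_536 < len);
}
```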

I can confirm a steady 9-10% CPU utilization, like over at candle. Not sure if this has to do with block/stack optimization for the 5800X3D, which has considerably more cache than typical consumer CPUs.

kstavro commented 11 months ago

It seems that the CPU utilization bottleneck in the above example is k=1. With k=1 the product is an outer product of an m-vector and an n-vector (a rank-1 update) rather than a full matrix multiplication, so gevv is called.
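With k = 1, each output element is a single product `lhs(i, 0) * rhs(0, j)` and there is no reduction at all. A naive reference sketch that makes this visible, using the same argument order as the `gemm_basic` call above and the gemm crate's documented update rule `dst := alpha*dst + beta*lhs*rhs` (alpha ignored when `read_dst` is false). `naive_gemm` is hypothetical, uses `f32` and non-negative strides, and omits the conjugation flags and parallelism; it is not gemm's kernel.

```rust
/// Naive strided GEMM: dst := alpha * dst + beta * lhs * rhs (alpha ignored if !read_dst).
/// Shapes: dst is m x n, lhs is m x k, rhs is k x n. Strides are in elements.
#[allow(clippy::too_many_arguments)]
fn naive_gemm(
    m: usize, n: usize, k: usize,
    dst: &mut [f32], dst_cs: usize, dst_rs: usize, read_dst: bool,
    lhs: &[f32], lhs_cs: usize, lhs_rs: usize,
    rhs: &[f32], rhs_cs: usize, rhs_rs: usize,
    alpha: f32, beta: f32,
) {
    for i in 0..m {
        for j in 0..n {
            // Inner product over k; with k = 1 this loop runs exactly once.
            let mut acc = 0.0f32;
            for p in 0..k {
                acc += lhs[i * lhs_rs + p * lhs_cs] * rhs[p * rhs_rs + j * rhs_cs];
            }
            let d = &mut dst[i * dst_rs + j * dst_cs];
            *d = if read_dst { alpha * *d + beta * acc } else { beta * acc };
        }
    }
}

fn main() {
    // k = 1: lhs is a 3 x 1 column, rhs is a 1 x 2 row, dst is 3 x 2 row-major.
    let (m, n, k) = (3, 2, 1);
    let lhs = [1.0f32, 2.0, 3.0];
    let rhs = [10.0f32, 20.0];
    let mut dst = vec![0.0f32; m * n];
    // Same stride pattern as the example: lhs_cs = lhs_rs = 1, rhs_rs = n, dst_rs = n.
    naive_gemm(m, n, k, &mut dst, 1, n, false, &lhs, 1, 1, &rhs, 1, n, 0.0, 1.0);
    println!("{:?}", dst); // [10.0, 20.0, 20.0, 40.0, 30.0, 60.0]
}
```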

gevv doesn't seem to implement any parallelism, just SIMD. Maybe it would help to introduce some parallelism there for large vectors, as in the example above? Once I increase k to at least 3, CPU utilization jumps straight to >96% (99% with k=8 and 100% with k=16).
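One possible shape for the suggested parallelism in the k = 1 case: split `dst` into disjoint blocks of rows and give each block to a scoped thread. This is a hypothetical `par_outer` helper assuming a contiguous row-major `dst` and `f32` data, not how gemm's gevv works. Each output element costs one multiply and one store, so any speedup would depend on memory bandwidth rather than core count.

```rust
use std::thread;

/// Outer product dst(i, j) = lhs[i] * rhs[j] for a contiguous row-major m x n `dst`,
/// with the rows split across up to `threads` scoped worker threads.
fn par_outer(dst: &mut [f32], lhs: &[f32], rhs: &[f32], threads: usize) {
    let (m, n) = (lhs.len(), rhs.len());
    assert_eq!(dst.len(), m * n, "dst must hold m * n elements");
    if m == 0 || n == 0 {
        return;
    }
    let threads = threads.max(1);
    // Rows per worker, rounded up so that every row is covered.
    let rows_per_chunk = (m + threads - 1) / threads;
    thread::scope(|s| {
        let blocks = dst.chunks_mut(rows_per_chunk * n).zip(lhs.chunks(rows_per_chunk));
        for (dst_rows, lhs_rows) in blocks {
            // Each worker owns a disjoint &mut block of dst, so no locking is needed.
            s.spawn(move || {
                for (row, &a) in dst_rows.chunks_mut(n).zip(lhs_rows) {
                    for (d, &b) in row.iter_mut().zip(rhs) {
                        *d = a * b;
                    }
                }
            });
        }
    });
}

fn main() {
    // Same shape as the k = 1 example: m = 4096, n = 64.
    let lhs: Vec<f32> = (0..4096).map(|i| i as f32).collect();
    let rhs: Vec<f32> = (0..64).map(|j| j as f32 * 0.5).collect();
    let mut dst = vec![0.0f32; lhs.len() * rhs.len()];
    par_outer(&mut dst, &lhs, &rhs, 16);
    println!("dst(1, 2) = {}", dst[rhs.len() + 2]); // lhs[1] * rhs[2] = 1 * 1.0
}
```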

I think the reason llama inference in candle stays at about 9% CPU utilization once generation of new tokens starts and the kv_cache kicks in is that, with a kv_cache, most of the matmuls are actually vector-matrix products (m = 1). I assume gemv is used there? As far as I can see from the code, gemv also relies only on SIMD, which would explain the CPU utilization.
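In the logged calls below, these matmuls have m = 1, `rhs_rs = 1` and `rhs_cs = k`, so column j of `rhs` is k contiguous elements and every output is a single dot product. A minimal single-threaded sketch of that vector-matrix case, assuming `f32` data and exactly that layout (`gemv_row` is hypothetical, not gemm's gemv):

```rust
/// Vector-matrix product for m = 1: dst[j] = sum over p of lhs[p] * rhs(p, j).
/// Assumes rhs is k x n with rhs_rs = 1 and rhs_cs = k, i.e. each column of rhs
/// is k contiguous elements (as in the logged candle calls).
fn gemv_row(dst: &mut [f32], lhs: &[f32], rhs: &[f32]) {
    let k = lhs.len();
    assert_eq!(rhs.len(), k * dst.len(), "rhs must hold k * n elements");
    if k == 0 {
        dst.iter_mut().for_each(|d| *d = 0.0);
        return;
    }
    for (d, col) in dst.iter_mut().zip(rhs.chunks_exact(k)) {
        // One dot product per output element; every weight is read exactly once.
        *d = lhs.iter().zip(col).map(|(a, b)| a * b).sum();
    }
}

fn main() {
    let lhs = [1.0f32, 2.0, 3.0, 4.0]; // k = 4
    // n = 3 columns, each stored contiguously.
    let rhs = [
        1.0f32, 0.0, 0.0, 0.0, // column 0 picks lhs[0]
        0.0, 1.0, 0.0, 0.0, // column 1 picks lhs[1]
        1.0, 1.0, 1.0, 1.0, // column 2 sums lhs
    ];
    let mut dst = [0.0f32; 3];
    gemv_row(&mut dst, &lhs, &rhs);
    println!("{:?}", dst); // [1.0, 2.0, 10.0]
}
```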

For reference, here are some gemm inputs from candle while generating new tokens (there are some big matmuls as well):

A sample of consecutive llama gemm inputs (logged at candle-core\src\cpu_backend.rs:1516-1527)

| m | k | n | dst_p.len() | dst_cs | dst_rs | lhs_p.len() | lhs_cs | lhs_rs | rhs_p.len() | rhs_cs | rhs_rs |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 4096 | 4096 | 4096 | 1 | 4096 | 4096 | 1 | 4096 | 16777216 | 4096 | 1 |
| 1 | 4096 | 11008 | 11008 | 1 | 11008 | 4096 | 1 | 4096 | 45088768 | 4096 | 1 |
| 1 | 4096 | 11008 | 11008 | 1 | 11008 | 4096 | 1 | 4096 | 45088768 | 4096 | 1 |
| 1 | 11008 | 4096 | 4096 | 1 | 4096 | 11008 | 1 | 11008 | 45088768 | 11008 | 1 |
| 1 | 4096 | 4096 | 4096 | 1 | 4096 | 4096 | 1 | 4096 | 16777216 | 4096 | 1 |
| 1 | 4096 | 4096 | 4096 | 1 | 4096 | 4096 | 1 | 4096 | 16777216 | 4096 | 1 |
| 1 | 4096 | 4096 | 4096 | 1 | 4096 | 4096 | 1 | 4096 | 16777216 | 4096 | 1 |
| 1 | 128 | 13 | 416 | 1 | 13 | 4096 | 1 | 128 | 53248 | 128 | 1 |
| 1 | 128 | 13 | 403 | 1 | 13 | 3968 | 1 | 128 | 51584 | 128 | 1 |
| 1 | 128 | 13 | 390 | 1 | 13 | 3840 | 1 | 128 | 49920 | 128 | 1 |
sarah-ek commented 11 months ago

gevv doesn't parallelize because the computation is memory-bound and doesn't benefit much from parallelism.

kstavro commented 11 months ago

Ok, I see. And what about gemv?

sarah-ek commented 11 months ago

same thing.
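A back-of-the-envelope check of the memory-bound argument for both kernels: compare the floating-point operations of a product with the bytes it has to move. This is a simple roofline-style model assuming every element of lhs, rhs and dst crosses the memory bus exactly once (caches ignored) and 2-byte f16 elements; `intensity` is a hypothetical helper. The gemv shape comes from the candle log above, the gevv shape from the k = 1 example, and a square product is included for contrast.

```rust
/// Arithmetic intensity (FLOPs per byte) of dst(m x n) = lhs(m x k) * rhs(k x n),
/// counting one multiply and one add per term and assuming every element of lhs,
/// rhs and dst crosses the memory bus exactly once (caches ignored).
fn intensity(m: u64, n: u64, k: u64, elem_bytes: u64) -> f64 {
    let flops = 2 * m * n * k;
    let bytes = (m * k + k * n + m * n) * elem_bytes;
    flops as f64 / bytes as f64
}

fn main() {
    let f16_bytes = 2;
    // (label, m, n, k)
    let shapes = [
        ("gemv, m=1 k=4096 n=11008 (candle log)", 1, 11008, 4096),
        ("gevv, m=4096 k=1 n=64 (k = 1 example)", 4096, 64, 1),
        ("square, m=n=k=4096 (for contrast)", 4096, 4096, 4096),
    ];
    for (name, m, n, k) in shapes {
        println!("{name}: {:.2} FLOP/byte", intensity(m, n, k, f16_bytes));
    }
}
```

Under this model both skinny shapes come out at roughly one FLOP per byte, while the square product reuses each loaded element many times (about 1365 FLOP/byte). That gap is the sense in which gevv and gemv are memory-bound: their speed is set mainly by how fast the data can be streamed from memory rather than by how much arithmetic the cores can do.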