sarah-quinones / faer-rs

Linear algebra foundation for the Rust programming language
https://faer-rs.github.io
MIT License

Add support for matrix multiplication with `half::f16`/`half::bf16` #32

Closed: coreylowman closed this issue 1 year ago

coreylowman commented 1 year ago

Is your feature request related to a problem? Please describe.
Currently, none of the CPU matrix multiplication crates support the half-precision (f16) datatype. There is the half crate (https://crates.io/crates/half), which provides a Rust type for it.

This means that any crate needing to support CPU matrix multiplication with the f16 datatype has to implement its own matrix multiplication algorithm by hand, which can be very slow.
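
For reference, and not from the thread itself: a minimal sketch of what the half crate provides, namely an f16 type with explicit from_f32/to_f32 conversions, where arithmetic is usually done by widening to f32 on CPUs that lack native f16 support.

use half::f16;

fn main() {
    // half::f16 is a 16-bit IEEE float; conversions to/from f32 are explicit.
    let a = f16::from_f32(1.5);
    let b = f16::from_f32(2.25);

    // Widen to f32, compute, then narrow back to f16.
    let product = f16::from_f32(a.to_f32() * b.to_f32());
    assert_eq!(product.to_f32(), 3.375);
}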

Describe the solution you'd like
It'd be great if faer included configurable support for these data types.

Describe alternatives you've considered
I'm currently using this super naive matmul implementation:

// Naive triple-loop matmul over raw pointers: C += A * B, with arbitrary
// row/column strides for each operand. C must already be initialized
// (typically zeroed), since this only accumulates into it.
fn naive_gemm<F: num_traits::Float + std::ops::AddAssign, M: Dim, K: Dim, N: Dim>(
    (m, k, n): (M, K, N),
    ap: *const F,
    a_strides: [usize; 2],
    bp: *const F,
    b_strides: [usize; 2],
    cp: *mut F,
    c_strides: [usize; 2],
) {
    for i_m in 0..m.size() {
        for i_k in 0..k.size() {
            for i_n in 0..n.size() {
                unsafe {
                    // Load A[i_m, i_k] and B[i_k, i_n] through their strides.
                    let a = *ap.add(a_strides[0] * i_m + a_strides[1] * i_k);
                    let b = *bp.add(b_strides[0] * i_k + b_strides[1] * i_n);
                    // Accumulate into C[i_m, i_n].
                    let c = cp.add(c_strides[0] * i_m + c_strides[1] * i_n);
                    *c += a * b;
                }
            }
        }
    }
}

This is extremely slow, and I'm not sure what the other alternatives are.
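
One common workaround, sketched here as an assumption rather than something from this thread: widen the f16 inputs to f32, do the multiplication in f32 (where optimized kernels, or at least autovectorized loops and f32 accumulation, are available), and narrow the result back to f16. The f32_matmul and f16_matmul_via_f32 helpers below are hypothetical names and rely only on the half crate's from_f32/to_f32 conversions:

use half::f16;

// Row-major (m x k) * (k x n) -> (m x n), all in f32. Any optimized f32
// matmul (faer, a BLAS binding, ...) could be dropped in here instead;
// a plain loop is shown only to keep the sketch self-contained.
fn f32_matmul(m: usize, k: usize, n: usize, a: &[f32], b: &[f32], c: &mut [f32]) {
    for i in 0..m {
        for p in 0..k {
            let a_ip = a[i * k + p];
            for j in 0..n {
                c[i * n + j] += a_ip * b[p * n + j];
            }
        }
    }
}

// Hypothetical helper: widen f16 -> f32, multiply, narrow back to f16.
fn f16_matmul_via_f32(m: usize, k: usize, n: usize, a: &[f16], b: &[f16]) -> Vec<f16> {
    let a32: Vec<f32> = a.iter().map(|x| x.to_f32()).collect();
    let b32: Vec<f32> = b.iter().map(|x| x.to_f32()).collect();
    let mut c32 = vec![0.0f32; m * n];
    f32_matmul(m, k, n, &a32, &b32, &mut c32);
    c32.into_iter().map(f16::from_f32).collect()
}

This widen/compute/narrow pattern is roughly what f16 kernels tend to do on CPUs without native f16 arithmetic, and accumulating in f32 also avoids some of the rounding error of accumulating directly in f16.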

Additional context

I'm the author of dfdx. dfdx has both CUDA and CPU support, and CUDA does have hardware acceleration for f16. It'd just be nice if f16 matmul on CPU were a bit faster!

sarah-quinones commented 1 year ago

f16 matrix multiplication is now implemented in gemm v0.15, though integration into faer is unlikely for now. I would ideally like to wait until f16 gets proper SIMD support, so I can figure out how best to integrate it into the library.