srush / llama2.rs

A fast llama2 decoder in pure Rust.
MIT License

Quick Code Review: Auto-vectorization #2

gaxler opened this issue 1 year ago

gaxler commented 1 year ago

Hi Sasha! Nice to see your take on llama2.rs! I did a port of Andrej's llama2.c here.

I had a chance to go over the code and compare it to my version. The only thing I want to mention is that you can get the compiler to auto-vectorize some computations (most notably matmul).

I did a quick benchmark of our implementations. Yours runs at ~52 t/s and mine at ~75 t/s (on a 2-CPU/4 GB Codespaces VM, running the stories15M model). My guess is that most of the difference comes from the compiler not being able to auto-vectorize your matmul implementation.

A good way to help the compiler auto-vectorize is to use iterators as much as possible. The key idea is to replace the following loop with an iterator:

xout.par_iter_mut().enumerate().for_each(|(i, v)| {
    let mut val = 0.0;
    for j in 0..n {
        val += w[i * n + j] * x[j];
    }
    *v = val;
});

Doing the following closes most of the gap and puts your implementation at ~74 t/s:

xout.par_iter_mut().enumerate().for_each(|(i, v)| {
    // Take row i of w as a slice and zip it with x: no per-element bounds checks.
    *v = w[i * n..(i + 1) * n]
        .iter()
        .zip(x.iter())
        .fold(0f32, |acc, (&wi, &xi)| acc + wi * xi);
});
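For anyone who wants to try the two variants outside the repo, here is a minimal self-contained sketch that puts them side by side. It assumes a row-major weight matrix `w` with `n` columns and one row per output, as in the snippets above, and it uses a sequential `iter_mut()` instead of rayon's `par_iter_mut()` so it has no dependencies. The function names `matmul_indexed` and `matmul_iter` are hypothetical. The slice-and-zip version removes the per-element bounds checks on `w[i * n + j]` and `x[j]`, which gives the optimizer a tighter inner loop; how far the f32 sum itself gets vectorized still depends on the compiler, since strict floating-point semantics limit reordering of the additions.

```rust
/// Indexed version: every `w[i * n + j]` and `x[j]` access is bounds-checked.
fn matmul_indexed(xout: &mut [f32], x: &[f32], w: &[f32], n: usize) {
    xout.iter_mut().enumerate().for_each(|(i, v)| {
        let mut val = 0.0;
        for j in 0..n {
            val += w[i * n + j] * x[j];
        }
        *v = val;
    });
}

/// Iterator version: one slice per row, zipped with `x`, so the inner loop
/// needs no per-element bounds checks.
fn matmul_iter(xout: &mut [f32], x: &[f32], w: &[f32], n: usize) {
    xout.iter_mut().enumerate().for_each(|(i, v)| {
        *v = w[i * n..(i + 1) * n]
            .iter()
            .zip(x.iter())
            .fold(0f32, |acc, (&wi, &xi)| acc + wi * xi);
    });
}

fn main() {
    // A 2 x 3 weight matrix (row-major) times a length-3 vector.
    let w: [f32; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let x: [f32; 3] = [1.0, 0.5, 2.0];
    let (mut a, mut b) = ([0.0f32; 2], [0.0f32; 2]);
    matmul_indexed(&mut a, &x, &w, 3);
    matmul_iter(&mut b, &x, &w, 3);
    // Both versions add the terms in the same order, so results match exactly.
    assert_eq!(a, b);
    println!("{:?}", b); // [8.0, 18.5]
}
```

Swapping `iter_mut()` back to rayon's `par_iter_mut()` only changes how rows are distributed across threads; the per-row inner loop is the part the compiler optimizes.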

You can also let the compiler use AVX2 and FMA instructions when it auto-vectorizes by passing these compiler flags: RUSTFLAGS="-C target-feature=+avx2,+fma"
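One way to confirm the flags actually took effect is a small check like the sketch below. `cfg!(target_feature = "...")` is resolved at compile time, so it reports what RUSTFLAGS enabled, while `is_x86_feature_detected!` asks the CPU at runtime. The function name `simd_report` is hypothetical. Keep in mind that a binary built with `+avx2,+fma` can fail with an illegal-instruction error on a CPU without those extensions.

```rust
/// Returns (compiled with +avx2,+fma, CPU supports AVX2, CPU supports FMA).
fn simd_report() -> (bool, bool, bool) {
    // Compile-time: true only if the build enabled both target features.
    let compiled_with = cfg!(target_feature = "avx2") && cfg!(target_feature = "fma");

    // Runtime detection only exists on x86/x86_64; report false elsewhere.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    let (cpu_avx2, cpu_fma) = (
        is_x86_feature_detected!("avx2"),
        is_x86_feature_detected!("fma"),
    );
    #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
    let (cpu_avx2, cpu_fma) = (false, false);

    (compiled_with, cpu_avx2, cpu_fma)
}

fn main() {
    let (compiled_with, cpu_avx2, cpu_fma) = simd_report();
    println!("compiled with +avx2,+fma: {compiled_with}");
    println!("CPU supports avx2: {cpu_avx2}, fma: {cpu_fma}");
}
```

Running it once with and once without `RUSTFLAGS="-C target-feature=+avx2,+fma"` should flip the first line while the CPU line stays the same.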

srush commented 1 year ago

Amazing, that's really helpful to know. Thanks for pointing it out.

Do you plan on continuing to work on this? Was planning on moving on, but now I'm kind of curious to implement quantization to scale to bigger models. Would be happy to collaborate if you were interested.

srush commented 1 year ago

Nice. This bumped me up from 0.92 t/s to 1.02 t/s on llama2 7B.

gaxler commented 1 year ago

> Amazing, that's really helpful to know. Thanks for pointing it out.
>
> Do you plan on continuing to work on this? Was planning on moving on, but now I'm kind of curious to implement quantization to scale to bigger models. Would be happy to collaborate if you were interested.

Yeah, that's what I hope to do before moving on. Would love to collaborate. I started playing with some prototypes on this branch (it's not very reader-friendly yet).

> Nice. This bumped me up from 0.92 t/s to 1.02 t/s on llama2 7B.

Nice! I wonder why the speed-up (about 11%) is so much smaller than on the 15M model (about 42%). Maybe the CPU is waiting on mmap page swaps?

srush commented 1 year ago

Nice, I will try to catch up on your code.

Some of the HF people recommended trying GPTQ-style inference (quantized weights × full-precision mat-vec). Which approach are you taking?

gaxler commented 1 year ago

My code is mostly experimenting with ways to build a clean matmul interface. I did a naive row-wise i8 quantization of the weights, with matmuls that accumulate in f32. But that's really just the first thing that popped into my mind.
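A naive row-wise i8 scheme like the one described can be sketched as follows. This is a minimal illustration, not the code on that branch: it assumes symmetric per-row absmax scaling (scale = max|w| / 127), stores each row as `i8` values plus one `f32` scale, and accumulates the mat-vec in `f32` against an unquantized activation vector. The names `QuantRow`, `quantize_rows`, and `matvec_q8` are hypothetical.

```rust
/// One quantized weight row: i8 values plus a per-row f32 scale.
struct QuantRow {
    q: Vec<i8>,
    scale: f32,
}

/// Symmetric row-wise quantization: each row is scaled so that its largest
/// absolute value maps to 127.
fn quantize_rows(w: &[f32], n: usize) -> Vec<QuantRow> {
    w.chunks(n)
        .map(|row| {
            let max_abs = row.iter().fold(0f32, |m, &v| m.max(v.abs()));
            // Avoid dividing by zero for an all-zero row.
            let scale = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };
            let q = row
                .iter()
                .map(|&v| (v / scale).round().clamp(-127.0, 127.0) as i8)
                .collect();
            QuantRow { q, scale }
        })
        .collect()
}

/// Mat-vec with i8 weights and f32 activations, accumulating in f32 and
/// applying the row scale once at the end.
fn matvec_q8(xout: &mut [f32], x: &[f32], rows: &[QuantRow]) {
    for (v, row) in xout.iter_mut().zip(rows) {
        let dot = row
            .q
            .iter()
            .zip(x)
            .fold(0f32, |acc, (&qi, &xi)| acc + qi as f32 * xi);
        *v = dot * row.scale;
    }
}

fn main() {
    let w: [f32; 6] = [0.5, -1.0, 0.25, 2.0, 0.0, -0.5];
    let x: [f32; 3] = [1.0, 2.0, 4.0];
    let rows = quantize_rows(&w, 3);
    let mut out = [0f32; 2];
    matvec_q8(&mut out, &x, &rows);
    // The unquantized result would be [-0.5, 0.0]; this one is close to it.
    println!("{:?}", out);
}
```

Keeping one scale per row means the scale can be pulled out of the inner loop, so the hot loop is a plain dot product over `i8` and `f32` slices.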

srush commented 1 year ago

Hi! I saw that you are also a maintainer of Triton and worked on the AoT compiler. I'm playing around with setting this project up to use Triton, mostly to learn it. Do you have any tips for getting this to work? Exporting PTX worked reasonably well at first, but I think I'm running into issues calling into it from Rust. Do you have pointers to recommended ways to do it?

My hacky code: https://github.com/srush/llama2.rs/pull/35/files#diff-7c199e27f9cec983de845ad01b4fd4e558534ee33fd49d8134cbab879361af67R158

gaxler commented 1 year ago

What issues are you running into? Is it slowness or something Triton-related?

I left some comments in the PR.

srush commented 1 year ago

Thanks! Once I got it running it was fast, but when I tried to further optimize the Triton code, the Rust version went out of sync with the Python version. I'm trying to make a minimal example.