srush / llama2.rs

A fast llama2 decoder in pure Rust.
MIT License

Speed comparison #41

Open Will-Zhao0 opened 9 months ago

Will-Zhao0 commented 9 months ago

Hi, I found this nice project while learning Rust. I'm curious how its speed compares to llama on PyTorch, and what the pros and cons are of implementing it in Rust on the CPU.

Thank you

srush commented 9 months ago

Honestly, it was mostly a learning project when I started. However, I do think it would be faster than PyTorch on CPU. It contains a lot of optimizations that directly make use of quantization + SIMD, which I think is harder to do in PyTorch directly without incurring Python overhead. I think this library is about as fast as you can make it on CPU without resorting to tricks like speculative sampling.
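
(As an aside for readers: the sketch below is not this repository's actual kernel, just a minimal illustration of the quantization + SIMD idea mentioned above. It shows a blockwise int8 dot product with a per-block f32 scale, where the fixed-size inner loop is easy for the compiler to auto-vectorize; the `QBlock`/`qdot` names and the block size of 32 are made up for the example.)

```rust
// Illustration only (not llama2.rs's real code): blockwise int8-quantized dot
// product. Each block of 32 weights shares one f32 scale; the fixed-size inner
// loop auto-vectorizes well with SIMD.

const BLOCK: usize = 32;

/// One quantized block: 32 signed 8-bit weights plus a shared scale.
struct QBlock {
    scale: f32,
    qs: [i8; BLOCK],
}

/// Dot product between quantized weights and f32 activations.
fn qdot(weights: &[QBlock], x: &[f32]) -> f32 {
    assert_eq!(weights.len() * BLOCK, x.len());
    let mut acc = 0.0f32;
    for (b, xs) in weights.iter().zip(x.chunks_exact(BLOCK)) {
        // Accumulate one block, then apply its scale once.
        let mut block_acc = 0.0f32;
        for i in 0..BLOCK {
            block_acc += b.qs[i] as f32 * xs[i];
        }
        acc += b.scale * block_acc;
    }
    acc
}

fn main() {
    // Tiny smoke test: two blocks of 32 values each.
    let weights = vec![
        QBlock { scale: 0.1, qs: [1; BLOCK] },
        QBlock { scale: 0.2, qs: [2; BLOCK] },
    ];
    let x = vec![1.0f32; 2 * BLOCK];
    println!("dot = {}", qdot(&weights, &x)); // 0.1*32 + 0.2*64 = 16
}
```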

It is currently much slower on GPU, but I'm also hoping to eventually make it much faster by writing kernels in JAX + Pallas. It probably won't be as fast as some of the libraries out there (I heard MLC is wildly fast), but it might be interesting to try.

navr32 commented 3 months ago

Have you done any tests with Burn, the Rust tensor computation / deep learning framework? Burn can run inference, for example: https://github.com/Gadersd/llama2-burn. Burn can also switch its compute backend easily, which is very interesting, and it looks well optimized.
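
(Editorial note: the snippet below is not Burn's actual API; it is only a plain-Rust sketch of the "swappable backend" idea this comment refers to, where model code is written once against a backend abstraction and the concrete backend is chosen by a type parameter. The `Backend` trait, `CpuBackend`, `matvec`, and `forward` names are made up for illustration.)

```rust
// Generic sketch of backend switching (not Burn's real API).

trait Backend {
    /// Multiply a row-major (m x k) matrix by a length-k vector.
    fn matvec(m: usize, k: usize, w: &[f32], x: &[f32]) -> Vec<f32>;
}

struct CpuBackend;

impl Backend for CpuBackend {
    fn matvec(m: usize, k: usize, w: &[f32], x: &[f32]) -> Vec<f32> {
        (0..m)
            .map(|row| {
                w[row * k..(row + 1) * k]
                    .iter()
                    .zip(x)
                    .map(|(a, b)| a * b)
                    .sum()
            })
            .collect()
    }
}

/// Model code is written once against the trait; the backend is picked
/// at the call site (CPU today, a GPU backend could slot in the same way).
fn forward<B: Backend>(w: &[f32], x: &[f32]) -> Vec<f32> {
    // m = k = x.len() keeps the example square and tiny.
    B::matvec(x.len(), x.len(), w, x)
}

fn main() {
    let w = vec![1.0f32; 4]; // 2x2 weight matrix, all ones
    let x = vec![2.0f32, 3.0];
    println!("{:?}", forward::<CpuBackend>(&w, &x)); // [5.0, 5.0]
}
```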