mukel / llama2.java

Inference Llama 2 in one file of pure Java
MIT License

Using llama2 7b but it's super slow #7

Closed by Boofii 6 months ago

Boofii commented 6 months ago

The performance section says it runs at almost 1 token per second, but mine takes a few minutes for a single word, and it isn't consistent. How do I make it faster?

mukel commented 6 months ago

You need enough RAM to keep all the weights in memory, e.g. the .bin version of Llama 2 7B requires > 32GB of RAM. This is a limitation of the .bin format used by llama2.c, which stores all weights as full 32-bit floats (roughly 7 billion parameters × 4 bytes per weight ≈ 27GB for the weights alone).

Please note that GraalVM (if you are using it) does not support the Vector API yet; you can disable the vectorized matmul with -Dllama2.VectorAPI=false. You can also run on OpenJDK, which does support the Vector API.
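
For illustration, here is a minimal, hypothetical sketch of the pattern behind that flag: switching between a `jdk.incubator.vector` dot product and a plain scalar loop based on the `llama2.VectorAPI` system property. This is not the project's actual matmul code, just the general idea (compile and run with `--add-modules jdk.incubator.vector`):

```java
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;

// Illustrative sketch only, NOT llama2.java's actual matmul: picks a Vector API
// dot product or a scalar fallback depending on -Dllama2.VectorAPI=true/false.
public class MatmulSwitchSketch {
    static final boolean USE_VECTOR_API =
            Boolean.parseBoolean(System.getProperty("llama2.VectorAPI", "true"));
    static final VectorSpecies<Float> SPECIES = FloatVector.SPECIES_PREFERRED;

    static float dot(float[] a, float[] b, int n) {
        float sum = 0f;
        int i = 0;
        if (USE_VECTOR_API) {
            FloatVector acc = FloatVector.zero(SPECIES);
            for (; i < SPECIES.loopBound(n); i += SPECIES.length()) {
                // acc += a[i..] * b[i..] across all vector lanes (fused multiply-add)
                acc = FloatVector.fromArray(SPECIES, a, i)
                                 .fma(FloatVector.fromArray(SPECIES, b, i), acc);
            }
            sum = acc.reduceLanes(VectorOperators.ADD);
        }
        for (; i < n; i++) {   // scalar tail, or the whole loop when the Vector API is disabled
            sum += a[i] * b[i];
        }
        return sum;
    }
}
```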

plokhotnyuk commented 6 months ago

For even more of a speedup you can use TornadoVM's port of this amazing project, which uses GraalVM's auto-vectorization and the GPU with different float types.

Boofii commented 6 months ago

My computer has a total of 16GB of RAM.

mukel commented 6 months ago

I'm working on a Llama 3 version that supports GGUF and some quantizations ... Llama 3 8B Q4_0 runs at 7 tokens/s on my laptop. Stay tuned.

Boofii commented 6 months ago

How good is your laptop? And in the meantime, is there some alternative fix I can apply?

mukel commented 6 months ago

Inference is constrained by memory bandwidth, not compute. Llama 3 8B quantized with Q4_0 is ~4.5GB, so it should fit even on modest machines and should be plenty fast, e.g. within 10% of llama.cpp. Basically, find your memory bandwidth and divide it by the (quantized) model size to get a good guess at the tokens/s you'd get.
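
As a worked example of that rule of thumb (the bandwidth figure below is an assumed value for a typical laptop, not a measurement):

```java
// Back-of-the-envelope only: each generated token streams roughly the whole model
// through memory once, so tokens/s ≈ effective memory bandwidth / quantized model size.
public class TokensPerSecondEstimate {
    public static void main(String[] args) {
        double bandwidthGBps = 30.0;  // assumption: typical dual-channel laptop memory bandwidth
        double modelSizeGB   = 4.5;   // Llama 3 8B Q4_0, per the comment above
        System.out.printf("~%.1f tokens/s%n", bandwidthGBps / modelSizeGB);  // prints ~6.7
    }
}
```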

I work on this only in my free time; give me a few days/weeks to polish and release the Llama 3 version.

Boofii commented 6 months ago

Oh, until now I used the llama2-7b.bin file I got from the export.py script; it's 27GB. So if I use a smaller file, will it be faster? If so, how can I get one? Maybe this Q4 you mentioned, or something like it?

mukel commented 6 months ago

llama2.java, in its current state, only supports full-precision tensors (float32 weights) ... it needs extra work to support quantized tensors. I have already implemented some quantizations (Q4_0, Q4_1, Q8_0 ... but not k-quants yet), support for GGUF, Llama 3 support ...
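
To make the difference concrete, here is a rough, hypothetical sketch of how a GGML/GGUF-style Q8_0 block can be dequantized. It is not llama2.java's code, only an illustration of why quantized weights take so much less memory than float32:

```java
import java.nio.ByteBuffer;

// Hypothetical illustration, not the project's code: GGML/GGUF Q8_0 stores weights in
// blocks of 32, each block being one float16 scale followed by 32 signed 8-bit values,
// so x[i] = scale * q[i] (~8.5 bits per weight instead of 32 for float32).
// Q4_0 is similar but packs two 4-bit values per byte.
public class Q8_0Sketch {
    static final int WEIGHTS_PER_BLOCK = 32;  // one block is 2 + 32 = 34 bytes on disk

    // `buf` is assumed to be little-endian and positioned at the start of one block.
    static void dequantizeBlock(ByteBuffer buf, float[] out, int offset) {
        float scale = Float.float16ToFloat(buf.getShort());  // Float.float16ToFloat needs JDK 20+
        for (int i = 0; i < WEIGHTS_PER_BLOCK; i++) {
            out[offset + i] = scale * buf.get();              // each byte is one signed 8-bit quant
        }
    }
}
```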

Boofii commented 6 months ago

I see, so I'll wait until your Llama 3 support comes out.