mukel / llama2.java

Inference Llama 2 in one file of pure Java
MIT License

Llamafile comparison? #8


hrstoyanov commented 6 months ago

@mukel thank you for creating this project! I would like to discuss the following topics:

  1. Please enable the Discussions tab for posts like this, which are not real "issues"

  2. Do you plan on releasing Llama3 code?

  3. Do you plan on supporting quantized Llama models with the Java Vector API?

  4. Can you run a benchmark against llamafile, whose vectorized version (AVX, NEON) claims to be the performance king for inference? (I am deciding between using your project or wrapping the llamafile C code with the Java 22 Foreign Function & Memory API.)

  5. Do you plan to implement model training as well? If so, take a look at Andrej Karpathy's llm.c repo.

mukel commented 6 months ago

Please enable the Discussions tab for posts like this, which are not real "issues"

Done! Thanks for the hint!

Do you plan on releasing Llama3 code?

Yes, I have a VERY cool demo running locally, stay tuned. I'll release the demo first, then Llama 3 as a Java library later, e.g. including a backend for Langchain4j and supporting function calling and structured responses (JSON) based on NousResearch's Hermes fine-tunes.
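
For a sense of what such a backend could look like, here is a minimal, hypothetical sketch against the pre-1.0 langchain4j ChatLanguageModel interface; Llama3ChatModel and runLocalInference are placeholder names for illustration, not the actual Llama3.java API:

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;

import java.util.List;

// Hypothetical adapter plugging a local GGUF model into langchain4j's 0.x interface.
public class Llama3ChatModel implements ChatLanguageModel {

    @Override
    public Response<AiMessage> generate(List<ChatMessage> messages) {
        // Placeholder: render the messages with the model's chat template,
        // run local inference, and wrap the completion as an AiMessage.
        String completion = runLocalInference(messages);
        return Response.from(AiMessage.from(completion));
    }

    private String runLocalInference(List<ChatMessage> messages) {
        throw new UnsupportedOperationException("to be backed by Llama3.java inference");
    }
}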

Do you plan on supporting quantized Llama models with the Java Vector API?

Yes, I already have fast matmul routines for Q4_0, Q4_1 and Q8_0 using the Vector API; k-quants are still a WIP.
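
As a rough illustration of the pattern (not the project's actual routines), a Vector API dot product over plain floats looks like the sketch below; the quantized Q4_0/Q8_0 variants layer per-block scales and byte-to-float conversion on top of the same loop. Run with --add-modules jdk.incubator.vector.

import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;

public class VectorDotSketch {

    private static final VectorSpecies<Float> F = FloatVector.SPECIES_PREFERRED;

    // Vectorized dot product with a scalar tail; a quantized variant would
    // dequantize each block (scale * int8 values) before the fused multiply-add.
    static float dot(float[] a, float[] b, int n) {
        FloatVector acc = FloatVector.zero(F);
        int i = 0;
        int upper = F.loopBound(n);
        for (; i < upper; i += F.length()) {
            FloatVector va = FloatVector.fromArray(F, a, i);
            FloatVector vb = FloatVector.fromArray(F, b, i);
            acc = va.fma(vb, acc); // acc += a[i..] * b[i..]
        }
        float sum = acc.reduceLanes(VectorOperators.ADD);
        for (; i < n; i++) {
            sum += a[i] * b[i]; // scalar tail for the remaining elements
        }
        return sum;
    }

    public static void main(String[] args) {
        float[] a = {1, 2, 3, 4, 5, 6, 7, 8, 9};
        float[] b = {9, 8, 7, 6, 5, 4, 3, 2, 1};
        System.out.println(dot(a, b, a.length)); // expected: 165.0
    }
}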

Can you run a benchmark against llamafile, whose vectorized version (AVX, NEON) claims to be the performance king for inference? (I am deciding between using your project or wrapping the llamafile C code with the Java 22 Foreign Function & Memory API.)

I bet llamafile is a bit faster, but hopefully the Graal compiler will soon add enough support for the Vector API to optimize, at least, the matmul routines. A big advantage here is that you have direct access to the model guts, e.g. you can serialize/compress/cache/hack the model state and tweak the inference rather easily.
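
For comparison, the wrapping route mentioned in the question, calling native llama.cpp/llamafile code through the Java 22 Foreign Function & Memory API, would look roughly like the sketch below. The library file name is illustrative; llama_print_system_info is a real llama.cpp export, but verify names and signatures against the header of the build you actually link, and run with --enable-native-access=ALL-UNNAMED to silence restricted-method warnings.

import java.lang.foreign.Arena;
import java.lang.foreign.FunctionDescriptor;
import java.lang.foreign.Linker;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.SymbolLookup;
import java.lang.foreign.ValueLayout;
import java.lang.invoke.MethodHandle;

public class NativeLlamaDemo {
    public static void main(String[] args) throws Throwable {
        Linker linker = Linker.nativeLinker();
        try (Arena arena = Arena.ofConfined()) {
            // Load the shared library built from llama.cpp (path is illustrative).
            SymbolLookup lib = SymbolLookup.libraryLookup("libllama.so", arena);
            // const char *llama_print_system_info(void);
            MethodHandle systemInfo = linker.downcallHandle(
                    lib.find("llama_print_system_info").orElseThrow(),
                    FunctionDescriptor.of(ValueLayout.ADDRESS));
            MemorySegment cstr = (MemorySegment) systemInfo.invokeExact();
            // The returned pointer is zero-length; reinterpret it to read the C string.
            System.out.println(cstr.reinterpret(Long.MAX_VALUE).getString(0));
        }
    }
}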

Do you plan to implement model training as well? If so, take a look at Andrej Karpathy's llm.c repo.

I was tempted to port llm.c to Java as well, but training requires way more resources and accessing GPUs from Java is rather cumbersome ATM. With TornadoVM and Project Babylon this could change...

mukel commented 6 months ago

Please check Llama3.java.

I compared mistral-7b-instruct-v0.2.Q4_0.llamafile and Mistral.java using the following HW configuration:
Intel Core i9-13900H (6 P-cores / 8 E-cores, 20 threads), 64 GB RAM at 5200 MT/s, Linux 6.6.30

Mistral.java

jbang Mistral.java \
  --model Mistral-7B-Instruct-v0.2-Q4_0.gguf \
  --max-tokens 512 \
  --prompt "Why is the sky blue?"

Output:

Parse Mistral-7B-Instruct-v0.2-Q4_0.gguf: 75 millis
Load LlaMa model: 100 millis
...
7.41 tokens/s (186)

llamafile

./mistral-7b-instruct-v0.2.Q4_0.llamafile \
  -n 512 \
  -p "[INST] Why is the sky blue? [/INST] "

Output:

...
llama_print_timings:        load time =     239.69 ms
llama_print_timings:      sample time =       4.91 ms /   164 runs   (    0.03 ms per token, 33421.64 tokens per second)
llama_print_timings: prompt eval time =     651.57 ms /    15 tokens (   43.44 ms per token,    23.02 tokens per second)
llama_print_timings:        eval time =   25324.82 ms /   163 runs   (  155.37 ms per token,     6.44 tokens per second)
llama_print_timings:       total time =   26000.08 ms /   178 tokens

I discovered that llamafile, by default, uses 1 thread per performance core: no hyper-threading, no efficiency cores ... Running llamafile with -t 20 (all available threads):

...
llama_print_timings:        load time =     285.27 ms
llama_print_timings:      sample time =       4.32 ms /   141 runs   (    0.03 ms per token, 32646.45 tokens per second)
llama_print_timings: prompt eval time =     680.35 ms /    15 tokens (   45.36 ms per token,    22.05 tokens per second)
llama_print_timings:        eval time =   19403.89 ms /   140 runs   (  138.60 ms per token,     7.22 tokens per second)
llama_print_timings:       total time =   20105.27 ms /   155 tokens

With -t 12 (all threads of performance cores):

...
llama_print_timings:        load time =     234.36 ms
llama_print_timings:      sample time =       4.14 ms /   100 runs   (    0.04 ms per token, 24154.59 tokens per second)
llama_print_timings: prompt eval time =     958.05 ms /    15 tokens (   63.87 ms per token,    15.66 tokens per second)
llama_print_timings:        eval time =   13545.93 ms /    99 runs   (  136.83 ms per token,     7.31 tokens per second)
llama_print_timings:       total time =   14523.26 ms /   114 tokens

This is a simple comparison, nothing rigorous; llama.cpp presumably beats Llama3.java, and llamafile probably outperforms Llama3.java and Mistral.java, yet I find these results rather surprising and unexpected TBH. One thing to take into account is that llamafile/llama.cpp do ingest prompts much faster.

hrstoyanov commented 5 months ago

Just saw you dropped it... will check it out! Thank you @mukel!