mukel / llama2.java

Inference Llama 2 in one file of pure Java
MIT License

A Java port of Andrej Karpathy's llama2.c

Check out the successor of this project, Llama3.java: practical Llama (3) inference in a single Java file, with additional features, including a --chat mode.

This is a pure Java port of Andrej Karpathy's awesome llama2.c, a very simple implementation for running inference on models with a Llama 2-like transformer-based LLM architecture.

Currently, there isn't anything really original here, but I'll continue polishing it while keeping it in sync with the original.
Besides the educational value, this project will be used to test and tune compiler optimizations on the JVM, particularly for the Graal compiler. This port used llama2.scala initially as a reference.

Build

Java 21+ is required, in particular for the MemorySegment mmap-ing feature.
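For readers unfamiliar with that feature, here is a minimal sketch of memory-mapping a file into a MemorySegment and reading a few integers from it. It is not taken from Llama2.java: the class name `MappedCheckpoint` and the method `readHeader` are hypothetical, and reading the start of the file as little-endian 32-bit ints is an assumption about the checkpoint layout. On Java 21 this API is still a preview feature, so the sketch needs `--enable-preview` as in the build commands; on Java 22 and later it should compile without it.

```java
import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.Arrays;

// Hypothetical helper, for illustration only.
final class MappedCheckpoint {
    // Assumption: header fields are stored as little-endian 32-bit ints.
    static final ValueLayout.OfInt INT_LE =
            ValueLayout.JAVA_INT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);

    /** Maps the whole file read-only and copies out the first `count` ints. */
    static int[] readHeader(Path file, int count) throws IOException {
        try (FileChannel channel = FileChannel.open(file, StandardOpenOption.READ);
             Arena arena = Arena.ofConfined()) {
            // The mapping lives as long as the arena; nothing is read eagerly.
            MemorySegment segment =
                    channel.map(FileChannel.MapMode.READ_ONLY, 0, channel.size(), arena);
            int[] header = new int[count];
            for (int i = 0; i < count; i++) {
                header[i] = segment.get(INT_LE, (long) i * Integer.BYTES);
            }
            return header; // copied out before the arena unmaps the file
        }
    }

    public static void main(String[] args) throws IOException {
        // Usage: java MappedCheckpoint <file> ; prints the first 7 ints of the file.
        System.out.println(Arrays.toString(readHeader(Path.of(args[0]), 7)));
    }
}
```

Mapping the file instead of reading it into a heap array lets the operating system page the data in on demand, which matters for multi-gigabyte checkpoints.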

The code expects tokenizer.bin in the current directory. You can use TinyStories checkpoints or get Llama 2 models by following the instructions.

wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin
wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin

To build and run manually:

javac --enable-preview -source 21 --add-modules=jdk.incubator.vector Llama2.java
java --enable-preview --add-modules=jdk.incubator.vector Llama2 stories15M.bin

Or run it directly with JBang:

jbang Llama2.java stories15M.bin
# With additional -D options and custom Java home.
JAVA_HOME=/path/to/java/home jbang -Djava.util.concurrent.ForkJoinPool.common.parallelism=0 -Dllama2.VectorAPI=false Llama2.java stories15M.bin

A Makefile and a run.sh script are also provided:

make # optional, run.sh already runs make

JAVA_HOME=$GRAALVM_HOME \
JAVA_RUNTIME_OPTIONS=-Djava.util.concurrent.ForkJoinPool.common.parallelism=8 \
./run.sh stories15M.bin

Native image

A standalone native image can be created with GraalVM:

JAVA_HOME=$GRAALVM_HOME NATIVE_IMAGE_OPTIONS="-march=native" make native-image
./llama2 stories15M.bin

It can also be built with Profile-Guided Optimizations (PGO) on Oracle GraalVM:

JAVA_HOME=$GRAALVM_HOME \
NATIVE_IMAGE_OPTIONS="--pgo-instrument -march=native --initialize-at-build-time=Llama2 -Dllama2.VectorAPI=false" \
make native-image

# Profile run to generate default.iprof, with no parallelism to speed up profiling.
./llama2 -Djava.util.concurrent.ForkJoinPool.common.parallelism=0 stories15M.bin

# Build optimized image
JAVA_HOME=$GRAALVM_HOME \
NATIVE_IMAGE_OPTIONS="--pgo -march=native --initialize-at-build-time=Llama2 -Dllama2.VectorAPI=false" \
make native-image

# Should run ~2X faster than regular image.
./llama2 stories15M.bin

Performance

Quick numbers on an AMD Ryzen 3950X 64GB, Arch Linux.
llama2.java executed on OpenJDK 20.0.2+9.
To make things fair w.r.t. vectorization, the Java version has a matmul implementation using the Vector API.
In these measurements the JVM is warmed up enough to reach peak tokens/s.
On GraalVM, please note that the Graal compiler doesn't support the Vector API yet; to avoid unexpected performance degradation, run with -Dllama2.VectorAPI=false.
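To give a rough idea of the non-vectorized path and of what the ForkJoinPool parallelism option affects, here is a minimal sketch of a scalar matrix-vector product parallelized over output rows with a parallel stream. It is an illustration under stated assumptions, not the code from Llama2.java: the class name `ParallelMatmul` is hypothetical, and the row-major weight layout follows the convention used by llama2.c.

```java
import java.util.Arrays;
import java.util.stream.IntStream;

// Hypothetical illustration; the real Llama2.java implementation may differ.
final class ParallelMatmul {
    /**
     * Computes xout = W * x, where W is stored row-major with d rows and n columns.
     * Each output row is an independent dot product, so rows can run in parallel.
     */
    static void matmul(float[] xout, float[] x, float[] w, int n, int d) {
        // Parallel streams execute on the common ForkJoinPool.
        IntStream.range(0, d).parallel().forEach(i -> {
            float sum = 0f;
            int rowStart = i * n;
            for (int j = 0; j < n; j++) {
                sum += w[rowStart + j] * x[j];
            }
            xout[i] = sum; // each task writes a distinct index, so no synchronization is needed
        });
    }

    public static void main(String[] args) {
        float[] w = {1, 2, 3, 4, 5, 6}; // 2 rows x 3 columns
        float[] x = {1, 0, -1};
        float[] out = new float[2];
        matmul(out, x, w, 3, 2);
        System.out.println(Arrays.toString(out)); // prints [-2.0, -2.0]
    }
}
```

Because parallel streams share the common pool, launching with -Djava.util.concurrent.ForkJoinPool.common.parallelism=0 should keep a loop like this on the calling thread, while a value such as 8 lets it spread across worker threads.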

Notes
The numbers below were collected using aggressive (gcc) compiler flags; a regular gcc -O2 ... build wouldn't be as fast.

Single-threaded

llama2.c compiled with gcc -Ofast -march=native run.c -lm -o run -march=native
llama2.java executed with -Djava.util.concurrent.ForkJoinPool.common.parallelism=0

| Model | Tokens per second | Speedup vs. llama2.c | Implementation |
|---|---|---|---|
| stories15M.bin | 363 | 1.0 | llama2.c |
| stories15M.bin | 237 | 0.65 | llama2.java |
| stories110M.bin | 51.71 | 1.0 | llama2.c |
| stories110M.bin | 42.20 | 0.81 | llama2.java |
| llama2_7B.bin | 0.92 | 1.0 | llama2.c |
| llama2_7B.bin | 0.88 | 0.95 | llama2.java |
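The speedup column is the ratio of llama2.java's tokens per second to llama2.c's for the same model. The listed values appear to be truncated rather than rounded to two decimals (for example, 42.20 / 51.71 ≈ 0.816 is listed as 0.81); that is an observation from the numbers, not something the author states. A minimal sketch of the arithmetic, with a hypothetical class name:

```java
import java.util.Locale;

// Hypothetical helper reproducing the speedup column from the tokens/s figures.
final class Speedup {
    /** Ratio of candidate to baseline throughput, truncated (not rounded) to two decimals. */
    static double speedup(double candidateTokensPerSec, double baselineTokensPerSec) {
        double ratio = candidateTokensPerSec / baselineTokensPerSec;
        return Math.floor(ratio * 100.0) / 100.0;
    }

    public static void main(String[] args) {
        // Single-threaded figures copied from the table: {llama2.c, llama2.java}.
        String[] models = {"stories15M.bin", "stories110M.bin", "llama2_7B.bin"};
        double[][] tokensPerSec = {{363, 237}, {51.71, 42.20}, {0.92, 0.88}};
        for (int i = 0; i < models.length; i++) {
            double s = speedup(tokensPerSec[i][1], tokensPerSec[i][0]);
            System.out.println(String.format(Locale.ROOT, "%s %.2f", models[i], s));
        }
        // prints:
        // stories15M.bin 0.65
        // stories110M.bin 0.81
        // llama2_7B.bin 0.95
    }
}
```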

Multi-threaded

llama2.c compiled with gcc -Ofast -fopenmp -march=native run.c -lm -o run -march=native
llama2.c executed with OMP_NUM_THREADS=8
llama2.java executed with -Djava.util.concurrent.ForkJoinPool.common.parallelism=8

| Model | Tokens per second | Speedup vs. llama2.c | Implementation |
|---|---|---|---|
| stories15M.bin | 1233 | 1.0 | llama2.c |
| stories15M.bin | 438 | 0.35 | llama2.java |
| stories110M.bin | 90 | 1.0 | llama2.c |
| stories110M.bin | 80 | 0.88 | llama2.java |
| llama2_7B.bin | 1.68 | 1.0 | llama2.c |
| llama2_7B.bin | 1.65 | 0.98 | llama2.java |

Notes
In stories15M.bin, the C version shows a huge speedup, very likely a cache effect; this is considered an outlier. Running with 16/32 threads may actually cause a slowdown; performance is, in most cases, U-shaped w.r.t. the number of threads. With that many threads, vectorization does not give any advantage, since throughput is limited by memory bandwidth.

Performance is already comparable to the original C code, bar vectorization, even if the Java code has not been optimized yet.

License

MIT