tjake / Jlama

Jlama is a modern LLM inference engine for Java
Apache License 2.0

CodeLlama loading is broken? #16

Closed jbellis closed 8 months ago

jbellis commented 9 months ago

This worked in the Oct 15 jlama:

$ ./run-cli.sh complete -p "def fib(" -t 0.2 -tc 24 -n 100 models/CodeLlama-7b-hf

Now it OOMs (note that I have doubled the default Xmx, which was not necessary in October):

Exception in thread "main" picocli.CommandLine$ExecutionException: Error while running command (com.github.tjake.jlama.cli.commands.CompleteCommand@32b260fa): java.lang.RuntimeException: java.lang.reflect.InvocationTargetException
    at picocli.CommandLine.executeUserObject(CommandLine.java:2035)
    at picocli.CommandLine.access$1500(CommandLine.java:148)
    at picocli.CommandLine$RunLast.executeUserObjectOfLastSubcommandWithSameParent(CommandLine.java:2461)
    at picocli.CommandLine$RunLast.handle(CommandLine.java:2453)
    at picocli.CommandLine$RunLast.handle(CommandLine.java:2415)
    at picocli.CommandLine$AbstractParseResultHandler.handleParseResult(CommandLine.java:2264)
    at picocli.CommandLine.parseWithHandlers(CommandLine.java:2664)
    at picocli.CommandLine.parseWithHandler(CommandLine.java:2599)
    at com.github.tjake.jlama.cli.JlamaCli.main(JlamaCli.java:30)
Caused by: java.lang.RuntimeException: java.lang.reflect.InvocationTargetException
    at com.github.tjake.jlama.model.ModelSupport.loadModel(ModelSupport.java:111)
    at com.github.tjake.jlama.model.ModelSupport.loadModel(ModelSupport.java:66)
    at com.github.tjake.jlama.cli.commands.CompleteCommand.run(CompleteCommand.java:16)
    at picocli.CommandLine.executeUserObject(CommandLine.java:2026)
    ... 8 more
Caused by: java.lang.reflect.InvocationTargetException
    at java.base/jdk.internal.reflect.DirectConstructorHandleAccessor.newInstance(DirectConstructorHandleAccessor.java:74)
    at java.base/java.lang.reflect.Constructor.newInstanceWithCaller(Constructor.java:502)
    at java.base/java.lang.reflect.Constructor.newInstance(Constructor.java:486)
    at com.github.tjake.jlama.model.ModelSupport.loadModel(ModelSupport.java:107)
    ... 11 more
Caused by: java.lang.OutOfMemoryError
    at java.base/jdk.internal.reflect.DirectConstructorHandleAccessor.newInstance(DirectConstructorHandleAccessor.java:62)
    at java.base/java.lang.reflect.Constructor.newInstanceWithCaller(Constructor.java:502)
    at java.base/java.lang.reflect.Constructor.newInstance(Constructor.java:486)
    at java.base/java.util.concurrent.ForkJoinTask.getThrowableException(ForkJoinTask.java:542)
    at java.base/java.util.concurrent.ForkJoinTask.reportException(ForkJoinTask.java:567)
    at java.base/java.util.concurrent.ForkJoinTask.invoke(ForkJoinTask.java:670)
    at java.base/java.util.stream.ForEachOps$ForEachOp.evaluateParallel(ForEachOps.java:160)
    at java.base/java.util.stream.ForEachOps$ForEachOp$OfInt.evaluateParallel(ForEachOps.java:189)
    at java.base/java.util.stream.AbstractPipeline.evaluate(AbstractPipeline.java:233)
    at java.base/java.util.stream.IntPipeline.forEach(IntPipeline.java:463)
    at java.base/java.util.stream.IntPipeline$Head.forEach(IntPipeline.java:620)
    at com.github.tjake.jlama.model.llama.LlamaModel.loadTransformerBlockWeights(LlamaModel.java:56)
    at com.github.tjake.jlama.model.AbstractModel.<init>(AbstractModel.java:109)
    at com.github.tjake.jlama.model.llama.LlamaModel.<init>(LlamaModel.java:31)
    at java.base/jdk.internal.reflect.DirectConstructorHandleAccessor.newInstance(DirectConstructorHandleAccessor.java:62)
    ... 14 more
Caused by: java.lang.OutOfMemoryError: Cannot reserve 180355136 bytes of direct buffer memory (allocated: 25708094948, limit: 25769803776)
    at java.base/java.nio.Bits.reserveMemory(Bits.java:178)
    at java.base/java.nio.DirectByteBuffer.<init>(DirectByteBuffer.java:127)
    at java.base/java.nio.ByteBuffer.allocateDirect(ByteBuffer.java:360)
    at com.github.tjake.jlama.util.UnsafeDirectByteBuffer.allocateAlignedByteBuffer(UnsafeDirectByteBuffer.java:36)
    at com.github.tjake.jlama.tensor.FloatBufferTensor.<init>(FloatBufferTensor.java:73)
    at com.github.tjake.jlama.safetensors.Weights.load(Weights.java:112)
    at com.github.tjake.jlama.safetensors.WeightLoader.load(WeightLoader.java:16)
    at com.github.tjake.jlama.safetensors.SafeTensorIndex.load(SafeTensorIndex.java:172)
    at com.github.tjake.jlama.model.llama.LlamaModel.lambda$loadTransformerBlockWeights$1(LlamaModel.java:70)
    at java.base/java.util.stream.ForEachOps$ForEachOp$OfInt.accept(ForEachOps.java:205)
    at java.base/java.util.stream.Streams$RangeIntSpliterator.forEachRemaining(Streams.java:104)
    at java.base/java.util.Spliterator$OfInt.forEachRemaining(Spliterator.java:712)
    at java.base/java.util.stream.AbstractPipeline.copyInto(AbstractPipeline.java:509)
    at java.base/java.util.stream.ForEachOps$ForEachTask.compute(ForEachOps.java:291)
    at java.base/java.util.concurrent.CountedCompleter.exec(CountedCompleter.java:754)
    at java.base/java.util.concurrent.ForkJoinTask.doExec(ForkJoinTask.java:387)
    at java.base/java.util.concurrent.ForkJoinPool$WorkQueue.topLevelExec(ForkJoinPool.java:1312)
    at java.base/java.util.concurrent.ForkJoinPool.scan(ForkJoinPool.java:1843)
    at java.base/java.util.concurrent.ForkJoinPool.runWorker(ForkJoinPool.java:1808)
    at java.base/java.util.concurrent.ForkJoinWorkerThread.run(ForkJoinWorkerThread.java:188)
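
The root-cause line above shows the reservation failing against the direct-buffer limit (~25.7 GB already allocated of a ~25.8 GB limit), not the Java heap. As a minimal sketch of that allocation path (hypothetical demo class, not Jlama code): direct buffers are reserved off-heap against `-XX:MaxDirectMemorySize`, which defaults to the max heap size, so raising `-Xmx` also raises this limit, but loading full-precision weights can still exceed it.

```java
import java.nio.ByteBuffer;

// Hypothetical demo of the allocation path that fails in the trace:
// ByteBuffer.allocateDirect reserves off-heap memory counted against
// -XX:MaxDirectMemorySize (default: the max heap size). The failing call
// tried to reserve ~180 MB when the limit was nearly exhausted.
public class DirectBufferDemo {
    public static void main(String[] args) {
        // Small stand-in for the ~180 MB reservation in the trace,
        // so the demo runs under any reasonable limit:
        ByteBuffer buf = ByteBuffer.allocateDirect(16 * 1024 * 1024); // 16 MiB off-heap
        System.out.println("direct=" + buf.isDirect() + " capacity=" + buf.capacity());
        // prints: direct=true capacity=16777216
    }
}
```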
tjake commented 9 months ago

I believe this is because in October the quantization was happening at runtime. You can still do this via the command-line arguments in run-cli.sh using the -q Q4 parameter, or you can quantize the model once with the quantize command.

Here's how I quantize Llama models:

./run-cli.sh quantize -q Q4 -s "model.embed_tokens.weight" -s "lm_head.weight" models/Mixtral-8x7B-Instruct-v0.1/
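
Alternatively, the original command could likely be adapted to quantize at load time. This is a sketch only, assuming the complete subcommand accepts the same -q flag mentioned above:

```shell
# Hypothetical variant of the original command with runtime Q4 quantization
# (-q placement assumed from the quantize example)
./run-cli.sh complete -p "def fib(" -t 0.2 -tc 24 -n 100 -q Q4 models/CodeLlama-7b-hf
```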