kherud / java-llama.cpp

Java Bindings for llama.cpp - A Port of Facebook's LLaMA model in C/C++
MIT License

Performance issues on Macbook Pro M1 Max #4

Closed: tiguchi closed this issue 1 year ago

tiguchi commented 1 year ago

I'm running into completion performance problems on my MacBook Pro M1 Max. I usually get about 17 tokens per second with a 13B GGUF model. When I run the same completion task through the shared library, it is really slow: I get at most 4 tokens per second. It feels like the completion runs on the CPU instead of on the GPU cores via Apple Metal. I do see "metal" mentioned repeatedly in the log output when the shared library loads a model, though, so I'm a bit stumped.

I tried a variety of settings, such as the number of GPU layers, but nothing has any effect.

When I run the llama.cpp main application from the terminal, I get fully accelerated performance. I don't even have to specify any layers or thread counts; it just works at optimum performance.

Just a guess: it looks like context handling (e.g. truncating the context) is done in Java. Could there be a performance bottleneck caused by copying context data between llama.cpp's native memory and the Java heap?

kherud commented 1 year ago

I can reproduce the issues on my M1 MacBook. I also agree that the problems likely come from the memory-copying overhead inherent to JNA. Since I think performance should be the first priority for this library, I started working on an alternative JNI implementation:

As soon as it is mature enough, I will merge this into master (hopefully during the coming week).

Currently, the basic functionality is working, but the compilation is still a bit tricky. The remaining work is mainly to automate it via GitHub Actions. If you want an early preview, you can look at the jni branch. The setup isn't documented yet, but there is a Makefile that hopefully works.

AayushSameerShah commented 1 year ago

Hi @kherud, thanks for this thread.

Actually, my requirements are as follows:

Kind of a poor setup here, I agree 😓 But I am also getting really slow speed here, around 2 tokens/second. I am not sure whether it is using all cores or whether I need to install BLAS or something else.

Is this the maximum speed I can get on a CPU? If not, what steps should I follow to get more speed?

Please advise, thanks 🙏🏻

PS: I also checked on Colab (using Python), where I still only get about 3 tokens/sec. I am not sure whether just running pip install llama-cpp-python is enough, but here is the runnable Colab.

I hope I am doing something wrong here, so that with some fixes I will get better speed in Java too.

kherud commented 1 year ago

Hi @AayushSameerShah, the poor performance currently comes from the way Java interfaces with the C++ library (the underlying technology is JNA). However, I am working on a new version using JNI, which will have much better performance. If you want, you can take an early look at the jni branch of this repository; the setup is basically the same CMake workflow as with llama.cpp. I will also publish a release to Maven Central in the next few days, which won't require any setup for CPU inference.

AayushSameerShah commented 1 year ago

Thanks @kherud I will surely check that out.

A weird request: do you happen to know of any Java library I can use to run the T5 model?

I have tried John Snow Labs, but it has certain memory-related issues. I can't find any other Java libraries for this.

There is DJL, but it doesn't provide any generation parameters such as temperature or sampling.

Can you suggest any? Thanks.

kherud commented 1 year ago

@AayushSameerShah My first bet would be to look into the PyTorch bindings for Java. If you're memory-constrained and looking for something like llama.cpp, I think the only option is to wait for the ggml project, which seems to have it on its roadmap (although, unfortunately, I don't think Java bindings exist yet).

@tiguchi I just released version 2.0 of this library, which now uses JNI instead of JNA. I hope this resolves the performance issues. Let me know if you have any problems. I'm closing this ticket for now, but feel free to re-open if you still experience issues.

AayushSameerShah commented 1 year ago

Hello @kherud! I am amazed by your contribution in making this library. Since I used this library before the JNI update, I wanted to test the results after the 2.0.0 update.

The thing is, I am still using the same model (Llama-chat-7b-4QM) as before, and after this update I am still getting about the same generation speed on CPU.

Let me walk you through my resources and the steps I used:

1️⃣ Library installation

Since I'm using Windows, which has some issues (#11), I had to build the .dll files manually.

So I used:

cmake .. -DBUILD_SHARED_LIBS=ON -DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=OpenBLAS

cmake --build . --config Release

I tried both with and without BLAS, but the generation speed was the same.

2️⃣ Resources

  • CPU: AMD Ryzen 5 5500U with Radeon Graphics
    Base speed: 2.10 GHz
    Sockets: 1
    Cores: 6
    Logical processors: 12
    Virtualization: Enabled
    L1 cache: 384 KB
    L2 cache: 3.0 MB
    L3 cache: 8.0 MB
  • RAM: 16 GB

3️⃣ The performance

llama_print_timings: load time = 1653.54 ms
llama_print_timings: sample time = 24.73 ms / 72 runs ( 0.34 ms per token, 2910.97 tokens per second)
llama_print_timings: prompt eval time = 11045.00 ms / 166 tokens ( 66.54 ms per token, 15.03 tokens per second)
llama_print_timings: eval time = 14023.68 ms / 71 runs ( 197.52 ms per token, 5.06 tokens per second)
llama_print_timings: total time = 25100.86 ms
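The per-token figures in a timings log like this follow directly from elapsed time divided by token count. As a sketch, the small Java program below re-derives them from the two evaluation lines. The regular expression is an assumption based only on the line format shown here; it is not part of java-llama.cpp or llama.cpp.

```java
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class TimingsCheck {
    // Assumed line shape: "<prompt eval|eval> time = <ms> ms / <count> <runs|tokens>",
    // tolerant of extra spaces around the numbers.
    static final Pattern TIMING = Pattern.compile(
            "(prompt eval|eval) time\\s*=\\s*([0-9.]+)\\s*ms\\s*/\\s*(\\d+)\\s+(?:runs|tokens)");

    /** Returns {milliseconds per token, tokens per second} for a timed span. */
    static double[] rates(double totalMs, int tokens) {
        double msPerToken = totalMs / tokens;
        return new double[]{msPerToken, 1000.0 / msPerToken};
    }

    public static void main(String[] args) {
        String[] log = {
            "llama_print_timings: prompt eval time = 11045.00 ms / 166 tokens",
            "llama_print_timings: eval time = 14023.68 ms / 71 runs",
        };
        for (String line : log) {
            Matcher m = TIMING.matcher(line);
            if (m.find()) {
                double[] r = rates(Double.parseDouble(m.group(2)), Integer.parseInt(m.group(3)));
                System.out.printf("%-11s %7.2f ms/token %6.2f tokens/s%n", m.group(1), r[0], r[1]);
            }
        }
    }
}
```

Running it reproduces the log's rates (about 15.03 tokens/s for the prompt, about 5.06 tokens/s for generation). The breakdown also shows that roughly 11 of the 25 seconds are spent evaluating the 166-token prompt before any output token is generated.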



Now, it looks slow: it takes around 20 seconds to generate the SQL for my prompt. I am aware that I can't expect much more speed while relying solely on the CPU 😓, but if there is any way to increase the speed, I would love to know it.

🤔 Also, what should I look for in the logger output to check for BLAS=1 and confirm that BLAS is enabled?

(Here is the txt file with the whole response: the_sql.txt, https://github.com/kherud/java-llama.cpp/files/12701808/the_sql.txt)

Thank you 🤗

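llama.cpp reports its compile-time features as a pipe-separated system-info string (for example `AVX = 1 | AVX2 = 1 | ... | BLAS = 1 | ...`), which its own command-line tools print at startup. Whether java-llama.cpp 2.0 forwards that line to the callback registered with `LlamaModel.setLogger` is an assumption, not something confirmed in this thread. Assuming such a line does reach the logger, a small hypothetical helper like `blasFlag` could pick the flag out of it:

```java
import java.util.Optional;

public class BlasCheck {
    /**
     * Looks for a "BLAS = 0/1" entry in a llama.cpp system-info style line
     * such as "AVX = 1 | AVX2 = 1 | BLAS = 1 | SSE3 = 1".
     * Returns empty if the line carries no BLAS entry at all.
     */
    static Optional<Boolean> blasFlag(String line) {
        for (String field : line.split("\\|")) {
            String[] kv = field.split("=");
            // Exact key match, so entries like "CLBLAST = 0" are not confused with BLAS
            if (kv.length == 2 && kv[0].trim().equals("BLAS")) {
                return Optional.of(kv[1].trim().equals("1"));
            }
        }
        return Optional.empty();
    }

    public static void main(String[] args) {
        // Hypothetical sample line in the system-info format described above
        String sample = "AVX = 1 | AVX2 = 1 | FMA = 1 | BLAS = 1 | SSE3 = 1";
        System.out.println(blasFlag(sample)); // Optional[true]
    }
}
```

Inside the existing `LlamaModel.setLogger((level, message) -> ...)` callback, one could call `BlasCheck.blasFlag(message).ifPresent(...)` on each message; a BLAS-enabled build would be expected to report `true`.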
AayushSameerShah commented 1 year ago

Hello @kherud 👋🏻 Apologies for commenting again; I just want to make sure you've seen my issue and my (pretty silly) question 😅

Thanks ^^

https://github.com/kherud/java-llama.cpp/issues/4#issuecomment-1731520398

khanjandharaiya commented 1 year ago

Hello @AayushSameerShah, I have the same issue.

@kherud Thank you for your contributions; they are greatly appreciated. Please guide us in resolving this issue. 🙏

Thank you.

tiguchi commented 1 year ago

Can you share your Java code, specifically the ModelParameters configuration you use?

AayushSameerShah commented 1 year ago

> Can you share your Java code, specifically the ModelParameters configuration you use?

Hello @tiguchi, here is the Java code that I use:

import de.kherud.llama.InferenceParameters;
import de.kherud.llama.LlamaModel;
import de.kherud.llama.ModelParameters;

import java.io.IOException;

public class LlamaTest {
    public static void main(String... args) throws IOException {
        // Print llama.cpp's native log output (model loading, timings, ...)
        LlamaModel.setLogger((level, message) -> System.out.print(message));

        // Default model parameters: nothing (such as the number of GPU layers) is set explicitly
        ModelParameters modelParams = new ModelParameters.Builder()
                .build();
        InferenceParameters inferParams = new InferenceParameters.Builder()
                .setTemperature(0.7f)
                .setPenalizeNl(true)
                .setMirostat(InferenceParameters.MiroStat.V2)
                .setAntiPrompt(new String[]{"\n"})
                .build();

        String modelPath = "path.gguf";
        String instruction = "Instructions:\n Write MYSQL queries to solve the following problem that obeys the constraints. No need to explain the queries and give only query. Please wrap your query answer using ```:\n"
                + "Please use column given in metadata below do not generate any additional column and do not create dummy data for this task\n"
                + "Be careful to not query for columns that do not exist.\n"
                + "List the table name which you are using\n" // "\n" added so this no longer runs into the next sentence
                + "Never query for all columns from a table. You must query only the columns that are needed to answer the question.";
        String userPrompt = "Question: What was total sales last month?\n"
                + "CREATE TABLE Sales"
                + "(GrossSales number,"
                + "ProductName text,"
                + "ProductCategory text,"
                + "CostofGoods number,"
                + "EmployeeName text,"
                + "Date Date,"
                + "State text,"
                + "City text,"
                + "Stock number)\n"
                + "Response:\n";

        String codeTemplate = "```sql\n{prompt}\n```";

        try (LlamaModel model = new LlamaModel(modelPath, modelParams)) {
            String separator = "\n"; // Choose a separator that your LlamaModel can recognize
            String prompt = instruction + separator + userPrompt;
            String formattedPrompt = codeTemplate.replace("{prompt}", prompt);

            // Note: formattedPrompt is only printed; the model itself receives `prompt`
            System.out.print("Prompt: " + formattedPrompt);
            System.out.print("\nModel: ");

            long inferenceStart = System.nanoTime();
            String out = model.complete(prompt, inferParams);
            long inferenceEnd = System.nanoTime(); // stop the clock before printing
            System.out.println(out);
            System.out.println("Execution time for inference: "
                    + (inferenceEnd - inferenceStart) / 1_000_000_000.0 + " seconds");
        }
    }
}

Thanks 🤗

tiguchi commented 1 year ago

> Can you share your Java code, specifically the ModelParameters configuration you use?

> Hello @tiguchi , Here is the JavaCode that I use...
>
>     ModelParameters modelParams = new ModelParameters.Builder()
>             .build();

There seems to be the problem right away. You are not specifying any GPU layers in the model parameters. I assume it defaults to zero, in which case the CPU is used for inference.

Try the following setting instead, and gradually raise the number of layers until you get an error message:

        ModelParameters modelParams = new ModelParameters.Builder()
                .setNGpuLayers(10)
                .build();
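The "raise the layer count until you get an error" procedure can be written as a small stepwise search. In this sketch, `tryLoad` is a hypothetical stand-in: a real version might build `new ModelParameters.Builder().setNGpuLayers(n).build()`, open a `LlamaModel` in a try-with-resources block, and return `false` if construction throws. Which exception java-llama.cpp throws when offloading too many layers is an assumption here, not something confirmed in this thread.

```java
import java.util.function.IntPredicate;

public class GpuLayerSearch {
    /**
     * Raises the GPU layer count in fixed steps and returns the highest count
     * for which tryLoad succeeded (0 if none did). Stops at the first failure.
     */
    static int maxWorkingLayers(IntPredicate tryLoad, int step, int upperBound) {
        if (step <= 0) {
            throw new IllegalArgumentException("step must be positive");
        }
        int best = 0;
        for (int n = step; n <= upperBound; n += step) {
            if (!tryLoad.test(n)) {
                break; // first failure: keep the last count that worked
            }
            best = n;
        }
        return best;
    }

    public static void main(String[] args) {
        // Fake loader for illustration: pretend anything above 35 layers fails to load
        IntPredicate fakeLoad = n -> n <= 35;
        System.out.println(maxWorkingLayers(fakeLoad, 10, 100)); // 30
    }
}
```

A linear search matches the "gradually raise" advice and never tries a count far beyond the last working one; a smaller step finds a tighter limit at the cost of more model loads.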