ggerganov / llama.cpp

LLM inference in C/C++
MIT License

Add Support for IBM Granite #7116

Closed · YorkieDev closed this issue 3 months ago

YorkieDev commented 4 months ago

Feature Description

IBM recently released their Granite models: a series of 3B to 34B code models with base and instruct finetunes.

https://huggingface.co/collections/ibm-granite/granite-code-models-6624c5cec322e4c148c8b330 https://github.com/ibm-granite

Many thanks to the llama.cpp community for their awesome work! It would be great to see this feature added. GGUFs can already be created, but loading them currently fails with a tokenizer error.

sroecker commented 4 months ago

The PR adding Granite support to transformers (MLP bias for the gate, up and down projections) can be found here: https://github.com/huggingface/transformers/pull/30031/files

psyv282j9d commented 4 months ago

Based on the discussion in the transformers mlp_bias PR, it's similar to Llama with just the MLP bias added.

sroecker commented 4 months ago

I tried to do this here: https://github.com/sroecker/llama.cpp/tree/add_mlp_bias, just adding biases to FFN_GATE, FFN_DOWN and FFN_UP. The tensor shapes seem to be correct, but the model outputs gibberish.

./main -m ~/Downloads/granite-3b-code-base.Q8_0.gguf -p "Question: Python code to calculate the Fibonacci series\n\nAnswer:\n", using the GGUF from https://huggingface.co/NikolayKozloff/granite-3b-code-instruct-Q8_0-GGUF.
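
For anyone following along, here is a quick, hedged way to confirm the bias tensors actually made it into a converted file (it uses the gguf Python package; the file name is a placeholder and the blk.N.ffn_*.bias naming is assumed):

    # List FFN bias tensors in a converted GGUF to verify they were written.
    import gguf

    reader = gguf.GGUFReader("granite-3b-code-base.f16.gguf")  # placeholder path
    biases = [t.name for t in reader.tensors
              if ".ffn_" in t.name and t.name.endswith(".bias")]
    print(len(biases), "FFN bias tensors found")
    print(biases[:6])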

mayank31398 commented 4 months ago

@sroecker are you tying the word embeddings? Unlike Llama, the input word embeddings and the output projection matrix are tied for Granite models.

sroecker commented 4 months ago

Ah, not yet, thanks! I guess we then need to define an additional ARCH (or save the mlp_bias boolean in the GGUF) and implement it the way MPT is handled: https://github.com/ggerganov/llama.cpp/blob/7e0b6a7b3ba94ff624dc27c1e0e735fded8819b8/llama.cpp#L5287

JohnClaw commented 4 months ago

(or save the mlp_bias boolean in the GGUF)

Is there a way to add mlp_bias to an already-made GGUF? I'm asking because you mentioned my Q8 GGUF in one of your previous messages.

sroecker commented 4 months ago

(or save the mlp_bias boolean in the GGUF)

Is there a way to add mlp_bias to an already-made GGUF? I'm asking because you mentioned my Q8 GGUF in one of your previous messages.

You could hack something together with the gguf writer: https://pypi.org/project/gguf/
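
A rough sketch of that direction, assuming the gguf package's GGUFReader/GGUFWriter API; the mlp_bias key name is hypothetical, and copying the existing metadata and tensors over is left out:

    # Fragment: add a custom boolean key while rewriting a GGUF.
    import gguf

    reader = gguf.GGUFReader("granite-3b-code-instruct.Q8_0.gguf")  # existing file
    writer = gguf.GGUFWriter("granite-3b-code-instruct.mlp-bias.gguf", arch="llama")

    writer.add_bool("llama.mlp_bias", True)  # hypothetical key name
    # ... copy the remaining key/value fields and all tensors from `reader` here ...

    writer.write_header_to_file()
    writer.write_kv_data_to_file()
    writer.write_tensors_to_file()
    writer.close()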

sroecker commented 4 months ago

So I've adapted build_llama to include the MLP biases as well. I've added a few FIXMEs to my branch to indicate places that might need to be adapted for the different Granite models. The output is now valid text, but unfortunately it repeats itself: <|endoftext|>Question: Fibonacci series in Python? \n\nAnswer: Python? \n\n series in Python? \n\n series in Python? \n\n series in Python? \

dataf3l commented 4 months ago

I'm here to offer words of support; I'm interested in exploring what IBM + Ollama can do.

adrianpuiu commented 4 months ago

I'm here to offer words of support; I'm interested in exploring what IBM + Ollama can do.

+1 to this

davideuler commented 4 months ago

Is there any progress on support for the Granite models?

jpodivin commented 4 months ago

AFAIK, we have been stuck on the issue of repeating text output. The tokenizer appears to be the culprit, yet it seems to be in order (correct token IDs, etc.). I don't know if @sroecker has made any strides since.

sroecker commented 4 months ago

AFAIK, we have been stuck on the issue of repeating text output. The tokenizer appears to be the culprit, yet it seems to be in order (correct token IDs, etc.). I don't know if @sroecker has made any strides since.

Yes, unfortunately. The lab version of Granite works well with llama.cpp: https://huggingface.co/instructlab/granite-7b-lab-GGUF. It doesn't have the MLP bias nodes and uses a different tokenizer, though.

I've tried a few things regarding tokenization. I checked that the tokenizer creates the same input tokens with ./main -m granite-3b-code-base.gguf -p "def generate():" -ngl 0 --override-kv tokenizer.ggml.add_bos_token=bool:false:

0 -> '<|endoftext|>'
589 -> 'def'
4450 -> ' generate'
2262 -> '():'

I also recreated the f16 GGUF forcing the pre-tokenizer to be llama-bpe instead of refact. No luck so far. There's a lot of ARCH-specific code all over llama.cpp which might change important parameters, so I'm thinking about creating a simple debugging example based on examples/simple/simple.cpp.
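
For comparison, the same prompt can be tokenized with the upstream HF tokenizer (a hedged sketch that just prints token IDs to diff against the ./main output above):

    # Tokenize the test prompt with the HF tokenizer for a side-by-side check.
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("ibm-granite/granite-3b-code-base")
    for tid in tok("def generate():")["input_ids"]:
        print(tid, "->", repr(tok.decode([tid])))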

mayank31398 commented 4 months ago

The lab version is a different model; it should not be confused with this one.

sroecker commented 4 months ago

The lab version is a different model; it should not be confused with this one.

I'm aware of that; it did work out of the box with the LLM_ARCH_LLAMA settings though, so I'm trying to find out exactly why. But you're right to point this out, as a few people have mixed these up.

I will check the convert-hf-to-gguf-update.py script again to rule out the tokenizer before I start digging deeper.

mayank31398 commented 4 months ago

Hmm, a quick question: are we tying the word embeddings and the output logits matrix? Llama doesn't do that, and Granite has tied embeddings; maybe that's the issue? I don't think the tokenizer should be the issue, since all Granite models use the starcoder tokenizer.

sroecker commented 4 months ago

Hmm, a quick question: are we tying the word embeddings and the output logits matrix? Llama doesn't do that, and Granite has tied embeddings; maybe that's the issue? I don't think the tokenizer should be the issue, since all Granite models use the starcoder tokenizer.

If no output layer is found, the word embeddings are used instead: https://github.com/ggerganov/llama.cpp/blob/541600201e6480f54ae09e58d16b154d4b4b331d/llama.cpp#L4926-L4932
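
On the conversion side, tying can amount to simply not emitting a separate output tensor. A minimal, hypothetical sketch using the Model.modify_tensors hook in convert-hf-to-gguf.py (an illustration, not the actual fix):

    # Hypothetical override inside a Model subclass of convert-hf-to-gguf.py.
    # If a checkpoint stores a duplicate lm_head.weight while embeddings are
    # tied, dropping it lets llama.cpp fall back to token_embd.weight, as in
    # the code linked above.
    def modify_tensors(self, data_torch, name, bid):
        del bid  # unused here
        if name == "lm_head.weight" and self.hparams.get("tie_word_embeddings", False):
            return []  # skip the duplicate; the output reuses the embeddings
        return [(self.map_tensor_name(name), data_torch)]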

mayank31398 commented 4 months ago

Hmm, ok so there are these differences between llama and granite:

  1. attention has bias (llama doesn't)
  2. mlp has bias (llama doesn't)
  3. tied word embeddings (llama doesn't)
  4. starcoder tokenizer

sroecker commented 4 months ago

Hmm, ok so there are these differences between llama and granite:

  1. attention has bias (llama doesn't)
  2. mlp has bias (llama doesn't)
  3. tied word embeddings (llama doesn't)
  4. starcoder tokenizer

Do all Granite code models use the starcoder tokenizer? Based on your HF repo comment, I tried to get the 20B and 34B models to run. They are recognized as the Starcoder arch by the convert-hf-to-gguf script, and all I had to modify was tying the embedding weights. 20B instruct works quite well, even with the BOS token. The Q3_K_L quant comes down to 11 GB. Please have a try with these changes: https://github.com/sroecker/llama.cpp/commit/6f201480de46aba0d5f718a2a8bdf424bd8e8274

For the 3B and 8B models, 1) and 4) remain. We have to check whether the attention bias is set up correctly in llm_build_kv; build_refact should be good for comparison.
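
A quick way to double-check points 1-3 straight from the HF config (a hedged sketch; it assumes a transformers build that already includes the mlp_bias field from the PR linked earlier):

    # Print the config flags that distinguish Granite from plain Llama.
    from transformers import AutoConfig

    cfg = AutoConfig.from_pretrained("ibm-granite/granite-3b-code-base")
    for key in ("attention_bias", "mlp_bias", "tie_word_embeddings"):
        print(key, "=", getattr(cfg, key, "<not set>"))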

mayank31398 commented 4 months ago

Yeah, all of them use the starcoder tokenizer.

DigitLib commented 4 months ago

In case it helps for the 8B instruct model: after converting with

python3 llama.cpp/convert.py granite-8b-ins --outfile granite-8b-ins/granite-8b-instruct.bin --outtype q8_0 --vocab-type bpe --pad-vocab

I get this error when starting ./llama.cpp/main -m ./granite-8b-ins/granite-8b-instruct.bin:

llama_model_load: error loading model: done_getting_tensors: wrong number of tensors; expected 578, got 470

The same numbers show up when using convert-hf-to-gguf.py. During the conversion it reports 578 tensors; the last two lines are:

INFO:convert:[577/578] Writing tensor blk.35.attn_v.weight | size 1024 x 4096 | type Q8_0 | T+ 84
INFO:convert:[578/578] Writing tensor output_norm.weight | size 4096 | type F32 | T+ 84

I tried q8_0, f16 and f32; same error.

Thank you for this great work!

mayank31398 commented 4 months ago

I don't think the 3B and 8B models are working yet, @DigitLib. The 34B and 20B PR is merged and it's working: https://github.com/ggerganov/llama.cpp/pull/7324

The 20B base GGUF is available now: https://huggingface.co/ibm-granite/granite-20b-code-base-GGUF. I will add the instruct and 34B versions tomorrow.

DigitLib commented 4 months ago

@mayank31398 I know, I just wanted to help with the 8B instruct. Thank you!

giuseppe commented 4 months ago

@DigitLib you need https://github.com/sroecker/llama.cpp/commit/36dc5bbffe083545045ec2441ddc7f5c085d3caf to load the smaller models

mayank31398 commented 4 months ago

If that commit is working, can we open a PR, @sroecker?

giuseppe commented 4 months ago

That doesn't seem to be enough. The model loads, but it doesn't produce any good results: https://github.com/ggerganov/llama.cpp/issues/7116#issuecomment-2100061526

coder543 commented 4 months ago

https://huggingface.co/coder543/granite-20b-code-instruct-GGUF/tree/main

I've uploaded the q8_0, q6_K, and q4_0 gguf files for the 20B Instruct model here. I've only lightly tested them, and this is my first time quantizing any LLMs, but it seemed like they were working okay?

If anyone wants to test them, I'm curious if they work for you.

The chat template seems to be something like this:

Question:
Write a React TypeScript component

Answer:
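
A tiny helper for that prompt shape, in case it is useful for scripted testing (the exact whitespace is a guess based on the template above):

    def granite_prompt(question: str) -> str:
        # "Question:/Answer:" format as described above; spacing is assumed.
        return f"Question:\n{question}\n\nAnswer:\n"

    print(granite_prompt("Write a React TypeScript component"))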

giuseppe commented 4 months ago

I've managed to get some output that makes some sense with the 3B model, and I've opened a PR:

IMHO it makes sense to define a new architecture for Granite, as there are substantial differences from the base Llama model. To convert the HF model using the code in my PR, I modified the config.json file in the Granite model to use:

  "architectures": [
    "GraniteForCausalLM"
  ],

@mayank31398 what do you think?

celek commented 4 months ago

@giuseppe did you get the 3B GGUF working with #7481? If you teach me how to get it running locally on my M1, I can run some tests too :)

HunterGerlach commented 4 months ago

To reproduce locally you can run the following:

  1. Clone down Giuseppe's branch, pip install necessary packages (e.g. torch, transformers, numpy, sentencepiece), and build Llama.cpp (i.e. run make)
  2. Download the 3B or 8B model from HF
  3. Modify the config.json per @giuseppe 's comment above (i.e. LlamaForCausalLM -> GraniteForCausalLM)
  4. Convert to GGUF (e.g. ./convert-hf-to-gguf.py path-to-granite-model --outtype q8_0 --outfile path-to-converted-model/converted-model-name.gguf)
  5. Run inference against the GGUF model (e.g. ./main -m path-to-converted-model/converted-model-name.gguf -p "Write a simple hello world script in python.")

Inference output should be something like the following (ignoring logging output for brevity):

print("Hello World")

mayank31398 commented 4 months ago

@giuseppe I think the problem is that it won't work out of the box when converting the model. I'm not sure what a good solution for this problem is, though.

giuseppe commented 4 months ago

I've tried to extend the new arch to the bigger models, but it doesn't make sense as GPTBigCodeForCausalLM already works fine.

To avoid confusion, I've renamed the new architecture to GraniteSmallForCausalLM so it is clear that it applies only to the smaller models.

giuseppe commented 4 months ago

@giuseppe I think the problem is that it won't work out of the box when converting the model.

Could we fix the name in the model files, or would that cause other issues?

mayank31398 commented 4 months ago

Is there no other solution? We cannot change the name on HF since a lot of people have already created forks of the models. If there is an alternative method we can explore, that would be better.

Maybe by checking whether mlp_bias is true?

giuseppe commented 4 months ago

That looks like a generic setting that other models could use in the future, but I have no weight in this decision. I will implement whatever fits best for the Granite model and the llama.cpp maintainers.

Would a flag for the conversion script that forces the arch be good enough?

@ggerganov do you have any suggestions?

giuseppe commented 4 months ago

I've added the following patch to the PR:

diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index 2d05de42..eb7d061a 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -2571,6 +2571,10 @@ def parse_args() -> argparse.Namespace:
         "--no-lazy", action="store_true",
         help="use more RAM by computing all outputs before writing (use in case lazy evaluation is broken)",
     )
+    parser.add_argument(
+        "--architecture", type=str, default=None,
+        help="force the architecture to use",
+    )
     parser.add_argument(
         "--model-name", type=str, default=None,
         help="name of the model",
@@ -2626,7 +2630,7 @@ def main() -> None:
     hparams = Model.load_hparams(dir_model)

     with torch.inference_mode():
-        model_class = Model.from_model_architecture(hparams["architectures"][0])
+        model_class = Model.from_model_architecture(args.architecture if args.architecture is not None else hparams["architectures"][0])
         model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file, args.no_lazy)

         logger.info("Set model parameters")

So we can just add --architecture GraniteSmallForCausalLM to the command line, and it works without having to change the model file.
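
Usage would then look something like this (the model path and output name are placeholders):

    python convert-hf-to-gguf.py path-to-granite-model \
        --outfile granite-3b-code-base.f16.gguf \
        --outtype f16 \
        --architecture GraniteSmallForCausalLM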

Is this an acceptable solution?

mayank31398 commented 4 months ago

I think it's better if @ggerganov gives this a review.