ggerganov / llama.cpp

LLM inference in C/C++
MIT License

Support for Jamba JambaForCausalLM #6372

Open maziyarpanahi opened 4 months ago

maziyarpanahi commented 4 months ago

Feature Description

A new MoE model was released today: JambaForCausalLM https://huggingface.co/ai21labs/Jamba-v0.1

Motivation

Another very good and open LLM

Possible Implementation

I can test any PR candidate

nonetrix commented 4 months ago

Have smaller Mamba based LLMs already been added in the past?

Green-Sky commented 4 months ago

@compilade added Mamba support. But Jamba seems to be a derivative and needs code modifications.

compilade commented 4 months ago

I'd like it very much if they released a smaller version of their model. I don't have enough RAM to run Mixtral (only have 8GB), and Jamba seems to be around the same size as Mixtral. A model with less than 1B total parameters (or even less than 200M) would be ideal for quickly figuring out implementation problems (and would waste much less disk space when debugging or modifying model conversion).

My free time is too scarce at the moment to work on this (until May). The KV cache of this model will be a complicated beast: it's both recurrent and attention-based, but never in the same "layer". This will require rethinking how the KV cache is allocated and how Mamba's state is stored, but I think it should still be possible to support eventually, given enough effort.

Similarly to llm_build_ffn, I think there will need to be some kind of llm_build_mamba to more easily share the code building the graph of a Mamba block between Mamba and Jamba.
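Here is a rough Python sketch of that idea. It is not llama.cpp's C++ graph code. One shared `build_mamba` helper (a hypothetical stand-in for the proposed llm_build_mamba) is reused by both a pure-Mamba model and a hybrid one. The sketch assumes that a layer whose `head_count_kv` entry is 0 is a Mamba layer. That is how the metadata dump in the jamba-900M log later in this thread reads, but it is still an assumption here.

```python
from typing import List

# Hypothetical stand-ins for llama.cpp's C++ graph builders: each one just
# returns a label, so the resulting per-layer plan can be inspected.
def build_mamba(il: int) -> str:
    return f"mamba[{il}]"

def build_attention(il: int) -> str:
    return f"attn[{il}]"

def build_ffn(il: int) -> str:
    return f"ffn[{il}]"

def build_mamba_model(n_layer: int) -> List[str]:
    # Pure Mamba: every layer is a single Mamba block, with no separate FFN.
    return [build_mamba(il) for il in range(n_layer)]

def build_jamba_model(head_count_kv: List[int]) -> List[str]:
    # Hybrid: the same build_mamba helper is reused for the recurrent layers.
    # Assumption: a layer with 0 KV heads is a Mamba layer.
    plan: List[str] = []
    for il, n_kv in enumerate(head_count_kv):
        plan.append(build_attention(il) if n_kv > 0 else build_mamba(il))
        plan.append(build_ffn(il))  # each Jamba layer also has an FFN (dense or MoE)
    return plan
```

With the head_count_kv array `[0, 0, 8, 0, 0, 0, 8, 0, 0, 0, 8, 0]` from the 900M model, attention lands on layers 2, 6 and 10. Every other layer goes through the same `build_mamba` helper that a pure Mamba model would use.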

Anyone wanting to work on this should start by building a strong mental model of how Mamba's state is managed in llama.cpp, as well as how the KV cache works (at least what goes where, not necessarily why). This is necessary because modifications of both of these will likely be needed to make this work.

Mamba in llama.cpp uses 1 KV cell per sequence, while attention uses one KV cell per token. To avoid conflicts between the two, we'll probably need to introduce tensor lists other than k_l and v_l in llama_kv_cache, backed by a different set of cells (and yet another session file format revision). Sequences are selected with inp_s_seq in ggml_compute_forward_ssm_conv_f32 and ggml_compute_forward_ssm_scan_f32. Each token from a batch has one input state/sequence, but the resulting state is copied to all the sequences assigned to that token.
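The following toy Python model shows the "one cell per sequence" behavior described above. It is a sketch, not llama.cpp's code. `RecurrentCache` and `process_token` are hypothetical names, and a single scalar with the update `h = 0.5 * h + x` stands in for Mamba's conv and SSM states. Each token reads exactly one input state. Here that is the state of its first assigned sequence, which is an assumption standing in for the inp_s_seq selection. The result is then copied to every sequence the token belongs to.

```python
from typing import Dict, List

class RecurrentCache:
    """Toy recurrent-state cache: one cell per sequence, not one per token."""

    def __init__(self, n_seq_max: int):
        # A single float per sequence stands in for Mamba's conv + SSM states.
        self.state: Dict[int, float] = {s: 0.0 for s in range(n_seq_max)}

    def process_token(self, x: float, seq_ids: List[int]) -> float:
        # The token reads exactly one input state; this sketch assumes it is
        # the state of the first sequence the token is assigned to.
        h_in = self.state[seq_ids[0]]
        h_out = 0.5 * h_in + x  # placeholder recurrence, not the real selective scan
        # The resulting state is copied to every sequence assigned to the token.
        for s in seq_ids:
            self.state[s] = h_out
        return h_out
```

The cache never grows with the number of processed tokens. An attention KV cache would instead need one new cell for every token.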

Simplifying how recurrent state operations are implemented is on my TODO list, and implementing both Jamba and RWKV should help with refactoring, but Jamba support in llama.cpp feels like a multi-week project, and I'll only have this kind of free time in May.

If anyone's too impatient, feel free to experiment and figure out a way to make Jamba work with llama.cpp. Even incomplete proofs of concept of how to manage the Jamba blocks should be useful.

maziyarpanahi commented 4 months ago

> @compilade added mamba support. But Jamba seems to be a derivative and needs code modifications.

For reference: https://github.com/ggerganov/llama.cpp/pull/5328

sorasoras commented 4 months ago

> Have smaller Mamba based LLMs already been added in the past?

It's not Mamba-based anymore. It's a mix of Transformer and Mamba, so that's going to be different.

trap20 commented 4 months ago

There is a Mini-Jamba on Hugging Face now: https://huggingface.co/TechxGenus/Mini-Jamba-v2

It might be helpful for testing, if it actually is a working Mini-Jamba model (I haven't checked that yet).

severian42 commented 4 months ago

Just checking to see if anyone has come close to getting Jamba working here. I've been figuring out fine-tuning and training of some new general-chat Jamba models, in preparation for when they can be more standardized for everyone. Once we can get Jamba as a GGUF, I think it'll do some awesome stuff for all of us.

https://huggingface.co/Severian/Jamba-Hercules https://huggingface.co/Severian/Jamba-Nexus-IKM

Any-Winter-4079 commented 4 months ago

Any update on Jamba support?

compilade commented 4 months ago

> Any update on Jamba support?

I've worked on refactoring the KV cache in the past weeks to allow managing both recurrent states and Attention's KV cache at once (see https://github.com/ggerganov/llama.cpp/compare/master...compilade/refactor-kv-cache). It's still a work in progress: state checkpoints (necessary to avoid re-processing the whole prompt when removing the last few tokens) are implemented, but not yet handled in the server. I'll open a PR when it's ready. I still need more time to think through the implementation (I'm currently very busy with other things).
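Here is why checkpoints matter. A recurrent state summarizes the whole prefix, so it can't be rolled back by deleting cells the way attention's per-token KV cells can. Without a saved earlier state, removing even one token means re-processing everything. The following toy Python sketch is not the design in the refactor-kv-cache branch. `CheckpointedState`, `step` and the `interval` parameter are hypothetical. It keeps a snapshot every `interval` tokens. Removing tokens from position `p0` onward restores the latest snapshot at or before `p0`, so only the tokens between that snapshot and `p0` need to be re-evaluated.

```python
from typing import Dict, List

def step(h: float, x: float) -> float:
    # Placeholder recurrence standing in for a Mamba layer's state update.
    return 0.5 * h + x

class CheckpointedState:
    """Toy single-sequence recurrent state with periodic checkpoints."""

    def __init__(self, interval: int):
        self.interval = interval
        self.h = 0.0
        self.n_past = 0                                # tokens folded into self.h
        self.checkpoints: Dict[int, float] = {0: 0.0}  # position -> saved state

    def decode(self, xs: List[float]) -> None:
        for x in xs:
            self.h = step(self.h, x)
            self.n_past += 1
            if self.n_past % self.interval == 0:
                self.checkpoints[self.n_past] = self.h

    def seq_rm(self, p0: int) -> int:
        """Drop tokens at positions >= p0; return how many kept tokens must be re-decoded."""
        pos = max(p for p in self.checkpoints if p <= p0)
        self.h = self.checkpoints[pos]
        self.n_past = pos
        # Snapshots after the restore point describe tokens that are now gone.
        self.checkpoints = {p: h for p, h in self.checkpoints.items() if p <= pos}
        return p0 - pos
```

For example, after decoding 7 tokens with `interval=4`, removing the last token restores the snapshot at position 4. Only positions 4 and 5 then need to be decoded again, instead of the whole prompt. The trade-off is the extra memory used to store each snapshot.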

After that, work on specific hybrid model architectures like Jamba will be possible.

severian42 commented 4 months ago

@compilade Thank you so much for taking this on. I have been trying on my own, but failing miserably to get Jamba quantized with llama.cpp.

I have been prepping by training as many Jamba models as possible, since that is more my wheelhouse.

For your endeavors, could I 'Buy You a Coffee' to help support? I know this extra work isn't easy by any means.

erlebach commented 3 months ago

Could somebody write about why quantizing Jamba and providing a GGUF is difficult? Thanks. Gordon.

compilade commented 3 months ago

> For your endeavors, could I 'Buy You a Coffee' to help support?

@severian42 I appreciate the offer (it means a lot!), but I can't accept for now. Receiving international donations seems a bit complicated accounting-wise and I don't want to have to think about this (yet). Still nice to know you want this to succeed!

> I know this extra work isn't easy by any means

Well, I don't see it as "work", more like exploring ideas. I like to be able to deeply think through some hard problems, and llama.cpp has plenty of that. :)

> Could somebody write about why quantizing Jamba and providing a gguf is difficult? Thanks. Gordon.

@erlebach The main difficulty is how the state is managed; some layers (the Attention layers) will use the KV cache while others (the Mamba layers) will use recurrent states. This is what is taking the most effort/time to implement, since the API around copying, removing and allocating KV cells needs to be re-thought to support both types of cache at the same time.
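To put numbers on the two kinds of cache, this back-of-the-envelope Python sketch reproduces the cache sizes printed in the jamba-900M log later in this thread. The sizing rules are my reading of those numbers, not code taken from llama.cpp:

- An attention layer stores f16 K and V for each of the `n_ctx` tokens, with `n_head_kv * head_dim` values each.
- A Mamba layer stores, per sequence, an f32 conv state (R) of `(d_conv - 1) * d_inner` values and an f32 SSM state (S) of `d_state * d_inner` values.

```python
from typing import List, Tuple

MiB = 1024 * 1024

def cache_sizes(head_count_kv: List[int], n_embd: int, n_head: int, n_ctx: int,
                n_seq: int, d_conv: int, d_inner: int,
                d_state: int) -> Tuple[float, float, float, float]:
    """Return (K, V, R, S) sizes in MiB for a hybrid attention/Mamba model."""
    head_dim = n_embd // n_head
    k = v = r = s = 0
    for n_kv in head_count_kv:
        if n_kv > 0:
            # Attention layer: one f16 (2-byte) entry per token of context.
            k += n_ctx * n_kv * head_dim * 2
            v += n_ctx * n_kv * head_dim * 2
        else:
            # Mamba layer: f32 (4-byte) states per sequence, independent of n_ctx.
            r += n_seq * (d_conv - 1) * d_inner * 4
            s += n_seq * d_state * d_inner * 4
    return k / MiB, v / MiB, r / MiB, s / MiB

print(cache_sizes([0, 0, 8, 0, 0, 0, 8, 0, 0, 0, 8, 0], n_embd=1024, n_head=32,
                  n_ctx=16384, n_seq=1, d_conv=4, d_inner=2048, d_state=16))
```

With the 900M model's metadata and a single sequence, this gives 24 MiB each for K and V and about 0.21 MiB and 1.12 MiB for R and S. Doubling `n_ctx` doubles only K and V. Each additional parallel sequence adds only another set of R and S.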

I have more free time these days, so my work in progress on the above at https://github.com/ggerganov/llama.cpp/compare/master...compilade/refactor-kv-cache should advance a bit quicker than in the past month. That said, I'm currently working on simplifying convert-hf-to-gguf.py (#7031) and making it use lazy operations (#7075), to avoid having all the weights of a model in RAM during conversion. This should make testing the conversion of big models (like Jamba, with its 100GB of bfloat16 weights) much easier and far less memory-hungry (and/or less disk-hungry, if the --use-temp-file option was used).
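The idea behind lazy conversion can be shown with a minimal Python sketch. This is not the code from #7075. `LazyTensor`, `convert` and the writer callback are hypothetical, and plain lists stand in for tensors. Each tensor is kept as a recipe (a load step plus any queued transformations). It is only materialized right before it is written, so roughly one tensor is in memory at a time instead of the whole model.

```python
from typing import Callable, Iterable, List, Tuple

class LazyTensor:
    """Hypothetical lazy tensor: stores how to compute the data, not the data."""

    def __init__(self, compute: Callable[[], List[float]]):
        self._compute = compute

    def map(self, fn: Callable[[List[float]], List[float]]) -> "LazyTensor":
        # Queue a transformation (e.g. a dtype cast or permute) without running it.
        return LazyTensor(lambda: fn(self._compute()))

    def materialize(self) -> List[float]:
        return self._compute()

def convert(tensors: Iterable[Tuple[str, LazyTensor]],
            write: Callable[[str, List[float]], None]) -> None:
    for name, lazy in tensors:
        data = lazy.materialize()  # only now is this tensor read and transformed
        write(name, data)
        del data                   # drop our reference before loading the next one
```

Building the list of `LazyTensor`s reads nothing. Each load then happens immediately before the matching write, which keeps peak memory bounded by the largest single tensor rather than the total model size.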

Quantization will likely not be a problem, since it seemed to work well enough for bigger Mamba models. I don't know why people keep repeating that it can't be quantized. The internal Mamba-specific tensors can't be, but even in pure Mamba models they make up less than ~5% of the weights, while the rest of the space is taken up by linear projections, which can be quantized.
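Here is a quick Python sanity check of that "less than ~5%" figure. It is a sketch under stated assumptions. It counts the parameters of one Mamba mixer using the tensor shapes of the reference Mamba block: in_proj, conv1d, x_proj, dt_proj, A_log, D and out_proj, with no projection biases. The hyperparameters are those of the jamba-900M model in the log later in this thread. Treating conv1d, A_log, D and the dt_proj bias as the "Mamba-specific" part is my assumption. Any small extra norms Jamba may add are ignored.

```python
from typing import Tuple

def mamba_param_split(d_model: int, d_inner: int, d_state: int,
                      dt_rank: int, d_conv: int) -> Tuple[int, int]:
    """Return (linear_projection_params, mamba_specific_params) for one Mamba mixer."""
    linear = (
        d_model * 2 * d_inner                # in_proj (x and z branches)
        + d_inner * (dt_rank + 2 * d_state)  # x_proj -> (dt, B, C)
        + dt_rank * d_inner                  # dt_proj weight
        + d_inner * d_model                  # out_proj
    )
    specific = (
        d_inner * d_conv + d_inner           # depthwise conv1d weight + bias
        + d_inner                            # dt_proj bias
        + d_inner * d_state                  # A_log
        + d_inner                            # D
    )
    return linear, specific

lin, spec = mamba_param_split(d_model=1024, d_inner=2048, d_state=16, dt_rank=256, d_conv=4)
print(f"{spec / (lin + spec):.2%} of the mixer is Mamba-specific")
```

For this configuration, the non-projection part comes out at about 0.6% of a Mamba layer's parameters. That is well under the ~5% bound.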

> Feel free to contribute code if you are though, you could help out @compilade which seems to be one piece of the puzzle

@nonetrix Thanks for reminding others that they too can help. (EDIT: hey, your comment was useful, you didn't need to delete it)

For examples of how to help:

nonetrix commented 3 months ago

> Feel free to contribute code if you are though, you could help out @compilade which seems to be one piece of the puzzle

> @nonetrix Thanks for reminding others that they too can help. (EDIT: hey, your comment was useful, you didn't need to delete it)

No, it was somewhat mean-spirited. I shouldn't have said what I said; I apologize.

erlebach commented 3 months ago

Thank you for the response. I was simply curious, since it was the first time I noticed a quantization effort take so much time. Truly, I appreciate all the hard work you guys put into this. Good luck!

pszemraj commented 3 months ago

@compilade @trap20 It's not perfect/SoTA, but I pretrained a small Jamba architecture (900M params, 8 experts) on about 20B tokens using the stock HF modeling code, if this helps with testing: https://hf.co/pszemraj/jamba-900M-v0.13-KIx2

There's a notebook there with an inference example (interestingly, it uses only a few GB of VRAM even if you generate 10k tokens!).

compilade commented 2 months ago

Okay, turns out I only had to put like, 2 to 3 more days of work on this and BAM it works.

As of today, in branch refactor-kv-cache, using the model from https://github.com/ggerganov/llama.cpp/issues/6372#issuecomment-2118981108, conversion works, loading works, and inference works (or at least it seems to, since it produces coherent sentences; thank you @pszemraj for training it!). I did not test quantization yet (except for Q8_0).

Example output from jamba-900M-v0.13-KIx2:

```console
$ ./bin/main -m /srv/LLMstash/tmp/jamba-900M.bf16.gguf --temp 0 -e -p "I believe the meaning of life is" --repeat-penalty 1.2 --repeat-last-n 256 -c 16384 -n 256
Log start
main: build = 3003 (0fd13e94)
main: built with gcc (GCC) 13.2.0 for x86_64-unknown-linux-gnu
main: seed = 1716594011
llama_model_loader: loaded meta data with 26 key-value pairs and 189 tensors from /srv/LLMstash/tmp/jamba-900M.bf16.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv 0: general.architecture str = jamba
llama_model_loader: - kv 1: general.name str = jamba-900M-v0.13-KIx2
llama_model_loader: - kv 2: jamba.block_count u32 = 12
llama_model_loader: - kv 3: jamba.context_length u32 = 16384
llama_model_loader: - kv 4: jamba.embedding_length u32 = 1024
llama_model_loader: - kv 5: jamba.feed_forward_length u32 = 4096
llama_model_loader: - kv 6: jamba.attention.head_count u32 = 32
llama_model_loader: - kv 7: jamba.attention.head_count_kv arr[i32,12] = [0, 0, 8, 0, 0, 0, 8, 0, 0, 0, 8, 0]
llama_model_loader: - kv 8: jamba.ssm.conv_kernel u32 = 4
llama_model_loader: - kv 9: jamba.ssm.inner_size u32 = 2048
llama_model_loader: - kv 10: jamba.ssm.state_size u32 = 16
llama_model_loader: - kv 11: jamba.ssm.time_step_rank u32 = 256
llama_model_loader: - kv 12: jamba.attention.layer_norm_rms_epsilon f32 = 0.000001
llama_model_loader: - kv 13: jamba.expert_count u32 = 8
llama_model_loader: - kv 14: jamba.expert_used_count u32 = 2
llama_model_loader: - kv 15: general.file_type u32 = 32
llama_model_loader: - kv 16: tokenizer.ggml.model str = gpt2
llama_model_loader: - kv 17: tokenizer.ggml.pre str = gpt-2
llama_model_loader: - kv 18: tokenizer.ggml.tokens arr[str,65024] = ["", "", "", "...
llama_model_loader: - kv 19: tokenizer.ggml.token_type arr[i32,65024] = [3, 3, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv 20: tokenizer.ggml.merges arr[str,64739] = ["Ġ Ġ", "ĠĠ ĠĠ", "i n", "ĠĠ �...
llama_model_loader: - kv 21: tokenizer.ggml.bos_token_id u32 = 0
llama_model_loader: - kv 22: tokenizer.ggml.eos_token_id u32 = 0
llama_model_loader: - kv 23: tokenizer.ggml.unknown_token_id u32 = 0
llama_model_loader: - kv 24: tokenizer.ggml.padding_token_id u32 = 0
llama_model_loader: - kv 25: general.quantization_version u32 = 2
llama_model_loader: - type f32: 121 tensors
llama_model_loader: - type bf16: 68 tensors
llm_load_vocab: special tokens definition check successful ( 29/65024 ).
llm_load_print_meta: format = GGUF V3 (latest)
llm_load_print_meta: arch = jamba
llm_load_print_meta: vocab type = BPE
llm_load_print_meta: n_vocab = 65024
llm_load_print_meta: n_merges = 64739
llm_load_print_meta: n_ctx_train = 16384
llm_load_print_meta: n_embd = 1024
llm_load_print_meta: n_head = 32
llm_load_print_meta: n_head_kv = 32
llm_load_print_meta: n_layer = 12
llm_load_print_meta: n_rot = 32
llm_load_print_meta: n_embd_head_k = 32
llm_load_print_meta: n_embd_head_v = 32
llm_load_print_meta: n_gqa = 0
llm_load_print_meta: n_embd_k_gqa = 0
llm_load_print_meta: n_embd_v_gqa = 0
llm_load_print_meta: f_norm_eps = 0.0e+00
llm_load_print_meta: f_norm_rms_eps = 1.0e-06
llm_load_print_meta: f_clamp_kqv = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale = 0.0e+00
llm_load_print_meta: n_ff = 4096
llm_load_print_meta: n_expert = 8
llm_load_print_meta: n_expert_used = 2
llm_load_print_meta: causal attn = 1
llm_load_print_meta: pooling type = 0
llm_load_print_meta: rope type = -1
llm_load_print_meta: rope scaling = linear
llm_load_print_meta: freq_base_train = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx = 16384
llm_load_print_meta: rope_finetuned = unknown
llm_load_print_meta: ssm_d_conv = 4
llm_load_print_meta: ssm_d_inner = 2048
llm_load_print_meta: ssm_d_state = 16
llm_load_print_meta: ssm_dt_rank = 256
llm_load_print_meta: model type = ?B
llm_load_print_meta: model ftype = BF16
llm_load_print_meta: model params = 887.66 M
llm_load_print_meta: model size = 1.67 GiB (16.19 BPW)
llm_load_print_meta: general.name = jamba-900M-v0.13-KIx2
llm_load_print_meta: BOS token = 0 ''
llm_load_print_meta: EOS token = 0 ''
llm_load_print_meta: UNK token = 0 ''
llm_load_print_meta: PAD token = 0 ''
llm_load_print_meta: LF token = 133 'Ä'
llm_load_tensors: ggml ctx size = 0.09 MiB
llm_load_tensors: CPU buffer size = 1713.16 MiB
......................................
llama_new_context_with_model: n_ctx = 16384
llama_new_context_with_model: n_batch = 2048
llama_new_context_with_model: n_ubatch = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_cache_init: CPU cache buf size = 49.34 MiB
llama_new_context_with_model: SSM state size = 1.34 MiB, R (f32): 0.21 MiB, S (f32): 1.12 MiB
llama_new_context_with_model: KV cache size = 48.00 MiB, K (f16): 24.00 MiB, V (f16): 24.00 MiB
llama_new_context_with_model: CPU output buffer size = 0.25 MiB
llama_new_context_with_model: CPU compute buffer size = 1062.03 MiB
llama_new_context_with_model: graph nodes = 621
llama_new_context_with_model: graph splits = 1

system_info: n_threads = 2 / 4 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 |
sampling:
  repeat_last_n = 256, repeat_penalty = 1.200, frequency_penalty = 0.000, presence_penalty = 0.000
  top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.000
  mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampling order:
CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temperature
generate: n_ctx = 16384, n_batch = 2048, n_predict = 256, n_keep = 0

I believe the meaning of life is not to be found in a single word, but rather as an expression of one's own feelings and thoughts. The idea that we are all born with our bodies, whether they are human or animal, has been around for centuries. It was believed by some that it was something like a body made up of bones, which were attached to each other at birth. The most common form of this type of bone is called a "bone." This is what makes it so hard to tell if you're alive or dead. In fact, there are many different types of bones, including those that have been used for various purposes such as healing wounds, wounding wounds, etc. In ancient times, people had a lot of teeth, and these were often very small. They could also be placed on top of their heads, where they would sit down and look at them. These were usually large, round stones, which were sometimes covered with hair. When the skin was removed from the head, the bones became more prominent, and the muscles began to grow larger. This kind of bone was known as a "bone" because it was made out of two parts: the outermost part (the innermost portion) and the innermost part (the outermost

llama_print_timings: load time = 252.28 ms
llama_print_timings: sample time = 303.07 ms / 256 runs ( 1.18 ms per token, 844.68 tokens per second)
llama_print_timings: prompt eval time = 200.72 ms / 8 tokens ( 25.09 ms per token, 39.86 tokens per second)
llama_print_timings: eval time = 12516.79 ms / 255 runs ( 49.09 ms per token, 20.37 tokens per second)
llama_print_timings: total time = 13213.95 ms / 263 tokens
Log end
```

Inference is CPU-only for now, because the Mamba implementation in llama.cpp is still CPU-only (ref: #6758).

To convert, I used

$ python3 convert-hf-to-gguf.py /srv/LLMstash/src/jamba-900M-v0.13-KIx2/ --outfile /srv/LLMstash/tmp/jamba-900M.{ftype}.gguf --outtype auto

(@severian42 note that convert-hf-to-gguf.py doesn't support loading 4-bit BitsAndBytes models yet (I might work on fixing this eventually). In the meantime, dequantizing to a 16-bit float type (either F16 or BF16) will be necessary.)

I'm going to open a PR very soon (in a few hours or tomorrow), I just need to write it up (there are quite a lot of lines changed and I want to explain).

For the impatient, the code is in https://github.com/ggerganov/llama.cpp/tree/compilade/refactor-kv-cache.

severian42 commented 2 months ago

Incredible! Thank you so much for putting in the work to get this running and updating us with the news. I will give this a try this weekend and report back. So excited to try this! Thanks again for using your smarts to further the open source LLM world. We owe you a big one 💪

compilade commented 2 months ago

There is still more work I need to put into this. I've got inference working, but things that are not yet done are:

Building this will (currently) output lots of warnings because I've renamed many functions related to KV cache management, and then deprecated the old names, but I did not yet update their usages in the various examples.

Finishing this will still take a few days, but I think I'll open a PR tomorrow anyway.

I think I won't change anything about how the GGUFs of Jamba are made, unless I've unexpectedly messed something up in the conversion code.

github-actions[bot] commented 1 month ago

This issue was closed because it has been inactive for 14 days since being marked as stale.