Closed by dranger003 4 months ago.
The paper says they use mean pooling for the embeddings, but currently embeddings for generative models like this are hard-coded to use the last token. If you change the code around llama.cpp:8087 to:

```cpp
const int64_t embd_pos = 0;
const int64_t embd_size = n_embd * n_tokens;
```

it'll pull in the embeddings for every token, and then you can do the mean pooling manually. That said, I just tried it and I'm still getting different results, so there must be something else going on as well.
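With `embd_pos = 0` and `embd_size = n_embd * n_tokens`, the copied buffer should hold one `n_embd`-sized row per token, in token order (that layout is an assumption based on the change above). Mean pooling is then a column-wise average of those rows. A minimal sketch; `mean_pool` is a hypothetical helper, not part of llama.cpp:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: mean pooling over per-token embeddings.
// `embd` points to n_tokens rows of n_embd floats each, stored one token
// after another (the layout assumed to result from copying the whole
// output buffer starting at position 0).
std::vector<float> mean_pool(const float * embd, int64_t n_tokens, int64_t n_embd) {
    assert(embd != nullptr && n_tokens > 0 && n_embd > 0);
    std::vector<float> pooled(static_cast<std::size_t>(n_embd), 0.0f);
    // Sum the rows...
    for (int64_t t = 0; t < n_tokens; ++t) {
        const float * row = embd + t * n_embd; // embedding of token t
        for (int64_t i = 0; i < n_embd; ++i) {
            pooled[i] += row[i];
        }
    }
    // ...then divide by the number of tokens.
    for (int64_t i = 0; i < n_embd; ++i) {
        pooled[i] /= static_cast<float>(n_tokens);
    }
    return pooled;
}
```

The helper averages every token in the range it is given; to pool over a subset, pass a pointer to the first wanted row and the number of rows to include.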
I'm pretty excited about GritLM, especially the document caching tricks they show. Would be cool to get it working here.
Thanks, I saw that but don't know what it means (no pun intended). Do I also need to change the way I call `llama_decode`?
Also, I added some quants on HF.
Good news @dranger003! I finally got the numbers to match up exactly. Check it out at iamlemec/llama.cpp:gritlm. Here are the changes that were needed:
llama.cpp
Also note the slight fix to the cosine similarity (we want similarity, not distance).
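The two are easy to mix up because they differ only by `1 -`: cosine similarity is the dot product divided by the product of the norms (higher means more alike), while cosine distance is one minus that (lower means more alike). A minimal sketch; the function names are illustrative and not taken from the sample program:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Cosine similarity: dot(a, b) / (|a| * |b|).
// 1.0 = same direction, 0.0 = orthogonal. Higher = more similar.
double cosine_similarity(const std::vector<float> & a, const std::vector<float> & b) {
    assert(a.size() == b.size() && !a.empty());
    double dot = 0.0, norm_a = 0.0, norm_b = 0.0;
    for (std::size_t i = 0; i < a.size(); ++i) {
        dot    += static_cast<double>(a[i]) * b[i];
        norm_a += static_cast<double>(a[i]) * a[i];
        norm_b += static_cast<double>(b[i]) * b[i];
    }
    if (norm_a == 0.0 || norm_b == 0.0) {
        return 0.0; // undefined for a zero vector; 0.0 chosen by convention here
    }
    return dot / (std::sqrt(norm_a) * std::sqrt(norm_b));
}

// Cosine distance: 1 - similarity. Lower = more similar, so it ranks
// results in the opposite order from the similarity.
double cosine_distance(const std::vector<float> & a, const std::vector<float> & b) {
    return 1.0 - cosine_similarity(a, b);
}
```

Accumulating in `double` is a design choice to limit rounding error over 4096-dimensional vectors; `float` accumulation would also work for a quick comparison.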
Thanks, just tested it and it works great! You even included the sample code, amazing. I think there is a lot of potential in using the same model for both representation and generation, especially when resources are scarce.
I've been trying to add the text generation piece and I noticed two things:
`embedding = false` and `pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED`, is that right? I was hoping we could avoid unloading and reloading the model, but maybe not.

@iamlemec Am I right in thinking a part two is needed for the master branch to work? I saw ggerganov's commit yesterday to fix embeddings, but GritLM doesn't work on master, which is why I'm asking.
EDIT: Just saw your v2 branch. If I can help with anything let me know!
@dranger003 Yup, that's my attempt at rebasing after the new embedding fixes. Unfortunately it's not giving the right results currently. Trying to figure it out now, but if you spot anything, do tell!
@dranger003 Finally figured it out. My `gritlm-v2` branch is working now. We switched away from using the KV cache for non-causal BERT models, but Llama-type models still use it, so I had to tweak the attention mask construction.
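For context on that mask tweak: GritLM embeds with bidirectional attention and generates with causal attention, and the difference comes down to which entries of the additive attention mask are set to negative infinity. The sketch below builds such a mask for a single sequence. `build_attn_mask` is a hypothetical helper, not the code in the `gritlm-v2` branch; llama.cpp's real mask is laid out over KV-cache cells rather than plain token positions, so treat this only as an illustration of causal versus bidirectional masking.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper, for illustration only: build an additive attention
// mask for one sequence of n_tokens tokens, stored row-major as
// mask[q * n_tokens + k]. The value is added to the score of query q
// attending to key k before the softmax: 0.0f keeps the pair, -INFINITY
// removes it (its softmax weight becomes zero).
std::vector<float> build_attn_mask(int n_tokens, bool causal) {
    assert(n_tokens >= 0);
    const std::size_t n = static_cast<std::size_t>(n_tokens);
    // Bidirectional (embedding mode): every token may attend to every token.
    std::vector<float> mask(n * n, 0.0f);
    if (causal) {
        // Causal (generation mode): hide keys that come after the query.
        for (std::size_t q = 0; q < n; ++q) {
            for (std::size_t k = q + 1; k < n; ++k) {
                mask[q * n + k] = -INFINITY;
            }
        }
    }
    return mask;
}
```

Because the causal mask hides later tokens, the per-token hidden states (and therefore any mean-pooled embedding) differ between the two modes even for identical input.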
@iamlemec I added the text generation code to the sample and it works great without reloading the model for both modes. Should we submit a PR?
EDIT: So it looks like we cannot use both embeddings and causal attention on the same context? Text generation works, but the embeddings are different when I set both embeddings and causal attention to true. Is that expected (i.e. do they need separate contexts)?
Just pushed a change where you can toggle embedding mode with `llama_set_embeddings`. This has the advantage of not having to manually fiddle with `causal_attn`; everything goes through `embeddings`. I haven't tested generation extensively, but the output is sensible at least.
Let's test and think about this for a day or two to see if it's ready for prime time. I might also try to see if I can get the doc caching example working.
I reproduced the embeddings sample from GritLM, and llama.cpp returns unexpected embedding values. I have been able to get embeddings working with other models. I verified the tokenization and everything looks correct (with and without special tokens and BOS/EOS).
Below is a sample program that reproduces the unexpected results. The GritLM Python inference code can be found here. The model supports both text generation and text representation (embeddings) and is based on Mistral 7B.
There is nothing special in this code; it is based on the embeddings sample in llama.cpp, so I'm not sure what is going on. Any guidance is appreciated.
Note: It seems HF is in maintenance right now, but I'll add a gguf link when they're back online.
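The token dumps in the output below show the prompt layout fed to the model: queries are wrapped as `<|user|>\n{instruction}\n<|embed|>\n` followed by the title, while documents get only `<|embed|>\n`. A minimal C++ sketch of that formatting follows. The name `gritlm_instruction` mirrors the Python helper of the same name in the GritLM README, but this port is only an illustration; `gritlm_input` is a hypothetical convenience wrapper, and the BOS token is assumed to be added by the tokenizer (`add_bos_token = true` in the model metadata) rather than by these strings.

```cpp
#include <cassert>
#include <string>

// Prefix for GritLM embedding inputs, matching the layout in the token dumps:
// a query carries its task instruction inside a <|user|> turn, while a
// document (empty instruction) gets only the <|embed|> marker.
// BOS is not included here; the tokenizer is assumed to add it.
std::string gritlm_instruction(const std::string & instruction) {
    if (instruction.empty()) {
        return "<|embed|>\n";
    }
    return "<|user|>\n" + instruction + "\n<|embed|>\n";
}

// Hypothetical wrapper: full text to tokenize and embed (prefix + text).
std::string gritlm_input(const std::string & instruction, const std::string & text) {
    return gritlm_instruction(instruction) + text;
}
```

Comparing the tokenization of these strings against the dumps is a quick way to rule out prompt formatting as the source of a mismatch.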
Sample source code:

```cpp
static float dot_product(const std::vector
```

Output:
```
./embeddings -m ggml-gritlm-7b-q8_0.gguf -ngl 33
ggml_init_cublas: GGML_CUDA_FORCE_MMQ: no
ggml_init_cublas: CUDA_USE_TENSOR_CORES: yes
ggml_init_cublas: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
llama_model_loader: loaded meta data with 24 key-value pairs and 291 tensors from ggml-gritlm-7b-q8_0.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv 0: general.architecture str = llama
llama_model_loader: - kv 1: general.name str = GritLM
llama_model_loader: - kv 2: llama.context_length u32 = 32768
llama_model_loader: - kv 3: llama.embedding_length u32 = 4096
llama_model_loader: - kv 4: llama.block_count u32 = 32
llama_model_loader: - kv 5: llama.feed_forward_length u32 = 14336
llama_model_loader: - kv 6: llama.rope.dimension_count u32 = 128
llama_model_loader: - kv 7: llama.attention.head_count u32 = 32
llama_model_loader: - kv 8: llama.attention.head_count_kv u32 = 8
llama_model_loader: - kv 9: llama.attention.layer_norm_rms_epsilon f32 = 0.000010
llama_model_loader: - kv 10: llama.rope.freq_base f32 = 10000.000000
llama_model_loader: - kv 11: general.file_type u32 = 7
llama_model_loader: - kv 12: tokenizer.ggml.model str = llama
llama_model_loader: - kv 13: tokenizer.ggml.tokens arr[str,32000] = ["", "", "<0x00>", "<...
llama_model_loader: - kv 14: tokenizer.ggml.scores arr[f32,32000] = [0.000000, 0.000000, 0.000000, 0.0000...
llama_model_loader: - kv 15: tokenizer.ggml.token_type arr[i32,32000] = [2, 3, 3, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...
llama_model_loader: - kv 16: tokenizer.ggml.bos_token_id u32 = 1
llama_model_loader: - kv 17: tokenizer.ggml.eos_token_id u32 = 2
llama_model_loader: - kv 18: tokenizer.ggml.unknown_token_id u32 = 0
llama_model_loader: - kv 19: tokenizer.ggml.padding_token_id u32 = 1
llama_model_loader: - kv 20: tokenizer.ggml.add_bos_token bool = true
llama_model_loader: - kv 21: tokenizer.ggml.add_eos_token bool = false
llama_model_loader: - kv 22: tokenizer.chat_template str = {{ bos_token }}{% for message in mess...
llama_model_loader: - kv 23: general.quantization_version u32 = 2
llama_model_loader: - type f32: 65 tensors
llama_model_loader: - type q8_0: 226 tensors
llm_load_vocab: special tokens definition check successful ( 259/32000 ).
llm_load_print_meta: format = GGUF V3 (latest)
llm_load_print_meta: arch = llama
llm_load_print_meta: vocab type = SPM
llm_load_print_meta: n_vocab = 32000
llm_load_print_meta: n_merges = 0
llm_load_print_meta: n_ctx_train = 32768
llm_load_print_meta: n_embd = 4096
llm_load_print_meta: n_head = 32
llm_load_print_meta: n_head_kv = 8
llm_load_print_meta: n_layer = 32
llm_load_print_meta: n_rot = 128
llm_load_print_meta: n_embd_head_k = 128
llm_load_print_meta: n_embd_head_v = 128
llm_load_print_meta: n_gqa = 4
llm_load_print_meta: n_embd_k_gqa = 1024
llm_load_print_meta: n_embd_v_gqa = 1024
llm_load_print_meta: f_norm_eps = 0.0e+00
llm_load_print_meta: f_norm_rms_eps = 1.0e-05
llm_load_print_meta: f_clamp_kqv = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: n_ff = 14336
llm_load_print_meta: n_expert = 0
llm_load_print_meta: n_expert_used = 0
llm_load_print_meta: pooling type = 0
llm_load_print_meta: rope type = 0
llm_load_print_meta: rope scaling = linear
llm_load_print_meta: freq_base_train = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx = 32768
llm_load_print_meta: rope_finetuned = unknown
llm_load_print_meta: model type = 7B
llm_load_print_meta: model ftype = Q8_0
llm_load_print_meta: model params = 7.24 B
llm_load_print_meta: model size = 7.17 GiB (8.50 BPW)
llm_load_print_meta: general.name = GritLM
llm_load_print_meta: BOS token = 1 ''
llm_load_print_meta: EOS token = 2 ''
llm_load_print_meta: UNK token = 0 ''
llm_load_print_meta: LF token = 13 '<0x0A>'
llm_load_tensors: ggml ctx size = 0.22 MiB
llm_load_tensors: offloading 32 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 33/33 layers to GPU
llm_load_tensors: CPU buffer size = 132.81 MiB
llm_load_tensors: CUDA0 buffer size = 7205.83 MiB
llama_new_context_with_model: n_ctx = 512
llama_new_context_with_model: freq_base = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init: CUDA0 KV buffer size = 64.00 MiB
llama_new_context_with_model: KV self size = 64.00 MiB, K (f16): 32.00 MiB, V (f16): 32.00 MiB
llama_new_context_with_model: CUDA_Host input buffer size = 10.01 MiB
llama_new_context_with_model: CUDA0 compute buffer size = 73.00 MiB
llama_new_context_with_model: CUDA_Host compute buffer size = 8.00 MiB
llama_new_context_with_model: graph splits (measure): 2
[1:][523: <][28766:|][18320:embed][28766:|][28767:>][13: ][28741:A][21690: purely][13669: peer][28733:-][532:to][28733:-][14720:peer][2751: version][302: of][13176: electronic][7877: cash][682: would][1914: allow][3270: online][14923: payments][298: to][347: be][2662: sent][5090: directly][477: from][624: one][4150: party][298: to][1698: another][1671: without][1404: going][1059: through][264: a][5593: financial][16854: institution][28723:.][13770: Digital][1492: sign][2863:atures][3084: provide][744: part][302: of][272: the][5165: solution][28725:,][562: but][272: the][2191: main][7196: benefits][460: are][3654: lost][513: if][264: a][16437: trusted][4008: third][4150: party][349: is][1309: still][3030: required][298: to][5297: prevent][3579: double][28733:-][886:sp][2570:ending][28723:.][816: We][19333: propose][264: a][5165: solution][298: to][272: the][3579: double][28733:-][886:sp][2570:ending][2700: problem][1413: using][264: a][13669: peer][28733:-][532:to][28733:-][14720:peer][3681: network][28723:.][415: The][3681: network][5104: tim][374:est][10991:amps][15852: transactions][486: by][659: has][2299:hing][706: them][778: into][396: an][15260: ongoing][7650: chain][302: of][7135: hash][28733:-][5527:based][7167: proof][28733:-][1009:of][28733:-][1328:work][28725:,][20345: forming][264: a][2395: record][369: that][3573: cannot][347: be][4648: changed][1671: without][312: re][2432:do][288:ing][272: the][7167: proof][28733:-][1009:of][28733:-][1328:work][28723:.][415: The][23397: longest][7650: chain][459: not][865: only][14449: serves][390: as][7167: proof][302: of][272: the][7768: sequence][302: of][3926: events][24385: witnessed][28725:,][562: but][7167: proof][369: that][378: it][1988: came][477: from][272: the][7639: largest][6313: pool][302: of][14865: CPU][1982: power][28723:.][1136: As][1043: long][390: as][264: a][7757: majority][302: of][14865: CPU][1982: power][349: is][12888: controlled][486: by][9249: nodes][369: that][460: are][459: not][18468: cooper][1077:ating][298: to][3517: attack][272: the][3681: network][28725:,][590: they][28742:'][584:ll][8270: generate][272: the][23397: longest][7650: chain][304: and][575: out][2644:pace][3517: attack][404:ers][28723:.][415: The][3681: network][3837: itself][6948: requires][13383: minimal][4693: structure][28723:.][351: M][9251:essages][460: are][11837: broadcast][356: on][264: a][1489: best][4261: effort][6451: basis][28725:,][304: and][9249: nodes][541: can][3530: leave][304: and][312: re][5906:join][272: the][3681: network][438: at][622: will][28725:,][22368: accepting][272: the][23397: longest][7167: proof][28733:-][1009:of][28733:-][1328:work][7650: chain][390: as][7167: proof][302: of][767: what][4243: happened][1312: while][590: they][654: were][4214: gone][28723:.]
[1:][523: <][28766:|][18320:embed][28766:|][28767:>][13: ][2595:All][2245: text][28733:-][5527:based][3842: language][4418: problems][541: can][347: be][9397: reduced][298: to][2477: either][8342: generation][442: or][28643: embedding][28723:.][10929: Current][4994: models][865: only][2225: perform][1162: well][438: at][624: one][442: or][272: the][799: other][28723:.][816: We][13097: introduce][1350: gener][1197:ative][2904: represent][1249:ational][13126: instruction][15013: tun][288:ing][325: (][8369:GR][1153:IT][28731:)][970: where][1403:by][264: a][2475: large][3842: language][2229: model][349: is][10898: trained][298: to][4269: handle][1560: both][1350: gener][1197:ative][304: and][28643: embedding][9796: tasks][486: by][11731: distingu][5596:ishing][1444: between][706: them][1059: through][11382: instructions][28723:.][3880: Comp][1327:ared][298: to][799: other][1565: open][4994: models][28725:,][813: our][10503: resulting][420: G][872:rit][27149:LM][28705: ][28787:7][28760:B][6491: sets][264: a][633: new][1665: state][302: of][272: the][1524: art][356: on][272: the][7576: Mass][495:ive][7379: Text][18065: Emb][286:ed][3202:ding][4121: Ben][338:ch][3325:mark][325: (][28755:M][3392:TE][28760:B][28731:)][304: and][575: out][487:per][14367:forms][544: all][4994: models][582: up][298: to][871: its][1669: size][356: on][264: a][2819: range][302: of][1350: gener][1197:ative][9796: tasks][28723:.][2463: By][19903: scaling][582: up][3629: further][28725:,][420: G][872:rit][27149:LM][28705: ][28783:8][28814:X][28787:7][28760:B][575: out][487:per][14367:forms][544: all][1565: open][1350: gener][1197:ative][3842: language][4994: models][369: that][478: we][3851: tried][1312: while][1309: still][1250: being][3352: among][272: the][1489: best][28643: embedding][4994: models][28723:.][2280: Not][1907:ably][28725:,][478: we][1300: find][369: that][19348: GR][1153:IT][9019: matches][4154: training][356: on][865: only][1350: gener][1197:ative][442: or][28643: embedding][1178: data][28725:,][5884: thus][478: we][541: can][521: un][1575:ify][1560: both][438: at][708: no][4397: performance][4320: loss][28723:.][13927: Among][799: other][7196: benefits][28725:,][272: the][521: un][2500:ification][4213: via][19348: GR][1153:IT][27480: speeds][582: up][8337: Ret][10212:riev][282:al][28733:-][21575:Aug][466:ment][286:ed][26802: Generation][325: (][28754:R][2377:AG][28731:)][486: by][876: >][28705: ][28784:6][28734:0][28823:%][354: for][1043: long][10181: documents][28725:,][486: by][708: no][3774: longer][22579: requiring][7681: separate][17913: retriev][282:al][304: and][8342: generation][4994: models][28723:.][3813: Mod][1190:els][28725:,][2696: code][28725:,][4345: etc][28723:.][460: are][21964: freely][2632: available][438: at][4449: https][1508:://][6222:github][28723:.][675:com][28748:/][2083:Context][840:ual][11741:AI][28748:/][820:gr][279:it][24174:lm][28723:.]
[1:][523: <][28766:|][1838:user][28766:|][28767:>][13: ][28777:G][5067:iven][264: a][10469: scientific][3830: paper][3941: title][28725:,][20132: retrieve][272: the][3830: paper][28742:'][28713:s][11576: abstract][13: ][28789:<][28766:|][18320:embed][28766:|][28767:>][13: ][8443:Bit][10817:coin][28747::][330: A][3242: Pe][263:er][28733:-][532:to][28733:-][22163:Peer][10394: Elect][7624:ronic][23439: Cash][2135: System]
[1:][523: <][28766:|][1838:user][28766:|][28767:>][13: ][28777:G][5067:iven][264: a][10469: scientific][3830: paper][3941: title][28725:,][20132: retrieve][272: the][3830: paper][28742:'][28713:s][11576: abstract][13: ][28789:<][28766:|][18320:embed][28766:|][28767:>][13: ][3602:Gener][1197:ative][17891: Represent][1249:ational][3133: Inst][3112:ruction][22756: Tun][288:ing]
Cosine similarity between "Bitcoin: A Peer-to-Peer Electronic Cash System" and "A purely peer-to-peer version of electronic cash w" is: 0.551
Cosine similarity between "Bitcoin: A Peer-to-Peer Electronic Cash System" and "All text-based language problems can be reduced to" is: 0.794
Cosine similarity between "Generative Representational Instruction Tuning" and "A purely peer-to-peer version of electronic cash w" is: 0.730
Cosine similarity between "Generative Representational Instruction Tuning" and "All text-based language problems can be reduced to" is: 0.803
```