abetlen / llama-cpp-python

Python bindings for llama.cpp
https://llama-cpp-python.readthedocs.io
MIT License

Cannot run T5-based models #1587

Open kyteinsky opened 1 month ago

kyteinsky commented 1 month ago


Expected Behavior

I'm trying to use models based on the T5 architecture, like "google/flan-t5-small" and "google/madlad400-3b-mt", after converting them to GGUF format (8-bit). They work nicely with llama-cli.

T5 support was a recent addition to llama.cpp: https://github.com/ggerganov/llama.cpp/pull/8141

Related issue, though I don't understand what needs to be changed here in the Python package: https://github.com/ggerganov/llama.cpp/issues/8398

Current Behavior

The model loads into GPU memory, but execution fails with a GGML_ASSERT and a core dump.

Environment and Context

I'm running v0.2.82 using the compiled whl provided for cuda 12.4.

$ lscpu

Architecture:            x86_64
  CPU op-mode(s):        32-bit, 64-bit
  Address sizes:         39 bits physical, 48 bits virtual
  Byte Order:            Little Endian
CPU(s):                  12
  On-line CPU(s) list:   0-11
Vendor ID:               GenuineIntel
  Model name:            Intel(R) Core(TM) i5-10400 CPU @ 2.90GHz
    CPU family:          6
    Model:               165
    Thread(s) per core:  2
    Core(s) per socket:  6
    Socket(s):           1
    Stepping:            5
    CPU(s) scaling MHz:  80%
    CPU max MHz:         4300.0000
    CPU min MHz:         800.0000
    BogoMIPS:            5799.77
    Flags:               fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc
                         art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid
                         sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow
                         flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx rdseed adx smap clflushopt intel_pt xsaveopt xsavec xgetbv1 xsaves dtherm
                         ida arat pln pts hwp hwp_notify hwp_act_window hwp_epp vnmi pku ospke md_clear flush_l1d arch_capabilities
Virtualization features: 
  Virtualization:        VT-x
Caches (sum of all):     
  L1d:                   192 KiB (6 instances)
  L1i:                   192 KiB (6 instances)
  L2:                    1.5 MiB (6 instances)
  L3:                    12 MiB (1 instance)
NUMA:                    
  NUMA node(s):          1
  NUMA node0 CPU(s):     0-11
Vulnerabilities:         
  Gather data sampling:  Mitigation; Microcode
  Itlb multihit:         KVM: Mitigation: VMX disabled
  L1tf:                  Not affected
  Mds:                   Not affected
  Meltdown:              Not affected
  Mmio stale data:       Mitigation; Clear CPU buffers; SMT vulnerable
  Retbleed:              Mitigation; Enhanced IBRS
  Spec rstack overflow:  Not affected
  Spec store bypass:     Mitigation; Speculative Store Bypass disabled via prctl
  Spectre v1:            Mitigation; usercopy/swapgs barriers and __user pointer sanitization
  Spectre v2:            Mitigation; Enhanced / Automatic IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
  Srbds:                 Mitigation; Microcode
  Tsx async abort:       Not affected

$ uname -a

Linux work 6.6.8-200.fsync.fc38.x86_64 #1 SMP PREEMPT_DYNAMIC TKG Sat Dec 30 04:42:45 UTC 2023 x86_64 GNU/Linux
$ python3 --version
Python 3.11.9

$ make --version
GNU Make 4.4.1
Built for x86_64-redhat-linux-gnu

$ g++ --version
g++ (GCC) 13.2.1 20240316 (Red Hat 13.2.1-7)

Failure Information (for bugs)


Steps to Reproduce


  1. Quantise a T5 model like flan-t5-small using the latest master of llama.cpp (python convert_hf_to_gguf.py ./flan-t5-small --outtype q8_0)
  2. Install llama-cpp-python using the pre-built wheel (llama_cpp_python @ https://github.com/abetlen/llama-cpp-python/releases/download/v0.2.82-cu121/llama_cpp_python-0.2.82-cp311-cp311-linux_x86_64.whl#sha256=8faeb9347543e9752c3dc117cd2f4d214a158444d648f6ddf0b9c360e96f2694)
  3. Try to load and run the model without any fancy params, just the minimum.
```python
from llama_cpp import Llama

llm = Llama(
    "./flan-t5-small/ggml-model-q8_0.gguf",
    n_gpu_layers=-1,  # offload all layers to the GPU
)

print(llm("hello, how are you?", max_tokens=1024, temperature=0.1))
```


Failure Logs

It is always the same.

```
llama_model_loader: loaded meta data with 28 key-value pairs and 190 tensors from ./flan-t5-small/ggml-model-q8_0.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv 0: general.architecture str = t5
llama_model_loader: - kv 1: general.name str = T5
llama_model_loader: - kv 2: t5.context_length u32 = 512
llama_model_loader: - kv 3: t5.embedding_length u32 = 512
llama_model_loader: - kv 4: t5.feed_forward_length u32 = 1024
llama_model_loader: - kv 5: t5.block_count u32 = 8
llama_model_loader: - kv 6: t5.attention.head_count u32 = 6
llama_model_loader: - kv 7: t5.attention.key_length u32 = 64
llama_model_loader: - kv 8: t5.attention.value_length u32 = 64
llama_model_loader: - kv 9: t5.attention.layer_norm_epsilon f32 = 0.000001
llama_model_loader: - kv 10: t5.attention.relative_buckets_count u32 = 32
llama_model_loader: - kv 11: t5.attention.layer_norm_rms_epsilon f32 = 0.000001
llama_model_loader: - kv 12: t5.decoder_start_token_id u32 = 0
llama_model_loader: - kv 13: general.file_type u32 = 7
llama_model_loader: - kv 14: tokenizer.ggml.model str = t5
llama_model_loader: - kv 15: tokenizer.ggml.pre str = default
llama_model_loader: - kv 16: tokenizer.ggml.tokens arr[str,32128] = ["", "", "", "▁", "X"...
llama_model_loader: - kv 17: tokenizer.ggml.scores arr[f32,32128] = [0.000000, 0.000000, 0.000000, -2.012...
llama_model_loader: - kv 18: tokenizer.ggml.token_type arr[i32,32128] = [3, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv 19: tokenizer.ggml.add_space_prefix bool = true
llama_model_loader: - kv 20: tokenizer.ggml.remove_extra_whitespaces bool = true
llama_model_loader: - kv 21: tokenizer.ggml.precompiled_charsmap arr[u8,237539] = [0, 180, 2, 0, 0, 132, 0, 0, 0, 0, 0,...
llama_model_loader: - kv 22: tokenizer.ggml.eos_token_id u32 = 1
llama_model_loader: - kv 23: tokenizer.ggml.unknown_token_id u32 = 2
llama_model_loader: - kv 24: tokenizer.ggml.padding_token_id u32 = 0
llama_model_loader: - kv 25: tokenizer.ggml.add_bos_token bool = false
llama_model_loader: - kv 26: tokenizer.ggml.add_eos_token bool = true
llama_model_loader: - kv 27: general.quantization_version u32 = 2
llama_model_loader: - type f32: 42 tensors
llama_model_loader: - type f16: 2 tensors
llama_model_loader: - type q8_0: 146 tensors
llm_load_vocab: special tokens cache size = 131
llm_load_vocab: token to piece cache size = 0.2123 MB
llm_load_print_meta: format = GGUF V3 (latest)
llm_load_print_meta: arch = t5
llm_load_print_meta: vocab type = UGM
llm_load_print_meta: n_vocab = 32128
llm_load_print_meta: n_merges = 0
llm_load_print_meta: vocab_only = 0
llm_load_print_meta: n_ctx_train = 512
llm_load_print_meta: n_embd = 512
llm_load_print_meta: n_layer = 8
llm_load_print_meta: n_head = 6
llm_load_print_meta: n_head_kv = 6
llm_load_print_meta: n_rot = 64
llm_load_print_meta: n_swa = 0
llm_load_print_meta: n_embd_head_k = 64
llm_load_print_meta: n_embd_head_v = 64
llm_load_print_meta: n_gqa = 1
llm_load_print_meta: n_embd_k_gqa = 384
llm_load_print_meta: n_embd_v_gqa = 384
llm_load_print_meta: f_norm_eps = 0.0e+00
llm_load_print_meta: f_norm_rms_eps = 1.0e-06
llm_load_print_meta: f_clamp_kqv = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale = 0.0e+00
llm_load_print_meta: n_ff = 1024
llm_load_print_meta: n_expert = 0
llm_load_print_meta: n_expert_used = 0
llm_load_print_meta: causal attn = 1
llm_load_print_meta: pooling type = 0
llm_load_print_meta: rope type = -1
llm_load_print_meta: rope scaling = linear
llm_load_print_meta: freq_base_train = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_ctx_orig_yarn = 512
llm_load_print_meta: rope_finetuned = unknown
llm_load_print_meta: ssm_d_conv = 0
llm_load_print_meta: ssm_d_inner = 0
llm_load_print_meta: ssm_d_state = 0
llm_load_print_meta: ssm_dt_rank = 0
llm_load_print_meta: model type = 80M
llm_load_print_meta: model ftype = Q8_0
llm_load_print_meta: model params = 76.96 M
llm_load_print_meta: model size = 78.04 MiB (8.51 BPW)
llm_load_print_meta: general.name = T5
llm_load_print_meta: EOS token = 1 ''
llm_load_print_meta: UNK token = 2 ''
llm_load_print_meta: PAD token = 0 ''
llm_load_print_meta: LF token = 3 '▁'
llm_load_print_meta: max token length = 20
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 4060 Ti, compute capability 8.9, VMM: yes
llm_load_tensors: ggml ctx size = 0.15 MiB
llm_load_tensors: offloading 8 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 9/9 layers to GPU
llm_load_tensors: CPU buffer size = 77.45 MiB
llm_load_tensors: CUDA0 buffer size = 61.38 MiB
............................................................
llama_new_context_with_model: n_ctx = 512
llama_new_context_with_model: n_batch = 512
llama_new_context_with_model: n_ubatch = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init: CUDA0 KV buffer size = 6.00 MiB
llama_new_context_with_model: KV self size = 6.00 MiB, K (f16): 3.00 MiB, V (f16): 3.00 MiB
llama_new_context_with_model: CUDA_Host output buffer size = 0.12 MiB
llama_new_context_with_model: CUDA0 compute buffer size = 64.75 MiB
llama_new_context_with_model: CUDA_Host compute buffer size = 10.00 MiB
llama_new_context_with_model: graph nodes = 454
llama_new_context_with_model: graph splits = 18
AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 0 |
Model metadata: {'tokenizer.ggml.padding_token_id': '0', 'tokenizer.ggml.eos_token_id': '1', 'tokenizer.ggml.add_bos_token': 'false', 'tokenizer.ggml.remove_extra_whitespaces': 'true', 'tokenizer.ggml.add_eos_token': 'true', 'tokenizer.ggml.add_space_prefix': 'true', 'general.quantization_version': '2', 'tokenizer.ggml.model': 't5', 'general.file_type': '7', 'general.architecture': 't5', 't5.attention.key_length': '64', 'tokenizer.ggml.pre': 'default', 'general.name': 'T5', 't5.attention.layer_norm_epsilon': '0.000001', 't5.decoder_start_token_id': '0', 't5.attention.relative_buckets_count': '32', 't5.context_length': '512', 'tokenizer.ggml.unknown_token_id': '2', 't5.embedding_length': '512', 't5.feed_forward_length': '1024', 't5.block_count': '8', 't5.attention.head_count': '6', 't5.attention.value_length': '64', 't5.attention.layer_norm_rms_epsilon': '0.000001'}
Using fallback chat format: llama-2
GGML_ASSERT: /tmp/pip-install-01mkxjkt/llama-cpp-python_0e927e07a7a14bca8a4650c8034eda8a/vendor/llama.cpp/ggml/src/ggml.c:5278: !ggml_is_transposed(a)
[New LWP 203804]
[New LWP 203808]
[New LWP 203809]
This GDB supports auto-downloading debuginfo from the following URLs:
Enable debuginfod for this session? (y or [n]) [answered N; input not from terminal]
Debuginfod has been disabled.
To make this setting permanent, add 'set debuginfod enabled off' to .gdbinit.
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/lib64/libthread_db.so.1".
0x00007f016f8f8bf3 in wait4 () from /lib64/libc.so.6
#0 0x00007f016f8f8bf3 in wait4 () from /lib64/libc.so.6
#1 0x00007f0153c2da1b in ggml_print_backtrace () from /home/tyrell/new-data/translate_bench/translate2/.venv/lib64/python3.11/site-packages/llama_cpp/lib/libggml.so
#2 0x00007f0153c55e41 in ggml_mul_mat () from /home/tyrell/new-data/translate_bench/translate2/.venv/lib64/python3.11/site-packages/llama_cpp/lib/libggml.so
#3 0x00007f0161c47e94 in llm_build_context::build_t5() () from /home/tyrell/new-data/translate_bench/translate2/.venv/lib64/python3.11/site-packages/llama_cpp/lib/libllama.so
#4 0x00007f0161bc50c1 in ?? () from /home/tyrell/new-data/translate_bench/translate2/.venv/lib64/python3.11/site-packages/llama_cpp/lib/libllama.so
#5 0x00007f0161bdc61f in llama_decode () from /home/tyrell/new-data/translate_bench/translate2/.venv/lib64/python3.11/site-packages/llama_cpp/lib/libllama.so
#6 0x00007f016ffe4be6 in ffi_call_unix64 () from /lib64/libffi.so.8
#7 0x00007f016ffe14bf in ffi_call_int.lto_priv () from /lib64/libffi.so.8
#8 0x00007f016ffe418e in ffi_call () from /lib64/libffi.so.8
#9 0x00007f016f62ae5a in _ctypes_callproc.cold () from /usr/lib64/python3.11/lib-dynload/_ctypes.cpython-311-x86_64-linux-gnu.so
#10 0x00007f016f62a722 in PyCFuncPtr_call.cold () from /usr/lib64/python3.11/lib-dynload/_ctypes.cpython-311-x86_64-linux-gnu.so
#11 0x00007f016fbb2793 in _PyObject_MakeTpCall () from /lib64/libpython3.11.so.1.0
#12 0x00007f016fbbae8d in _PyEval_EvalFrameDefault () from /lib64/libpython3.11.so.1.0
#13 0x00007f016fbeddea in gen_send_ex2 () from /lib64/libpython3.11.so.1.0
#14 0x00007f016fbedc78 in gen_iternext () from /lib64/libpython3.11.so.1.0
#15 0x00007f016fbbad0d in _PyEval_EvalFrameDefault () from /lib64/libpython3.11.so.1.0
#16 0x00007f016fbeddea in gen_send_ex2 () from /lib64/libpython3.11.so.1.0
#17 0x00007f016fc1ef43 in builtin_next () from /lib64/libpython3.11.so.1.0
#18 0x00007f016fbd4701 in cfunction_vectorcall_FASTCALL () from /lib64/libpython3.11.so.1.0
#19 0x00007f016fbc82b7 in PyObject_Vectorcall () from /lib64/libpython3.11.so.1.0
#20 0x00007f016fbbae8d in _PyEval_EvalFrameDefault () from /lib64/libpython3.11.so.1.0
#21 0x00007f016fbb715a in _PyEval_Vector () from /lib64/libpython3.11.so.1.0
#22 0x00007f016fbb5591 in _PyObject_FastCallDictTstate () from /lib64/libpython3.11.so.1.0
#23 0x00007f016fbde2dc in _PyObject_Call_Prepend () from /lib64/libpython3.11.so.1.0
#24 0x00007f016fc81352 in slot_tp_call () from /lib64/libpython3.11.so.1.0
#25 0x00007f016fbb2793 in _PyObject_MakeTpCall () from /lib64/libpython3.11.so.1.0
#26 0x00007f016fbbae8d in _PyEval_EvalFrameDefault () from /lib64/libpython3.11.so.1.0
#27 0x00007f016fbb715a in _PyEval_Vector () from /lib64/libpython3.11.so.1.0
#28 0x00007f016fc3d27c in PyEval_EvalCode () from /lib64/libpython3.11.so.1.0
#29 0x00007f016fc5a723 in run_eval_code_obj () from /lib64/libpython3.11.so.1.0
#30 0x00007f016fc56c8a in run_mod () from /lib64/libpython3.11.so.1.0
#31 0x00007f016fc6bfc2 in pyrun_file () from /lib64/libpython3.11.so.1.0
#32 0x00007f016fc6b7a8 in _PyRun_SimpleFileObject () from /lib64/libpython3.11.so.1.0
#33 0x00007f016fc6b468 in _PyRun_AnyFileObject () from /lib64/libpython3.11.so.1.0
#34 0x00007f016fc65d78 in Py_RunMain () from /lib64/libpython3.11.so.1.0
#35 0x00007f016fc2d43b in Py_BytesMain () from /lib64/libpython3.11.so.1.0
#36 0x00007f016f845b4a in __libc_start_call_main () from /lib64/libc.so.6
#37 0x00007f016f845c0b in __libc_start_main () from /lib64/libc.so.6
#38 0x000055f651012095 in _start ()
[Inferior 1 (process 203792) detached]
[1] 203792 IOT instruction (core dumped) python test.py
```

fairydreaming commented 1 month ago

@kyteinsky Inference with T5-like models requires using some new API functions (e.g. llama_encode()). Without them, not only llama-cpp-python but all other software based on llama.cpp won't be able to run encoder-decoder models like T5. I think you should report this bug to the Python module.

kyteinsky commented 1 month ago

hey @fairydreaming, thanks for the T5 implementation! Which Python module are you referring to? I was under the impression that llama-cpp-python is the Python wrapper to use. Also, it doesn't depend on any other llama.cpp-related package.

fairydreaming commented 1 month ago

@kyteinsky Sure, it is the Python wrapper to use, but it's a completely separate project from llama.cpp. Since it's not part of the llama.cpp project, I think feature requests like support for encoder-decoder T5 models should be reported here: https://github.com/abetlen/llama-cpp-python

kyteinsky commented 1 month ago

That's where we are :)

fairydreaming commented 1 month ago

@kyteinsky lol, didn't notice that all that time, sorry

mtsdurica commented 1 month ago

Same problem here, really hoping the T5 architecture will get implemented soon!

Sadeghi85 commented 1 month ago

@fairydreaming

Hi,

Can you help with the Python code? I tried the following, but the output is gibberish:

I added these functions to "_internals.py" (https://github.com/abetlen/llama-cpp-python/blob/main/llama_cpp/_internals.py#L354):

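```python
# Methods added to the _LlamaContext class in llama_cpp/_internals.py
def llama_model_has_encoder(self, model: llama_cpp.llama_model_p) -> bool:
    # `model` is the raw model pointer (self._model.model), not a _LlamaModel wrapper
    return llama_cpp.llama_model_has_encoder(model)

def encode(self, batch: "_LlamaBatch"):
    assert self.ctx is not None
    assert batch.batch is not None
    return_code = llama_cpp.llama_encode(
        self.ctx,
        batch.batch,
    )
    if return_code != 0:
        raise RuntimeError(f"llama_encode returned {return_code}")
```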

Then these lines in "llama.py" (https://github.com/abetlen/llama-cpp-python/blob/main/llama_cpp/llama.py#L633):

```python
if self._ctx.llama_model_has_encoder(self._model.model):
    self._ctx.encode(self._batch)

self._ctx.decode(self._batch)
```

fairydreaming commented 1 month ago

@Sadeghi85 What your code is missing is the preparation of the input to the decode() call. Check how it's done in the llama-cli source code:

https://github.com/ggerganov/llama.cpp/blob/b841d0740855c5af1344a81f261139a45a2b39ee/examples/main/main.cpp#L536-L552

So before calling decode(), you have to call llama_model_decoder_start_token() and prepare a batch containing only this single token (assuming there is only a single sequence). This batch will be the input to the decode() call.
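The steps described above (run the prompt through the encoder once, then start the decoder from a single decoder start token instead of the prompt) can be sketched in plain Python. This is a hypothetical, library-agnostic sketch: `greedy_generate`, `encode` and `decode_step` are illustrative names, not llama-cpp-python APIs. In a real program, `encode` would presumably wrap `llama_encode()` on a batch holding the prompt tokens, and `decode_step` would wrap `llama_decode()` on a one-token batch and return that token's logits. Falling back to the BOS token when `llama_model_decoder_start_token()` returns -1 is an assumption based on how llama-cli appears to handle it.

```python
from typing import Callable, List, Sequence


def greedy_generate(
    prompt_tokens: Sequence[int],
    encode: Callable[[Sequence[int]], None],
    decode_step: Callable[[int, int], Sequence[float]],
    decoder_start_token: int,
    bos_token: int,
    eos_token: int,
    max_new_tokens: int = 32,
) -> List[int]:
    """Greedy encoder-decoder generation (hypothetical helper, not a library API).

    encode(tokens)          -- runs the encoder once over the whole prompt
    decode_step(token, pos) -- feeds ONE decoder token at position pos and
                               returns the logits for the next token
    """
    # 1. The prompt goes through the encoder exactly once.
    encode(prompt_tokens)

    # 2. The decoder starts from a single start token, not from the prompt.
    #    Assumption: -1 means "no start token", so fall back to BOS.
    token = decoder_start_token if decoder_start_token != -1 else bos_token

    output: List[int] = []
    for pos in range(max_new_tokens):
        logits = decode_step(token, pos)
        # Greedy sampling: take the highest-scoring token id.
        token = max(range(len(logits)), key=lambda i: logits[i])
        if token == eos_token:
            break
        output.append(token)
    return output


if __name__ == "__main__":
    # Toy stand-ins: the fake "decoder" emits 3, then 4, then EOS (id 1).
    script = [3, 4, 1]

    def fake_encode(tokens: Sequence[int]) -> None:
        pass

    def fake_decode(token: int, pos: int) -> List[float]:
        logits = [0.0] * 5
        logits[script[pos]] = 1.0
        return logits

    print(greedy_generate([7, 8, 1], fake_encode, fake_decode,
                          decoder_start_token=0, bos_token=9, eos_token=1))  # [3, 4]
```

The difference from the code that produced gibberish is that the decoder never receives the prompt batch; it sees the prompt only through the encoder output, and its own input starts from the start token.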

Sadeghi85 commented 1 month ago

I tried, but I really don't know what I'm doing. Waiting for @abetlen to take a look at this.

yugaljain1999 commented 2 weeks ago

@Sadeghi85 Have you been able to fix this issue? cc @fairydreaming @abetlen


fairydreaming commented 2 weeks ago

I see people are having serious problems using T5, so I created a branch with a high-level example of inference with a T5 model:

https://github.com/fairydreaming/llama-cpp-python/tree/t5

There is also a second branch containing a low-level example of T5 inference:

https://github.com/fairydreaming/llama-cpp-python/tree/fix-low-level-examples

Hope it helps.

yugaljain1999 commented 2 weeks ago

@fairydreaming Thanks a lot for the quick response and the scripts you provided.

Is it also possible to pass a batch of prompts rather than a single prompt to the model?

If yes, how can we do that? Your response would be highly appreciated.

fairydreaming commented 2 weeks ago

@yugaljain1999 Yes, you can pass multiple prompts. I don't know how it works in the llama-cpp-python high-level API, but in llama.cpp (low-level API) you do it by creating a batch containing tokens with different seq_id values.

For example, if you have two prompts tokenized into the token arrays pt1 and pt2, you create a batch and fill it like this:

```python
# Assumes `batch` has room for all tokens, e.g. allocated with
# llama_cpp.llama_batch_init(len(pt1) + len(pt2), 0, 2)
batch.n_tokens = len(pt1) + len(pt2)
batch_index = 0
for i in range(len(pt1)):
    batch.token[batch_index] = pt1[i]
    batch.pos[batch_index] = i
    batch.n_seq_id[batch_index] = 1
    batch.seq_id[batch_index][0] = 0  # first prompt has sequence id 0
    batch.logits[batch_index] = False  # no logits needed for the encoder pass
    batch_index += 1

for i in range(len(pt2)):
    batch.token[batch_index] = pt2[i]
    batch.pos[batch_index] = i  # positions restart at 0 for the second prompt
    batch.n_seq_id[batch_index] = 1
    batch.seq_id[batch_index][0] = 1  # second prompt has sequence id 1
    batch.logits[batch_index] = False
    batch_index += 1
```

Then you call llama_encode() on this batch.

For llama_decode(), you initialize the batch with the decoder start token for both sequences:

```python
batch.n_tokens = 2
for i in range(2):
    batch.token[i] = llama_cpp.llama_model_decoder_start_token(model)
    batch.pos[i] = 0
    batch.n_seq_id[i] = 1  # index i, so both tokens get their count set
    batch.seq_id[i][0] = i  # first token has seq_id 0, second has seq_id 1
    batch.logits[i] = True  # we need logits to sample the next token of each sequence
```

Then you call llama_decode() on this batch, sample the next token for each sequence, fill the batch with those tokens (advancing pos by one), call llama_decode() again to generate the following tokens for both sequences, and so on.
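The two-prompt example above generalizes to any number of prompts. The sketch below is hypothetical: `fill_encoder_batch`, `fill_decoder_batch` and `FakeBatch` are illustrative names, not llama-cpp-python APIs. It touches only the `llama_batch` fields used in the snippets above (`n_tokens`, `token`, `pos`, `n_seq_id`, `seq_id`, `logits`), so it can be exercised with a plain Python stand-in; with llama-cpp-python the real batch would presumably be allocated with `llama_cpp.llama_batch_init(n_tokens, 0, n_seq_max)` and must be large enough for every token. Sequences that have already produced EOS are simply left out of the decoder batch, and all active sequences advance in lockstep, so they share the same `pos`.

```python
from typing import Dict, List, Sequence


def fill_encoder_batch(batch, prompts: Sequence[Sequence[int]]) -> None:
    """Pack several tokenized prompts into one batch; prompt k uses seq_id k."""
    idx = 0
    for seq, tokens in enumerate(prompts):
        for pos, tok in enumerate(tokens):
            batch.token[idx] = tok
            batch.pos[idx] = pos          # positions restart at 0 for each prompt
            batch.n_seq_id[idx] = 1
            batch.seq_id[idx][0] = seq
            batch.logits[idx] = False     # the encoder pass needs no logits
            idx += 1
    batch.n_tokens = idx


def fill_decoder_batch(batch, next_tokens: Dict[int, int], pos: int) -> None:
    """One decoder token per still-active sequence (next_tokens: seq_id -> token)."""
    for i, (seq, tok) in enumerate(sorted(next_tokens.items())):
        batch.token[i] = tok
        batch.pos[i] = pos                # active sequences advance in lockstep
        batch.n_seq_id[i] = 1
        batch.seq_id[i][0] = seq
        batch.logits[i] = True            # logits are needed to sample each next token
    batch.n_tokens = len(next_tokens)


class FakeBatch:
    """Hypothetical stand-in exposing the same field names as llama_batch."""

    def __init__(self, size: int) -> None:
        self.n_tokens = 0
        self.token: List[int] = [0] * size
        self.pos: List[int] = [0] * size
        self.n_seq_id: List[int] = [0] * size
        self.seq_id: List[List[int]] = [[0] for _ in range(size)]
        self.logits: List[bool] = [False] * size


if __name__ == "__main__":
    enc = FakeBatch(8)
    fill_encoder_batch(enc, [[5, 6, 1], [7, 1]])   # two toy prompts
    print(enc.n_tokens, enc.pos[:enc.n_tokens])    # 5 [0, 1, 2, 0, 1]

    dec = FakeBatch(2)
    fill_decoder_batch(dec, {0: 0, 1: 0}, pos=0)   # start token id 0, as in the log above
    print([s[0] for s in dec.seq_id])              # [0, 1]
```

Keying the decoder tokens by seq_id rather than by batch index keeps each sampled token attached to the right sequence once some sequences have finished and the batch shrinks.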

yugaljain1999 commented 2 weeks ago

@fairydreaming Thanks for this info. I was also wondering whether we could use llama-server for scalable parallel batch inference, but when I try the command below I get this error:

Command - ./llama-server -m ../madlad400-3b-mt.gguf -np 16 --temp 0 -ngl -1

Error -

examples/server/server.cpp:693: GGML_ASSERT(llama_add_eos_token(model) != 1) failed
./llama-server(+0x25c87b)[0x6448e414e87b]
./llama-server(+0x25e717)[0x6448e4150717]
./llama-server(+0x4775ff)[0x6448e43695ff]
./llama-server(+0x63fe8)[0x6448e3f55fe8]

Any idea how we can resolve it? Any help would be appreciated.

Thanks

fairydreaming commented 2 weeks ago

@yugaljain1999 T5 models are still not supported in llama-server

yugaljain1999 commented 1 week ago

@fairydreaming Thanks for letting me know.

Since llama-server won't work with T5, I'm now continuing with the low-level API example of T5 inference you shared. Which parameters do we need to change in this script to take full advantage of GPU acceleration?

Thanks
