ml-explore / mlx-examples


T5 tokenizer decoding error with CodeT5+ #1021

zcbenz commented 1 month ago
$ python3 convert.py --model codet5p-220m
$ python3 t5.py --model codet5p-220m --prompt 'def print_hello_world():<extra_id_0>' --max-tokens 10
[INFO] Generating with T5...
Input:  def print_hello_world():<extra_id_0>
<extra_id_0>ĊĠĠĠĠprintĠ"HelloĠWorld"ĊĊ

For comparison, hf_t5.py produces the correct output with the following changes:

diff --git a/t5/hf_t5.py b/t5/hf_t5.py
index 98c6da8..23d9644 100644
--- a/t5/hf_t5.py
+++ b/t5/hf_t5.py
@@ -23,11 +23,11 @@ def embed(t5_model: str):

 def generate(t5_model: str):
-    prompt = "translate English to German: As much as six inches of rain could fall in the New York City region through Monday morning, and officials warned of flooding along the coast."
+    prompt = "def print_hello_world():<extra_id_0>"
     tokenizer = AutoTokenizer.from_pretrained(t5_model)
     torch_model = AutoModelForSeq2SeqLM.from_pretrained(t5_model)
     torch_tokens = tokenizer(prompt, return_tensors="pt", padding=True).input_ids
-    outputs = torch_model.generate(torch_tokens, do_sample=False, max_length=512)
+    outputs = torch_model.generate(torch_tokens, do_sample=False, max_length=10)
     print(tokenizer.decode(outputs[0], skip_special_tokens=True))

$ python3 hf_t5.py --model codet5p-220m
    print "Hello World"

It seems this tokenizer does not work well with the example's token-by-token streaming decoding.
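
For reference, the Ġ and Ċ characters in the output above are the byte-level BPE markers for a space and a newline: the raw token pieces only become real whitespace once they pass through the tokenizer's byte-level decoder, which is what decoding the full sequence does. A minimal illustration with the Hugging Face tokenizer (the Salesforce/codet5p-220m hub name is an assumption here; substitute the local checkpoint path used above):

from transformers import AutoTokenizer

# CodeT5+ ships a byte-level BPE (Roberta-style) tokenizer, unlike the
# SentencePiece tokenizer of the original T5 checkpoints.
tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5p-220m")

ids = tokenizer('\n    print "Hello World"\n\n', add_special_tokens=False).input_ids

# Joining the raw token pieces keeps the byte-level markers
# (Ġ stands for a space, Ċ for a newline):
print("".join(tokenizer.convert_ids_to_tokens(ids)))
# -> ĊĠĠĠĠprintĠ"HelloĠWorld"ĊĊ

# Decoding the whole id sequence runs the byte decoder and recovers
# the actual whitespace.
print(tokenizer.decode(ids))

The joined pieces are exactly the garbled string t5.py printed, which suggests the streaming path is emitting raw token pieces rather than decoding the accumulated id sequence.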

awni commented 1 month ago

Thanks for flagging. Indeed, the way we do streaming decoding in the T5 example is not correct for most tokenizers (you typically can't decode each new token individually, as we do here). It should either be a proper streaming decoder, or we just eat the quadratic cost and re-decode the entire prefix at each step.
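
If we take the simple route, a sketch of the "re-decode the entire prefix" approach could look like the following (illustrative only, not the current t5.py code; tokenizer is assumed to be the Hugging Face tokenizer and generate_tokens is a hypothetical iterator over sampled token ids):

def stream_decode(tokenizer, token_ids):
    # Re-decode the full prefix after every new token and yield only the
    # suffix that has not been printed yet. Quadratic in sequence length,
    # but correct for tokenizers whose pieces are not valid text on their
    # own (SentencePiece, byte-level BPE, ...).
    ids, printed = [], ""
    for tok in token_ids:
        ids.append(tok)
        text = tokenizer.decode(ids, skip_special_tokens=True)
        # An incomplete multi-byte character decodes to U+FFFD; hold it
        # back until the next token completes it.
        if text.endswith("\ufffd"):
            continue
        yield text[len(printed):]
        printed = text

for piece in stream_decode(tokenizer, generate_tokens(prompt)):
    print(piece, end="", flush=True)
print()

A proper incremental detokenizer avoids the quadratic cost by only re-decoding a short tail of the sequence, but something like the above is enough to make the CodeT5+ case print correctly.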

I'll mark this as a bug; it should be a fairly simple fix.