clovaai / donut

Official Implementation of OCR-free Document Understanding Transformer (Donut) and Synthetic Document Generator (SynthDoG), ECCV 2022
https://arxiv.org/abs/2111.15664
MIT License

Issue with tokenizing '1' preceded by a char #229

Open arnaudstiegler opened 11 months ago

arnaudstiegler commented 11 months ago

Looking at the Donut tokenizer in transformers, it seems that '1' is the only digit missing a standalone token in the vocabulary. As a result, any '1' directly preceded by another character is tokenized into an unknown token (<unk>).

Reproduction:

from transformers import DonutProcessor

processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")

print(processor.decode(processor.tokenizer('A1')['input_ids']))
# <s> A<unk></s>
print(processor.decode(processor.tokenizer('A 1')['input_ids']))
# <s> A 1</s>
print(processor.decode(processor.tokenizer('A2')['input_ids']))
# <s> A2</s>
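To check every digit at once instead of testing them one by one, a small scan over "A0" to "A9" (plus the space-separated variants for comparison) can flag each input whose ids contain the unknown-token id. This is a minimal sketch: the helper name `find_unk_inputs` is hypothetical, and it only assumes the tokenizer is callable, returns a dict with "input_ids", and exposes `unk_token_id`, as Hugging Face tokenizers such as `processor.tokenizer` do.

```python
def find_unk_inputs(tokenizer, texts):
    """Return the subset of `texts` whose encoding contains the unknown-token id."""
    flagged = []
    for text in texts:
        ids = tokenizer(text)["input_ids"]
        if tokenizer.unk_token_id in ids:
            flagged.append(text)
    return flagged


# Each digit directly after a letter, then after a space for comparison
candidates = [f"A{d}" for d in range(10)] + [f"A {d}" for d in range(10)]
```

Usage: `find_unk_inputs(processor.tokenizer, candidates)`. If the report above holds, "A1" would be the only input flagged.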

The issue seems to be coming from the base tokenizer used for Donut:

from transformers import XLMRobertaTokenizer
test = XLMRobertaTokenizer.from_pretrained("hyunwoongko/asian-bart-ecjk")

print(test.decode(test('A1')['input_ids']))
# <s> A<unk></s>
print(test.decode(test('A 1')['input_ids']))
# <s> A 1</s>

Interestingly, when using the tokenizer assigned by default (i.e. Bart), the problem goes away:

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("hyunwoongko/asian-bart-ecjk")

print(tokenizer.decode(tokenizer('A1')['input_ids']))
# ▁A 1 </s> en_XX
print(tokenizer.decode(tokenizer('A 1')['input_ids']))
# ▁A ▁1 </s> en_XX
DoctorSlimm commented 11 months ago

bump

arnaudstiegler commented 11 months ago

One solution that works is:

It's not ideal, because you basically drop a lot of tokens from the embeddings (every token formatted as ▁{token}), which means you lose access to part of the "pretraining knowledge" and force the model to adapt during fine-tuning. Empirically, I'm seeing some performance gaps from doing this.
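To get a rough sense of how much of the vocabulary such a workaround gives up, one can count the entries that start with the SentencePiece word-boundary marker "▁" (U+2581). A minimal sketch: the function name is hypothetical, and `vocab` is assumed to be any token-to-id mapping, such as the dict returned by a tokenizer's `get_vocab()`.

```python
WORD_BOUNDARY = "\u2581"  # SentencePiece's "▁" marker for a preceding space


def boundary_prefixed_share(vocab):
    """Return (count, fraction) of vocabulary entries that start with '▁'."""
    if not vocab:
        return 0, 0.0
    # Iterating a dict yields its keys, i.e. the token strings
    prefixed = sum(1 for token in vocab if token.startswith(WORD_BOUNDARY))
    return prefixed, prefixed / len(vocab)
```

Usage: `boundary_prefixed_share(processor.tokenizer.get_vocab())`. The count only measures how many embeddings would go unused; it says nothing about how much those embeddings matter after fine-tuning.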

DoctorSlimm commented 11 months ago

@arnaudstiegler Interesting. What about replacing every "1" in the input text with <one> and adding <one> as a token to the tokenizer? Would that confuse the default loss calculation in the Trainer?
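The substitution idea amounts to a reversible text mapping: applied to the target strings before tokenization, and undone on the decoded output. A minimal sketch, assuming the literal "<one>" never occurs in the real data (otherwise decoding would be ambiguous) and that it is applied to field values rather than to task tags that might themselves contain "1"; the function names are hypothetical. On the model side one would presumably also register the placeholder, for example with `tokenizer.add_tokens(["<one>"])`, and resize the decoder embeddings to `len(tokenizer)`. This sketch does not answer the question about the Trainer's loss.

```python
PLACEHOLDER = "<one>"  # placeholder token name from the suggestion above


def encode_ones(text: str) -> str:
    """Replace every '1' with the placeholder before tokenization."""
    if PLACEHOLDER in text:
        # The mapping would not be reversible if the placeholder already occurs
        raise ValueError(f"input already contains {PLACEHOLDER!r}")
    return text.replace("1", PLACEHOLDER)


def decode_ones(text: str) -> str:
    """Restore '1' in generated text after decoding."""
    return text.replace(PLACEHOLDER, "1")
```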

felixvor commented 5 months ago

I made a community post about this in the donut-base model repo with a bit more detail, but haven't gotten a response yet. I just found this issue and would like to leave a bump.

Has anyone found a better solution for this? We still fine-tune with token and fix the strings in post-processing. The training runs we did on small datasets do seem to indicate that the model knows that means 1, but we're still not entirely sure 😅

hancheolcho commented 4 months ago

I have the same problem.

In addition, there are quite a few <unk> tokens when it comes to Korean text.

에스케이뷰아파트 => 에스케이<unk>아파트
자원의 탐사, 채취 =>  자원의▁<unk>사,▁채취
전산시스템의 공동활용 => 전산시스<unk>의▁공동활용
...
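To put a number on how often <unk> appears for a given corpus (Korean or otherwise), one can measure the share of produced token ids that equal `unk_token_id`. A minimal sketch: the function name is hypothetical, it assumes the Hugging Face tokenizer interface (callable with `add_special_tokens`, returns "input_ids", exposes `unk_token_id`), and `add_special_tokens=False` keeps <s> and </s> out of the denominator.

```python
def unk_rate(tokenizer, texts):
    """Fraction of produced token ids that are the unknown token."""
    total = 0
    unknown = 0
    for text in texts:
        ids = tokenizer(text, add_special_tokens=False)["input_ids"]
        total += len(ids)
        unknown += sum(1 for i in ids if i == tokenizer.unk_token_id)
    return unknown / total if total else 0.0
```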
hancheolcho commented 4 months ago

I found an interesting case with a CORD example.

I thought that at least the original tokenizer, hyunwoongko/asian-bart-ecjk, would not make such an error. So I tested XLMRobertaTokenizer with hyunwoongko/asian-bart-ecjk, and surprisingly it still produces <unk> for the token 1!

And by chance I tested XLMRobertaTokenizerFast with the same model, and surprisingly it DOES NOT produce <unk> for the token 1! Could it be a difference between the Fast and non-Fast tokenizer implementations, similar to the thing @arnaudstiegler mentioned?

Did I miss something here, OTL?

from transformers import XLMRobertaTokenizer, XLMRobertaTokenizerFast

# Task tokens for the CORD-v2 example, added to both tokenizers
cord_tokens = [
    "<s_cord-v2>", "</s>", "<s_menu>", "</s_menu>", "<s_nm>", "</s_nm>", "<s_unitprice>", "</s_unitprice>", "<s_cnt>", "</s_cnt>",
    "<s_price>", "</s_price>", "<s_total>", "</s_total>", "<s_total_price>", "</s_total_price>", "<s_cashprice>", "</s_cashprice>",
    "<s_changeprice>", "</s_changeprice>", "<s_menuqty_cnt>", "</s_menuqty_cnt>",
]

# CORD example
s = "<s_cord-v2><s_menu><s_nm>2005-CHEESE JOHN</s_nm><s_unitprice>9.500,00</s_unitprice><s_cnt>x1</s_cnt><s_price>9.500,00</s_price></s_menu><s_total><s_total_price>9.500,00</s_total_price><s_cashprice>20.000,00</s_cashprice><s_changeprice>10.500</s_changeprice><s_menuqty_cnt>1</s_menuqty_cnt></s_total></s>"

# Run the slow (SentencePiece-based) tokenizer, then the fast one, on the same string
for tokenizer_cls in (XLMRobertaTokenizer, XLMRobertaTokenizerFast):
    tokenizer = tokenizer_cls.from_pretrained("hyunwoongko/asian-bart-ecjk")
    tokenizer.add_special_tokens(
        {"additional_special_tokens": cord_tokens},
        replace_additional_special_tokens=False,
    )
    input_ids = tokenizer(s, add_special_tokens=False)["input_ids"]
    print("".join(tokenizer.convert_ids_to_tokens(input_ids)))
    print("unk found!" if tokenizer.unk_token_id in input_ids else "ok!")
    print()
<s_cord-v2><s_menu><s_nm>▁2005-CHEESE▁JOHN</s_nm><s_unitprice>▁9.500,00</s_unitprice><s_cnt>▁x<unk></s_cnt><s_price>▁9.500,00</s_price></s_menu><s_total><s_total_price>▁9.500,00</s_total_price><s_cashprice>▁20.000,00</s_cashprice><s_changeprice>▁10.500</s_changeprice><s_menuqty_cnt>▁1</s_menuqty_cnt></s_total></s>
unk found!

<s_cord-v2><s_menu><s_nm>▁2005-CHEESE▁JOHN</s_nm><s_unitprice>▁9.500,00</s_unitprice><s_cnt>▁x1</s_cnt><s_price>▁9.500,00</s_price></s_menu><s_total><s_total_price>▁9.500,00</s_total_price><s_cashprice>▁20.000,00</s_cashprice><s_changeprice>▁10.500</s_changeprice><s_menuqty_cnt>▁1</s_menuqty_cnt></s_total></s>
ok!
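To check whether the slow/fast discrepancy extends beyond this one CORD string, the two tokenizers can be run side by side over many samples, reporting any sample where exactly one of them emits the unknown token. A minimal sketch: the function name is hypothetical, and it assumes both objects follow the Hugging Face tokenizer interface (callable with `add_special_tokens`, expose `unk_token_id`). It compares only the presence of <unk> rather than full id sequences, since the two implementations might segment some strings differently without either producing <unk>.

```python
def unk_disagreements(slow, fast, texts):
    """Return (text, slow_has_unk, fast_has_unk) for texts where only one tokenizer emits <unk>."""
    report = []
    for text in texts:
        slow_ids = slow(text, add_special_tokens=False)["input_ids"]
        fast_ids = fast(text, add_special_tokens=False)["input_ids"]
        slow_unk = slow.unk_token_id in slow_ids
        fast_unk = fast.unk_token_id in fast_ids
        if slow_unk != fast_unk:
            report.append((text, slow_unk, fast_unk))
    return report
```

Usage: pass the two tokenizers built in the loop above (before or after adding the CORD tokens) together with a list of ground-truth strings from the dataset.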
felixvor commented 4 months ago

My current take on this is to use UDOP instead of Donut.