ArvinZhuang / DSI-transformers

A huggingface transformers implementation of "Transformer Memory as a Differentiable Search Index"
MIT License

DocIds overriding existing integer tokens? #10

Open WonbinKweon opened 1 week ago

WonbinKweon commented 1 week ago

I see that you use the existing integer tokens in the tokenizer for the DocIds. However, in the paper, the authors extend the token matrix for the DocIds (i.e., they add new tokens for the integer DocIds). Did I understand correctly?

WonbinKweon commented 1 week ago

[image: Hits@1 / Hits@10 comparison figure]

Actually, I tried using new tokens for atomic ids (one unique token for each document):

    import numpy as np

    # Add one new token per document (atomic docids) and resize the embedding matrix.
    new_tokens = list(set(train_dataset.train_data['text_id'] + test_dataset.train_data['text_id']))
    tokenizer.add_tokens(new_tokens)
    model.resize_token_embeddings(len(tokenizer))

    # Each new token encodes to [docid_token, eos], so [:, 0] keeps only the docid token ids.
    decode_vocab = np.array(tokenizer(new_tokens).input_ids)[:, 0].tolist()
    def restrict_decode_vocab(batch_idx, prefix_beam):
        # Only docid tokens are allowed at every decoding step.
        return decode_vocab

This gave about Hits@1 = 0.3 / Hits@10 = 0.6 ("new_tok" in the figure), which is similar to the DSI paper. Your code is "int"/"int2" in the figure.
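
For reference, a restriction function like this is normally passed to Hugging Face beam search through the `prefix_allowed_tokens_fn` argument of `generate`; a minimal sketch (the model/input names and beam settings below are illustrative, not this repo's exact evaluation code):

    # Constrained generation sketch: only docid tokens may be produced.
    # `model`, `tokenizer`, and `input_ids` are assumed to exist as above.
    outputs = model.generate(
        input_ids,                     # tokenized queries
        max_length=2,                  # decoder start token + one atomic docid token
        num_beams=10,
        num_return_sequences=10,       # top-10 candidates for Hits@10
        prefix_allowed_tokens_fn=restrict_decode_vocab,
    )
    predicted_docids = [tokenizer.decode(seq[1:]) for seq in outputs]  # drop the start token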

ArvinZhuang commented 1 week ago

Hi @WonbinKweon, good to see that atomic ids work well with our code! My original code was aiming to reproduce the "Naive String Docid" setting, where no new tokens should be added to the vocab.
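
Roughly, the Naive String Docid setting just encodes the docid string with the pretrained vocabulary, so there is no `add_tokens` / `resize_token_embeddings` step; a minimal sketch (assuming a T5 tokenizer, the docid value is made up):

    # Naive String Docid sketch: the identifier is tokenized with the existing
    # SentencePiece vocabulary and may split into several sub-tokens.
    from transformers import T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained("t5-base")
    labels = tokenizer("42915", return_tensors="pt").input_ids  # digit sub-tokens + EOS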

WonbinKweon commented 1 week ago

I see. I tried Naive String Docid with new tokens, but it shows poor performance; I don't know why. Anyway, thank you for your code!

WonbinKweon commented 4 days ago

[image: Hits@1 / Hits@10 comparison figure]

I found that using constrained beam search with a trie gives slightly better performance (int1_10k vs int1_trie):

import pygtrie

# Map every original docid string to a compact integer string ("0", "1", ...).
docids = list(set(train_dataset.train_data['text_id'] + test_dataset.train_data['text_id']))

new_docids = {}
for idx, docid in enumerate(docids):
    new_docids[docid] = str(idx)

train_dataset.new_docids = new_docids
eval_dataset.new_docids = new_docids
test_dataset.new_docids = new_docids

# Build a trie over the tokenized docids.
# Each key is [0 (pad / decoder start)] + docid tokens + [1 (eos)].
t = pygtrie.Trie()
for k in new_docids:
    t[[0] + tokenizer(new_docids[k]).input_ids] = k

def restrict_decode_vocab(batch_idx, prefix_beam):
    # Look up the trie node reached by the tokens generated so far
    # (uses pygtrie's private _get_node, which returns (node, trace)).
    try:
        child = t._get_node(prefix_beam.tolist())[0].children
    except KeyError:
        print("key error:", prefix_beam)
        return [0] * 10

    if len(child) == 1:  # exactly one valid continuation
        return [child.step] * 10
    elif len(child) == 0:  # leaf: the prefix is already a complete docid
        if type(child) != pygtrie._NoChildren:
            print("no child:", prefix_beam)
        return [0] * 10
    else:  # several valid continuations
        return list(child)
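
The leading 0 in each trie key matches T5's decoder start token (the pad token, id 0), so `prefix_beam` always begins with 0; constraining every beam to prefixes that actually lead to a stored docid is likely why the trie variant edges out the flat integer restriction. For anyone who prefers to avoid pygtrie's private `_get_node`, an equivalent (if slower) lookup can be sketched with the public API:

def restrict_decode_vocab_public(batch_idx, prefix_beam):
    # Enumerate all docid keys under the current prefix and collect the set of
    # valid next tokens; fall back to the pad token (0) when the prefix is
    # already a complete docid or not in the trie at all. Slower than the
    # _get_node version because it walks the whole subtree on every call.
    prefix = prefix_beam.tolist()
    try:
        next_tokens = {key[len(prefix)] for key in t.iterkeys(prefix=prefix)
                       if len(key) > len(prefix)}
    except KeyError:
        return [0]
    return list(next_tokens) if next_tokens else [0]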