WonbinKweon opened 1 week ago
Actually, I tried using new tokens for atomic ids (one unique token per document):

```python
import numpy as np

# Register one new token per docid and resize the embedding matrix to match
new_tokens = list(set(train_dataset.train_data['text_id'] + test_dataset.train_data['text_id']))
tokenizer.add_tokens(new_tokens)
model.resize_token_embeddings(len(tokenizer))

# Each added token maps to exactly one id, so keep the first (only) id per docid
decode_vocab = np.array(tokenizer(new_tokens).input_ids)[:, 0].tolist()

def restrict_decode_vocab(batch_idx, prefix_beam):
    return decode_vocab
```
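For context, a restriction function like this would typically be plugged into Hugging Face beam search via `generate`'s `prefix_allowed_tokens_fn` argument. A minimal sketch of the callback contract, with made-up token ids:

```python
# Illustrative ids only: pretend these are the three new atomic docid tokens
decode_vocab = [32100, 32101, 32102]

def restrict_decode_vocab(batch_idx, prefix_beam):
    # Regardless of the prefix, only atomic docid tokens may be generated,
    # so each decoded sequence is exactly one docid token.
    return decode_vocab

# With transformers, this hooks into constrained beam search roughly as:
# model.generate(input_ids, num_beams=10,
#                prefix_allowed_tokens_fn=restrict_decode_vocab)
```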
This gets about Hits@1 0.3 / Hits@10 0.6 ("new_tok" in the figure), which is similar to the DSI paper. Your code corresponds to "int"/"int2" in the figure.
Hi @WonbinKweon, good to see atomic ids work well with our code! My original code aims to reproduce the "Naive String Docid" setting, where no new tokens are added to the vocab.
I see. I also tried Naive String Docid with new tokens, but it shows poor performance; I don't know why. Anyway, thank you for your code!
I found that constrained beam search with a trie gives slightly better performance (int1_10k vs int1_trie):
```python
import pygtrie

# Remap the raw docids to consecutive integer strings
docids = list(set(train_dataset.train_data['text_id'] + test_dataset.train_data['text_id']))
new_docids = {}
for idx, docid in enumerate(docids):
    new_docids[docid] = str(idx)
train_dataset.new_docids = new_docids
eval_dataset.new_docids = new_docids
test_dataset.new_docids = new_docids

# Build a trie over the tokenized docids
t = pygtrie.Trie()
for k in new_docids:
    # prepend 0 (pad) at the front; the tokenizer appends 1 (eos) at the end
    t[[0] + tokenizer(new_docids[k]).input_ids] = k

def restrict_decode_vocab(batch_idx, prefix_beam):
    # Note: _get_node and _NoChildren are pygtrie internals, not public API
    try:
        child = t._get_node(prefix_beam.tolist())[0].children
    except KeyError:
        print("key error:", prefix_beam)
        return [0] * 10
    if len(child) == 1:  # exactly one continuation; repeat to fill the beam
        return [child.step] * 10
    elif len(child) == 0:  # leaf: only pad is allowed
        if type(child) != pygtrie._NoChildren:
            print("no child:", prefix_beam)
        return [0] * 10
    else:  # several continuations: allow all of them
        return list(child)
```
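Since the lookup above leans on pygtrie internals (`_get_node`, `_NoChildren`), the same constraint can be sketched with a plain nested-dict trie, assuming docid token sequences framed as `[0, ...ids..., 1]` (the ids below are made up):

```python
# Token sequences for three hypothetical docids: [pad, ..., eos]
docid_token_seqs = [[0, 5, 7, 1], [0, 5, 8, 1], [0, 6, 1]]

# Build a nested-dict trie: each key is a token id, each value a sub-trie
trie = {}
for seq in docid_token_seqs:
    node = trie
    for tok in seq:
        node = node.setdefault(tok, {})

def allowed_tokens(prefix):
    """Return the token ids that may follow `prefix` along some docid."""
    node = trie
    for tok in prefix:
        if tok not in node:
            return [0]  # dead branch: force pad
        node = node[tok]
    return list(node) or [0]  # empty node (leaf): force pad
```

For example, `allowed_tokens([0, 5])` returns `[7, 8]`, and once eos is reached only pad is allowed.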
I see that you use the existing integer tokens in the tokenizer for docids. However, in the paper the authors extend the token embedding matrix for the docids (i.e., new tokens for integer docids). Did I understand correctly?