nelson-liu opened this issue 3 years ago (Open)
Hi @nelson-liu ,
I think the bug comes from the `add_symbol` function in fairseq/fairseq/data/dictionary.py:

```python
def add_symbol(self, word, n=1, overwrite=False):
    """Adds a word to the dictionary"""
    if word in self.indices and not overwrite:
        idx = self.indices[word]
        self.count[idx] = self.count[idx] + n
        return idx
    else:
        idx = len(self.symbols)
        self.indices[word] = idx
        self.symbols.append(word)
        self.count.append(n)
        return idx
```
The condition should be changed to `if word in self.indices and overwrite:`. Otherwise, even when `overwrite` is indeed set to `True`, the symbol is not overwritten: the `else` branch runs and appends a new entry for a word that is already in the dictionary.
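To make the failure concrete, here is a minimal standalone sketch (a hypothetical `TinyDict`, not the real fairseq `Dictionary`) that copies the `add_symbol` logic quoted above and shows the duplicate entry appearing:

```python
# Minimal sketch reproducing the add_symbol logic above, to show how
# overwrite=True ends up duplicating the symbol instead of overwriting it.
class TinyDict:
    def __init__(self):
        self.symbols = []   # index -> word
        self.count = []     # index -> count
        self.indices = {}   # word -> index

    def add_symbol(self, word, n=1, overwrite=False):
        if word in self.indices and not overwrite:
            idx = self.indices[word]
            self.count[idx] = self.count[idx] + n
            return idx
        else:
            # with overwrite=True this branch runs even for known words,
            # appending a second copy of the symbol
            idx = len(self.symbols)
            self.indices[word] = idx
            self.symbols.append(word)
            self.count.append(n)
            return idx

d = TinyDict()
d.add_symbol("<unk>")                  # initial special token at index 0
d.add_symbol("<unk>", overwrite=True)  # intended to overwrite it...
print(d.symbols)                       # ['<unk>', '<unk>'] -- duplicated
```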
In your case, this results in the dictionary having duplicate special tokens. Normally, when loading from a file, the dictionary should start with 4 special tokens (`<s>`, `<pad>`, `</s>` and `<unk>`, in that order), followed by the entries in the file. If your file already has special tokens with `#fairseq:overwrite`, they should be overwritten. The dict.txt saved during preprocessing is meant to skip those first tokens, and should not have those `#fairseq:overwrite` tags.
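The intended loading behavior described above can be sketched as follows. This is a simplified illustration, not fairseq's actual `add_from_file`; the line format assumed here is `word count`, optionally suffixed with `#fairseq:overwrite`:

```python
# Sketch of the intended semantics: the four special tokens come first,
# and a "#fairseq:overwrite" entry for an existing symbol should reuse
# its existing slot instead of appending a duplicate.
SPECIALS = ["<s>", "<pad>", "</s>", "<unk>"]

def load_dict(lines):
    symbols = list(SPECIALS)
    indices = {s: i for i, s in enumerate(symbols)}
    for line in lines:
        rest, last = line.rsplit(" ", 1)
        overwrite = last == "#fairseq:overwrite"
        # strip the count to recover the word
        word = rest.rsplit(" ", 1)[0] if overwrite else rest
        if word in indices and not overwrite:
            raise RuntimeError(f"Duplicate word found: {word}")
        if word not in indices:
            indices[word] = len(symbols)
            symbols.append(word)
        # if word is already present and overwrite is set, keep its slot
    return symbols

print(load_dict(["<unk> 5 #fairseq:overwrite", "hello 10", "world 7"]))
# ['<s>', '<pad>', '</s>', '<unk>', 'hello', 'world']
```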
I have created a pull request that fixes the bug (#5329).
🐛 Bug
To Reproduce
Steps to reproduce the behavior (always include the command you ran):
Run `fairseq-preprocess` with a `--srcdict` that has `#fairseq:overwrite`. For example, the command in the roberta pretraining tutorial: https://github.com/pytorch/fairseq/blob/master/examples/roberta/README.pretraining.md#1-preprocess-the-data

`#fairseq:overwrite` is not preserved in `dict.txt`.
Expected behavior
When using `--srcdict`, the `dict.txt` should be exactly the same as the one passed in to `fairseq-preprocess`.
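One way to verify this expectation after a preprocessing run is to compare the two files byte for byte. The helper name `dicts_match` and the paths in the comment are illustrative, not part of fairseq:

```python
# Hypothetical check: the dict.txt emitted by fairseq-preprocess should be
# byte-identical to the file passed via --srcdict, e.g.
#   dicts_match("gpt2_bpe/dict.txt", "data-bin/wikitext-103/dict.txt")
def dicts_match(srcdict_path: str, output_dict_path: str) -> bool:
    with open(srcdict_path) as src, open(output_dict_path) as out:
        return src.read() == out.read()
```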
Environment
How you installed fairseq (pip, source): pip