Open tarekziade opened 2 weeks ago
Example of generation config with a heavy bad words ids list https://huggingface.co/Mozilla/distilvit/blob/main/generation_config.json
(btw I don't know how to change the label from bug to enhancement)
I've done a bit of testing and it's a bit trickier than that, unfortunately. The bad_words_ids is a list of lists structured as follows:
[
  [a],       // ALWAYS block a
  [b, c],    // only block c if preceded by [b]
  [d],       // ALWAYS block d
  [e, f, g], // only block g if preceded by [e, f]
  ...
]
This means we still need to iterate over the entire list, especially to handle these single-token "bad words". This code:
for (let j = 1; j <= bad_word_ids.length - 1 && bad_word_ids.length < ids.length; ++j) {
    // NOTE: We use != instead of !== to compare bigint and number
    // @ts-ignore
    if (bad_word_ids.at(-j - 1) != ids.at(-j)) {
        // We have found a mismatch
        mark = false;
        break;
    }
}
will check whether the tokens before the last one in a block-list entry match the most recent ids; if they do not, we won't block the last id of that entry.
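To make the rule concrete, here is a small standalone sketch of the same check. It is a simplified illustration, not the library's code: the function name `tokensToBlock` is hypothetical, and skipping entries whose prefix is longer than the generated sequence is an assumption of the sketch.

```javascript
// Simplified, standalone sketch of the matching rule (not the library's code).
// `ids` is the sequence generated so far; `badWordsIds` is the list of lists.
// Returns the token ids that should be blocked at the next step.
function tokensToBlock(ids, badWordsIds) {
  const blocked = [];
  for (const badWordIds of badWordsIds) {
    // Number of tokens that must precede the last one for it to be blocked.
    const prefixLength = badWordIds.length - 1;
    // Assumption of this sketch: an entry cannot match if its prefix is
    // longer than what has been generated so far.
    if (prefixLength > ids.length) continue;

    let mark = true;
    for (let j = 1; j <= prefixLength; ++j) {
      // Loose inequality so that bigint and number ids compare by value.
      if (badWordIds.at(-j - 1) != ids.at(-j)) {
        mark = false; // mismatch: do not block this entry's last token
        break;
      }
    }
    // Single-token entries have prefixLength 0, so they are always blocked.
    if (mark) blocked.push(badWordIds.at(-1));
  }
  return blocked;
}
```

With `ids = [5, 2]`, an entry `[2, 3]` blocks `3` and `[5, 2, 7]` blocks `7`, while `[9, 2, 8]` blocks nothing because `9` does not match `5`.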
The good news is that you shouldn't see a massive difference in performance. For a block list of 800 entries, I only see a ~10ms difference in the unit test I created. For a block list of 100,000 entries, the difference is more noticeable, but I don't see a list that large happening in practice.
System Info
v3
Environment/Platform
Description
Consider using a Map in NoBadWordsLogitsProcessor
Reproduction
The nested loops in the NoBadWordsLogitsProcessor class can be slow when there are many bad words. That is my case, for instance, with distilgpt2, which has ~800 bad words in its vocabulary.
Building a static Map can speed up the lookups. Something like
and then
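As a rough illustration of the Map idea, here is one way such a lookup could be sketched. Everything in it is an assumption made for illustration: the helper names `buildBadWordsIndex` and `tokensToBlockIndexed` are hypothetical, and keying multi-token entries by the token that must come immediately before the blocked one is just one possible design. Single-token entries are collected once up front, since they are always blocked.

```javascript
// Hypothetical helper: build the lookup structures once (e.g. at construction time).
function buildBadWordsIndex(badWordsIds) {
  const always = [];            // tokens from single-token entries: always blocked
  const byPrevious = new Map(); // previous token id (as Number) -> multi-token entries
  for (const seq of badWordsIds) {
    if (seq.length === 0) continue;
    if (seq.length === 1) {
      always.push(seq[0]);
      continue;
    }
    // Key on the token that must immediately precede the blocked token.
    const key = Number(seq.at(-2));
    if (!byPrevious.has(key)) byPrevious.set(key, []);
    byPrevious.get(key).push(seq);
  }
  return { always, byPrevious };
}

// Hypothetical helper: per generation step, only inspect entries whose
// second-to-last token equals the most recently generated id.
function tokensToBlockIndexed(ids, index) {
  const blocked = [...index.always];
  if (ids.length === 0) return blocked;
  // Number(...) normalizes bigint ids so they hit the same Map key.
  const candidates = index.byPrevious.get(Number(ids.at(-1))) ?? [];
  for (const seq of candidates) {
    const prefixLength = seq.length - 1;
    if (prefixLength > ids.length) continue;
    let match = true;
    // j starts at 2 because position 1 was already matched by the Map lookup.
    for (let j = 2; j <= prefixLength; ++j) {
      if (seq.at(-j - 1) != ids.at(-j)) {
        match = false;
        break;
      }
    }
    if (match) blocked.push(seq.at(-1));
  }
  return blocked;
}
```

The index would be built once and reused at every step, so the per-step cost depends on the number of single-token entries plus the few multi-token entries sharing the last generated id, rather than on a scan of every entry's prefix.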