xenova / transformers.js

State-of-the-art Machine Learning for the web. Run 🤗 Transformers directly in your browser, with no need for a server!
https://huggingface.co/docs/transformers.js
Apache License 2.0

(v3) speed up NoBadWordsLogitsProcessor #913

Open tarekziade opened 2 weeks ago

tarekziade commented 2 weeks ago

System Info

v3

Environment/Platform

Description

Consider using a Map in NoBadWordsLogitsProcessor.

Reproduction

The nested loops in the NoBadWordsLogitsProcessor class can be slow to run when you have a lot of bad words. That is the case for me, for instance, with distilgpt2, which has ~800 bad words in its vocabulary.

Building a static Map can speed up the lookups. Something like:

// e.g. in the constructor: group bad-word sequences by their final token id
this.bad_words_map = new Map();
for (const bad_word_ids of this.bad_words_ids) {
    const key = bad_word_ids.at(-1);
    if (!this.bad_words_map.has(key)) {
        this.bad_words_map.set(key, []);
    }
    this.bad_words_map.get(key).push(bad_word_ids.slice(0, -1));
}

and then

_call(input_ids, logits) {
    for (let i = 0; i < input_ids.length; ++i) {
        const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
        const ids = input_ids[i];
        const last_id = ids.at(-1);

        // Only bad words ending in last_id can possibly apply here
        if (this.bad_words_map.has(last_id)) {
            const prefixes = this.bad_words_map.get(last_id);
            for (const prefix of prefixes) {
                // Block if the preceding ids match this bad word's prefix
                if (ids.slice(-prefix.length).every((v, idx) => v === prefix[idx])) {
                    batch_logits_data[last_id] = -Infinity;
                    break;
                }
            }
        }
    }
    return logits;
}
tarekziade commented 2 weeks ago

Example of generation config with a heavy bad words ids list https://huggingface.co/Mozilla/distilvit/blob/main/generation_config.json

(btw I don't know how to change the label from bug to enhancement)

xenova commented 2 weeks ago

I've done some testing, and it's a bit trickier than that, unfortunately. bad_words_ids is a list of lists, structured as follows:

[
  [a], // ALWAYS block a
  [b, c], // only block c if preceded by [b]
  [d], // ALWAYS block d
  [e, f, g], // only block g if preceded by [e, f]
  ...
]

This means we still need to iterate over the entire list, especially to handle these "single bad words". This code:

for (let j = 1; j <= bad_word_ids.length - 1 && bad_word_ids.length < ids.length; ++j) {

    // NOTE: We use != instead of !== to compare bigint and number
    // @ts-ignore
    if (bad_word_ids.at(-j - 1) != ids.at(-j)) {
        // We have found a mismatch
        mark = false;
        break;
    }
}

checks whether the tokens before the last one in a block-list entry match the most recently generated ids; if they don't, we won't block that entry's final id.
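Extracted as a standalone helper (the name `prefixMatches` is illustrative; in the real processor, `mark` is initialised before this loop runs), the check behaves like:

```javascript
// Returns true if the final token of bad_word_ids should be blocked next,
// i.e. all tokens of bad_word_ids except the last match the tail of `ids`.
// A single-token entry has no prefix to compare, so it is always blocked.
function prefixMatches(bad_word_ids, ids) {
    let mark = true;
    for (let j = 1; j <= bad_word_ids.length - 1 && bad_word_ids.length < ids.length; ++j) {
        // NOTE: the library uses != instead of !== to compare BigInt and Number ids
        if (bad_word_ids.at(-j - 1) != ids.at(-j)) {
            // We have found a mismatch
            mark = false;
            break;
        }
    }
    return mark;
}
```

For example, `prefixMatches([3, 7], [1, 2, 3])` is true (7 gets blocked because the 3 matches), `prefixMatches([3, 7], [1, 2, 4])` is false, and `prefixMatches([5], [1, 2])` is true because a single-token entry is always blocked.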


The good news is that you shouldn't see a massive difference in performance. For a block list of 800 entries, I only see a ~10ms difference in the unit test I created. For a block list of 100,000, the difference is more noticeable, but I don't see that happening in practice.