xenova / transformers.js

State-of-the-art Machine Learning for the web. Run 🤗 Transformers directly in your browser, with no need for a server!
https://huggingface.co/docs/transformers.js

Help understanding logits and model vocabs #713

Closed thekevinscott closed 2 months ago

thekevinscott commented 2 months ago

Question

I'm trying to write a custom LogitsProcessor and have some questions. For reference, I'm using Xenova/phi-1_5_dev. I'm trying to implement custom logic for whitelisting or blacklisting tokens, but I'm running into difficulties understanding how to interpret token ids, tokens, and their decoded counterparts.

Here's what I think I understand:

I'd appreciate any insight or feedback on whether my assumptions above are correct or not. Thank you!
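
For concreteness, here's the round trip between strings, token ids, and tokens as I currently understand it (a minimal sketch; I'm assuming the tokenizer is callable as shown and that convert_ids_to_tokens exists alongside convert_tokens_to_ids):

import { AutoTokenizer } from '@xenova/transformers';

const tokenizer = await AutoTokenizer.from_pretrained('Xenova/phi-1_5_dev');

// String -> token ids (input_ids.data is a BigInt64Array)
const { input_ids } = tokenizer('Write me some code');
const ids = [...input_ids.data].map(Number);

// Token ids -> raw vocab tokens (vocab entries, e.g. with 'Ġ' prefixes)
const tokens = tokenizer.model.convert_ids_to_tokens(ids);

// Token ids -> human-readable text
const text = tokenizer.decode(ids, { skip_special_tokens: true });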

thekevinscott commented 2 months ago

The precipitating reason for this question is that I'm trying to force an eos token, but the model continues generating.

Here's a sample LogitsProcessor:

import { TextGenerationPipeline, Tensor } from '@xenova/transformers';

class LogitsProcessor {
  pipeline: TextGenerationPipeline;
  tokenIds: number[];
  stopTokenId: number;

  constructor(pipeline: TextGenerationPipeline, str: string) {
    this.pipeline = pipeline;
    // Tokenize the target string; input_ids.data is a BigInt64Array.
    const { input_ids } = pipeline.tokenizer(str);
    this.tokenIds = [...input_ids.data].map((n: bigint) => Number(n));
    // Resolve the tokenizer's configured EOS token to its id.
    this.stopTokenId = this.pipeline.tokenizer.model.convert_tokens_to_ids([
      this.pipeline.tokenizer.getToken('eos_token'),
    ])[0];
  }

  processors = [(inputTokens: number[], logits: Tensor) => {
    if (inputTokens.length > this.tokenIds.length) {
      console.warn('This should not happen');
      return logits;
    }
    // Force the next token: the next id of the target string while it is
    // still being emitted, then the EOS id once the string is complete.
    const id = inputTokens.length === this.tokenIds.length
      ? this.stopTokenId
      : this.tokenIds[inputTokens.length];
    logits.data.fill(-Infinity);
    logits.data[id] = Infinity;
    return logits;
  }];

  [Symbol.iterator]() {
    return this.processors.values();
  }
}

instantiated with:

    const prompt = 'Write me some code';
    const logitsProcessor = new LogitsProcessor(this.pipeline, prompt + ' foo');

This will generate output like the following:

Write me some code foo<|endoftext|>
Student: A company has a budget of $5000 to spend on advertising. They want

I don't understand why <|endoftext|> is being treated as part of the text output rather than as a signal to stop generation. I assume the answer is that I'm misunderstanding something in my initial question above.

thekevinscott commented 2 months ago

Hmm. I just tried with Xenova/gpt2 and I now see the following output:

Write me some code foo<|endoftext|>

So, maybe the issue with the model not stopping is specific to the model being used?

I still don't understand why <|endoftext|> is being returned as part of the text generated, though.

thekevinscott commented 2 months ago

> I still don't understand why <|endoftext|> is being returned as part of the text generated, though.

This was user error. I'm calling .decode() manually and was neglecting to pass skip_special_tokens; passing that option successfully omits <|endoftext|> from the text:

    const decoded = tokenizer.decode(outputTokenIds[0], {
      skip_special_tokens: true,
    });

I still don't understand why phi-1_5_dev is not stopping on an eos token though.

thekevinscott commented 2 months ago

Here's what I've found for phi-1_5_dev:

Forcing 2 as the eos token successfully stops generation for phi-1_5_dev (while forcing 50256 does not). However, token ID 2 gets decoded incorrectly:

    const decoded = tokenizer.decode(outputTokenIds[0], {
      skip_special_tokens: true,
    });

> "Write me some code foo#"

Which makes sense, as token ID 2 maps to # in vocab.json.

So I guess this entire thread boils down to: why the discrepancy? Does this indicate a bug in the model? Or am I misunderstanding how I should be decoding output tokens containing eos tokens?
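
To double-check the discrepancy, here's roughly how I'm comparing the configured eos token against what actually stops generation (a sketch using the same getToken / convert_tokens_to_ids calls as the processor above; tokenizer is this.pipeline.tokenizer):

// What the tokenizer config reports as the EOS token, and its vocab id
const eosToken = tokenizer.getToken('eos_token');
const [eosId] = tokenizer.model.convert_tokens_to_ids([eosToken]);
console.log(eosToken, eosId); // expecting '<|endoftext|>' and its vocab id

// ...while the id that actually stops generation decodes to '#':
console.log(tokenizer.decode([2])); // '#'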

xenova commented 2 months ago

Thanks for the report! Indeed, it looks like the EOS token is incorrectly set in the original version of the model (see here), which is why it's also (incorrectly) set to 2 in our version. I have updated it here (it's set to null here).
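
Until the updated config propagates, you should also be able to override the stop token per call by passing eos_token_id in the generation params (a sketch, assuming the generation options mirror the Python generation config here; 50256 is <|endoftext|> in this vocab):

const result = await generator(prompt, {
  max_new_tokens: 100,
  eos_token_id: 50256, // <|endoftext|>; overrides the incorrect config value
});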

xenova commented 2 months ago

As for your original question, you can use the NoBadWordsLogitsProcessor logits processor for this (see here). You can use it by setting bad_words_ids in the generation params object:

// Generate text
const result = await generator(prompt, {
  max_new_tokens: 100,
  bad_words_ids: [[123]], // list of list of token ids (2D since you can specify a sequence of tokens to skip)
});
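
If you'd rather ban a word by its text than hard-code an id, you can tokenize it first (a sketch; note that a single word often maps to several token ids, which is why the option is 2D):

// Tokenize the word to ban; its ids form one banned sequence.
const { input_ids } = generator.tokenizer('foo');
const bad_words_ids = [[...input_ids.data].map(Number)];

const result = await generator(prompt, {
  max_new_tokens: 100,
  bad_words_ids,
});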
thekevinscott commented 2 months ago

Thanks for the response! Those edits look great. Does eos_token_id: null imply anything in particular, or does it just mean it falls back to whatever the default eos_token_id is?

> As for your original question, you can use the NoBadWordsLogitsProcessor logits processor for this (see here)

Appreciate that reference. My actual use case is a bit more complicated: I'm trying to implement a GBNF grammar parser similar to llama.cpp's implementation. But it's good to know this exists so I don't have to reinvent the wheel in the future!