keras-team / keras-nlp

Modular Natural Language Processing workflows with Keras
Apache License 2.0

Stop on multiple end tokens #1518

Closed: grasskin closed this 5 months ago.

mattdangerw commented 5 months ago

I think we should first consider the overall API experience we want here. What about something like this?

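import keras_nlp

# Example setup (the preset name and prompt are illustrative).
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
tokenizer = gemma_lm.preprocessor.tokenizer
prompt = "What is Keras?"

# Default: stop at gemma_lm.preprocessor.tokenizer.end_token_id, or raise an
# error if gemma_lm.preprocessor is None.
gemma_lm.generate(
    prompt,
    max_length=64,
    stop_token_ids="auto",
)
# Don't stop until max_length is reached!
gemma_lm.generate(
    prompt,
    max_length=64,
    stop_token_ids=None,
)
# Custom: provide multiple stop tokens. Here we also stop on the
# literal word "stop".
gemma_lm.generate(
    prompt,
    max_length=64,
    stop_token_ids=[tokenizer.end_token_id, tokenizer.token_to_id("stop")],
)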

I don't really like setting this on the tokenizer. Tokenizer special token ids are not generally set by the user, and every tokenizer.xx_token_id is currently just a single integer. Preprocessing can also be detached from the task, in which case the CausalLM does not even have a tokenizer to query.
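To make the three stop_token_ids modes in the proposal concrete, here is a minimal pure-Python sketch of what each setting means for a finished sequence, assuming generation keeps the first stop token and drops everything after it. `truncate_at_stop` is a hypothetical helper, not a keras-nlp function, and the token ids are made up.

```python
from typing import List, Optional, Sequence


def truncate_at_stop(
    token_ids: Sequence[int],
    stop_token_ids: Optional[Sequence[int]],
) -> List[int]:
    """Hypothetical helper: cut a generated sequence after the first stop token.

    `stop_token_ids=None` means "never stop early", so the sequence is
    returned unchanged (generation would have run to max_length).
    """
    if stop_token_ids is None:
        return list(token_ids)
    stop_set = set(stop_token_ids)
    for i, token in enumerate(token_ids):
        if token in stop_set:
            # Keep the stop token itself, drop everything after it.
            return list(token_ids[: i + 1])
    # No stop token was generated, so the whole sequence is kept.
    return list(token_ids)


# Made-up ids: 1 plays the role of end_token_id, 42 the id of "stop".
generated = [5, 6, 42, 7, 1, 8]
print(truncate_at_stop(generated, None))     # [5, 6, 42, 7, 1, 8]
print(truncate_at_stop(generated, [1]))      # [5, 6, 42, 7, 1]
print(truncate_at_stop(generated, [1, 42]))  # [5, 6, 42]
```

With multiple stop tokens, whichever one appears first ends the sequence, which is why the custom list stops earlier than the end-token-only case.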

mattdangerw commented 5 months ago

If we go with the above proposal, we should update the sampler API to also take stop_token_ids, but it does not need the "auto" value.
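At the sampler level the argument is just a list of ids (or None), with no "auto" to resolve. Below is a minimal greedy-decoding sketch, assuming a toy `next_token_fn` that stands in for the model; every sequence advances in lockstep like a batch, a sequence is marked done once it emits any stop token or reaches `max_length`, and the loop ends when all are done. The names are hypothetical and this is not keras-nlp's Sampler implementation.

```python
from typing import Callable, List, Optional, Sequence


def greedy_generate(
    next_token_fn: Callable[[List[int]], int],
    prompts: List[List[int]],
    max_length: int,
    stop_token_ids: Optional[Sequence[int]] = None,
) -> List[List[int]]:
    """Hypothetical lockstep greedy loop that supports multiple stop tokens."""
    stop_set = set(stop_token_ids) if stop_token_ids is not None else set()
    sequences = [list(p) for p in prompts]
    # A sequence that already fills max_length is finished before we start.
    done = [len(seq) >= max_length for seq in sequences]
    while not all(done):
        # One decoding step for every unfinished sequence.
        for i, seq in enumerate(sequences):
            if done[i]:
                continue
            token = next_token_fn(seq)
            seq.append(token)
            # Finish on any stop token, or once max_length is reached.
            if token in stop_set or len(seq) >= max_length:
                done[i] = True
    return sequences


# Toy "model": always predicts the previous token plus one.
def toy_model(seq):
    return seq[-1] + 1


print(greedy_generate(toy_model, [[1], [10]], 6, stop_token_ids=[4, 12]))
# [[1, 2, 3, 4], [10, 11, 12]]
print(greedy_generate(toy_model, [[1], [10]], 6, stop_token_ids=None))
# [[1, 2, 3, 4, 5, 6], [10, 11, 12, 13, 14, 15]]
```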

We can do this with Gemma at first, but we should eventually update all models to have a consistent API surface.

We might also want to factor out a helper into tensor_utils.py; it would help readability:

from keras import ops


def any_equal(inputs, values):
    """Return a mask that is True anywhere `inputs` has a value in `values`."""
    output = ops.equal(inputs, values[0])
    for value in values[1:]:
        # OR in the equality mask for each additional value.
        output = ops.logical_or(output, ops.equal(inputs, value))
    return output

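For a quick check without Keras, here is the same "matches any of these values" mask written with NumPy. It is a standalone illustration, not part of tensor_utils.py; `np.isin` computes the identical mask in a single call, which makes a handy cross-check.

```python
import numpy as np


def any_equal_np(inputs, values):
    """NumPy version of any_equal: True wherever `inputs` matches any value."""
    output = np.equal(inputs, values[0])
    for value in values[1:]:
        # OR in the equality mask for each additional value.
        output = np.logical_or(output, np.equal(inputs, value))
    return output


token_ids = np.array([[5, 6, 42, 7], [1, 0, 0, 0]])
mask = any_equal_np(token_ids, [1, 42])
print(mask.tolist())  # [[False, False, True, False], [True, False, False, False]]
assert np.array_equal(mask, np.isin(token_ids, [1, 42]))
```

The unrolled equal/logical_or loop only relies on elementwise ops, which is what makes it easy to express with backend-agnostic ops.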
grasskin commented 5 months ago

We're currently defaulting to a mix: if a preprocessor is specified we use "auto", otherwise we use None. Should we raise an error if no preprocessor is specified, or just fall back to None?

grasskin commented 5 months ago

Discussed offline: we're going to do a full refactor and go with the saner choice of raising an error if "auto" is specified without a preprocessor. The API will be more consistent for multi-token requirements.
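The agreed default can be summarized as a small resolution step. This is a sketch under assumptions: `resolve_stop_token_ids` is a hypothetical function, and the preprocessor is any object exposing `tokenizer.end_token_id` (faked here with `SimpleNamespace`); it is not the merged keras-nlp code.

```python
from types import SimpleNamespace
from typing import Any, Optional, Sequence, Tuple, Union

StopTokenIds = Union[str, None, Sequence[int]]


def resolve_stop_token_ids(
    stop_token_ids: StopTokenIds,
    preprocessor: Optional[Any],
) -> Optional[Tuple[int, ...]]:
    """Hypothetical: turn the user-facing argument into concrete ids or None."""
    if stop_token_ids is None:
        # Never stop early; generate until max_length.
        return None
    if isinstance(stop_token_ids, str):
        if stop_token_ids != "auto":
            raise ValueError(f'Unknown stop_token_ids value: "{stop_token_ids}"')
        if preprocessor is None:
            # "auto" needs a tokenizer to look up the end token, so fail loudly
            # instead of silently falling back to None.
            raise ValueError(
                'stop_token_ids="auto" requires a preprocessor. Pass explicit '
                "token ids or stop_token_ids=None instead."
            )
        return (preprocessor.tokenizer.end_token_id,)
    # Explicit ids: keep them as an immutable tuple.
    return tuple(stop_token_ids)


# Stand-in object for illustration only.
preprocessor = SimpleNamespace(tokenizer=SimpleNamespace(end_token_id=1))
print(resolve_stop_token_ids("auto", preprocessor))  # (1,)
print(resolve_stop_token_ids(None, None))            # None
print(resolve_stop_token_ids([1, 42], None))         # (1, 42)
```

Raising rather than falling back means a detached-preprocessing setup never silently runs to max_length just because "auto" could not be resolved.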

grasskin commented 5 months ago

@mattdangerw this works for Gemma; if the overall approach looks good to you, we can replicate it in other models. Since we're switching to stop_token_ids, we now explicitly require an iterable rather than a single int. I've already fixed the sampling tests.
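Requiring an iterable means a bare int (the old single-token style) should be rejected with a clear message rather than failing obscurely later. A minimal validation sketch, assuming a hypothetical `validate_stop_token_ids` that runs after "auto" has been resolved; the choice of `TypeError` is also an assumption.

```python
import numbers
from typing import Iterable, Tuple


def validate_stop_token_ids(stop_token_ids: Iterable[int]) -> Tuple[int, ...]:
    """Hypothetical check: accept an iterable of integer ids, reject a bare int."""
    if isinstance(stop_token_ids, numbers.Integral):
        raise TypeError(
            "stop_token_ids must be an iterable of ints, e.g. "
            f"[{stop_token_ids}], not a single int."
        )
    ids = tuple(stop_token_ids)
    for token_id in ids:
        if not isinstance(token_id, numbers.Integral):
            raise TypeError(
                f"Expected integer token ids, got {type(token_id).__name__}."
            )
    # Normalize to plain Python ints (e.g. from NumPy integer scalars).
    return tuple(int(token_id) for token_id in ids)


print(validate_stop_token_ids([1]))      # (1,)
print(validate_stop_token_ids((1, 42)))  # (1, 42)
```

Passing `1` instead of `[1]` raises a TypeError whose message suggests the list form. Checking `numbers.Integral` rather than `int` also accepts NumPy integer scalars.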