caikit / caikit-nlp

Apache License 2.0

Include input token count in results #334

Closed mynhardtburger closed 5 months ago

mynhardtburger commented 6 months ago

Depends on caikit data model updates in this PR: https://github.com/caikit/caikit/pull/675

This extends the embedding module to include the input_token_count in the results of the EmbeddingModule's run_ methods.

The sum_token_count(tokenized: BatchEncoding) -> int function counts the tokens that require model attention, using the Encoding.attention_mask property of the encodings returned by SentenceTransformerWithTruncate.tokenizer(). [PAD] tokens are excluded because they are irrelevant to truncation and the max_token_count parameter, while [CLS] and [SEP] are included because the model counts them when applying the max length and truncation.
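The counting rule above can be sketched in Python as follows. This is a minimal, self-contained illustration, not the PR's implementation: FakeEncoding is a hypothetical stand-in for tokenizers.Encoding that models only attention_mask, and the function takes a plain list of encodings rather than a BatchEncoding. With a real fast Hugging Face tokenizer, the per-input encodings would be available through BatchEncoding.encodings.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class FakeEncoding:
    """Hypothetical stand-in for tokenizers.Encoding (tokens + attention_mask only)."""
    tokens: List[str]
    attention_mask: List[int]


def sum_token_count(encodings: Sequence[FakeEncoding]) -> int:
    """Count the tokens the model attends to across a batch.

    [PAD] positions carry mask 0 and are skipped; [CLS] and [SEP]
    carry mask 1 and are therefore counted.
    """
    return sum(sum(enc.attention_mask) for enc in encodings)


# Two inputs padded to the same length of 5 tokens.
batch = [
    FakeEncoding(["[CLS]", "hello", "world", "[SEP]", "[PAD]"], [1, 1, 1, 1, 0]),
    FakeEncoding(["[CLS]", "hi", "[SEP]", "[PAD]", "[PAD]"], [1, 1, 1, 0, 0]),
]
total = sum_token_count(batch)  # → 7 (4 + 3; the three [PAD] tokens are ignored)
```

Summing the mask rather than counting tokens by name means the padding length chosen by the tokenizer never affects the reported count.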

Additionally, tests were added to confirm that sort order is maintained.

Various other quality-of-life type hints were also added.

mynhardtburger commented 6 months ago

FYI: @markstur

mynhardtburger commented 6 months ago

The error case in _truncate_input_tokens() could possibly also be refactored to use _sum_token_count(), avoiding the need to rerun the tokenization.
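A minimal sketch of what that refactor might look like, under loudly stated assumptions: the error case is taken to be a check that an input exceeds a maximum token length, and the texts are assumed to be tokenized once without truncation so the same encodings can serve both the check and the count. per_input_token_counts and check_and_count are hypothetical names, and FakeEncoding stands in for tokenizers.Encoding; none of this is the actual caikit-nlp code.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class FakeEncoding:
    """Hypothetical stand-in for tokenizers.Encoding (only attention_mask is modeled)."""
    attention_mask: List[int]


def per_input_token_counts(encodings: Sequence[FakeEncoding]) -> List[int]:
    """Attended-token count per input ([CLS]/[SEP] included, [PAD] excluded)."""
    return [sum(enc.attention_mask) for enc in encodings]


def check_and_count(encodings: Sequence[FakeEncoding], max_length: int) -> int:
    """Hypothetical helper: raise if any input is too long, else return the total.

    ASSUMPTION: encodings come from a single tokenization pass WITHOUT
    truncation, so the counts reflect the full inputs.
    """
    counts = per_input_token_counts(encodings)
    too_long = [i for i, count in enumerate(counts) if count > max_length]
    if too_long:
        raise ValueError(
            f"inputs at indices {too_long} exceed the maximum of {max_length} tokens"
        )
    return sum(counts)


batch = [FakeEncoding([1, 1, 1, 0]), FakeEncoding([1, 1, 1, 1])]
total = check_and_count(batch, max_length=4)  # → 7
```

The design point is that the attention masks are computed once and reused, so the length check and the reported input_token_count cannot disagree.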

evaline-ju commented 6 months ago

@mynhardtburger caikit 0.26.14 is available with the data model update