This is the issue we can use for the tokenizers design review.
CC @ericstj @michaelgsharp @stephentoub @luisquintanilla @terrajobst @JakeRadMSFT @LittleLittleCloud
Hey @tarekgh,
Thank you for sharing the design! I've had the chance to take a look and have some thoughts - fwiw, probably a bit from a BERT perspective ;-).
- IIRC, it's easiest to feed `Memory<int>`/`Memory<long>` into onnxruntime. If `EncodeToIds` returns `IReadOnlyList<int>`, can we avoid copying the memory?
- If `Normalize(ReadOnlySpan<char> original)` returns `string`, wouldn't we do more allocations than needed for some normalizations?
- If the `EncodeToIds` overload with `maxTokenCount` has `out string normalizedText`, wouldn't that require us to construct and allocate the normalized string even if the caller might not be interested in it?
- Couldn't we do something like `pretokenize -> try to encode -> on failure: normalize -> try to encode` and thereby skip the normalization work for most of the input?
- Consider `BertNormalize = RemoveDiacritics(UnicodeNormalize(RemoveControlChars(Lowercase(input))))` and an input that is mostly ascii/latin1/... chars but not lowercase. Wouldn't we pay for the allocations for the whole input, because `UnicodeNormalize` is part of the normalization, even though lowercasing would often be sufficient and could work alloc-free on `Span<char>`?
- If a tokenizer had a `List<Normalizer>` or something similar, wouldn't it be able to save allocations and normalization operations? E.g. by doing `pretokenize -> try to encode -> on failure: normalize[0] -> try to encode -> on failure: normalize[1] -> try to encode -> ...`?
- How would `MapTokenToId` handle tokens that map to multiple ids? Would it be able to handle a token that is only valid as a suffix, e.g. by calling it with `"##suffix"`?
- Can `CountTokens` (and `IndexOfTokenCount` and `LastIndexOfTokenCount`) have faster implementations than e.g. `EncodeToIds().Count`? Not sure, but I think for BERT the cost might be similar.
- Padding input_ids is out of scope? Combining multiple strings as a single model input (e.g. paragraph + question about paragraph) is out of scope? Batching is out of scope (e.g. for parallelization)? Encoding texts longer than `maxTokenCount` with a stride is out of scope?

If you think I can contribute to/help with any of this please let me know :-)
@georg-jung thanks for the feedback.
> IIRC, it's easiest to feed `Memory<int>`/`Memory<long>` into onnxruntime. If `EncodeToIds` returns `IReadOnlyList<int>`, can we avoid copying the memory?
I'll think more about whether we can return `Memory<T>` instead. The challenge is that during encoding we need to store the produced ids, which means allocating the array beforehand; we'd have to estimate the initial allocation correctly to avoid allocating unused memory while also avoiding re-allocation if the size isn't enough. Anyway, the idea is worth thinking about.
Also, in the future we can expose new APIs that take a destination span/`Memory`.
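Such an overload might look something like the following (purely a hypothetical shape, not part of the current proposal):

```csharp
// The caller owns the destination buffer, so the encode itself allocates nothing;
// OperationStatus can signal DestinationTooSmall, mirroring the proposed Decode overload.
public OperationStatus EncodeToIds(ReadOnlySpan<char> text, Span<int> destination, out int idsWritten);
```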
> If `Normalize(ReadOnlySpan<char> original)` returns `string`, wouldn't we do more allocations than needed for some normalizations?
Yes. We are designing the library for the generic use case, and if there is normalization, then we need allocations. We always return the normalized string so callers can perform subsequent calls on that string without more allocations or re-doing the normalization. Anyone who wants to avoid the allocations can do the normalization manually before calling the tokenizer.
> If the `EncodeToIds` overload with `maxTokenCount` has `out string normalizedText`, wouldn't that require us to construct and allocate the normalized string even if the caller might not be interested in it?
`normalizedText` is allocated only if the tokenizer is created with a normalizer object and the caller of `EncodeToIds` passes `considerNormalization = true`. The caller has full control over the operation.
> Couldn't we do something like `pretokenize -> try to encode -> on failure: normalize -> try to encode` and thereby skip the normalization work for most of the input?
I don't think we can do that. You never know whether the encode failed or not; it is possible for encoding to succeed on a non-normalized string and return unexpected values. Users can choose to do this themselves by calling the tokenizer with `considerNormalization = false`, checking the result, and deciding to call again with `true` if needed.
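In other words, the optimistic strategy stays in the caller's hands. A sketch, where `LooksWrong` is a hypothetical caller-defined check:

```csharp
IReadOnlyList<int> ids = tokenizer.EncodeToIds(text, considerNormalization: false);
if (LooksWrong(ids)) // hypothetical check, e.g. "result contains unknown-token Ids"
{
    ids = tokenizer.EncodeToIds(text, considerNormalization: true);
}
```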
> How would `MapTokenToId` handle tokens that map to multiple ids?
It doesn't map to multiple Ids. It will return `null` if it cannot map the token to a single Id. Users can call `EncodeToIds` if they need the full results.
> Would it be able to handle a token that is only valid as a suffix, e.g. by calling it with `"##suffix"`?
This depends on the tokenizer. If the tokenizer supports that, then it should be able to map it. Note that we are supporting many tokenizers, so each tokenizer can decide how to map tokens to/from Ids.
> Can `CountTokens` (and `IndexOfTokenCount` and `LastIndexOfTokenCount`) have faster implementations than e.g. `EncodeToIds().Count`? Not sure, but I think for BERT the cost might be similar.
Yes, we don't implement `CountTokens` by calling `EncodeToIds().Count`; we try to optimize for such cases. Sometimes we need to create a cache, though, which helps subsequent calls and speeds things up further.
> Padding input_ids is out of scope? Combining multiple strings as a single model input (e.g. paragraph + question about paragraph) is out of scope? Batching is out of scope (e.g. for parallelization)? Encoding texts longer than `maxTokenCount` with a stride is out of scope?
For the current version these are out of scope until we see enough demand for such features. We need to get the main features in first, covering the majority of the scenarios we have seen so far.
> If you think I can contribute to/help with any of this please let me know :-)
Sure, we appreciate all help!
CC @stephentoub
Thanks for the detailed response!
> The challenge is that during encoding we need to store the produced ids, which means allocating the array beforehand; we'd have to estimate the initial allocation correctly to avoid allocating unused memory while also avoiding re-allocation if the size isn't enough.
An overload that writes to a passed-in `Span` would of course be easy in that regard, as it's then up to the caller :D. An approach I took with my BERT tokenizer is to re-use an internal buffer for subsequent calls, so the allocation cost becomes less relevant when encoding multiple documents.
Maybe I'm mistaken, but regarding `IReadOnlyList<int>` specifically: if one wants to pass the result as an input to onnxruntime, wouldn't it then always be necessary to write something similar to

```csharp
var res = EncodeToIds(...);
// List<int> can't expose its storage as Memory<int>; anything that isn't an int[] ends up copied.
var modelInput = res is int[] arr ? arr.AsMemory() : res.ToArray().AsMemory();
```
> if there is normalization, then we need allocations.
I was thinking of something like `MemoryExtensions.ToLower`, or maybe `RemoveControlAndReplacement(ReadOnlySpan<char> text, out ReadOnlySpan<char> cleaned)` with a re-used internal buffer or similar. Couldn't at least many normalizations then be alloc-free, e.g. lowercasing, uppercasing, stripping control chars and, probably most important, the "no-op" normalization, where the input is already normalized according to the normalizer it is passed to?
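To make that concrete, a minimal sketch of an alloc-free lowercase step over a reused buffer (`ReusableLowercaser` and its sizing policy are illustrative, not proposed API):

```csharp
using System;
using System.Globalization;

// After the first call, lowercasing allocates nothing as long as inputs fit the buffer
// (assumes single-threaded use; the returned span is only valid until the next call).
internal sealed class ReusableLowercaser
{
    private char[] _buffer = new char[1024];

    public ReadOnlySpan<char> Lowercase(ReadOnlySpan<char> text)
    {
        if (_buffer.Length < text.Length)
            _buffer = new char[text.Length];

        int written = text.ToLower(_buffer, CultureInfo.InvariantCulture);
        return _buffer.AsSpan(0, written);
    }
}
```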
I'm a bit in a hurry and will think about the other points soon... Thanks for always taking the time to discuss this, I think it is really interesting!
I think trying to optimize for normalization scenarios will complicate the tokenizer interfaces. What is proposed should be enough for users to decide how far they want to optimize. If they really need allocation-free normalization, they can do it themselves before calling the tokenizer. If they are ok with one allocation but want to avoid the allocations/processing on subsequent calls, they can reuse the normalized string with `considerNormalization = false`. Trying to optimize normalization while supporting every possible scenario would be very challenging and would make the APIs more complicated than typical users want.
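For example, a caller can pay the normalization cost once and reuse the result across calls (a sketch; assumes the tokenizer was created with a normalizer):

```csharp
string normalized = tokenizer.Normalizer!.Normalize(input);
IReadOnlyList<int> ids = tokenizer.EncodeToIds(normalized, considerNormalization: false);
int count = tokenizer.CountTokens(normalized, considerNormalization: false);
```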
```csharp
namespace Microsoft.ML.Tokenizers
{
    public struct EncodeResults<T>
    {
        public IReadOnlyList<T> Tokens { get; set; }
        public string? NormalizedText { get; set; }
    }

    public struct EncodeSettings
    {
        public bool ConsiderNormalization { get; set; }
        public bool ConsiderPreTokenization { get; set; }
        public bool ProduceNormalizedString { get; set; }
        public int MaxTokenCount { get; set; }
    }

    public abstract partial class Tokenizer
    {
        protected Tokenizer();

        public virtual Normalizer? Normalizer { get; }
        public virtual PreTokenizer? PreTokenizer { get; }

        public IReadOnlyList<int> EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true);
        public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true);
        public IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string? normalizedText, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true);
        public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedText, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true);
        protected abstract EncodeResults<int> EncodeToIds(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings);

        public IReadOnlyList<Token> EncodeToTokens(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true);
        public IReadOnlyList<Token> EncodeToTokens(ReadOnlySpan<char> text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true);
        protected abstract EncodeResults<Token> EncodeToTokens(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings);

        public int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true);
        public int CountTokens(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true);
        protected abstract int CountTokens(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings);

        public int GetIndexByTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true);
        public int GetIndexByTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true);
        public int GetIndexByTokenCountFromEnd(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true);
        public int GetIndexByTokenCountFromEnd(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true);
        protected abstract int GetIndexByTokenCount(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings, bool fromEnd, out string? normalizedString, out int tokenCount);

        public abstract string? Decode(IEnumerable<int> ids);
        public abstract OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, out int idsConsumed, out int charsWritten);

        public virtual int? MapTokenToId(string token);
        public abstract int? MapTokenToId(ReadOnlySpan<char> token);
        public abstract string? MapIdToToken(int? id);

        //
        // Factory methods
        //

        public static Task<Tokenizer> CreateTiktokenAsync(Stream vocabStream, PreTokenizer? preTokenizer, Normalizer? normalizer,
                                                          IReadOnlyDictionary<string, int>? specialTokens = null,
                                                          int cacheSize = 8192, CancellationToken cancellationToken = default);
        public static Task<Tokenizer> CreateTiktokenAsync(string vocabFilePath, PreTokenizer? preTokenizer, Normalizer? normalizer,
                                                          IReadOnlyDictionary<string, int>? specialTokensEncoder = null,
                                                          int cacheSize = 8192, CancellationToken cancellationToken = default);
        public static Tokenizer CreateTiktokenForEncoding(string encodingName, IReadOnlyDictionary<string, int>? extraSpecialTokens = null, Normalizer? normalizer = null);
        public static Tokenizer CreateTiktokenForModel(string modelName, IReadOnlyDictionary<string, int>? extraSpecialTokens = null, Normalizer? normalizer = null);
        public static Tokenizer CreateTiktokenForModel(string modelName, Stream vocabStream, IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
                                                       int cacheSize = 8192, Normalizer? normalizer = null);
        public static Task<Tokenizer> CreateTiktokenForModelAsync(string modelName, Stream vocabStream, IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
                                                                  int cacheSize = 8192, Normalizer? normalizer = null, CancellationToken cancellationToken = default);
        public static Tokenizer CreateLlama(Stream modelStream, bool addBeginOfSentence = true, bool addEndOfSentence = false);
        public static Tokenizer CreateCodeGen(Stream vocabStream, Stream mergesStream, bool addPrefixSpace = false, bool addBeginOfSentence = false, bool addEndOfSentence = false);
        public static Tokenizer CreatePhi2(Stream vocabStream, Stream mergesStream, bool addPrefixSpace = false, bool addBeginOfSentence = false, bool addEndOfSentence = false);
    }
}
```
Notes from the API design review:

- `MapTokenToId` returns `null` for cases where one token is mapped to more than one ID; callers can fall back to `EncodeToIds`.
- Move `Create`/`CreateAsync` to the concrete types, e.g. `Tokenizer.CreateTiktokenAsync()` would become `TiktokenTokenizer.CreateAsync()`.
- Rename `Token` -> `EncodedToken` (`TextToken`/`TextTokenizer` were also considered).
- `Token` should expose `Index` and `Range` and use those.
- Suffix the concrete type names with `Tokenizer` for consistency and discoverability.
- `Bpe`: consider `FrozenDictionary<,>` for the vocabulary.
- `Tiktoken`: remove `Encoder` and `Decoder` as it's not clear how someone would use it.
- `Vocab` -> `Vocabulary`.
```csharp
namespace Microsoft.ML.Tokenizers;

public struct EncodeResults<T>
{
    public IReadOnlyList<T> Tokens { get; set; }
    public string? NormalizedText { get; set; }
}

public struct EncodeSettings
{
    public bool ConsiderNormalization { get; set; }
    public bool ConsiderPreTokenization { get; set; }
    public bool ProduceNormalizedString { get; set; }
    public int MaxTokenCount { get; set; }
}

public abstract partial class Tokenizer
{
    protected Tokenizer();

    public virtual Normalizer? Normalizer { get; }
    public virtual PreTokenizer? PreTokenizer { get; }

    public IReadOnlyList<int> EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true);
    public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true);
    public IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string? normalizedText, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true);
    public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedText, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true);
    protected abstract EncodeResults<int> EncodeToIds(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings);

    public IReadOnlyList<Token> EncodeToTokens(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true);
    public IReadOnlyList<Token> EncodeToTokens(ReadOnlySpan<char> text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true);
    protected abstract EncodeResults<Token> EncodeToTokens(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings);

    public int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true);
    public int CountTokens(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true);
    protected abstract int CountTokens(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings);

    public int GetIndexByTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true);
    public int GetIndexByTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true);
    public int GetIndexByTokenCountFromEnd(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true);
    public int GetIndexByTokenCountFromEnd(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true);
    protected abstract int GetIndexByTokenCount(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings, bool fromEnd, out string? normalizedString, out int tokenCount);

    public abstract string? Decode(IEnumerable<int> ids);
    public abstract OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, out int idsConsumed, out int charsWritten);
}

public abstract partial class Normalizer
{
    protected Normalizer();
    public string Normalize(string text);
    public string Normalize(ReadOnlySpan<char> text);
    protected abstract string Normalize(string? text, ReadOnlySpan<char> textSpan);
}

public abstract partial class PreTokenizer
{
    protected PreTokenizer();
    public IEnumerable<Range> PreTokenize(string text);
    public IEnumerable<Range> PreTokenize(ReadOnlySpan<char> text);
    protected abstract IEnumerable<Range> PreTokenize(string? text, ReadOnlySpan<char> textSpan);
}

public readonly struct EncodedToken
{
    public EncodedToken(int id, string value, Range range);
    public int Id { get; }
    public Range Range { get; }
    public string Value { get; }
}

public sealed partial class LowerCaseNormalizer : Normalizer
{
    public static LowerCaseNormalizer Instance { get; }
    public override string Normalize(ReadOnlySpan<char> original);
    public override string Normalize(string original);
}

public sealed partial class UpperCaseNormalizer : Normalizer
{
    public static UpperCaseNormalizer Instance { get; }
    public override string Normalize(ReadOnlySpan<char> original);
    public override string Normalize(string original);
}

public sealed partial class SentencePieceNormalizer : Normalizer
{
    public SentencePieceNormalizer(bool removeExtraWhitespace, bool addDummyPrefix, bool escapeWhitespace, bool treatWhitespaceAsSuffix);
    public bool AddDummyPrefix { get; }
    public bool EscapeWhitespace { get; }
    public bool RemoveExtraWhitespace { get; }
    public bool TreatWhitespaceAsSuffix { get; }
    public override string Normalize(ReadOnlySpan<char> original);
    public override string Normalize(string original);
}

public sealed partial class TiktokenPreTokenizer : PreTokenizer
{
    public static TiktokenPreTokenizer Create(Regex regex, IReadOnlyDictionary<string, int> specialTokensEncoder);
}

public sealed partial class WhitespacePreTokenizer : PreTokenizer
{
    public static WhitespacePreTokenizer Instance { get; }
}

public sealed partial class RobertaPreTokenizer : PreTokenizer
{
    public static RobertaPreTokenizer Instance { get; }
}

public sealed partial class BpeTokenizer : Tokenizer
{
    public static BpeTokenizer Create(string vocabFile, string? mergesFile, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, string? unknownToken = null,
                                      string? continuingSubwordPrefix = null, string? endOfWordSuffix = null, bool? fuseUnknownTokens = false);
    public static BpeTokenizer Create(Stream vocabStream, Stream? mergesStream, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, string? unknownToken = null,
                                      string? continuingSubwordPrefix = null, string? endOfWordSuffix = null, bool? fuseUnknownTokens = false);
    public string? ContinuingSubwordPrefix { get; }
    public string? EndOfWordSuffix { get; }
    public bool? FuseUnknownTokens { get; }
    public string? UnknownToken { get; }
    public IReadOnlyDictionary<string, int> Vocab { get; }
    public string? Decode(IEnumerable<int> ids, bool considerSpecialTokens);
}

public sealed partial class TiktokenTokenizer : Tokenizer
{
    public static TiktokenTokenizer Create(string vocabFilePath, PreTokenizer? preTokenizer, IReadOnlyDictionary<string, int>? specialTokens = null, Normalizer? normalizer = null, int? cacheSize = 8192);
    public static TiktokenTokenizer Create(Stream vocabStream, PreTokenizer? preTokenizer, IReadOnlyDictionary<string, int>? specialTokens = null, Normalizer? normalizer = null, int? cacheSize = 8192);
    // public IReadOnlyDictionary<int, ReadOnlyMemory<byte>> Decoder { get; }
    // public IReadOnlyDictionary<ReadOnlyMemory<byte>, int> Encoder { get; }
    public IReadOnlyDictionary<string, int> SpecialTokens { get; }
    public IReadOnlyDictionary<string, int> Vocab { get; }
}
```
Follow-up notes:

- `out int textLength` should be `out int charsConsumed`, and `EncodeResults<T>` should gain a `CharsConsumed` member.
- `Llama`: we want a dedicated type (and other types like it).
- Avoid `FrozenDictionary<,>` due to its creation cost.

```csharp
namespace Microsoft.ML.Tokenizers;

public struct EncodeResults<T>
{
    // Existing:
    // public IReadOnlyList<T> Tokens { get; set; }
    // public string? NormalizedText { get; set; }
    public int CharsConsumed { get; set; }
}

public partial class LlamaTokenizer : SentencePieceTokenizer
{
    public static LlamaTokenizer Create();
}
```
Tokenizers are a crucial component of Large Language Models (LLMs) like GPT-3 or BERT. They are responsible for the tokenization process, which breaks natural language text down into smaller, manageable pieces called tokens. Tokens can be words, characters, sub-words, numbers, or symbols, and they allow the model to process and understand the text.
This issue presents the APIs proposed for the Microsoft.ML.Tokenizers library, intended for design review. The design introduces an abstract class named `Tokenizer`, which defines the primary interface for all supported tokenizers. Additionally, the `Tokenizer` class includes factory methods for creating the various types of tokenizers.

The Tokenizer can optionally be configured with a normalizer, which normalizes the text before it is processed. Normalization can take various forms, such as uppercasing, lowercasing, Unicode normalization, and removing or inserting specific characters. Normalization is optional for the tokenizer, and it is left to the discretion of either the tokenizer or the user whether to utilize any normalizers.
Pre-tokenization is an additional component that the tokenizer can be configured with, aimed at splitting the input text into smaller units prior to processing. While pre-tokenization is also an optional feature, it is commonly utilized in most tokenizers. Many pre-tokenizers employ regex for this purpose.
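For instance, the proposed pre-tokenizer surface yields `Range`s that point back into the input text (a sketch; the exact splits depend on the pre-tokenizer):

```csharp
string text = "Hello world!";
foreach (Range range in WhitespacePreTokenizer.Instance.PreTokenize(text))
{
    Console.WriteLine(text[range]); // each range selects one pre-token, e.g. "Hello"
}
```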
The typical sequence of operations for the Tokenizer involves normalizing the input (if a normalizer is configured), pre-tokenizing it into smaller units (if a pre-tokenizer is configured), and then encoding those units into tokens or Ids.
Tokenizers offer the following functionalities:

- Encoding text to Ids: `EncodeToIds` in the proposed design.
- Counting the tokens of a given text: `CountTokens` in the proposed design.
- Full encoding, producing the tokens with their Ids and offsets: `Encode` in the proposed design.
- Counting tokens up to a max token count and finding the corresponding index in the text: `IndexOfTokenCount` and `LastIndexOfTokenCount`.
- Decoding Ids back to a string: `Decode` in the proposed design.
- Mapping a string token to an Id and vice versa: `MapTokenToId` and `MapIdToToken` in the proposed design.

Tokenizers typically rely on vocabulary files, which are provided to the tokenizer during instantiation. Users commonly pass these vocabularies as either a file or a stream to the tokenizer constructor. Vocabulary files can vary in format, such as JSON, plain text, protobuf, and more. Each tokenizer determines the specific formats of files it can be instantiated with.
Usage Examples:
Create BPE tokenizer using the constructor:
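An illustrative sketch with placeholder file paths (the updated proposal exposes creation through a `Create` factory rather than a public constructor):

```csharp
using Microsoft.ML.Tokenizers;

// Placeholder paths; any compatible BPE vocabulary/merges pair works here.
Tokenizer tokenizer = BpeTokenizer.Create("vocab.json", "merges.txt");
```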
Create Tiktoken tokenizer using factory method:
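A sketch using the factory method from the listing above; "gpt-4" stands in for any supported model name:

```csharp
Tokenizer tokenizer = Tokenizer.CreateTiktokenForModel("gpt-4");
```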
Encode to Ids:
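For example:

```csharp
IReadOnlyList<int> ids = tokenizer.EncodeToIds("Hello World");
Console.WriteLine(string.Join(", ", ids));
```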
Count Tokens:
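Counting avoids materializing the token list (a sketch):

```csharp
int tokenCount = tokenizer.CountTokens("Hello World");
```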
Full Encoding:
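Full encoding returns the tokens together with their Ids and offsets (a sketch against the updated `EncodedToken` surface):

```csharp
IReadOnlyList<EncodedToken> tokens = tokenizer.EncodeToTokens("Hello World", out string? normalizedText);
foreach (EncodedToken token in tokens)
{
    Console.WriteLine($"{token.Value} -> Id {token.Id}, Range {token.Range}");
}
```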
Count tokens up to max token count:
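A sketch; the return value is the index of the character where the token budget is reached, and `tokenCount` reports how many tokens actually fit:

```csharp
int index = tokenizer.GetIndexByTokenCount("Hello World", maxTokenCount: 10,
    out string? normalizedText, out int tokenCount);
```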
Decoding Ids back to string:
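For example:

```csharp
string? text = tokenizer.Decode(ids);
```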
Map string token to Id and vice versa:
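A sketch; `MapTokenToId` returns `null` when the token doesn't map to a single Id:

```csharp
int? id = tokenizer.MapTokenToId("Hello");
string? token = tokenizer.MapIdToToken(id);
```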
Proposal:

- Namespace
- Tokenizer Abstraction
- Normalization abstraction
- Pre-tokenization abstraction
- Token class
- Concrete Normalizers
- Concrete Pre-tokenizers
- Concrete Tokenizer - Bpe
- Concrete Tokenizer - Tiktoken
- Concrete Tokenizer - EnglishRoberta
- Concrete Tokenizer - CodeGen
- Concrete Tokenizer - SentencePiece