dotnet / machinelearning

ML.NET is an open source and cross-platform machine learning framework for .NET.
https://dot.net/ml

Tokenizers Library Design #7144

Closed tarekgh closed 4 months ago

tarekgh commented 7 months ago

Tokenizers are a crucial component of Large Language Models (LLMs) like GPT-3 or BERT. They are responsible for the tokenization process: breaking natural language text down into smaller, manageable pieces called tokens. These tokens can be words, characters, sub-words, numbers, or symbols, and they allow the LLM to process and understand the text.

This issue presents the APIs proposed for the Microsoft.ML.Tokenizers library, intended for design review. The design introduces an abstract class named Tokenizer, which defines the primary interfaces for all supported tokenizers. Additionally, the Tokenizer class includes a factory method for creating various types of tokenizers.

The Tokenizer can optionally be configured with normalizers, which are used to normalize the text before processing it. Normalization can take various forms such as uppercasing, lowercasing, Unicode normalization, and removing or inserting specific characters in the input text. The normalization feature is optional, and it is left to the discretion of either the tokenizer or the user to decide whether to utilize any normalizers.
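For illustration, a custom normalizer written against the Normalizer abstraction proposed below might look like this (a sketch only; the class name and the zero-width-space stripping are assumptions, not part of the proposal):

    using System;
    using System.Text;

    // A sketch of a custom normalizer: lowercases the text and strips zero-width spaces.
    public sealed class LowerCaseStripNormalizer : Normalizer
    {
        public override string Normalize(string original) => Normalize(original.AsSpan());

        public override string Normalize(ReadOnlySpan<char> original)
        {
            var builder = new StringBuilder(original.Length);
            foreach (char c in original)
            {
                if (c != '\u200B') // skip zero-width space
                {
                    builder.Append(char.ToLowerInvariant(c));
                }
            }

            return builder.ToString();
        }
    }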

Pre-tokenization is an additional component that the tokenizer can be configured with, aimed at splitting the input text into smaller units prior to processing. While pre-tokenization is also an optional feature, it is commonly utilized in most tokenizers. Many pre-tokenizers employ regex for this purpose.
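As an example, a regex-based pre-tokenizer built on the PreTokenizer abstraction proposed below could be sketched as follows (the class name and pattern are illustrative, and the tuples are assumed to carry the offset and length of each split):

    using System;
    using System.Collections.Generic;
    using System.Text.RegularExpressions;

    // A sketch of a regex-based pre-tokenizer yielding (offset, length) pairs for each split.
    public sealed class WordPreTokenizer : PreTokenizer
    {
        private static readonly Regex _splitter = new Regex(@"\w+|[^\w\s]+", RegexOptions.Compiled);

        public override IEnumerable<(int, int)> PreTokenize(string text)
        {
            foreach (Match match in _splitter.Matches(text))
            {
                yield return (match.Index, match.Length);
            }
        }

        public override IEnumerable<(int, int)> PreTokenize(ReadOnlySpan<char> text)
        {
            // Regex needs a string; materializing the span keeps the sketch simple.
            return PreTokenize(text.ToString());
        }
    }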

The typical sequence of operations for the Tokenizer involves:

- Normalizing the input text (optional).
- Pre-tokenizing the (possibly normalized) text into smaller units (optional).
- Encoding the resulting units into tokens and token IDs using the tokenizer's model.

Tokenizers offer the following functionalities:

- Encoding text to token IDs (EncodeToIds).
- Counting the tokens produced for a text (CountTokens).
- Full encoding to Token objects carrying the token string, ID, and offset (Encode).
- Finding the index in the text that covers a maximum token count, from the start or the end (IndexOfTokenCount, LastIndexOfTokenCount).
- Decoding token IDs back to a string (Decode).
- Mapping a single token to its ID and vice versa (MapTokenToId, MapIdToToken).

Tokenizers typically rely on vocabulary files, which are provided to the tokenizer during instantiation. Users commonly pass these vocabularies as either a file or a stream to the tokenizer constructor. Vocabulary files can vary in format, such as JSON, plain text, protobuf, and more. Each tokenizer determines the specific formats of files it can be instantiated with.

Usage Example:

Create a BPE tokenizer using the constructor:

    Tokenizer tokenizer = new Bpe(vocabStream: vocabStream, mergesStream: mergesStream, normalizer: null, preTokenizer: WhiteSpace.Instance);

Create Tiktoken tokenizer using factory method:

    const string IMStart = "<|im_start|>"; // assumed definitions for these chat special tokens
    const string IMEnd = "<|im_end|>";
    Dictionary<string, int> specialTokens = new Dictionary<string, int> { { IMStart, 100264 }, { IMEnd, 100265 } };
    Tokenizer tokenizer = Tokenizer.CreateTiktokenForModel("gpt-4", specialTokens);

Encode to Ids:

    IReadOnlyList<int> encoded = tokenizer.EncodeToIds("Hello World");

Count tokens:

    int idsCount = tokenizer.CountTokens("Hello World");

Full encoding:

    // APIs that return information tied to the input or normalized text usually expose an out normalizedString parameter,
    // which is null if no normalization was performed.
    // Each Token contains the token string, the token ID, and the offset of the token mapped to the input or normalized text.
    IReadOnlyList<Token> result = tokenizer.Encode(text, out string? normalizedString);

Count tokens up to max token count:

    int length = tokenizer.IndexOfTokenCount(text, maxTokenCount: 10, out string? normalizedString, out int tokenCount);

    int index = tokenizer.LastIndexOfTokenCount(text, maxTokenCount: 3, out normalizedString, out tokenCount);
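The returned index can then be used, for example, to truncate text so it fits within a token budget. A hypothetical helper sketched on top of the API above (the name TruncateToTokenBudget is not part of the proposal):

    // A sketch: trim the text so it encodes to at most maxTokens tokens.
    static string TruncateToTokenBudget(Tokenizer tokenizer, string text, int maxTokens)
    {
        int index = tokenizer.IndexOfTokenCount(text, maxTokens, out string? normalized, out int tokenCount);

        // When normalization occurred, the returned index refers to the normalized text.
        string source = normalized ?? text;
        return source.Substring(0, index);
    }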

Decode Ids back to a string:

    string decodedText = tokenizer.Decode(idsArray);

Map a string token to an Id and vice versa:

    int? id = tokenizer.MapTokenToId("Hello");

    string? token = tokenizer.MapIdToToken(tokenId);

Proposal:

Namespace

namespace Microsoft.ML.Tokenizers

Tokenizer Abstraction

    public abstract partial class Tokenizer
    {
        protected Tokenizer() { }

        public virtual Normalizer? Normalizer { get { throw null; } }

        public virtual PreTokenizer? PreTokenizer { get { throw null; } }

        public virtual IReadOnlyList<int> EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public abstract IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true);

        public virtual IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string? normalizedText, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public abstract IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedText, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true);

        public virtual int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public abstract int CountTokens(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true);

        public virtual IReadOnlyList<Token> Encode(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public abstract IReadOnlyList<Token> Encode(ReadOnlySpan<char> text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true);

        public virtual int IndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public abstract int IndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true);

        public virtual int LastIndexOfTokenCount(string text, int maxTokenCount, out string? processedText, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public abstract int LastIndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? processedText, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true);

        public virtual string? Decode(IEnumerable<int> ids) { throw null; }

        public virtual int? MapTokenToId(string token) { throw null; }
        public abstract int? MapTokenToId(ReadOnlySpan<char> token);

        public abstract string? MapIdToToken(int? id);

       //
       // Factory methods
       // 

        public static Task<Tokenizer> CreateTiktokenAsync(Stream vocabStream, PreTokenizer? preTokenizer, Normalizer? normalizer, IReadOnlyDictionary<string, int>? specialTokens = null,
                                                          int cacheSize = 8192, CancellationToken cancellationToken = default) { throw null; }

        public static Task<Tokenizer> CreateTiktokenAsync(string vocabFilePath, PreTokenizer? preTokenizer, Normalizer? normalizer, IReadOnlyDictionary<string, int>? specialTokensEncoder = null,
                                                          int cacheSize = 8192, CancellationToken cancellationToken = default) { throw null; }

        public static Tokenizer CreateTiktokenForEncoding(string encodingName, IReadOnlyDictionary<string, int>? extraSpecialTokens = null, Normalizer? normalizer = null) { throw null; }

        public static Tokenizer CreateTiktokenForModel(string modelName, IReadOnlyDictionary<string, int>? extraSpecialTokens = null, Normalizer? normalizer = null) { throw null; }

        public static Tokenizer CreateTiktokenForModel(string modelName, Stream vocabStream, IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
                                                       int cacheSize = 8192, Normalizer? normalizer = null) { throw null; }

        public static Task<Tokenizer> CreateTiktokenForModelAsync(string modelName, Stream vocabStream, IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
                                                                  int cacheSize = 8192, Normalizer? normalizer = null, CancellationToken cancellationToken = default) { throw null; }

        public static Tokenizer CreateLlama(Stream modelStream, bool addBeginOfSentence = true, bool addEndOfSentence = false) { throw null; }

        public static Tokenizer CreateCodeGen(Stream vocabStream, Stream mergesStream, bool addPrefixSpace = false, bool addBeginOfSentence = false, bool addEndOfSentence = false) { throw null; }

        public static Tokenizer CreatePhi2(Stream vocabStream, Stream mergesStream, bool addPrefixSpace = false, bool addBeginOfSentence = false, bool addEndOfSentence = false) { throw null; }
    }
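For example, the Llama factory method can be consumed like this (a sketch; "tokenizer.model" is a placeholder path):

    // A sketch: create a Llama tokenizer from a SentencePiece model file and encode some text.
    using Stream modelStream = File.OpenRead("tokenizer.model");
    Tokenizer llamaTokenizer = Tokenizer.CreateLlama(modelStream);
    IReadOnlyList<int> ids = llamaTokenizer.EncodeToIds("Hello World");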

Normalization abstraction

    public abstract partial class Normalizer
    {
        protected Normalizer() { }

        public abstract string Normalize(string original);
        public abstract string Normalize(ReadOnlySpan<char> original);
    }

Pre-tokenization abstraction

    public abstract partial class PreTokenizer
    {
        protected PreTokenizer() { }

        public abstract IEnumerable<(int, int)> PreTokenize(string text);
        public abstract IEnumerable<(int, int)> PreTokenize(ReadOnlySpan<char> text);
    }

Token class

   // returned from Tokenizer.Encode(...)

    public readonly struct Token
    {
        public Token(int id, string value, (int, int) offset) { }

        public int Id { get { throw null; } }

        public (int Index, int Length) Offset { get { throw null; } }

        public string Value { get { throw null; } }
    }
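For example, the offsets on each Token make it possible to map tokens back to the exact slice of text they cover (a sketch reusing the tokenizer and text variables from the earlier examples):

    // A sketch: print each token alongside the text slice its offset points at.
    IReadOnlyList<Token> tokens = tokenizer.Encode(text, out string? normalizedString);
    string source = normalizedString ?? text;
    foreach (Token token in tokens)
    {
        string slice = source.Substring(token.Offset.Index, token.Offset.Length);
        Console.WriteLine($"{token.Id}: '{token.Value}' -> '{slice}'");
    }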

Concrete Normalizers

    public sealed partial class LowerCaseNormalizer : Normalizer
    {
        public override string Normalize(ReadOnlySpan<char> original) { throw null; }
        public override string Normalize(string original) { throw null; }
    }

    public sealed partial class UpperCaseNormalizer : Normalizer
    {
        public override string Normalize(ReadOnlySpan<char> original) { throw null; }

        public override string Normalize(string original) { throw null; }
    }

    public sealed partial class SentencePieceNormalizer : Normalizer
    {
        public SentencePieceNormalizer(bool removeExtraWhiteSpaces, bool addDummyPrefix, bool escapeWhiteSpaces, bool treatWhitespaceAsSuffix) { }
        public bool AddDummyPrefix { get { throw null; } }
        public bool EscapeWhiteSpaces { get { throw null; } }
        public bool RemoveExtraWhiteSpaces { get { throw null; } }
        public bool TreatWhitespaceAsSuffix { get { throw null; } }

        public override string Normalize(ReadOnlySpan<char> original) { throw null; }
        public override string Normalize(string original) { throw null; }
    }

Concrete Pre-tokenizers

    public sealed partial class TiktokenPreTokenizer : PreTokenizer
    {
        public TiktokenPreTokenizer(Regex regex, IReadOnlyDictionary<string, int> specialTokensEncoder) { }

        public override IEnumerable<(int, int)> PreTokenize(string text) { throw null; }
        public override IEnumerable<(int, int)> PreTokenize(ReadOnlySpan<char> text) { throw null; }
    }

    public sealed partial class WhiteSpace : PreTokenizer
    {
        public static WhiteSpace Instance { get { throw null; } }

        public override IEnumerable<(int, int)> PreTokenize(string text) { throw null; }
        public override IEnumerable<(int, int)> PreTokenize(ReadOnlySpan<char> text) { throw null; }
    }

    public sealed partial class RobertaPreTokenizer : PreTokenizer
    {
        public static RobertaPreTokenizer Instance { get { throw null; } }

        public override IEnumerable<(int, int)> PreTokenize(string text) { throw null; }
        public override IEnumerable<(int, int)> PreTokenize(ReadOnlySpan<char> text) { throw null; }
    }

Concrete Tokenizer - Bpe

    public sealed partial class Bpe : Tokenizer
    {
        public Bpe(string vocabFile, string? mergesFile, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, string? unknownToken = null, 
                           string? continuingSubwordPrefix = null, string? endOfWordSuffix = null, bool? fuseUnknownTokens = false) { }

        public Bpe(Stream vocabStream, Stream? mergesStream, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, string? unknownToken = null, 
                          string? continuingSubwordPrefix = null, string? endOfWordSuffix = null, bool? fuseUnknownTokens = false) { }

        public string? ContinuingSubwordPrefix { get { throw null; } }

        public string? EndOfWordSuffix { get { throw null; } }

        public bool? FuseUnknownTokens { get { throw null; } }

        public string? UnknownToken { get { throw null; } }

        public IReadOnlyDictionary<string, int> Vocab { get { throw null; } }

        public string? Decode(IEnumerable<int> ids, bool considerSpecialTokens) { throw null; }

        public override Normalizer? Normalizer { get { throw null; } }
        public override PreTokenizer? PreTokenizer { get { throw null; } }
        public override int CountTokens(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override string? Decode(IEnumerable<int> ids) { throw null; }
        public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override IReadOnlyList<Token> Encode(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override IReadOnlyList<int> EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override int IndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override int IndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override int LastIndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override int LastIndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override string? MapIdToToken(int? id) { throw null; }
        public override int? MapTokenToId(ReadOnlySpan<char> token) { throw null; }
    }

Concrete Tokenizer - Tiktoken

    public sealed partial class Tiktoken : Tokenizer
    {
        public Tiktoken(Stream vocabStream, PreTokenizer? preTokenizer, IReadOnlyDictionary<string, int>? specialTokens = null, Normalizer? normalizer = null, int? cacheSize = 8192) { }

        public Tiktoken(string vocabFilePath, PreTokenizer? preTokenizer, IReadOnlyDictionary<string, int>? specialTokens = null, Normalizer? normalizer = null, int? cacheSize = 8192) { }

        public IReadOnlyDictionary<int, ReadOnlyMemory<byte>> Decoder { get { throw null; } }

        public IReadOnlyDictionary<ReadOnlyMemory<byte>, int> Encoder { get { throw null; } }

        public IReadOnlyDictionary<string, int> SpecialTokens { get { throw null; } }

        public IReadOnlyDictionary<string, int> Vocab { get { throw null; } }

        public override Normalizer? Normalizer { get { throw null; } }
        public override PreTokenizer? PreTokenizer { get { throw null; } }
        public override int CountTokens(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override string? Decode(IEnumerable<int> ids) { throw null; }
        public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override IReadOnlyList<Token> Encode(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override IReadOnlyList<int> EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override int IndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override int IndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override int LastIndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override int LastIndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override string? MapIdToToken(int? id) { throw null; }
        public override int? MapTokenToId(ReadOnlySpan<char> token) { throw null; }
    }

Concrete Tokenizer - EnglishRoberta

    public sealed partial class EnglishRoberta : Tokenizer
    {
        public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highestOccurrenceMappingStream, PreTokenizer? preTokenizer, Normalizer? normalizer, bool filterUnsupportedChars, bool disposeStream) { }

        public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highestOccurrenceMappingStream, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, bool filterUnsupportedChars = true) { }

        public EnglishRoberta(string vocabularyPath, string mergePath, string highestOccurrenceMappingPath, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, bool filterUnsupportedChars = true) { }

        public bool FilterUnsupportedChars { get { throw null; } }

        public int PadIndex { get { throw null; } }

        public int SymbolsCount { get { throw null; } }

        public IReadOnlyDictionary<string, int> Vocab { get { throw null; } }

        public int AddMaskSymbol(string mask = "<mask>") { throw null; }

        public IReadOnlyList<int> ConvertIdsToOccurrenceRanks(IReadOnlyList<int> ids) { throw null; }

        public IReadOnlyList<int> ConvertIdsToOccurrenceValues(IReadOnlyList<int> ids) { throw null; }

        public IReadOnlyList<int> ConvertOccurrenceRanksToIds(IReadOnlyList<int> ranks) { throw null; }

        public bool IsSupportedChar(char ch) { throw null; }

        public override Normalizer? Normalizer { get { throw null; } }
        public override PreTokenizer? PreTokenizer { get { throw null; } }
        public override int CountTokens(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override string? Decode(IEnumerable<int> ids) { throw null; }
        public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override IReadOnlyList<Token> Encode(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override IReadOnlyList<int> EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override int IndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override int IndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override int LastIndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override int LastIndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override string? MapIdToToken(int? id) { throw null; }
        public override int? MapTokenToId(ReadOnlySpan<char> token) { throw null; }
    }

Concrete Tokenizer - CodeGen

    public sealed partial class CodeGen : Tokenizer
    {
        public CodeGen(string vocabularyPath, string mergePath, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, IReadOnlyDictionary<string, int>? addedTokens = null,
                       bool? addPrefixSpace = false, bool? addBeginningOfSentence = false, bool? addEndOfSentence = false, string? unknownToken = "<|endoftext|>",
                       string? beginningOfSentenceToken = "<|endoftext|>", string? endOfSentenceToken = "<|endoftext|>") { }

        public CodeGen(Stream vocabularyStream, Stream mergeStream, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, IReadOnlyDictionary<string, int>? addedTokens = null,
                       bool? addPrefixSpace = false, bool? addBeginningOfSentence = false, bool? addEndOfSentence = false, string? unknownToken = "<|endoftext|>",
                       string? beginningOfSentenceToken = "<|endoftext|>", string? endOfSentenceToken = "<|endoftext|>") { }

        public bool AddBeginningOfSentence { get { throw null; } }

        public IReadOnlyDictionary<string, int> AddedTokens { get { throw null; } }

        public bool AddEndOfSentence { get { throw null; } }

        public bool AddPrefixSpace { get { throw null; } }

        public int? BeginningOfSentenceId { get { throw null; } }

        public string? BeginningOfSentenceToken { get { throw null; } }

        public int? EndOfSentenceId { get { throw null; } }

        public string? EndOfSentenceToken { get { throw null; } }

        public string? UnknownToken { get { throw null; } }

        public int? UnknownTokenId { get { throw null; } }

        public IReadOnlyDictionary<string, int> Vocab { get { throw null; } }

        public IReadOnlyList<int> EncodeToIds(string text, bool addPrefixSpace, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }

        public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool addPrefixSpace, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }

        public IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, bool addPrefixSpace, bool addBeginningOfSentence, bool addEndOfSentence, out string? normalizedString, 
                                                               out int textLength, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }

        public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, bool addPrefixSpace, bool addBeginningOfSentence, bool addEndOfSentence, out string? normalizedString, 
                                                               out int textLength, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }

        public int CountTokens(ReadOnlySpan<char> text, bool addPrefixSpace, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }

        public int CountTokens(string text, bool addPrefixSpace, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }

        public int IndexOfTokenCount(string text, int maxTokenCount, bool addPrefixSpace, bool addBeginningOfSentence, bool addEndOfSentence, out string? normalizedString, 
                                                               out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }

        public int IndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, bool addPrefixSpace, bool addBeginningOfSentence, bool addEndOfSentence, out string? normalizedString, 
                                                                out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }

        public int LastIndexOfTokenCount(string text, int maxTokenCount, bool addPrefixSpace, bool addBeginningOfSentence, bool addEndOfSentence, out string? normalizedString, 
                                                                out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }

        public int LastIndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, bool addPrefixSpace, bool addBeginningOfSentence, bool addEndOfSentence, out string? normalizedString, 
                                                                 out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }

        public IReadOnlyList<Token> Encode(ReadOnlySpan<char> text, bool addPrefixSpace, bool addBeginningOfSentence, bool addEndOfSentence, out string? normalizedString, 
                                                                 bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }

        public IReadOnlyList<Token> Encode(string text, bool addPrefixSpace, bool addBeginningOfSentence, bool addEndOfSentence, out string? normalizedString, 
                                                                  bool considerPreTokenization = true,  bool considerNormalization = true) { throw null; }

        public string? Decode(IEnumerable<int> ids, bool hasPrefixSpace, bool considerSpecialTokens) { throw null; }

        public override Normalizer? Normalizer { get { throw null; } }
        public override PreTokenizer? PreTokenizer { get { throw null; } }
        public override int CountTokens(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override string? Decode(IEnumerable<int> ids) { throw null; }
        public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override IReadOnlyList<Token> Encode(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override IReadOnlyList<int> EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override int IndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override int IndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override int LastIndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override int LastIndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override string? MapIdToToken(int? id) { throw null; }
        public override int? MapTokenToId(ReadOnlySpan<char> token) { throw null; }
    }

Concrete Tokenizer - SentencePiece

    public sealed partial class SentencePiece : Tokenizer
    {
        internal SentencePiece() { }

        public bool AddBeginningOfSentence { get { throw null; } }

        public bool AddDummyPrefix { get { throw null; } }

        public bool AddEndOfSentence { get { throw null; } }

        public int BeginningOfSentenceId { get { throw null; } }

        public string BeginningOfSentenceToken { get { throw null; } }

        public bool ByteFallback { get { throw null; } }

        public int EndOfSentenceId { get { throw null; } }

        public string EndOfSentenceToken { get { throw null; } }

        public bool EscapeWhiteSpaces { get { throw null; } }

        public bool TreatWhitespaceAsSuffix { get { throw null; } }

        public int UnknownId { get { throw null; } }

        public string UnknownToken { get { throw null; } }

        public IReadOnlyDictionary<string, int> Vocab { get { throw null; } }

        public int CountTokens(ReadOnlySpan<char> text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }

        public int CountTokens(ReadOnlySpan<char> text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue) { throw null; }

        public int CountTokens(string text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }

        public IReadOnlyList<int> EncodeToIds(string text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }

        public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }

        public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerNormalization, 
                                                            out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue) { throw null; }

        public IReadOnlyList<int> EncodeToIds(string text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? normalizedString, 
                                                             out int textLength, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }

        public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? normalizedString, 
                                                             out int textLength, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }

        public int IndexOfTokenCount(string text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? normalizedString, 
                                                            out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }

        public int IndexOfTokenCount(ReadOnlySpan<char> text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? normalizedString, 
                                                             out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }

        public int LastIndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, bool addBeginningOfSentence, bool addEndOfSentence, bool considerNormalization, 
                                                              out string? normalizedString, out int tokenCount) { throw null; }

        public IReadOnlyList<Token> Encode(string text, out string? normalizedString, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }

        public IReadOnlyList<Token> Encode(ReadOnlySpan<char> text, out string? normalizedString, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }

        public override Normalizer? Normalizer { get { throw null; } }
        public override PreTokenizer? PreTokenizer { get { throw null; } }
        public override int CountTokens(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override string? Decode(IEnumerable<int> ids) { throw null; }
        public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override IReadOnlyList<Token> Encode(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override IReadOnlyList<int> EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override int IndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override int IndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override int LastIndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override int LastIndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public override string? MapIdToToken(int? id) { throw null; }
        public override int? MapTokenToId(ReadOnlySpan<char> token) { throw null; }
    }
tarekgh commented 7 months ago

This is the issue we can use for the tokenizers design review.

CC @ericstj @michaelgsharp @stephentoub @luisquintanilla @terrajobst @JakeRadMSFT @LittleLittleCloud

georg-jung commented 7 months ago

Hey @tarekgh,

Thank you for sharing the design! I've had the chance to take a look and have some thoughts - fwiw, probably a bit from a BERT perspective ;-).

If you think I can contribute to/help with any of this please let me know :-)

tarekgh commented 6 months ago

@georg-jung thanks for the feedback.

> IIRC, it's easiest to feed Memory<T>/Span<T> into onnxruntime. If EncodeToIds returns IReadOnlyList<int>, can we avoid copying the memory?

I'll think more about whether we can return Memory<T> instead. The challenge is that during encoding we need to store the produced ids, which requires allocating the array beforehand; we need to estimate the initial allocation correctly to avoid allocating unused memory and to avoid re-allocation when there isn't enough space. Anyway, the idea is worth thinking about. Also, in the future we can expose new APIs that take a destination span/Memory, letting the callers decide how to manage the allocation.

> If Normalize(ReadOnlySpan<char> original) returns string, wouldn't we do more allocations than needed for some normalizations?

Yes, we are designing the library for the generic use case: if there is normalization, then we need to allocate. We always return the normalized string so callers can perform subsequent calls on it without further allocations or re-doing the normalization. Anyone who wants to avoid the allocations can do the normalization manually before calling the tokenizer.

> If the EncodeToIds overload with maxTokenCount has out string normalizedText, wouldn't that require us to construct and allocate the normalized string even if the caller might not be interested in it?

normalizedText is allocated only if the tokenizer is created with a normalizer object and the caller of EncodeToIds passes considerNormalization = true. The caller has full control over the operation.

> Couldn't we do something like pretokenize -> try to encode -> on failure: normalize -> try to encode and thereby skip the normalization work for most of the input?

I don't think we can do that. You never know whether the encode failed: encoding can succeed on an un-normalized string and return unexpected values. Users can opt into this behavior by calling the tokenizer with considerNormalization = false, checking the result, and deciding whether to call again with true if needed.

> How would MapTokenToId handle tokens that map to multiple ids?

It doesn't map to multiple Ids. It will return null if it cannot map the token to a single Id. Users can call EncodeToIds if they need the full results.
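For example, a caller needing IDs for an arbitrary string could fall back to full encoding when the single-token mapping fails (a sketch; word is a placeholder input):

    // A sketch: prefer the single-token mapping and fall back to full encoding.
    int? id = tokenizer.MapTokenToId(word);
    IReadOnlyList<int> ids = id is int single ? new[] { single } : tokenizer.EncodeToIds(word);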

> Would it be able to handle a token that is only valid as a suffix, e.g. by calling it with "##suffix"?

This depends on the tokenizer. If the tokenizer supports that, then it should be able to map it. Note that we support many tokenizers, so each tokenizer decides how it maps tokens to/from Ids.

> Can CountTokens (and IndexOfTokenCount and LastIndexOfTokenCount) have faster implementations than e.g. EncodeToIds().Count? Not sure, but I think for BERT the cost might be similar.

Yes. We don't implement CountTokens by calling EncodeToIds().Count; we try to optimize for such cases. Sometimes we need to build a cache, though, which helps subsequent calls and speeds things up further.

> Padding input_ids is out of scope? Combining multiple strings as a single model input (e.g. paragraph + question about paragraph) is out of scope? Batching is out of scope (e.g. for parallelization)? Encoding texts longer than maxTokenCount with a stride is out of scope?

For the current version, these are out of scope until we see enough demand for such features. We need to get the main features in first, which cover the majority of the scenarios we are seeing so far.

> If you think I can contribute to/help with any of this please let me know :-)

Sure, we appreciate all the help!

CC @stephentoub

georg-jung commented 6 months ago

Thanks for the detailed response!

> The challenge is that during encoding we need to store the produced ids, which requires allocating the array beforehand; we need to estimate the initial allocation correctly to avoid allocating unused memory and to avoid re-allocation when there isn't enough space.

An overload that writes to a passed-in Span would of course be easy in that regard, as it's then up to the caller :D. An approach I took with my BERT tokenizer is to reuse an internal buffer for subsequent calls, so the allocation cost becomes less relevant when encoding multiple documents.

Maybe I'm mistaken, but regarding IReadOnlyList<int> specifically: if one wanted to pass the result as an input to onnxruntime, wouldn't it always be necessary to write something similar to

    var res = tokenizer.EncodeToIds(...);
    // List<int> has no AsMemory, so anything that isn't already an int[] ends up being copied.
    var modelInput = res is int[] array ? array.AsMemory() : res.ToArray().AsMemory();

> if there is normalization, then we need to allocate.

I was thinking of something like MemoryExtensions.ToLower or maybe RemoveControlAndReplacement(ReadOnlySpan<char> text, out ReadOnlySpan<char> cleaned) with a re-used internal buffer or similar. Couldn't then at least many normalizations be alloc-free, e.g. lowercasing, uppercasing, stripping control chars and, probably most important, the "no-op" normalization, where the input already is normalized according to the normalizer it is passed to?
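To illustrate the suggestion, lowercasing can already be done allocation-free over a caller-managed buffer with the existing span APIs (a sketch of the idea, not part of the proposed surface):

    // A sketch: lowercase into a reusable buffer instead of allocating a new string each time.
    static ReadOnlySpan<char> LowerCase(ReadOnlySpan<char> text, ref char[] buffer)
    {
        if (buffer.Length < text.Length)
        {
            buffer = new char[text.Length]; // grow the reused buffer as needed
        }

        int written = text.ToLowerInvariant(buffer); // MemoryExtensions.ToLowerInvariant, no string allocation
        return buffer.AsSpan(0, written);
    }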

I'm a bit in a hurry and will think about the other points soon... Thanks for always taking the time to discuss this, I think it is really interesting!

tarekgh commented 6 months ago

I think trying to optimize for normalization scenarios will complicate the tokenizer interfaces. What is proposed should be enough for users to decide how far they need to optimize. If they really need allocation-free normalization, they can do it themselves before calling the tokenizer. If they are OK with the allocation but want to avoid the allocations/processing on subsequent calls, they can use the returned normalized string with considerNormalization = false. Trying to optimize for normalization while supporting every possible scenario would be very challenging and would make the APIs more complicated than typical users want.

bartonjs commented 6 months ago

Video

namespace Microsoft.ML.Tokenizers
{
    public struct EncodeResults<T>
    {
        public IReadOnlyList<T> Tokens { get; set; }
        public string? NormalizedText { get; set; }
    }

    public struct EncodeSettings
    {
        public bool ConsiderNormalization { get; set; }
        public bool ConsiderPreTokenization { get; set; }
        public bool ProduceNormalizedString { get; set; }
        public int MaxTokenCount { get; set; }
    }

    public abstract partial class Tokenizer
    {
        protected Tokenizer() { }

        public virtual Normalizer? Normalizer { get { throw null; } }

        public virtual PreTokenizer? PreTokenizer { get { throw null; } }

        public IReadOnlyList<int> EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true);

        public IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string? normalizedText, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedText, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true);

        protected abstract EncodeResults<int> EncodeToIds(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings);

        public IReadOnlyList<Token> EncodeToTokens(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public IReadOnlyList<Token> EncodeToTokens(ReadOnlySpan<char> text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true);

        protected abstract EncodeResults<Token> EncodeToTokens(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings);

        public int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public int CountTokens(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true);

        protected abstract int CountTokens(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings);

        public int GetIndexByTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public int GetIndexByTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }

        public int GetIndexByTokenCountFromEnd(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public int GetIndexByTokenCountFromEnd(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }

        protected abstract int GetIndexByTokenCount(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings, bool fromEnd, out string? normalizedString, out int tokenCount);

        public abstract string? Decode(IEnumerable<int> ids);
        public abstract OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, out int idsConsumed, out int charsWritten);

        public virtual int? MapTokenToId(string token) { throw null; }
        public abstract int? MapTokenToId(ReadOnlySpan<char> token);

        public abstract string? MapIdToToken(int? id);

       //
       // Factory methods
       // 

        public static Task<Tokenizer> CreateTiktokenAsync(Stream vocabStream, PreTokenizer? preTokenizer, Normalizer? normalizer, IReadOnlyDictionary<string, int>? specialTokens = null,
                                                          int cacheSize = 8192, CancellationToken cancellationToken = default) { throw null; }

        public static Task<Tokenizer> CreateTiktokenAsync(string vocabFilePath, PreTokenizer? preTokenizer, Normalizer? normalizer, IReadOnlyDictionary<string, int>? specialTokensEncoder = null,
                                                          int cacheSize = 8192, CancellationToken cancellationToken = default) { throw null; }

        public static Tokenizer CreateTiktokenForEncoding(string encodingName, IReadOnlyDictionary<string, int>? extraSpecialTokens = null, Normalizer? normalizer = null) { throw null; }

        public static Tokenizer CreateTiktokenForModel(string modelName, IReadOnlyDictionary<string, int>? extraSpecialTokens = null, Normalizer? normalizer = null) { throw null; }

        public static Tokenizer CreateTiktokenForModel(string modelName, Stream vocabStream, IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
                                                       int cacheSize = 8192, Normalizer? normalizer = null) { throw null; }

        public static Task<Tokenizer> CreateTiktokenForModelAsync(string modelName, Stream vocabStream, IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
                                                                  int cacheSize = 8192, Normalizer? normalizer = null, CancellationToken cancellationToken = default) { throw null; }

        public static Tokenizer CreateLlama(Stream modelStream, bool addBeginOfSentence = true, bool addEndOfSentence = false) { throw null; }

        public static Tokenizer CreateCodeGen(Stream vocabStream, Stream mergesStream, bool addPrefixSpace = false, bool addBeginOfSentence = false, bool addEndOfSentence = false) { throw null; }

        public static Tokenizer CreatePhi2(Stream vocabStream, Stream mergesStream, bool addPrefixSpace = false, bool addBeginOfSentence = false, bool addEndOfSentence = false) { throw null; }
    }
}
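For illustration, the span-based Decode proposed above could be consumed with a pooled buffer like this (a sketch; the helper name and the incremental-status handling are assumptions, not reviewed API behavior):

    using System.Buffers;
    using System.Collections.Generic;
    using System.Linq;
    using System.Text;

    // A sketch: decode all ids through the span-based API, appending chunks as the buffer fills.
    static string DecodeAll(Tokenizer tokenizer, IEnumerable<int> ids)
    {
        char[] buffer = ArrayPool<char>.Shared.Rent(256);
        try
        {
            var builder = new StringBuilder();
            IEnumerable<int> remaining = ids;
            while (true)
            {
                OperationStatus status = tokenizer.Decode(remaining, buffer, out int idsConsumed, out int charsWritten);
                builder.Append(buffer, 0, charsWritten);
                if (status == OperationStatus.Done)
                {
                    return builder.ToString();
                }

                // Assumed DestinationTooSmall: continue with the ids that were not consumed yet.
                remaining = remaining.Skip(idsConsumed);
            }
        }
        finally
        {
            ArrayPool<char>.Shared.Return(buffer);
        }
    }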
terrajobst commented 6 months ago
namespace Microsoft.ML.Tokenizers;

public struct EncodeResults<T>
{
    public IReadOnlyList<T> Tokens { get; set; }
    public string? NormalizedText { get; set; }
}

public struct EncodeSettings
{
    public bool ConsiderNormalization { get; set; }
    public bool ConsiderPreTokenization { get; set; }
    public bool ProduceNormalizedString { get; set; }
    public int MaxTokenCount { get; set; }
}

public abstract partial class Tokenizer
{
    protected Tokenizer();

    public virtual Normalizer? Normalizer { get; }

    public virtual PreTokenizer? PreTokenizer { get; }

    public IReadOnlyList<int> EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true);
    public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true);

    public IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string? normalizedText, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true);
    public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedText, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true);

    protected abstract EncodeResults<int> EncodeToIds(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings);

    public IReadOnlyList<Token> EncodeToTokens(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true);
    public IReadOnlyList<Token> EncodeToTokens(ReadOnlySpan<char> text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true);

    protected abstract EncodeResults<Token> EncodeToTokens(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings);

    public int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true);
    public int CountTokens(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true);

    protected abstract int CountTokens(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings);

    public int GetIndexByTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true);
    public int GetIndexByTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true);

    public int GetIndexByTokenCountFromEnd(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true);
    public int GetIndexByTokenCountFromEnd(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true);

    protected abstract int GetIndexByTokenCount(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings, bool fromEnd, out string? normalizedString, out int tokenCount);

    public abstract string? Decode(IEnumerable<int> ids);
    public abstract OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, out int idsConsumed, out int charsWritten);
}

public abstract partial class Normalizer
{
    protected Normalizer();

    public string Normalize(string text);
    public string Normalize(ReadOnlySpan<char> text);

    protected abstract string Normalize(string? text, ReadOnlySpan<char> textSpan);
}

public abstract partial class PreTokenizer
{
    protected PreTokenizer();

    public IEnumerable<Range> PreTokenize(string text);
    public IEnumerable<Range> PreTokenize(ReadOnlySpan<char> text);

    protected abstract IEnumerable<Range> PreTokenize(string? text, ReadOnlySpan<char> textSpan);
}

public readonly struct EncodedToken
{
    public EncodedToken(int id, string value, Range range);
    public int Id { get; }
    public Range Range { get; }
    public string Value { get; }
}

public sealed partial class LowerCaseNormalizer : Normalizer
{
    public static LowerCaseNormalizer Instance { get; }
    public override string Normalize(ReadOnlySpan<char> original);
    public override string Normalize(string original);
}

public sealed partial class UpperCaseNormalizer : Normalizer
{
    public static UpperCaseNormalizer Instance { get; }
    public override string Normalize(ReadOnlySpan<char> original);
    public override string Normalize(string original);
}

public sealed partial class SentencePieceNormalizer : Normalizer
{
    public SentencePieceNormalizer(bool removeExtraWhitespace, bool addDummyPrefix, bool escapeWhitespace, bool treatWhitespaceAsSuffix);
    public bool AddDummyPrefix { get; }
    public bool EscapeWhitespace { get; }
    public bool RemoveExtraWhitespace { get; }
    public bool TreatWhitespaceAsSuffix { get; }
    public override string Normalize(ReadOnlySpan<char> original);
    public override string Normalize(string original);
}

public sealed partial class TiktokenPreTokenizer : PreTokenizer
{
    public static TiktokenPreTokenizer Create(Regex regex, IReadOnlyDictionary<string, int> specialTokensEncoder);
}

public sealed partial class WhitespacePreTokenizer : PreTokenizer
{
    public static WhitespacePreTokenizer Instance { get; }
}

public sealed partial class RobertaPreTokenizer : PreTokenizer
{
    public static RobertaPreTokenizer Instance { get; }
}

public sealed partial class BpeTokenizer : Tokenizer
{
    public static BpeTokenizer Create(string vocabFile, string? mergesFile, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, string? unknownToken = null,
                                      string? continuingSubwordPrefix = null, string? endOfWordSuffix = null, bool? fuseUnknownTokens = false);

    public static BpeTokenizer Create(Stream vocabStream, Stream? mergesStream, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, string? unknownToken = null,
                                      string? continuingSubwordPrefix = null, string? endOfWordSuffix = null, bool? fuseUnknownTokens = false);

    public string? ContinuingSubwordPrefix { get; }
    public string? EndOfWordSuffix { get; }
    public bool? FuseUnknownTokens { get; }
    public string? UnknownToken { get; }

    public IReadOnlyDictionary<string, int> Vocab { get; }

    public string? Decode(IEnumerable<int> ids, bool considerSpecialTokens);
}

public sealed partial class TiktokenTokenizer : Tokenizer
{
    public static TiktokenTokenizer Create(string vocabFilePath, PreTokenizer? preTokenizer, IReadOnlyDictionary<string, int>? specialTokens = null, Normalizer? normalizer = null, int? cacheSize = 8192);
    public static TiktokenTokenizer Create(Stream vocabStream, PreTokenizer? preTokenizer, IReadOnlyDictionary<string, int>? specialTokens = null, Normalizer? normalizer = null, int? cacheSize = 8192);

    // public IReadOnlyDictionary<int, ReadOnlyMemory<byte>> Decoder { get; }
    // public IReadOnlyDictionary<ReadOnlyMemory<byte>, int> Encoder { get; }
    public IReadOnlyDictionary<string, int> SpecialTokens { get; }
    public IReadOnlyDictionary<string, int> Vocab { get; }
}
terrajobst commented 5 months ago

Video

namespace Microsoft.ML.Tokenizers;

public struct EncodeResults<T>
{
    // Existing:
    // public IReadOnlyList<T> Tokens { get; set; }
    // public string? NormalizedText { get; set; }

    public int CharsConsumed { get; set; }
}

public partial class LlamaTokenizer : SentencePieceTokenizer 
{
    public static LlamaTokenizer Create();
}