dotnet/machinelearning

WordPiece and Bert Tokenizer Design review #7281

Closed: tarekgh closed this issue 1 week ago

tarekgh commented 3 weeks ago

Proposal

The proposal omits the overridden properties and methods that are defined in the abstraction we already reviewed.

WordPiece Tokenizer

namespace Microsoft.ML.Tokenizers
{
    public partial class WordPieceTokenizer : Tokenizer
    {
        public static WordPieceTokenizer Create(
                        string vocabFilePath,
                        PreTokenizer? preTokenizer = null,
                        Normalizer? normalizer = null,
                        IReadOnlyDictionary<string, int>? specialTokens = null,
                        string unknownToken = "[UNK]",
                        string continuingSubwordPrefix = DefaultContinuingSubwordPrefix,
                        int maxInputCharsPerWord = DefaultMaxInputCharsPerWord);

        public static WordPieceTokenizer Create(
                        Stream vocabStream,
                        PreTokenizer? preTokenizer = null,
                        Normalizer? normalizer = null,
                        IReadOnlyDictionary<string, int>? specialTokens = null,
                        string unknownToken = "[UNK]",
                        string continuingSubwordPrefix = DefaultContinuingSubwordPrefix,
                        int maxInputCharsPerWord = DefaultMaxInputCharsPerWord);

        public static async Task<WordPieceTokenizer> CreateAsync(
                        Stream vocabStream,
                        PreTokenizer? preTokenizer = null,
                        Normalizer? normalizer = null,
                        IReadOnlyDictionary<string, int>? specialTokens = null,
                        string unknownToken = "[UNK]",
                        string continuingSubwordPrefix = DefaultContinuingSubwordPrefix,
                        int maxInputCharsPerWord = DefaultMaxInputCharsPerWord,
                        CancellationToken cancellationToken = default);

        /// <summary>
        /// Gets the unknown token.
        /// A token that is not in the vocabulary is mapped to this token instead of being converted to an ID.
        /// </summary>
        public string UnknownToken { get; }

        /// <summary>
        /// Gets the ID of the unknown token.
        /// A token that is not in the vocabulary is mapped to this ID instead.
        /// </summary>
        public int UnknownTokenId { get; }

        /// <summary>
        /// Gets the prefix to use for sub-words that are not the first part of a word.
        /// </summary>
        public string ContinuingSubwordPrefix { get; }

        /// <summary>
        /// Gets the maximum number of characters allowed in a single word.
        /// </summary>
        public int MaxInputCharsPerWord { get; }

        /// <summary>
        /// Gets the special tokens and their corresponding ids.
        /// </summary>
        public IReadOnlyDictionary<string, int>? SpecialTokens { get; }

        /// <summary>
        /// Decodes the given IDs back to a string.
        /// </summary>
        /// <param name="ids">The list of IDs to decode.</param>
        /// <param name="skipSpecialTokens">Indicates whether to skip special tokens during decoding.</param>
        /// <returns>The decoded string.</returns>
        public string Decode(IEnumerable<int> ids, bool skipSpecialTokens);

        public OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, bool skipSpecialTokens,
                   out int idsConsumed, out int charsWritten);
    }
}
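
For context, a rough usage sketch of the proposed surface. This is illustrative only: the vocabulary file name is hypothetical, and EncodeToIds comes from the Tokenizer abstraction reviewed earlier.

// Sketch only: "vocab.txt" is a hypothetical WordPiece vocabulary file
// (one token per line, as used by BERT-style models).
WordPieceTokenizer tokenizer = WordPieceTokenizer.Create("vocab.txt");

// Encode via the base Tokenizer surface, then round-trip through the
// proposed Decode overload, dropping any special tokens.
IReadOnlyList<int> ids = tokenizer.EncodeToIds("hello world");
string roundTripped = tokenizer.Decode(ids, skipSpecialTokens: true);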

Bert Tokenizer

namespace Microsoft.ML.Tokenizers
{
    public sealed partial class BertTokenizer : WordPieceTokenizer
    {
        public static BertTokenizer Create(
                    string vocabFilePath,
                    bool doLowerCase = true,
                    bool doBasicTokenization = true,
                    bool splitOnSpecialTokens = true,
                    string unknownToken = "[UNK]",
                    string sepToken = "[SEP]",
                    string padToken = "[PAD]",
                    string clsToken = "[CLS]",
                    string maskToken = "[MASK]",
                    bool tokenizeChineseChars = true,
                    bool stripAccents = false);

        public static BertTokenizer Create(
                    Stream vocabStream,
                    bool doLowerCase = true,
                    bool doBasicTokenization = true,
                    bool splitOnSpecialTokens = true,
                    string unknownToken = "[UNK]",
                    string sepToken = "[SEP]",
                    string padToken = "[PAD]",
                    string clsToken = "[CLS]",
                    string maskToken = "[MASK]",
                    bool tokenizeChineseChars = true,
                    bool stripAccents = false);

        public static async Task<BertTokenizer> CreateAsync(
                    Stream vocabStream,
                    bool doLowerCase = true,
                    bool doBasicTokenization = true,
                    bool splitOnSpecialTokens = true,
                    string unknownToken = "[UNK]",
                    string sepToken = "[SEP]",
                    string padToken = "[PAD]",
                    string clsToken = "[CLS]",
                    string maskToken = "[MASK]",
                    bool tokenizeChineseChars = true,
                    bool stripAccents = false);

        /// <summary>
        /// Gets a value indicating whether the tokenizer should lowercase the input text.
        /// </summary>
        public bool DoLowerCase { get; }

        /// <summary>
        /// Gets a value indicating whether the tokenizer should perform basic tokenization
        /// (cleaning the text, normalizing it, lowercasing, and so on).
        /// </summary>
        public bool DoBasicTokenization { get; }

        /// <summary>
        /// Gets a value indicating whether the tokenizer should split on the special tokens or treat special tokens as normal text.
        /// </summary>
        public bool SplitOnSpecialTokens { get; }

        /// <summary>
        /// Gets the separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for sequence classification 
        /// or for a text and a question for question answering.
        /// It is also used as the last token of a sequence built with special tokens.
        /// </summary>
        public string SepToken { get; }

        /// <summary>
        /// Gets the separator token ID.
        /// </summary>
        public int SepTokenId { get; }

        /// <summary>
        /// Gets the token used for padding, for example when batching sequences of different lengths.
        /// </summary>
        public string PadToken { get; }

        /// <summary>
        /// Gets the padding token ID.
        /// </summary>
        public int PadTokenId { get; }

        /// <summary>
        /// Gets the classifier token which is used when doing sequence classification (classification of the whole sequence 
        /// instead of per-token classification).
        /// It is the first token of the sequence when built with special tokens.
        /// </summary>
        public string ClsToken { get; }

        /// <summary>
        /// Gets the classifier token ID.
        /// </summary>
        public int ClsTokenId { get; }

        /// <summary>
        /// Gets the mask token used for masking values. This is the token used when training this model with masked language modeling.
        /// This is the token which the model will try to predict.
        /// </summary>
        public string MaskToken { get; }

        /// <summary>
        /// Gets the mask token ID.
        /// </summary>
        public int MaskTokenId { get; }

        /// <summary>
        /// Gets a value indicating whether the tokenizer should split Chinese characters into individual tokens.
        /// </summary>
        public bool TokenizeChineseChars { get; }

        /// <summary>
        /// Gets a value indicating whether the tokenizer should strip accent characters.
        /// </summary>
        public bool StripAccents { get; }

        public IReadOnlyList<int> EncodeToIds(string text, bool addSpecialTokens,
                      bool considerPreTokenization = true, bool considerNormalization = true);

        public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool addSpecialTokens,
                      bool considerPreTokenization = true, bool considerNormalization = true);

        public IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, bool addSpecialTokens, out string? normalizedText,
                       out int charsConsumed, bool considerPreTokenization = true, bool considerNormalization = true);

        public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, bool addSpecialTokens,
                        out string? normalizedText, out int charsConsumed, bool considerPreTokenization = true, bool considerNormalization = true);

        /// <summary>
        /// Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and 
        /// adding special tokens. A BERT sequence has the following format:
        ///     - single sequence: `[CLS] tokenIds0 [SEP]`
        ///     - pair of sequences: `[CLS] tokenIds0 [SEP] tokenIds1 [SEP]`
        /// </summary>
        /// <param name="tokenIds0">List of IDs to which the special tokens will be added.</param>
        /// <param name="tokenIds1">Optional second list of IDs for sequence pairs.</param>
        /// <returns>The list of IDs with special tokens added.</returns>
        /// <exception cref="ArgumentNullException">When <paramref name="tokenIds0"/> is null.</exception>
        public IReadOnlyList<int> BuildInputsWithSpecialTokens(IEnumerable<int> tokenIds0, IEnumerable<int>? tokenIds1 = null);

        public OperationStatus BuildInputsWithSpecialTokens(IEnumerable<int> tokenIds0, Span<int> buffer, out int written,
                             IEnumerable<int>? tokenIds1 = null);

        /// <summary>
        /// Retrieves the special-tokens mask for a list of IDs.
        /// </summary>
        /// <param name="tokenIds0">List of IDs.</param>
        /// <param name="tokenIds1">Optional second list of IDs for sequence pairs.</param>
        /// <param name="alreadyHasSpecialTokens">Indicates whether the token list is already formatted with special tokens
        /// for the model.</param>
        /// <returns>A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.</returns>
        /// <exception cref="ArgumentNullException">When <paramref name="tokenIds0"/> is null.</exception>
        public IReadOnlyList<int> GetSpecialTokensMask(IEnumerable<int> tokenIds0, IEnumerable<int>? tokenIds1 = null,
                    bool alreadyHasSpecialTokens = false);

        public OperationStatus GetSpecialTokensMask(IEnumerable<int> tokenIds0, Span<int> buffer, out int written,
                     IEnumerable<int>? tokenIds1 = null, bool alreadyHasSpecialTokens = false);

        /// <summary>
        /// Creates a mask from the two given sequences, to be used in a sequence-pair classification task.
        /// A BERT sequence pair mask has the following format:
        ///         0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        ///         | first sequence    | second sequence |
        /// If <paramref name="tokenIds1"/> is null, this method only returns the first portion of the type ids (0s).
        /// </summary>
        /// <param name="tokenIds0">List of token IDs for the first sequence.</param>
        /// <param name="tokenIds1">Optional list of token IDs for the second sequence.</param>
        /// <returns>List of token type IDs according to the given sequence(s).</returns>
        /// <exception cref="ArgumentNullException">When <paramref name="tokenIds0"/> is null.</exception>
        public IReadOnlyList<int> CreateTokenTypeIdsFromSequences(IEnumerable<int> tokenIds0, IEnumerable<int>? tokenIds1 = null);

        public OperationStatus CreateTokenTypeIdsFromSequences(IEnumerable<int> tokenIds0, Span<int> buffer, out int written,
                         IEnumerable<int>? tokenIds1 = null);
    }
}
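
To make the classification helpers concrete, a hedged usage sketch (the vocabulary file is hypothetical, and the exact ID values depend on it):

BertTokenizer bert = BertTokenizer.Create("vocab.txt"); // hypothetical vocab file

// Encode two sequences without special tokens, then assemble them into
// the BERT pair format [CLS] tokenIds0 [SEP] tokenIds1 [SEP].
IReadOnlyList<int> idsA = bert.EncodeToIds("How old are you?", addSpecialTokens: false);
IReadOnlyList<int> idsB = bert.EncodeToIds("I am six.", addSpecialTokens: false);

IReadOnlyList<int> inputIds = bert.BuildInputsWithSpecialTokens(idsA, idsB);

// Token type IDs: 0 for the [CLS] idsA [SEP] portion, 1 for idsB [SEP].
IReadOnlyList<int> typeIds = bert.CreateTokenTypeIdsFromSequences(idsA, idsB);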

PreTokenizer Factory methods

namespace Microsoft.ML.Tokenizers
{
    public abstract partial class PreTokenizer
    {
        // @"\w+|[\p{P}]"
        public static PreTokenizer CreateWhiteSpaceOrPunctuationPreTokenizer(IReadOnlyDictionary<string, int>? specialTokensEncoder = null)

        // @"\w+|[^\w\s]+"
        public static PreTokenizer CreateWordOrNonWordPreTokenizer(IReadOnlyDictionary<string, int>? specialTokensEncoder = null)

        // @"\S+"
        public static PreTokenizer CreateWhiteSpacePreTokenizer(IReadOnlyDictionary<string, int>? specialTokensEncoder = null)

    }
}
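
The regex comments above fully determine the splitting behavior. As a standalone illustration of how the three patterns differ (using Regex directly rather than the proposed factory methods):

using System;
using System.Text.RegularExpressions;

string text = "Hello, world!!";

// @"\w+|[\p{P}]"  -> word runs, plus each punctuation character on its own:
//                    "Hello", ",", "world", "!", "!"
// @"\w+|[^\w\s]+" -> word runs, or runs of non-word, non-space characters:
//                    "Hello", ",", "world", "!!"
// @"\S+"          -> whitespace-separated chunks: "Hello,", "world!!"
foreach (Match m in Regex.Matches(text, @"\w+|[\p{P}]"))
    Console.Write($"\"{m.Value}\" ");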
terrajobst commented 2 weeks ago

WordPieceTokenizer

namespace Microsoft.ML.Tokenizers
{
    public partial class WordPieceTokenizer : Tokenizer
    {
        public static WordPieceTokenizer Create(
            string vocabFilePath,
            PreTokenizer? preTokenizer = null,
            Normalizer? normalizer = null,
            IReadOnlyDictionary<string, int>? specialTokens = null,
            string unknownToken = "[UNK]",
            string continuingSubwordPrefix = DefaultContinuingSubwordPrefix,
            int maxInputCharsPerWord = DefaultMaxInputCharsPerWord
        );

        public static WordPieceTokenizer Create(
            Stream vocabStream,
            PreTokenizer? preTokenizer = null,
            Normalizer? normalizer = null,
            IReadOnlyDictionary<string, int>? specialTokens = null,
            string unknownToken = "[UNK]",
            string continuingSubwordPrefix = DefaultContinuingSubwordPrefix,
            int maxInputCharsPerWord = DefaultMaxInputCharsPerWord
        );

        public static async Task<WordPieceTokenizer> CreateAsync(
            Stream vocabStream,
            PreTokenizer? preTokenizer = null,
            Normalizer? normalizer = null,
            IReadOnlyDictionary<string, int>? specialTokens = null,
            string unknownToken = "[UNK]",
            string continuingSubwordPrefix = DefaultContinuingSubwordPrefix,
            int maxInputCharsPerWord = DefaultMaxInputCharsPerWord,
            CancellationToken cancellationToken = default
        );

        public string UnknownToken { get; }
        public int UnknownTokenId { get; }
        public string ContinuingSubwordPrefix { get; }
        public int MaxInputCharsPerWord { get; }

        public IReadOnlyDictionary<string, int>? SpecialTokens { get; }
        public string Decode(IEnumerable<int> ids, bool skipSpecialTokens);

        public OperationStatus Decode(
            IEnumerable<int> ids,
            Span<char> destination,
            bool skipSpecialTokens, 
            out int idsConsumed,
            out int charsWritten
        );
    }
}

BertTokenizer

namespace Microsoft.ML.Tokenizers;

public sealed partial class BertTokenizer : WordPieceTokenizer
{
    public static BertTokenizer Create(
        string vocabFilePath,
        bool doLowerCase = true,
        bool doBasicTokenization = true,
        bool splitOnSpecialTokens = true,
        string unknownToken = "[UNK]",
        string sepToken = "[SEP]",
        string padToken = "[PAD]",
        string clsToken = "[CLS]",
        string maskToken = "[MASK]",
        bool tokenizeChineseChars = true,
        bool stripAccents = false
    );  

    public static BertTokenizer Create(
        Stream vocabStream,
        bool doLowerCase = true,
        bool doBasicTokenization = true,
        bool splitOnSpecialTokens = true,
        string unknownToken = "[UNK]",
        string sepToken = "[SEP]",
        string padToken = "[PAD]",
        string clsToken = "[CLS]",
        string maskToken = "[MASK]",
        bool tokenizeChineseChars = true,
        bool stripAccents = false
    );

    public static async Task<BertTokenizer> CreateAsync(
        Stream vocabStream,
        bool doLowerCase = true,
        bool doBasicTokenization = true,
        bool splitOnSpecialTokens = true,
        string unknownToken = "[UNK]",
        string sepToken = "[SEP]",
        string padToken = "[PAD]",
        string clsToken = "[CLS]",
        string maskToken = "[MASK]",
        bool tokenizeChineseChars = true,
        bool stripAccents = false
    );

    public bool DoLowerCase { get; }
    public bool DoBasicTokenization { get; }
    public bool SplitOnSpecialTokens { get; }
    public string SepToken { get; }
    public int SepTokenId { get; }
    public string PadToken { get; }
    public int PadTokenId { get; }
    public string ClsToken { get; }
    public int ClsTokenId { get; }
    public string MaskToken { get; }
    public int MaskTokenId { get; }
    public bool TokenizeChineseChars { get; }
    public bool StripAccents { get; }

    public IReadOnlyList<int> EncodeToIds(
        string text,
        bool addSpecialTokens, 
        bool considerPreTokenization = true,
        bool considerNormalization = true
    );

    public IReadOnlyList<int> EncodeToIds(
        ReadOnlySpan<char> text,
        bool addSpecialTokens, 
        bool considerPreTokenization = true,
        bool considerNormalization = true
    );

    public IReadOnlyList<int> EncodeToIds(
        string text,
        int maxTokenCount,
        bool addSpecialTokens,
        out string? normalizedText,        
        out int charsConsumed,
        bool considerPreTokenization = true,
        bool considerNormalization = true
    );

    public IReadOnlyList<int> EncodeToIds(
        ReadOnlySpan<char> text,
        int maxTokenCount,
        bool addSpecialTokens,        
        out string? normalizedText,
        out int charsConsumed,
        bool considerPreTokenization = true,
        bool considerNormalization = true
    );

    public IReadOnlyList<int> BuildInputsWithSpecialTokens(
        IEnumerable<int> tokenIds0,
        IEnumerable<int>? tokenIds1 = null
    );

    public OperationStatus BuildInputsWithSpecialTokens(
        IEnumerable<int> tokenIds0,
        Span<int> buffer,
        out int written,        
        IEnumerable<int>? tokenIds1 = null
    );

    public IReadOnlyList<int> GetSpecialTokensMask(
        IEnumerable<int> tokenIds0,
        IEnumerable<int>? tokenIds1 = null, 
        bool alreadyHasSpecialTokens = false
    );

    public OperationStatus GetSpecialTokensMask(
        IEnumerable<int> tokenIds0,
        Span<int> buffer,
        out int written,        
        IEnumerable<int>? tokenIds1 = null,
        bool alreadyHasSpecialTokens = false
    );

    public IReadOnlyList<int> CreateTokenTypeIdsFromSequences(
        IEnumerable<int> tokenIds0,
        IEnumerable<int>? tokenIds1 = null
    );

    public OperationStatus CreateTokenTypeIdsFromSequences(
        IEnumerable<int> tokenIds0,
        Span<int> buffer,
        out int written,        
        IEnumerable<int>? tokenIds1 = null
    );
}

PreTokenizer Factory methods

namespace Microsoft.ML.Tokenizers;

public partial class PreTokenizer
{
    public static PreTokenizer CreateWordOrPunctuation(IReadOnlyDictionary<string, int>? specialTokens = null);
    public static PreTokenizer CreateWordOrNonWord(IReadOnlyDictionary<string, int>? specialTokens = null);
    public static PreTokenizer CreateWhitespace(IReadOnlyDictionary<string, int>? specialTokens = null);
}
terrajobst commented 2 weeks ago

From @tarekgh

  • TokenizeChineseChars:
    • We are going to cover some of the CJK ranges, but not all of them, so I suggest the name IndividuallyTokenizeCjk. The docs will list the exact ranges we use.
  • stripAccents:
    • I discovered we need to remove only the non-spacing marks, so I suggest naming it removeNonSpacingMarks.
  • doLowerCase, doBasicTokenization:
    • I suggest lowerCaseBeforeTokenization and applyBasicTokenization. I am fine with applyLowerCasing if Jeremy is fine with it.
  • Using options:
    • I now think polymorphism is not a bad idea after all when combined with options. I am suggesting the following:

Incorporating both our feedback above and the options proposal, I believe the full API looks as follows.

namespace Microsoft.ML.Tokenizers;

public class WordPieceOptions
{
    public WordPieceOptions();
    public PreTokenizer? PreTokenizer { get; set; }
    public Normalizer? Normalizer { get; set; }
    public IReadOnlyDictionary<string, int>? SpecialTokens { get; set; }
    public string UnknownToken { get; set; } = "[UNK]";
    public string ContinuingSubwordPrefix { get; set; } = DefaultContinuingSubwordPrefix;
    public int MaxInputCharsPerWord { get; set; } = DefaultMaxInputCharsPerWord;
}

public sealed class BertOptions : WordPieceOptions
{
    public BertOptions();
    public bool LowerCaseBeforeTokenization { get; set; } = true;
    public bool ApplyBasicTokenization { get; set; } = true;
    public bool SplitOnSpecialTokens { get; set; } = true;
    public string SeparatorToken { get; set; } = "[SEP]";
    public string PaddingToken { get; set; } = "[PAD]";
    public string ClassificationToken { get; set; } = "[CLS]";
    public string MaskingToken { get; set; } = "[MASK]";
    public bool IndividuallyTokenizeCjk { get; set; } = true;
    public bool RemoveNonSpacingMarks { get; set; }
}

public partial class WordPieceTokenizer : Tokenizer
{
    public static WordPieceTokenizer Create(string vocabFilePath, WordPieceOptions? options = null);
    public static WordPieceTokenizer Create(Stream vocabStream, WordPieceOptions? options = null);
    public static Task<WordPieceTokenizer> CreateAsync(Stream vocabStream, WordPieceOptions? options = null, CancellationToken cancellationToken = default);

    public string UnknownToken { get; }
    public int UnknownTokenId { get; }
    public string ContinuingSubwordPrefix { get; }
    public int MaxInputCharsPerWord { get; }

    public IReadOnlyDictionary<string, int>? SpecialTokens { get; }
    public string Decode(IEnumerable<int> ids, bool skipSpecialTokens);

    public OperationStatus Decode(
        IEnumerable<int> ids,
        Span<char> destination,
        bool skipSpecialTokens, 
        out int idsConsumed,
        out int charsWritten
    );
}

public sealed partial class BertTokenizer : WordPieceTokenizer
{
    public static BertTokenizer Create(string vocabFilePath, BertOptions? options = null);
    public static BertTokenizer Create(Stream vocabStream, BertOptions? options = null);
    public static Task<BertTokenizer> CreateAsync(Stream vocabStream, BertOptions? options = null, CancellationToken cancellationToken = default);

    public bool LowerCaseBeforeTokenization { get; }
    public bool ApplyBasicTokenization { get; }
    public bool SplitOnSpecialTokens { get; }
    public string SeparatorToken { get; }
    public int SeparatorTokenId { get; }
    public string PaddingToken { get; }
    public int PaddingTokenId { get; }
    public string ClassificationToken { get; }
    public int ClassificationTokenId { get; }
    public string MaskingToken { get; }
    public int MaskingTokenId { get; }
    public bool IndividuallyTokenizeCjk { get; }
    public bool RemoveNonSpacingMarks { get; }

    public IReadOnlyList<int> EncodeToIds(
        string text,
        bool addSpecialTokens, 
        bool considerPreTokenization = true,
        bool considerNormalization = true
    );

    public IReadOnlyList<int> EncodeToIds(
        ReadOnlySpan<char> text,
        bool addSpecialTokens, 
        bool considerPreTokenization = true,
        bool considerNormalization = true
    );

    public IReadOnlyList<int> EncodeToIds(
        string text,
        int maxTokenCount,
        bool addSpecialTokens,
        out string? normalizedText,        
        out int charsConsumed,
        bool considerPreTokenization = true,
        bool considerNormalization = true
    );

    public IReadOnlyList<int> EncodeToIds(
        ReadOnlySpan<char> text,
        int maxTokenCount,
        bool addSpecialTokens,        
        out string? normalizedText,
        out int charsConsumed,
        bool considerPreTokenization = true,
        bool considerNormalization = true
    );

    public IReadOnlyList<int> BuildInputsWithSpecialTokens(
        IEnumerable<int> tokenIds,
        IEnumerable<int>? additionalTokenIds = null
    );

    public OperationStatus BuildInputsWithSpecialTokens(
        IEnumerable<int> tokenIds,
        Span<int> destination,
        out int valuesWritten,        
        IEnumerable<int>? additionalTokenIds = null
    );

    public IReadOnlyList<int> GetSpecialTokensMask(
        IEnumerable<int> tokenIds,
        IEnumerable<int>? additionalTokenIds = null, 
        bool alreadyHasSpecialTokens = false
    );

    public OperationStatus GetSpecialTokensMask(
        IEnumerable<int> tokenIds,
        Span<int> destination,
        out int valuesWritten,        
        IEnumerable<int>? additionalTokenIds = null,
        bool alreadyHasSpecialTokens = false
    );

    public IReadOnlyList<int> CreateTokenTypeIdsFromSequences(
        IEnumerable<int> tokenIds,
        IEnumerable<int>? additionalTokenIds = null
    );

    public OperationStatus CreateTokenTypeIdsFromSequences(
        IEnumerable<int> tokenIds,
        Span<int> destination,
        out int valuesWritten,        
        IEnumerable<int>? additionalTokenIds = null
    );
}

public partial class PreTokenizer
{
    public static PreTokenizer CreateWordOrPunctuation(IReadOnlyDictionary<string, int>? specialTokens = null);
    public static PreTokenizer CreateWordOrNonWord(IReadOnlyDictionary<string, int>? specialTokens = null);
    public static PreTokenizer CreateWhitespace(IReadOnlyDictionary<string, int>? specialTokens = null);
}
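
Under this shape, construction would look roughly like the following (the vocabulary file is again hypothetical):

var options = new BertOptions
{
    LowerCaseBeforeTokenization = true,
    IndividuallyTokenizeCjk = true,
    RemoveNonSpacingMarks = false,
};

// "vocab.txt" is a hypothetical vocabulary file.
BertTokenizer bert = BertTokenizer.Create("vocab.txt", options);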
tarekgh commented 1 week ago

https://github.com/dotnet/machinelearning/pull/7291