dotnet / machinelearning

ML.NET is an open source and cross-platform machine learning framework for .NET.
https://dot.net/ml
MIT License
8.92k stars 1.86k forks

Introducing Tiktoken Tokenizer #6981

Closed: tarekgh closed this pull request 4 months ago

tarekgh commented 5 months ago

This change adds support for the Tiktoken tokenizer to the Microsoft.ML.Tokenizers library. The logic is largely derived from the Microsoft Tokenizers Library, with optimizations and adjustments to the public APIs. Further API refinements are pending and are tracked in issue #6982.

Usage

```csharp
using System.Collections.Generic;
using Microsoft.ML.Tokenizers;
using Xunit; // xUnit's Assert.Equal

Tokenizer tokenizer = await Tokenizer.CreateByModelNameAsync("gpt-4");

// Encoding to Ids
string text = "Hello World";
IReadOnlyList<int> encoded = tokenizer.EncodeToIds(text);
Assert.Equal(new List<int>() { 9906, 4435 }, encoded);
Assert.Equal(text, tokenizer.Decode(encoded)!);

// Full encoding to tokens, Ids, and offsets
TokenizerResult result = tokenizer.Encode(text);
Assert.Equal(new List<int>() { 9906, 4435 }, result.Ids);
Assert.Equal(new string[] { "Hello", " World" }, result.Tokens);
Assert.Equal(new List<(int, int)> { (0, 5), (5, 11) }, result.Offsets);
```
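The offsets in the usage example are `(Index, End)` pairs, the same tuple shape as the `offset` parameter of `Split`, and the values `(0, 5)` for "Hello" and `(5, 11)` for " World" indicate that `End` is exclusive. The following is a minimal sketch, independent of ML.NET, of how a caller could map tokens back to the input text under that assumption; `OffsetSketch` and `SliceByOffsets` are hypothetical names.

```csharp
using System.Collections.Generic;
using System.Linq;

// Hypothetical helper, not part of Microsoft.ML.Tokenizers.
// Assumes each offset is (Index, End) with End exclusive.
public static class OffsetSketch
{
    public static List<string> SliceByOffsets(string text, IEnumerable<(int Index, int End)> offsets) =>
        offsets.Select(o => text.Substring(o.Index, o.End - o.Index)).ToList();
}
```

For example, `SliceByOffsets("Hello World", new List<(int, int)> { (0, 5), (5, 11) })` returns "Hello" and " World", matching `result.Tokens`; because the offsets are contiguous, concatenating the slices reproduces the input.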

API changes


```diff
namespace Microsoft.ML.Tokenizers
{
    public class Tokenizer
    {
+        /// <summary>
+        /// Encodes the input text into a result object that holds the tokens list, the token Ids, and the token offset mapping.
+        /// </summary>
+        /// <param name="sequence">The text to tokenize.</param>
+        /// <param name="skipSpecialTokens">Indicates whether to skip the special tokens during the encoding.</param>
+        /// <returns>The tokenization result, which includes the tokens list, the token Ids, and the token offset mapping.</returns>
+        public TokenizerResult Encode(string sequence, bool skipSpecialTokens); // overload adding the skipSpecialTokens parameter.

+        /// <summary>
+        /// Encodes the input text into token Ids.
+        /// </summary>
+        /// <param name="sequence">The text to tokenize.</param>
+        /// <param name="skipSpecialTokens">Indicates whether to skip the special tokens during the encoding.</param>
+        /// <returns>The list of encoded token Ids.</returns>
+        public IReadOnlyList<int> EncodeToIds(string sequence, bool skipSpecialTokens = false);

+        /// <summary>
+        /// Creates a tokenizer based on the model name.
+        /// </summary>
+        /// <param name="modelName">The model name.</param>
+        /// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the model.</param>
+        /// <param name="normalizer">The normalizer applied to the text before tokenization.</param>
+        /// <returns>The tokenizer.</returns>
+        public static Task<Tokenizer> CreateByModelNameAsync(
+                                                string modelName,
+                                                IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
+                                                Normalizer? normalizer = null);
    }

-    public class Split : IEquatable<Split>
+    public readonly struct Split : IEquatable<Split>
     {
-        public Split(string token, (int Index, int End) offset);
+        public Split(string token, (int Index, int End) offset, bool isSpecialToken = false);

+        /// <summary>
+        /// Gets whether the current Split is a special token.
+        /// </summary>
+        public bool IsSpecialToken { get; }
    }

    public abstract class PreTokenizer
    {
+        // Primarily aimed at minimizing memory allocations and enabling enumeration of one item at a time,
+        // rather than holding a large list in a collection.
+        // This change is reflected in all public classes that derive from PreTokenizer.
-        public abstract IReadOnlyList<Split> PreTokenize(string sentence);
+        public abstract IEnumerable<Split> PreTokenize(string sentence, bool skipSpecialTokens = false);
    }

    public sealed class TokenizerResult
    {
-        public TokenizerResult(string originalString, string normalizedString, IReadOnlyList<Split> splits, bool offsetsMappedToOriginalString);
+        public TokenizerResult(string originalString, string normalizedString, IEnumerable<Split> splits, bool offsetsMappedToOriginalString);
    }

    public abstract class Model
    {
+        public virtual IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialToken); // overload adding the isSpecialToken parameter.

+        public virtual bool TokenizeToIds(string sequence, bool isSpecialToken, List<int> accumulatedIds); // consumed by Tokenizer.EncodeToIds

+        public virtual int? TokenToId(string token, bool skipSpecialTokens); // overload adding the skipSpecialTokens parameter.
    }

+    public sealed class Tiktoken : Model
+    {
+        public Tiktoken(string tikTokenBpeFile, IReadOnlyDictionary<string, int>? specialTokensEncoder = null, int cacheSize = DefaultCacheSize);
+        public Tiktoken(Stream tikTokenBpeFileStream, IReadOnlyDictionary<string, int>? specialTokensEncoder = null, int cacheSize = DefaultCacheSize);

+        public IReadOnlyDictionary<string, int>? SpecialTokens { get; }

+        // Implements the Model abstract methods
+    }

+    public sealed class TikTokenPreTokenizer : PreTokenizer
+    {
+        public TikTokenPreTokenizer(string regexPattern, IReadOnlyDictionary<string, int>? specialTokensEncoder);

+        // Implements the PreTokenizer abstract methods
+    }
}
```
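The `TikTokenPreTokenizer` and `Tiktoken` types listed above divide the work into two stages: a regex-based pre-tokenizer that separates special tokens from ordinary text, and a model that byte-pair encodes each ordinary piece. Below is a simplified, self-contained sketch of that two-stage flow, modeled on how OpenAI's tiktoken works rather than on the code in this PR. Assumptions: `TiktokenSketch` and its members are hypothetical; ranks are keyed by base64 strings (the encoding used in `.tiktoken` rank files) instead of raw bytes; any regex passed in is a caller-supplied toy, not the GPT-4 pattern; and `skipSpecialTokens: true` is interpreted as "do not recognize special tokens". `PreTokenize` is a `yield return` iterator, which mirrors the stated motivation for switching `PreTokenizer.PreTokenize` to `IEnumerable<Split>`: pieces are produced one at a time instead of being collected into a list first.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Text.RegularExpressions;

// Illustrative sketch of a tiktoken-style pipeline; NOT the Microsoft.ML.Tokenizers implementation.
// Assumption: ranks are keyed by the base64 form of each byte sequence, as in a .tiktoken rank file.
public static class TiktokenSketch
{
    // Lazily yields pieces with end-exclusive (Index, End) offsets, one at a time.
    // Assumption: skipSpecialTokens == true means special tokens are not recognized
    // and are split by the regex like ordinary text.
    public static IEnumerable<(string Piece, int Index, int End, bool IsSpecial)> PreTokenize(
        string text, Regex pattern, IEnumerable<string> specialTokens, bool skipSpecialTokens = false)
    {
        List<string> alternatives = specialTokens.Select(s => Regex.Escape(s)).ToList();
        int start = 0;
        if (!skipSpecialTokens && alternatives.Count > 0)
        {
            var specialRegex = new Regex(string.Join("|", alternatives));
            foreach (Match special in specialRegex.Matches(text))
            {
                // Ordinary text before this special token.
                foreach (var split in SplitWithPattern(text, start, special.Index, pattern))
                    yield return split;

                yield return (special.Value, special.Index, special.Index + special.Length, true);
                start = special.Index + special.Length;
            }
        }

        // Ordinary text after the last special token (or the whole text).
        foreach (var split in SplitWithPattern(text, start, text.Length, pattern))
            yield return split;
    }

    private static IEnumerable<(string Piece, int Index, int End, bool IsSpecial)> SplitWithPattern(
        string text, int start, int end, Regex pattern)
    {
        string segment = text.Substring(start, end - start);
        foreach (Match m in pattern.Matches(segment))
            yield return (m.Value, start + m.Index, start + m.Index + m.Length, false);
    }

    // Rank-based byte-pair merge: start from single bytes and repeatedly merge the
    // leftmost adjacent pair whose concatenation has the lowest rank.
    // Assumes every single byte has a rank, as in real tiktoken vocabularies.
    public static List<int> BytePairEncode(byte[] piece, IReadOnlyDictionary<string, int> ranks)
    {
        List<byte[]> parts = piece.Select(b => new[] { b }).ToList();
        while (parts.Count > 1)
        {
            int bestIndex = -1, bestRank = int.MaxValue;
            for (int i = 0; i < parts.Count - 1; i++)
            {
                string key = Convert.ToBase64String(parts[i].Concat(parts[i + 1]).ToArray());
                if (ranks.TryGetValue(key, out int rank) && rank < bestRank)
                {
                    bestRank = rank;
                    bestIndex = i;
                }
            }

            if (bestIndex < 0)
                break; // no adjacent pair is in the vocabulary

            parts[bestIndex] = parts[bestIndex].Concat(parts[bestIndex + 1]).ToArray();
            parts.RemoveAt(bestIndex + 1);
        }

        return parts.Select(p => ranks[Convert.ToBase64String(p)]).ToList();
    }

    // Pre-tokenize, then map special tokens directly and byte-pair encode everything else.
    public static List<int> EncodeToIds(string text, Regex pattern, IReadOnlyDictionary<string, int> ranks,
        IReadOnlyDictionary<string, int> specialTokens, bool skipSpecialTokens = false)
    {
        var ids = new List<int>();
        foreach (var split in PreTokenize(text, pattern, specialTokens.Keys, skipSpecialTokens))
        {
            if (split.IsSpecial)
                ids.Add(specialTokens[split.Piece]);
            else
                ids.AddRange(BytePairEncode(Encoding.UTF8.GetBytes(split.Piece), ranks));
        }
        return ids;
    }
}
```

With a toy vocabulary a=0, b=1, ab=2, aa=3, `BytePairEncode` on the bytes of "aab" merges the lowest-ranked pair ("ab") first and then stops, because "aab" has no rank, giving Ids [0, 2]. The nested loop is quadratic and was chosen for readability rather than speed.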
tarekgh commented 5 months ago

CC @ericstj @luisquintanilla @stephentoub @LittleLittleCloud @shengyfu

codecov[bot] commented 5 months ago

Codecov Report

Attention: 210 lines in your changes are missing coverage. Please review.

Comparison is base (902102e) 68.80% compared to head (35e2cbc) 68.81%.

:exclamation: Current head 35e2cbc differs from pull request most recent head 4cd96b3. Consider uploading reports for the commit 4cd96b3 to get more accurate results

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main    #6981      +/-   ##
==========================================
+ Coverage   68.80%   68.81%   +0.01%
==========================================
  Files        1249     1256       +7
  Lines      249686   250425     +739
  Branches    25485    25569      +84
==========================================
+ Hits       171795   172335     +540
- Misses      71294    71466     +172
- Partials     6597     6624      +27
```

| [Flag](https://app.codecov.io/gh/dotnet/machinelearning/pull/6981/flags?src=pr&el=flags&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=dotnet) | Coverage Δ | |
|---|---|---|
| [Debug](https://app.codecov.io/gh/dotnet/machinelearning/pull/6981/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=dotnet) | `68.81% <72.62%> (+0.01%)` | :arrow_up: |
| [production](https://app.codecov.io/gh/dotnet/machinelearning/pull/6981/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=dotnet) | `63.28% <66.87%> (+0.01%)` | :arrow_up: |
| [test](https://app.codecov.io/gh/dotnet/machinelearning/pull/6981/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=dotnet) | `88.44% <100.00%> (+0.02%)` | :arrow_up: |

Flags with carried forward coverage won't be shown. [Click here](https://docs.codecov.io/docs/carryforward-flags?utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=dotnet#carryforward-flags-in-the-pull-request-comment) to find out more.
| [Files](https://app.codecov.io/gh/dotnet/machinelearning/pull/6981?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=dotnet) | Coverage Δ | |
|---|---|---|
| [...Microsoft.ML.Tokenizers/PreTokenizer/Whitespace.cs](https://app.codecov.io/gh/dotnet/machinelearning/pull/6981?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=dotnet#diff-c3JjL01pY3Jvc29mdC5NTC5Ub2tlbml6ZXJzL1ByZVRva2VuaXplci9XaGl0ZXNwYWNlLmNz) | `100.00% <100.00%> (ø)` | |
| [src/Microsoft.ML.Tokenizers/TokenizerResult.cs](https://app.codecov.io/gh/dotnet/machinelearning/pull/6981?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=dotnet#diff-c3JjL01pY3Jvc29mdC5NTC5Ub2tlbml6ZXJzL1Rva2VuaXplclJlc3VsdC5jcw==) | `100.00% <100.00%> (+9.09%)` | :arrow_up: |
| [...Microsoft.ML.Tokenizers.Tests/PreTokenizerTests.cs](https://app.codecov.io/gh/dotnet/machinelearning/pull/6981?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=dotnet#diff-dGVzdC9NaWNyb3NvZnQuTUwuVG9rZW5pemVycy5UZXN0cy9QcmVUb2tlbml6ZXJUZXN0cy5jcw==) | `95.31% <100.00%> (ø)` | |
| [test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs](https://app.codecov.io/gh/dotnet/machinelearning/pull/6981?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=dotnet#diff-dGVzdC9NaWNyb3NvZnQuTUwuVG9rZW5pemVycy5UZXN0cy9UaXRva2VuVGVzdHMuY3M=) | `100.00% <100.00%> (ø)` | |
| [...rc/Microsoft.ML.Tokenizers/PreTokenizer/Roberta.cs](https://app.codecov.io/gh/dotnet/machinelearning/pull/6981?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=dotnet#diff-c3JjL01pY3Jvc29mdC5NTC5Ub2tlbml6ZXJzL1ByZVRva2VuaXplci9Sb2JlcnRhLmNz) | `57.14% <33.33%> (-19.79%)` | :arrow_down: |
| [...c/Microsoft.ML.Tokenizers/Utils/BytePairEncoder.cs](https://app.codecov.io/gh/dotnet/machinelearning/pull/6981?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=dotnet#diff-c3JjL01pY3Jvc29mdC5NTC5Ub2tlbml6ZXJzL1V0aWxzL0J5dGVQYWlyRW5jb2Rlci5jcw==) | `95.23% <95.23%> (ø)` | |
| [...crosoft.ML.Tokenizers/PreTokenizer/PreTokenizer.cs](https://app.codecov.io/gh/dotnet/machinelearning/pull/6981?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=dotnet#diff-c3JjL01pY3Jvc29mdC5NTC5Ub2tlbml6ZXJzL1ByZVRva2VuaXplci9QcmVUb2tlbml6ZXIuY3M=) | `83.33% <81.48%> (-7.58%)` | :arrow_down: |
| [...Microsoft.ML.Tokenizers/Utils/ByteArrayComparer.cs](https://app.codecov.io/gh/dotnet/machinelearning/pull/6981?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=dotnet#diff-c3JjL01pY3Jvc29mdC5NTC5Ub2tlbml6ZXJzL1V0aWxzL0J5dGVBcnJheUNvbXBhcmVyLmNz) | `65.38% <65.38%> (ø)` | |
| [src/Microsoft.ML.Tokenizers/Model/Model.cs](https://app.codecov.io/gh/dotnet/machinelearning/pull/6981?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=dotnet#diff-c3JjL01pY3Jvc29mdC5NTC5Ub2tlbml6ZXJzL01vZGVsL01vZGVsLmNz) | `7.69% <7.69%> (ø)` | |
| [src/Microsoft.ML.Tokenizers/Utils/LruCache.cs](https://app.codecov.io/gh/dotnet/machinelearning/pull/6981?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=dotnet#diff-c3JjL01pY3Jvc29mdC5NTC5Ub2tlbml6ZXJzL1V0aWxzL0xydUNhY2hlLmNz) | `66.66% <66.66%> (ø)` | |
| ... and [3 more](https://app.codecov.io/gh/dotnet/machinelearning/pull/6981?src=pr&el=tree-more&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=dotnet) | | |

... and [3 files with indirect coverage changes](https://app.codecov.io/gh/dotnet/machinelearning/pull/6981/indirect-changes?src=pr&el=tree-more&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=dotnet)