Closed: Yikai-Liao closed this issue 5 months ago.
Sounds nice, waiting for the translated document and more benchmarks!
I will finish the blog translation as soon as possible. Also, do you have any suggestions for more benchmarks? I'm not quite sure what kind of testing I should do, and the current code interface doesn't exactly align with the implementation in tokenizers (in fact, this code comes from an assignment in my NLP class).
Would try with different hardware, different datasets and check this benchmark
Note that this document is not a one-to-one translation of the original. It mainly focuses on implementation details and omits some background knowledge about the BPE algorithm.
The BPE algorithm, an unsupervised subword segmentation algorithm, is widely used in NLP, especially in large language models represented by GPT (which uses its variant, byte-level BPE). However, most current open-source implementations of the algorithm train very inefficiently on Chinese.
This is mainly because, although the original algorithm takes whole sentences as input, for Latin-script languages such as English a certain approximation can be made that significantly reduces the time complexity: pre-split the text on spaces first, then run the BPE algorithm within each word.
It is called an approximation because it forbids merges across spaces, which the original algorithm allows; but for English it works well in most cases. Most importantly, it reduces the worst-case $O(N^2)$ time complexity of the original algorithm to $O(M^2)$, where N is the length of the sequence and M is the length of a word. (See the blog Byte Pair Encoding and Data Structures [^1] for a detailed complexity analysis.) Since $M \ll N$ in English, this approach works very well.
But this approximation does not perform well in all languages. It obviously fails for languages like Chinese and Japanese, which cannot be pre-split on spaces. And even among Latin-script languages, some, like German, have quite long words, which breaks the precondition that makes this approximation a significant optimization.
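For readers who want to see the approximation concretely, here is a small sketch (a toy example; the helper names are mine, not from the post) contrasting whole-sequence pair counting with pre-split, per-word counting:

```python
from collections import Counter

def pair_counts_whole(symbols):
    # Original BPE: count adjacent symbol pairs over the entire sequence.
    return Counter(zip(symbols, symbols[1:]))

def pair_counts_presplit(text):
    # Approximation: split on spaces first, then count pairs inside each
    # word, so pairs spanning a space boundary are never counted or merged.
    counts = Counter()
    for word in text.split(" "):
        counts.update(zip(word, word[1:]))
    return counts

text = "a banana and a bandana"
whole = pair_counts_whole(list(text))
split = pair_counts_presplit(text)
```

The pre-split version never sees cross-space pairs such as `("a", " ")`, which is exactly the restriction that makes it an approximation.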
So here I have implemented an optimized version of the BPE algorithm without this approximation. Even as a pure Python implementation, it is substantially better in speed and memory footprint than the Rust version implemented in the Hugging Face Tokenizer [^2]. Note that the version implemented in the Tokenizer is not the original $O(N^2)$ version; it is also optimized. Here's a time comparison [^time]:
Implementation | user time | system time | total time | cpu |
---|---|---|---|---|
My version (Single Thread) | 2.70s | 0.05s | 2.761s | 99% |
Tokenizer (Single Thread) | 5.51s | 1.60s | 5.411s | 131% |
Tokenizer (Multi Threads) | 8.51s | 3.52s | 2.849s | 422% |
The biggest problem with the original BPE algorithm is that after merging just one token pair, all the statistics have to be recounted. For a language like Chinese, which has a large symbol inventory of its own, each merge actually changes very little of the previous round's statistics. Moreover, each merge also requires traversing the whole corpus, so a lot of time is wasted on retrieval.
A more ideal approach is therefore to modify only the data that actually changes in each round, so that once the modification is complete, the updated statistics can be used directly in the next round to find and merge the next best symbol pair.
The implementation in the HuggingFace Tokenizer follows this principle as well. However, its efficiency can be further improved.
In order to modify only the parts that need to be modified, without redundant retrieval operations, we essentially need to solve three core problems. Here, I'll start by listing the data structures I chose:

- `pair_freq_queue`: a priority queue holding the total frequency of each token pair
- `word_pair_pos`: a `Dict[Tuple[str, str], List[int]]` [^set] holding the starting positions of all current pairs
- `seg_status`: a `uint8` array, equal in length to the string, representing the current merge status
- `word_pair_len`: a `Dict[Tuple[str, str], int]` holding the number of valid starting positions in `word_pair_pos` under the current merge status
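A minimal sketch of how these structures could be initialized for a toy string (assuming character-level symbols; the `init_structures` helper and the use of `bytearray` in place of the post's `uint8` `array.array` are my own choices, not the post's code):

```python
import heapq
from collections import defaultdict

def init_structures(text):
    # Starting positions of every adjacent symbol pair.
    word_pair_pos = defaultdict(list)
    for i in range(len(text) - 1):
        word_pair_pos[(text[i], text[i + 1])].append(i)
    # Number of currently valid starting positions per pair.
    word_pair_len = {pair: len(pos) for pair, pos in word_pair_pos.items()}
    # Merge status: every symbol initially has length 1.
    seg_status = bytearray([1] * len(text))
    # heapq is a min-heap, so negated counts pop the most frequent pair first.
    pair_freq_queue = [(-n, pair) for pair, n in word_pair_len.items()]
    heapq.heapify(pair_freq_queue)
    return word_pair_pos, word_pair_len, seg_status, pair_freq_queue
```

For "banana", for instance, the pair `("a", "n")` is recorded at positions 1 and 3, and the heap's top entry carries frequency 2.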
Using a priority queue to find the highest-frequency symbol pair is a very natural choice. However, there are different options for what information to store in the queue.

In the Tokenizer, the priority queue stores each specific token pair occurrence (a token pair together with one position), which makes the queue quite long. Each time a token pair is popped from the queue, only that one occurrence is merged.

In my implementation, I store only the total frequency of each token pair in the queue (e.g., "ba" and "na" forming a pair 100 times in the corpus). We can then fetch all positions of this pair from a hashmap and process them in one pass, without popping new entries from the priority queue.

A similar lazy-update strategy is used for checking whether a popped token pair's frequency is still valid (that is, whether the frequency cached in the priority queue is consistent with the current statistics). A schematic is as follows [^freq]:
```python
while True:
    cached_freq, pair = pair_freq_queue.pop()
    cur_freq = -word_pair_len[pair]
    if cached_freq == cur_freq:
        break
    else:
        pair_freq_queue.push((cur_freq, pair))
```
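A runnable version of the same lazy-validation idea using Python's `heapq` (the `pop_valid_pair` helper name is mine; as in the schematic above, frequencies are stored negated because `heapq` is a min-heap):

```python
import heapq

def pop_valid_pair(heap, word_pair_len):
    # Pop the most frequent pair, lazily fixing stale heap entries:
    # if the cached frequency no longer matches the current count,
    # re-push the entry with its corrected priority and try again.
    while heap:
        cached_freq, pair = heapq.heappop(heap)
        cur_freq = -word_pair_len.get(pair, 0)
        if cached_freq == cur_freq:
            return pair, -cur_freq
        heapq.heappush(heap, (cur_freq, pair))
    return None, 0

# Demo: the cached count for ("a", "n") is stale (the real count is 2),
# so the valid top of the queue is ("n", "a") with frequency 3.
heap = [(-5, ("a", "n")), (-3, ("n", "a"))]
heapq.heapify(heap)
counts = {("a", "n"): 2, ("n", "a"): 3}
best_pair, best_freq = pop_valid_pair(heap, counts)
```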
In my implementation, I don't use a Word class to represent merged token pairs. Instead, I use a `uint8` array, `seg_status`, to represent the merge status. It is initialized to all ones. For the string "apple", a possible merging process might look like this:
```
    a  p  p  l  e
1. [1, 1, 1, 1, 1]  init
2. [1, 1, 2, 2, 1]  merge p + l
3. [1, 1, 3, 0, 3]  merge pl + e
4. [2, 2, 3, 0, 3]  merge a + p
5. [5, 0, 0, 0, 5]  merge ap + ple
```
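As a sanity check, here is a small sketch that reproduces the five states above (the `apply_merge` helper reflects my reading of the update rule: each token stores its length at both its start and end positions, with zeros in between):

```python
def apply_merge(seg_status, i, len_a, len_b):
    # Merge the token of length len_a starting at i with the token of
    # length len_b that follows it, updating length markers in place.
    length = len_a + len_b
    seg_status[i + len_a - 1] = 0        # clear the first token's end marker
    seg_status[i + len_a] = 0            # clear the second token's start marker
    seg_status[i] = length               # new start marker
    seg_status[i + length - 1] = length  # new end marker

seg = [1, 1, 1, 1, 1]      # "apple", initial state
apply_merge(seg, 2, 1, 1)  # p + l    -> [1, 1, 2, 2, 1]
apply_merge(seg, 2, 2, 1)  # pl + e   -> [1, 1, 3, 0, 3]
apply_merge(seg, 0, 1, 1)  # a + p    -> [2, 2, 3, 0, 3]
apply_merge(seg, 0, 2, 3)  # ap + ple -> [5, 0, 0, 0, 5]
```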
The effect of this representation is as follows: suppose we find a recorded starting position 2 for the symbol `p`. We can read the value at this position in `seg_status` and determine whether the position is still legal in the current merge state by checking whether that value equals the length of the symbol.

Once we have a legal symbol, we can find its neighboring symbols in the current merge state in $O(1)$. For example, in the state of the second line, we can look up the neighbors of the symbol `pl`, which starts at position 2:
token | start position | length |
---|---|---|
this token | 2 | `seg_status[2]` |
previous token | `2 - seg_status[2-1]` | `seg_status[2-1]` |
next token | `2 + seg_status[2]` | `seg_status[2 + seg_status[2]]` |
It is easy to see that under this mechanism, even if a symbol has length 1, so that its start and end positions coincide, the same procedure can be used to query the symbols before and after it. Thus the data in `seg_status` provides two efficient auxiliary functions when merging symbol pairs: checking whether a recorded position is still legal, and locating a symbol's neighbors.
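The lookup table above can be written out as a small helper (the `neighbors` function is my own sketch, not the post's code):

```python
def neighbors(seg_status, start):
    # O(1) lookup of (start, length) for the previous, current, and next
    # token, given a valid token start, using the stored length markers.
    this_len = seg_status[start]
    prev = None
    if start > 0:
        prev_len = seg_status[start - 1]  # end marker of the previous token
        prev = (start - prev_len, prev_len)
    nxt = None
    nxt_start = start + this_len
    if nxt_start < len(seg_status):
        nxt = (nxt_start, seg_status[nxt_start])
    return prev, (start, this_len), nxt

# "apple" after merging p + l (state 2 in the example above):
seg = [1, 1, 2, 2, 1]
prev, this, nxt = neighbors(seg, 2)  # the token "pl" starting at 2
```

Here the previous token is "p" at (1, 1) and the next is "e" at (4, 1).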
At the initialization stage, the starting positions of all adjacent symbol pairs are counted and stored in a dictionary, and the length of each position list is assigned to `word_pair_len`. If the number of distinct pairs is $V$ and the number of tokens is $N$, then the space complexity of `word_pair_pos` is $O(N + V)$ and that of `word_pair_len` is $O(V)$.
Next, the process for each merge round is as follows (special cases are discussed later):

1. Take the starting positions of all occurrences of the chosen symbol pair and filter out the illegal ones using `seg_status`.

2. Use `seg_status` to locate the neighboring symbols before and after each position, denoted `(pre_word, word_a, word_b, nxt_word)`, where `(word_a, word_b)` is the pair being merged; denote the merged symbol as `word_comb`.

3. Update `word_pair_len` so that it always matches the current merge state:

    ```python
    word_pair_len[(pre_word, word_a)] -= 1
    word_pair_len[(word_b, nxt_word)] -= 1
    word_pair_len[(pre_word, word_comb)] += 1
    word_pair_len[(word_comb, nxt_word)] += 1
    ```

4. Update `seg_status` so that it reflects the current merge state. With `i` denoting the starting position of a legal pair being merged, the update is [^merge]:

    ```python
    seg_status[i + len(word_a) - 1] = 0                  # clear word_a's old end marker
    seg_status[i + len(word_a)] = 0                      # clear word_b's old start marker
    seg_status[i] = len(word_comb)                       # new start marker
    seg_status[i + len(word_comb) - 1] = len(word_comb)  # new end marker
    ```

5. Record the starting positions of the newly created symbol pairs in `new_pair_pos`:

    ```python
    new_pair_pos[(pre_word, word_comb)].append(i - len(pre_word))
    new_pair_pos[(word_comb, nxt_word)].append(i)
    ```

After all modifications have been made, merge the information in `new_pair_pos` into `word_pair_pos` and `word_pair_len`, and push the updated pairs into the priority queue. [^sync]
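Putting the pieces together, a hedged sketch of one full merge round over a raw string (the helper name and the immediate, non-batched updating of `word_pair_len` are my simplifications of the post's scheme; stale positions are skipped by re-checking `seg_status`):

```python
from collections import defaultdict

def merge_round(text, pair, word_pair_pos, word_pair_len, seg_status):
    a, b = pair
    comb = a + b
    new_pair_pos = defaultdict(list)
    for i in word_pair_pos.pop(pair, []):
        # A position is valid iff the tokens starting at i and i + len(a)
        # still have exactly the lengths of a and b.
        if seg_status[i] != len(a):
            continue
        if i + len(a) >= len(seg_status) or seg_status[i + len(a)] != len(b):
            continue
        if i > 0:  # the pair on the left is destroyed, a new one appears
            p_len = seg_status[i - 1]
            pre = text[i - p_len:i]
            word_pair_len[(pre, a)] = word_pair_len.get((pre, a), 0) - 1
            word_pair_len[(pre, comb)] = word_pair_len.get((pre, comb), 0) + 1
            new_pair_pos[(pre, comb)].append(i - p_len)
        j = i + len(comb)
        if j < len(seg_status):  # same for the pair on the right
            nxt = text[j:j + seg_status[j]]
            word_pair_len[(b, nxt)] = word_pair_len.get((b, nxt), 0) - 1
            word_pair_len[(comb, nxt)] = word_pair_len.get((comb, nxt), 0) + 1
            new_pair_pos[(comb, nxt)].append(i)
        # Update the start/end length markers.
        seg_status[i + len(a) - 1] = 0
        seg_status[i + len(a)] = 0
        seg_status[i] = len(comb)
        seg_status[i + len(comb) - 1] = len(comb)
    word_pair_len.pop(pair, None)
    for p, pos in new_pair_pos.items():  # sync the new positions back
        word_pair_pos[p].extend(pos)

text = "banana"
word_pair_pos = defaultdict(list)
for k in range(len(text) - 1):
    word_pair_pos[(text[k], text[k + 1])].append(k)
word_pair_len = {p: len(v) for p, v in word_pair_pos.items()}
seg_status = [1] * len(text)
merge_round(text, ("a", "n"), word_pair_pos, word_pair_len, seg_status)
```

Because positions are processed left to right against a live `seg_status`, merging `("a", "n")` in "banana" correctly produces the new pair `("an", "an")`.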
During the merge process, there are two special cases that arise and need to be handled:

```
A B A B -> AB AB
0 0 0 0 -> 00 00
```
In essence, the second special case is also a special case of the first. The modification needed to handle them is also simple:

- If the two symbols immediately before the merge position happen to be `(word_a, word_b)`, then the pair to add should be `(word_comb, word_comb)`, not `(pre_word, word_comb)`.
- If the two symbols immediately after happen to be `(word_a, word_b)`, then there is no need to add the following pair, because it has already been added by the previous rule.
- In the case `word_a == word_b`, I prefer to match from back to front, e.g. `1 0 0 0` -> `1 0 00`. Of course, you could also match from front to back, but I don't think `1 00 0` is as good as the former. I didn't really look into how the HuggingFace Tokenizer handles this.
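One way to realize the back-to-front preference is to pre-filter the recorded positions so that accepted merges never overlap (a hypothetical helper, not the post's code; a merge of a pair whose symbols both have width `w`, starting at position `i`, occupies `[i, i + 2 * w)`):

```python
def match_back_to_front(positions, width):
    # Scan the sorted positions from the back, keeping a position only if
    # its merge would not overlap the last accepted (rightmost) merge.
    accepted = []
    last = None
    for i in reversed(positions):
        if last is None or i + 2 * width <= last:
            accepted.append(i)
            last = i
    accepted.reverse()
    return accepted

# "1000": the pair ("0", "0") occurs at positions 1 and 2 (width 1).
# Back-to-front matching keeps only position 2, giving 1 0 00.
chosen = match_back_to_front([1, 2], 1)
```

For "0000" (positions 0, 1, 2) the same rule yields positions 0 and 2, i.e. the grouping `00 00` shown above.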
In the process above, suppose we merge `l` valid positions in a round; then the total length of the position arrays we append is in general `2 * l`, which makes the memory footprint grow rapidly. We don't really need that much space, because many of the positions stored in `word_pair_pos` are invalid.
To control this growth, a memory compression mechanism can be introduced. The principle is also very simple: each time we pop from the priority queue, compare `cur_freq` (the number of currently valid positions) with the length of the position array stored in `word_pair_pos`, and check whether their ratio crosses a threshold. If it does, the proportion of invalid starting positions in the array is already very high, so we can filter the array using `seg_status` and free a lot of memory. This keeps the memory footprint at a roughly stable level.
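A sketch of that check (the `maybe_compact` helper and the 0.5 threshold are my own assumptions for illustration):

```python
def maybe_compact(pair, cur_freq, word_pair_pos, seg_status, symbol_len,
                  threshold=0.5):
    # If the fraction of still-valid positions drops below `threshold`,
    # rebuild the position array, keeping only entries whose seg_status
    # value still equals the length of the pair's first symbol.
    positions = word_pair_pos[pair]
    if positions and cur_freq / len(positions) < threshold:
        word_pair_pos[pair] = [i for i in positions
                               if seg_status[i] == symbol_len]

word_pair_pos = {("a", "n"): [1, 3, 5]}
seg_status = [1, 2, 1, 0, 1, 1, 1]  # positions 1 and 3 are stale
maybe_compact(("a", "n"), 1, word_pair_pos, seg_status, symbol_len=1)
```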
Compared to the original BPE training process, trading space for time is unavoidable to achieve our optimization goal, but with the memory compression mechanism the final space complexity can still be kept at $O(N)$ with a small constant. In my Python implementation, based on observing the memory usage of the background process, the memory consumed by data other than the string itself is about 2-3 times the space occupied by the string, which is better than the single-threaded version of the Tokenizer. Notably, the memory footprint of the Tokenizer's BPE training increases several times over when multi-threading is enabled.
[^time]: Timing was done with the `time` utility on the Linux command line and includes file-reading time; the file read operations are done in Python in all cases. The Python version is 3.9, the CPU is an i7-10875H, and the Linux kernel is 6.2.13-zen-1-zen.
[^set]: Note that it is possible to use a set (hash set or BTree set) to hold the starting positions in `word_pair_pos`, but a set has significantly higher memory overhead than an array (especially Python's native sets), which gives the algorithm a very large constant in its space complexity. It was so high that I couldn't finish training on a 1 GB Chinese wiki corpus with 32 GB RAM + 32 GB swap. Using `array.array('I')` for storage, training completes with about 5 GB of memory in total. The reason for not using `np.array` is that there is no need for vectorization, and element access in `np.array` is slower than in native arrays.
[^freq]: Python's native priority queue (`heapq`) is a min-heap only, so the reference code uses the negated frequency as the queue priority.
[^merge]: The order of these assignments matters: the line that zeroes `i + len(word_a)` must come before the one that writes the new end marker, because for `len(word_b) == 1` those two positions coincide; swapping them would leave the end marker at 0 instead of `len(word_comb)`, which breaks our rule.
[^sync]: The reason we don't modify `word_pair_len` and `word_pair_pos` directly during traversal is to minimize accesses to `word_pair_len`, avoiding repeated `+= 1` operations. It also makes it easier to filter the additions by `min_freq`.
If I understand correctly, I need to train BPE on big.txt with different hardware?
I tested it on my MacBook Pro (Intel i7-7700HQ (8) @ 2.80GHz, 16GB RAM), using `%%timeit` in a Jupyter notebook, with `vocab_size = int(1e5)` and `min_freq = 10`.
Note that how to use multiple threads in my implementation is still open for discussion.
By the way, the blog was translated with the help of DeepL (not run directly through DeepL), so it might not read very naturally.
Hi, any further discussion?
Hey! Sorry but I'm a bit low on bandwidth I need to read the blogpost and take some time to check this out! 🚀
Very exciting otherwise ! 🤗
Have not had the time yet sorry
This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.
@ArthurZucker Is there any other help I can offer?
Actually if you could open a PR it would be amazing! 🤗
I will give it a try
I ran into a problem with the max token length. In my implementation, I use a vector of `u8` (which uses the same amount of memory as the original corpus) to store the length in bytes of each token, which means the max token length must be less than 256.
This works in most cases. Even in very demanding situations, I think `u16` is adequate (twice the memory). But the tokenizer's original implementation uses `usize` to store the `max_token_length`. On a 64-bit machine, this means an 8x memory overhead.
So I'm asking what I should do about this. @ArthurZucker
Memory should not be that much of a problem so would keep usize. Or does it affect speed too much?
Theoretically there is little difference in speed.
But for GB-scale corpus data, a memory overhead of 8 times the text size is still something to worry about, I think. If the server doesn't have enough memory, a lot of data is likely to end up in swap, which can significantly hurt performance.
Also, I'm not quite sure how tokenizer does parallelism now, e.g., how to take a mutex lock on some global data. So I'll implement the single-threaded version in my fork first.
Or we could add runtime dispatch on the dtype according to `max_token_length`, and let users decide which to use. Oh, by the way, is `max_token_length` measured in bytes or in Unicode characters?
@ArthurZucker Progress is much faster than I thought it would be. I've now passed the 3 built-in test cases. After tidying up the code and adding comments, I'll upload it to my fork first. Further parallel optimizations will follow, as well as adding dropout support.
But I think we need to add some special case tests, like how to handle strings like "00000000000".
In the process of modifying the BPE training code in the tokenizer, I think I've found the main reason the original implementation is slow: Word's merge method is implemented inefficiently. Frequent remove and insert operations on large vectors are very costly.
Let the number of symbols in a word be N, and the number of matched token pairs be M. The complexity of calling the merge function is $O(M \cdot N)$, and in the worst case (a sentence that keeps repeating a single character) it is $O(N^2)$.
Luckily, the tokenizer has a good pre-tokenization mechanism that keeps N small, so the overall complexity remains manageable.
But after a detailed comparison today, I think this merge method is still the main difference between the original implementation and mine. In my implementation, processing each position of the token pair requires only $O(1)$, so for a single "Word" the complexity of the merge operation is reduced to $O(M)$.
Moreover, changing the Vec to a linked list still wouldn't reduce the original implementation to $O(M)$, because each call to merge traverses the entire vector of symbols, an operation that is itself $O(N)$.
```rust
pub(super) fn merge(
    &mut self,
    c1: u32,
    c2: u32,
    replacement: u32,
    max_length: usize,
) -> Vec<(Pair, i32)> {
    let mut changes: Vec<(Pair, i32)> = vec![];
    let mut i = 0;
    loop {
        if i >= self.symbols.len() {
            break;
        }
        // Found a pair
        if self.symbols[i].c == c1 && i + 1 < self.symbols.len() && self.symbols[i + 1].c == c2
        {
            let first = self.symbols[i];
            let second = self.symbols[i + 1];
            // Remove in place
            let new_s = Symbol {
                c: replacement,
                prev: first.prev,
                next: second.next,
                len: first.len + second.len,
            };
            // If there are other characters before the pair
            if i > 0 {
                changes.push(((self.symbols[i - 1].c, first.c), -1));
                if self.symbols[i - 1].len + new_s.len < max_length {
                    changes.push(((self.symbols[i - 1].c, replacement), 1));
                }
            }
            self.symbols.insert(i, new_s); // Insert replacement before first char of pair
            self.symbols.remove(i + 1); // Remove first char of pair
            self.symbols.remove(i + 1); // And then the second
            // If there are other characters after the pair
            if i < self.symbols.len() - 1 {
                changes.push(((second.c, self.symbols[i + 1].c), -1));
                if self.symbols[i + 1].len + new_s.len < max_length {
                    changes.push(((replacement, self.symbols[i + 1].c), 1));
                }
            }
        }
        i += 1;
    }
    changes
}
```
Wow, that sounds great. If we can just modify that single function, it would be pretty impressive! I did not implement any of that, so I'll have to dive in a bit once the PR is opened! 🤗 Thanks a lot already.
I found out about this article after reading in the tiktoken library's README that their code offers much faster inference speeds, and I started investigating the issues with Hugging Face's library. I know this issue concerns training a tokenizer, but it still looks really exciting. I'm still trying to fully understand the details, but in the meantime I'm curious what the status of this is: has it been implemented and merged? Is there anything I could try and help with?
@AugustasMacijauskas Thank you for your attention. I have written in detail about the improvement method in my previous reply. However, I found it still a bit difficult to add it to the existing interface of tokenizers. There are two main problems:
- The algorithm I proposed limits the maximum length of a single token to 256, which does not meet the interface's semantic requirements (and lifting this limit requires some changes to the algorithm's design).
- The tokenizer's BPE algorithm includes support for adding prefixes and suffixes to consecutive characters, which I hadn't considered before and haven't figured out how to support yet.
Since I have limited energy at the moment, I haven't made much progress yet. Of course, you can take a look at my original repository (albeit with very few comments) to help you understand my implementation. Feel free to raise an issue under my repository if you have any questions!
Oh, by the way, tiktoken is also written in Rust (with just about 600 lines). So it might be more feasible to introduce the Rust part of tiktoken directly into tokenizers?
When you say "introduce the rust part of tiktoken directly into tokenizers", isn't that irrelevant since tiktoken only has code for inference, while what you propose regards tokenizer training?
Sorry, I didn't look closely at tiktoken; I thought it contained both training and inference code.
By the way, one thing worth noting is that BPE's training and inference processes are similar, and in my own attempts this improved method can also be used for inference, although I didn't notice much performance gain, probably because the test data was small.
Yeah, tiktoken only contains inference code. Either way, thank you for your answers, I'll take some more time to process the code you proposed and I might come back if I have some more questions.
I'll check if it's possible to include the Rust part that makes tiktoken faster here. I think they have a super efficient regex thing. Will check.
That's a good idea. It actually made me realize it would be great to profile each part of the tokenization process separately for both tiktoken and huggingface to see what improvements can be made. Essentially, the running times for regex splitting and then for computing the tokens from the vocab should be profiled (though maybe more fine-grained profiling could be useful too). I could try looking into this, or is it well known that the regex splitting is the bottleneck?
Also, not sure how much of a difference this makes, but tiktoken operates on the byte level instead of the string level. Any chance this leads to performance improvements?
Oh, and they use simple Python multiprocessing for parallelism instead of `rayon`. This is surprising since it is much simpler, yet seems to perform faster. Are there any resources I could read on why `rayon` is used in `tokenizers`?
Early this year, I wrote a new implementation of the BPE algorithm in pure Python, which is faster than the version in the Tokenizer.
I hope this implementation can help tokenizers further improve BPE training performance.
I have written a blog post in Chinese about this implementation, and I will try to translate it to English if there is any need. By the way, the code is quite short in my opinion, at merely about 400 lines.
Here is the code: https://github.com/Yikai-Liao/efficient_bpe