ybracke opened this issue 1 year ago
Creating overlapping sequences with huggingface functionality:

```python
import transformers

checkpoint = "google/byt5-small"
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
batch_encoding = tokenizer.prepare_for_model(
    [2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
    add_special_tokens=False,
    truncation=True,
    max_length=5,
    stride=2,
    return_overflowing_tokens=True,
)
print(batch_encoding)
```
Output:

```python
{'attention_mask': [1, 1, 1, 1, 1],
 'input_ids': [2, 3, 4, 5, 6],
 'num_truncated_tokens': 5,
 'overflowing_tokens': [5, 6, 7, 8, 9, 10, 11]}
```
Now pass `batch_encoding.overflowing_tokens` to `tokenizer.prepare_for_model` again to get the next five tokens, with the first two overlapping the end of the previous chunk, and so on.
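The windowing pattern that repeated `prepare_for_model` calls produce can be mirrored in plain Python. This is a minimal sketch; `chunk_with_overlap` is a hypothetical helper, not an HF API, and it assumes `stride < max_length`:

```python
def chunk_with_overlap(ids, max_length=5, stride=2):
    """Split ids into windows of at most max_length tokens, where each
    window after the first repeats the last `stride` tokens of the
    previous one (same overlap as prepare_for_model with stride)."""
    chunks = [ids[:max_length]]
    step = max_length - stride
    pos = step
    # stop once the remainder would consist only of already-seen tokens
    while pos + stride < len(ids):
        chunks.append(ids[pos:pos + max_length])
        pos += step
    return chunks

print(chunk_with_overlap(list(range(2, 12))))
# [[2, 3, 4, 5, 6], [5, 6, 7, 8, 9], [8, 9, 10, 11]]
```

The first two windows match the `input_ids` of the two `prepare_for_model` calls above.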
The notebook `reports/dataset-sizes.ipynb` contains stats on the mean length and upper outer fence (uof = Q3 + 3*IQR) for my datasets. It shows that for many datasets the uof is >128 subword tokens. (Byte lengths are also provided for some datasets.) On the other hand, all but two datasets have a uof <= 172, so choosing 176 (= 11*16) as `max_length` might be a good idea. 252 would lie well above all uofs for my datasets, as does the typical `max_length` value of 512.
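The fence statistic used above can be computed with the stdlib. A small sketch (the sequence lengths here are made up; the notebook uses the real dataset lengths):

```python
import statistics

def upper_outer_fence(lengths):
    """Upper outer fence as defined above: Q3 + 3 * IQR."""
    # method="inclusive" matches the common linear-interpolation quartiles
    q1, _median, q3 = statistics.quantiles(lengths, n=4, method="inclusive")
    return q3 + 3 * (q3 - q1)

lengths = [40, 55, 60, 70, 80, 95, 120, 400]  # toy data
print(upper_outer_fence(lengths))  # 228.75
```

Note that the outlier 400 barely moves the fence, which is the point of using quartiles rather than the maximum.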
BERT-like models can only process input up to a fixed length (e.g. 512 tokens). T5-like models (e.g. ByT5) were trained on fixed-length input but can process input of any size (see here or here). However, long inputs can lead to high memory requirements or out-of-memory errors with ByT5.
Therefore, we take the following approach:
**Training** Sequences exceeding the `max_seq_len` (currently: 512) in the training set can simply be excluded from training, since there aren't too many. Alternatively, the data processing step could chunk these input sequences into more than one. The latter approach needs a bit more work, so we go with the first one.

**Inference** As stated above, ByT5 can deal with any input length but may run into memory issues. In my test data there are not many examples that heavily exceed 512 bytes, so for now we do not treat them differently. We just pad each batch to its longest sequence, making many batches shorter than 512 and a few of them longer than 512.
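The two decisions above amount to very little code. A sketch over lists of token ids, assuming a pad id of 0 (ByT5's actual pad id should be taken from the tokenizer):

```python
PAD_ID = 0          # assumption; in practice use tokenizer.pad_token_id
MAX_SEQ_LEN = 512

def filter_training(examples):
    """Drop training examples longer than MAX_SEQ_LEN."""
    return [ids for ids in examples if len(ids) <= MAX_SEQ_LEN]

def pad_batch(batch):
    """Dynamic padding: pad only to the longest sequence in this batch."""
    longest = max(len(ids) for ids in batch)
    return [ids + [PAD_ID] * (longest - len(ids)) for ids in batch]

print(pad_batch([[5, 6, 7], [8, 9], [10]]))
# [[5, 6, 7], [8, 9, 0], [10, 0, 0]]
```

In practice `transformers` does the padding step via `DataCollatorWithPadding`; the sketch just makes the per-batch length explicit.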
**Perspective** Generally, we should have a way to deal with super-long input sequences during inference. For example, we could decide that an input sequence should never exceed 1024 bytes and, if it does, process it in multiple steps. A simple way would be to chunk the input after n bytes, treat the pieces as separate sequences, get the generations, and put them back together. This way the split could occur within a word, which is not desirable, so a better way would be to chunk at the last space, comma, etc. That would make a customized chunking method necessary, and it would still leave us with the problem that contextual information is lost. The best way would be to split the original long sequence into overlapping chunks, get the generation for each chunk, and then harmonize the outputs. Here is a description of a similar approach for a token classification problem (!= seq2seq). To get the overlapping sequences we can make use of existing huggingface functions, see next comment. To harmonize the generations, we might have to write a custom `generate` function (hf tutorial on generation, code of `generate`).
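The "chunk at the last space" idea from the middle option might look like this. A hedged sketch: `chunk_at_spaces` is a hypothetical helper, it counts characters rather than bytes for simplicity, and it only handles spaces, not commas:

```python
def chunk_at_spaces(text, max_len=1024):
    """Split text into pieces of at most max_len characters, cutting at
    the last space before the limit so words stay intact."""
    chunks = []
    while len(text) > max_len:
        cut = text.rfind(" ", 0, max_len)
        if cut <= 0:
            cut = max_len     # no space found: hard cut inside a word
        chunks.append(text[:cut])
        text = text[cut:].lstrip(" ")
    if text:
        chunks.append(text)
    return chunks

print(chunk_at_spaces("aaa bbb ccc", max_len=7))
# ['aaa', 'bbb ccc']
```

For the preferred overlapping-chunks option, the harmonization step (merging the generations for the overlap regions) is the part this sketch does not cover and where the custom `generate` logic would come in.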