ybracke / transnormer

A lexical normalizer for historical spelling variants using a transformer architecture.
GNU General Public License v3.0

Dealing with long input sequences #49

Open ybracke opened 1 year ago

ybracke commented 1 year ago

BERT-like models can only process input up to a fixed length (e.g. 512 tokens). T5-like models (e.g. ByT5) have been trained with fixed-length input but can process input of any size (see here or here). However, long inputs can lead to high memory requirements or out-of-memory errors with ByT5.

Therefore, we take the following approach:

Training: Sequences exceeding max_seq_len (currently 512) can simply be excluded from the training set, since there are not many of them. Alternatively, the data processing step could chunk these input sequences into several shorter ones. The latter approach needs a bit more work, so we go with the first one.
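A minimal sketch of the filtering step, assuming a Hugging Face datasets.Dataset with a hypothetical column orig holding the source text (column name and toy data are assumptions; the checkpoint is the same google/byt5-small used below):

import datasets
import transformers

MAX_SEQ_LEN = 512  # current max_seq_len

tokenizer = transformers.AutoTokenizer.from_pretrained("google/byt5-small")

# Toy dataset; in practice this would be the training split.
dataset = datasets.Dataset.from_dict({"orig": ["Ein kurtzer Satz.", "x" * 600]})

def fits_max_len(example):
    # ByT5 tokenizes to UTF-8 bytes (+ EOS), so this is roughly the byte length.
    return len(tokenizer(example["orig"])["input_ids"]) <= MAX_SEQ_LEN

filtered = dataset.filter(fits_max_len)
print(len(dataset), "->", len(filtered))  # 2 -> 1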

Inference: As stated above, ByT5 can deal with any input length but may run into memory issues. In my test data, there are not many examples that heavily exceed 512 bytes, so for now we do not treat them differently. We just pad each batch to its longest sequence, making most batches shorter than 512 and only a few longer.
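For illustration, this kind of dynamic padding can be done directly via the tokenizer (the example sentences are made up):

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("google/byt5-small")

batch = ["Ein kurtzer Satz.", "Ein etwas laengerer Satz mit deutlich mehr Woertern."]

# padding="longest" pads only up to the longest sequence in this batch,
# so most batches stay well below 512 bytes.
enc = tokenizer(batch, padding="longest", return_tensors="pt")
print(enc["input_ids"].shape)  # (2, length of the longest sequence incl. EOS)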

Perspective: Generally, we should have a way to deal with very long input sequences during inference. For example, we could require that an input sequence never exceed 1024 bytes and, if it does, process it in multiple steps. A simple way would be to chunk the input after n bytes, treat the parts as separate sequences, get the generations and concatenate them. With this approach the split could occur within a word, which is not desirable, so a better way would be to chunk at the last space, comma, etc. This requires a customized chunking method and still leaves us with the problem that contextual information is lost at the chunk boundaries.

The best way would be to split the original long sequence into overlapping chunks, get the generation for each chunk and then harmonize the outputs. Here is a description of a similar approach for a token classification problem (!= seq2seq). To get the overlapping sequences we can make use of existing huggingface functions, see next comment. To harmonize the generations, we might have to write a custom generate function (hf tutorial on generation, code generate).
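A minimal sketch of the simpler whitespace-aware chunking variant described above; the 1024-byte limit and the function name are illustrative assumptions, and context across chunk boundaries is still lost:

def chunk_at_whitespace(text: str, max_bytes: int = 1024) -> list[str]:
    # Split text into pieces of at most max_bytes UTF-8 bytes,
    # preferring to cut at the last space before the limit.
    chunks = []
    while len(text.encode("utf-8")) > max_bytes:
        # Decode the first max_bytes bytes, dropping a possibly cut-off character.
        head = text.encode("utf-8")[:max_bytes].decode("utf-8", errors="ignore")
        cut = head.rfind(" ")
        if cut <= 0:  # no space found: fall back to a hard cut
            cut = len(head)
        chunks.append(text[:cut])
        text = text[cut:].lstrip()
    chunks.append(text)
    return chunks

# Each chunk would be normalized separately and the generations re-joined with spaces.
print(chunk_at_whitespace("ein " * 400, max_bytes=1024))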

ybracke commented 9 months ago

Creating overlapping sequences with huggingface functionality:

import transformers
checkpoint = "google/byt5-small"
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
batch_encoding = tokenizer.prepare_for_model(
    [2,3,4,5,6,7,8,9,10,11], 
    add_special_tokens=False,
    truncation=True,
    max_length=5,
    stride=2, 
    return_overflowing_tokens=True)
print(batch_encoding)

Output:

{'attention_mask': [1, 1, 1, 1, 1],
 'input_ids': [2, 3, 4, 5, 6],
 'num_truncated_tokens': 5,
 'overflowing_tokens': [5, 6, 7, 8, 9, 10, 11]}

Now pass batch_encoding.overflowing_tokens to tokenizer.prepare_for_model again to get the next window of five tokens, whose first two tokens overlap with the end of the previous window, and so on.
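A small loop that repeats this until nothing overflows, collecting all overlapping windows (same toy values as above):

import transformers

checkpoint = "google/byt5-small"
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

token_ids = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
windows = []
remaining = token_ids
while remaining:
    enc = tokenizer.prepare_for_model(
        remaining,
        add_special_tokens=False,
        truncation=True,
        max_length=5,
        stride=2,
        return_overflowing_tokens=True,
    )
    windows.append(enc["input_ids"])
    # 'overflowing_tokens' is empty once the remainder fits into max_length
    remaining = enc.get("overflowing_tokens", [])

print(windows)
# [[2, 3, 4, 5, 6], [5, 6, 7, 8, 9], [8, 9, 10, 11]]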

Link to documentation

ybracke commented 8 months ago

The notebook reports/dataset-sizes.ipynb contains stats on the mean length and the upper outer fence (uof = Q3 + 3*IQR) for my datasets. It shows that for many datasets the uof is >128 subword tokens. (Byte lengths are also provided for some datasets.) On the other hand, all but two datasets have a uof <= 172, so choosing 176 (= 11*16) as max_length might be a good idea. 252 would lie (way) above all uofs for my datasets, as does the typical max_length value of 512.
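For reference, a tiny sketch of how the upper outer fence is computed (the length values here are made up; the real numbers are in reports/dataset-sizes.ipynb):

import numpy as np

# Hypothetical sequence lengths in subword tokens
lengths = np.array([40, 55, 60, 72, 80, 95, 110, 130, 150, 400])

q1, q3 = np.percentile(lengths, [25, 75])
iqr = q3 - q1
uof = q3 + 3 * iqr  # upper outer fence
print(f"Q3={q3}, IQR={iqr}, uof={uof}")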