keras-team / keras-nlp

Modular Natural Language Processing workflows with Keras
Apache License 2.0

Rework examples/bert pretraining to use KerasNLP preprocessing #347

Open mattdangerw opened 1 year ago

mattdangerw commented 1 year ago

This will be somewhat exploratory, as it is unclear what the precise solution should look like. Right now, we have a large script, inherited from the original BERT repo, that preprocesses BERT inputs for pretraining. Overall it is quite long and does not leverage KerasNLP components.

We would like a simpler preprocessing approach that leverages tf.data and KerasNLP layers for tokenizing, packing, and masking text.

mattdangerw commented 1 year ago

One potential way this could work:

Rework the split sentence script to go from a raw Wikipedia dump and book text files to a set of sharded files containing triples of entries of the form (sentence1, sentence2, next_sentence_label). The output file format should be either CSV or TFRecord (and sharded).
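As a sketch of what one output shard could contain, assuming TFRecord is chosen over CSV (the feature keys and filename pattern here are illustrative, not from any existing script):

```python
import tensorflow as tf

def make_example(sentence1, sentence2, next_sentence_label):
    """Serialize one (sentence1, sentence2, next_sentence_label) triple."""
    feature = {
        "sentence1": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[sentence1.encode("utf-8")])
        ),
        "sentence2": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[sentence2.encode("utf-8")])
        ),
        "next_sentence_label": tf.train.Feature(
            int64_list=tf.train.Int64List(value=[next_sentence_label])
        ),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))

# The data prep script would loop over many such shards; one is shown here.
with tf.io.TFRecordWriter("pretraining-00000-of-00500.tfrecord") as writer:
    example = make_example("He went to the store.", "He bought milk.", 1)
    writer.write(example.SerializeToString())
```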

This "data prep" script does not need to leverage tf.data at all, but we do want it to be simple and efficient when working the roughly ~20GB of input text that come with Bert pre-training. It will definitely need to efficiently use multithreading on a CPU, most likely with the multiprocessing model.

We will also need to take care that the input sentences are of the correct length. There is a good bit of logic here around deciding where to split a pretraining example between the first and second sentences. Recreating that logic exactly would require tokenizing the input just to measure its token length, but I suspect we can use a heuristic and approximate with the text's word count.
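For example, the heuristic could budget words rather than tokens. A minimal sketch, where the words-per-token ratio is an illustrative guess rather than a measured constant:

```python
TARGET_SEQ_LENGTH = 512  # token budget for a packed pretraining example
WORDS_PER_TOKEN = 0.7    # rough guess; wordpieces are shorter than words
MAX_WORDS = int(TARGET_SEQ_LENGTH * WORDS_PER_TOKEN)

def group_sentences(sentences):
    """Greedily accumulate consecutive sentences up to the word budget.
    A later step would split each group into a first and second segment."""
    group, group_words = [], 0
    for sentence in sentences:
        num_words = len(sentence.split())
        if group and group_words + num_words > MAX_WORDS:
            yield group
            group, group_words = [], 0
        group.append(sentence)
        group_words += num_words
    if group:
        yield group
```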

After reworking the data prep script, we would then need to do all of the tokenization, segment packing, and word masking using tf.data and the WordPieceTokenizer, MultiSegmentPacker, and MLMMaskGenerator layers. This preprocessing will live inside the bert_train.py script. We will need to validate that it is performant, both in terms of throughput and model quality.
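A rough sketch of what that tf.data pipeline could look like; the vocabulary path, special tokens, and exact layer arguments are placeholders that would need checking against the KerasNLP API:

```python
import keras_nlp
import tensorflow as tf

SEQ_LENGTH = 512
VOCAB_FILE = "bert_vocab.txt"  # hypothetical vocabulary path

tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(
    vocabulary=VOCAB_FILE, lowercase=True
)
packer = keras_nlp.layers.MultiSegmentPacker(
    sequence_length=SEQ_LENGTH,
    start_value=tokenizer.token_to_id("[CLS]"),
    end_value=tokenizer.token_to_id("[SEP]"),
)
masker = keras_nlp.layers.MLMMaskGenerator(
    vocabulary_size=tokenizer.vocabulary_size(),
    mask_selection_rate=0.15,
    mask_token_id=tokenizer.token_to_id("[MASK]"),
    mask_selection_length=77,  # ceil(0.15 * 512)
)

def preprocess(sentence1, sentence2, next_sentence_label):
    # Tokenize both segments and pack them into one
    # [CLS] sentence1 [SEP] sentence2 [SEP] sequence.
    token_ids, segment_ids = packer((tokenizer(sentence1), tokenizer(sentence2)))
    # Randomly select and mask tokens for the masked language model task.
    masked = masker(token_ids)
    return {
        "token_ids": masked["token_ids"],
        "segment_ids": segment_ids,
        "mask_positions": masked["mask_positions"],
        "mask_ids": masked["mask_ids"],
        "mask_weights": masked["mask_weights"],
        "next_sentence_label": next_sentence_label,
    }

# The real script would read the sharded data prep output; toy data here.
ds = (
    tf.data.Dataset.from_tensor_slices(
        (["He went to the store."], ["He bought milk."], [1])
    )
    .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)
```

The training step would then split this dictionary into model inputs and the MLM/NSP labels.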