
#+TITLE: BERTax training utilities

#+OPTIONS: ^:nil

This repository contains utilities for pre-training and fine-tuning [[https://github.com/rnajena/bertax][BERTax]] models, as well as various utility functions and scripts used in the development of BERTax.

In addition to the [[https://doi.org/10.1101/2021.07.09.451778][described mode of training BERTax on /genomic/ DNA sequences]], development scripts for training on /gene/ sequences are included as well.

To train a new BERTax model, the user must provide one of the following three data structures.

*** multi-fastas with TaxIDs

For training any model, you can use multi-fastas containing the sequences of your classes of interest.

#+begin_example
[class_1].fa
[class_2].fa
...
[class_n].fa
#+end_example

The =fasta= files contain headers, each consisting of the sequence's associated species TaxID and a consecutive index.

Example fasta file (you can find our pretraining fasta files [[https://osf.io/qg6mv/files/osfstorage/60faa9f84d949102092323b4][here]]):

#+begin_example
>380669 0
TCGAGATACCAGATGGAAATCCTCCAGAGGTATTATCGGAA
>264076 1
GCAGACGAGTTCACCACTGCTGCAGGAAAAGAT
>218387 2
AACTATGCATAGGGCCTTTGCCGGCACTAT
#+end_example
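As a minimal sketch (illustrative only, not a repository script), such a per-class multi-fasta could be written like this; the output file name is hypothetical and the TaxID/sequence pairs are taken from the example above:

#+begin_src python
# Illustrative sketch: write one class's multi-fasta with ">TaxID index"
# headers, using the example records from above.
records = [
    (380669, "TCGAGATACCAGATGGAAATCCTCCAGAGGTATTATCGGAA"),
    (264076, "GCAGACGAGTTCACCACTGCTGCAGGAAAAGAT"),
    (218387, "AACTATGCATAGGGCCTTTGCCGGCACTAT"),
]

with open("class_1.fa", "w") as fh:  # hypothetical output file name
    for index, (taxid, seq) in enumerate(records):
        fh.write(f">{taxid} {index}\n{seq}\n")
#+end_src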

After generating these files, you can transform them for training.

*** fragments directories

For training the normal, genomic DNA-based models, a fixed directory structure with one =json= file and one =txt= file per class is required:

#+begin_example
[class_1]_fragments.json
[class_1]_species_picked.txt
[class_2]_fragments.json
[class_2]_species_picked.txt
...
[class_n]_fragments.json
[class_n]_species_picked.txt
#+end_example

The =json= files must consist of a simple list of sequences. Example json file:

#+begin_src json
["ACGTACGTACGATCGA", "TACACTTTTTA", ..., "ATACTATCTATCTA"]
#+end_src

The =txt= files are ordered lists of the corresponding TaxIDs, meaning the first listed TaxID describes the taxonomic origin of the first sequence in the =json= file with the same prefix.

Example txt file:

#+begin_src
380669
264076
218387
11569
204873
346884
565995
11318
#+end_src
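A minimal sketch (illustrative, not a repository script) of writing such a file pair while keeping both files in the same order; it assumes one TaxID per line in the =txt= file:

#+begin_src python
import json

# Illustrative sketch: write one class's fragments/TaxID file pair in sync,
# so the i-th TaxID in the txt file belongs to the i-th sequence in the json.
fragments = ["ACGTACGTACGATCGA", "TACACTTTTTA", "ATACTATCTATCTA"]
taxids = [380669, 264076, 218387]  # one TaxID per fragment, same order
assert len(fragments) == len(taxids)

with open("class_1_fragments.json", "w") as fh:
    json.dump(fragments, fh)

with open("class_1_species_picked.txt", "w") as fh:
    fh.write("\n".join(str(t) for t in taxids) + "\n")
#+end_src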

*** gene model training directories

The gene models were used in an early stage of BERTax development, where a different directory structure was required:

Each sequence is contained in its own fasta file; additionally, a =json= file containing all file names and associated classes can speed up preprocessing tremendously.

#+begin_example
[class_1]/
  [sequence_1.fa]
  [sequence_2.fa]
  ...
  [sequence_n.fa]
[class_2]/
  ...
.../
[class_n]/
  ...
  [sequence_l.fa]
{files.json}
#+end_example

The =json= file contains a list of two lists of equal size; the first list contains the file paths to the fasta files, the second the associated classes:

#+begin_src json
[["class_1/sequence1.fa", "class_1/sequence2.fa", ..., "class_n/sequence_l.fa"],
 ["class_1", "class_1", ..., "class_n"]]
#+end_src

** Training process

The normal, genomic DNA-based model can be pre-trained with [[file:models/bert_nc.py]] and fine-tuned with [[file:models/bert_nc_finetune.py]].

For example, the BERTax model was pre-trained with:

#+begin_src shell
python -m models.bert_nc fragments_root_dir --batch_size 32 --head_num 5 \
  --transformer_num 12 --embed_dim 250 --feed_forward_dim 1024 --dropout_rate 0.05 \
  --name bert_nc_C2 --epochs 10
#+end_src

and fine-tuned with:

#+begin_src shell
python -m models.bert_nc_finetune bert_nc_C2.h5 fragments_root_dir --multi_tax \
  --epochs 15 --batch_size 24 --save_name _small_trainingset_filtered_fix_classes_selection \
  --store_predictions --nr_seqs 1000000000
#+end_src

The development gene models can be pre-trained with [[file:models/bert_pretrain.py]]:

#+begin_src shell
python -m models.bert_pretrain bert_gene_C2 --epochs 10 --batch_size 32 --seq_len 502 \
  --head_num 5 --embed_dim 250 --feed_forward_dim 1024 --dropout_rate 0.05 \
  --root_fa_dir sequences --from_cache sequences/files.json
#+end_src

and fine-tuned with [[file:models/bert_finetune.py]]:

#+begin_src shell
python -m models.bert_finetune bert_gene_C2_trained.h5 --epochs 4 \
  --root_fa_dir sequences --from_cache sequences/files.json
#+end_src

All training scripts can be called with the =--help= flag to list the various adjustable parameters.

** Using BERT models

It is recommended to use fine-tuned models in the BERTax tool with the parameter =--custom_model_file=.

However, a more minimal script for predicting the classes of multi-fasta sequences with a trained model is also available in this repository:

#+begin_src shell
python -m utils.test_bert finetuned_bert.h5 --fasta sequences.fa
#+end_src

** Benchmarking

If the user needs a predefined training and test set, for example for benchmarking different approaches:

#+begin_src shell
python -m preprocessing.make_dataset single_sequences_json_folder/ out_folder/ --unbalanced
#+end_src

This creates the files =test.tsv=, =train.tsv=, and =classes.pkl=, which can be used by =bert_nc_finetune=:

#+begin_src shell
python -m models.bert_nc_finetune bert_nc_trained.h5 make_dataset_out_folder/ \
  --unbalanced --use_defined_train_test_set
#+end_src
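To sanity-check the generated dataset before fine-tuning, here is a minimal sketch (illustrative only; it makes no assumption about the exact TSV column layout and simply prints the first few rows so you can see it):

#+begin_src python
import csv
import pickle

# Illustrative sketch: peek at the files produced by preprocessing.make_dataset.
with open("make_dataset_out_folder/classes.pkl", "rb") as fh:
    print("classes:", pickle.load(fh))

with open("make_dataset_out_folder/train.tsv") as fh:
    for i, row in enumerate(csv.reader(fh, delimiter="\t")):
        print(row)  # inspect the first few rows to see the column layout
        if i >= 2:
            break
#+end_src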

If fasta files are necessary, e.g., for competing methods, you can convert =train.tsv= and =test.tsv= via:

#+begin_src shell
python -m preprocessing.dataset2fasta make_dataset_out_folder/
#+end_src