An instruction-based unified model for performing various biomedical tasks.
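You may want to check out our paper, "In-BoXBART: Get Instructions into Biomedical Multi-Task Learning" (NAACL 2022 Findings).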
This work explores the impact of instructional prompts on biomedical multi-task learning. We introduce BoX, a collection of 32 instruction tasks for biomedical NLP across (X) various categories. Using this meta-dataset, we propose a unified model, termed In-BoXBART, that can jointly learn all tasks of BoX without any task-specific modules. To the best of our knowledge, this is the first attempt to propose a unified model in the biomedical domain and to use instructions to achieve generalization across several biomedical tasks. The figure below shows an overview of the approach.
You can easily load the model with Transformers instead of downloading it manually. BART-base is the backbone of our model. Here is how to load the model in PyTorch:
```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Download (and cache) the tokenizer and model weights from the Hugging Face Hub.
tokenizer = AutoTokenizer.from_pretrained("cogint/in-boxbart")
model = AutoModelForSeq2SeqLM.from_pretrained("cogint/in-boxbart")
```
Or just clone the model repository:

```shell
git lfs install
git clone https://huggingface.co/cogint/in-boxbart
```
Please see ./templates for the instructional prompts corresponding to all BoX tasks used in our experiments. For more details, refer to our paper.
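Each prompt concatenates a task definition, positive examples, and the instance to be labelled. The sketch below reproduces the layout of the HoC prompt shown in the inference example later in this README. It assumes you have already read the definition and examples out of the matching file in ./templates; since that file's structure is not described here, `PositiveExample` and `build_prompt` are hypothetical helpers, not part of the repository.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class PositiveExample:
    """One positive example of a task (hypothetical field names)."""
    input_text: str
    output: str
    explanation: str


def build_prompt(definition: str, examples: List[PositiveExample], instance: str) -> str:
    """Compose a prompt in the layout of the HoC example in this README."""
    # Each example is rendered as [input: ..., output: ..., explanation: ...].
    rendered = [
        f"[input: {ex.input_text}, output: {ex.output}, explanation: {ex.explanation}]"
        for ex in examples
    ]
    # Assumption: several examples are joined with ", " (the README shows only one).
    positives = "[" + ", ".join(rendered) + " ]"
    return (
        f"Instruction: Definition: {definition}, "
        f"Positive Examples: {positives}; "
        f"Instance: input: {instance}, output: ?"
    )
```

With a single example, this yields the same `[[input: ..., output: ..., explanation: ...] ]` bracket layout as the HoC prompt; check the template files before relying on it for tasks with several examples.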
The following sections describe how to use our source code to fine-tune other models.
Setup
Run the following command to install all the dependencies required to run the model:

```shell
pip install -r requirements.txt
```
Training
To fine-tune a model on your data, run scripts/finetune_model.py with the following arguments:
```
python scripts/finetune_model.py [ARGUMENTS]

--model_name_or_path            Path of the model you want to fine-tune. To fine-tune BART, use "facebook/bart-base". (default: None)
--do_train                      True or False. (default: False)
--do_eval                       True or False. (default: False)
--do_predict                    True or False. (default: False)
--train_file                    Path of an optional training data file (JSON or CSV), used if --do_train is True. (default: None)
--validation_file               Path of an optional evaluation data file (JSON or CSV) on which metrics (ROUGE) are computed, used if --do_eval is True. (default: None)
--test_file                     Path of an optional test data file (JSON or CSV) on which metrics (ROUGE) are computed, used if --do_predict is True. (default: None)
--output_dir                    Output directory where model predictions and checkpoints are written. (default: None)
--per_device_train_batch_size   Batch size per GPU/TPU core/CPU for training. (default: 8)
--per_device_eval_batch_size    Batch size per GPU/TPU core/CPU for evaluation. (default: 8)
--gradient_accumulation_steps   Number of update steps to accumulate before performing a backward/update pass. (default: 1)
--predict_with_generate         Whether to use generate to compute generative metrics (ROUGE, BLEU). (default: False)
--save_strategy                 Checkpoint save strategy to use: no, steps, or epoch. (default: steps)
```
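The batch-size arguments combine multiplicatively. Assuming scripts/finetune_model.py hands these arguments to the Hugging Face Trainer (the names match `Seq2SeqTrainingArguments`, but this is an assumption about the script), the number of examples per optimizer update is per_device_train_batch_size × gradient_accumulation_steps × number of devices. The helpers below are hypothetical and only illustrate that arithmetic.

```python
import math


def effective_batch_size(per_device_train_batch_size: int,
                         gradient_accumulation_steps: int = 1,
                         num_devices: int = 1) -> int:
    """Training examples that contribute to a single optimizer update."""
    for value in (per_device_train_batch_size, gradient_accumulation_steps, num_devices):
        if value < 1:
            raise ValueError("all arguments must be positive integers")
    return per_device_train_batch_size * gradient_accumulation_steps * num_devices


def accumulation_steps_for(target_batch_size: int,
                           per_device_train_batch_size: int,
                           num_devices: int = 1) -> int:
    """Smallest --gradient_accumulation_steps whose effective batch reaches the target."""
    per_step = per_device_train_batch_size * num_devices
    return max(1, math.ceil(target_batch_size / per_step))
```

For example, the defaults (8 and 1) on one GPU give an effective batch of 8; to reach 32 on a single GPU with the default per-device size, `accumulation_steps_for(32, 8)` suggests `--gradient_accumulation_steps 4`.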
Evaluation
To evaluate a model, run scripts/evaluation.py with the following arguments:
```
python scripts/evaluation.py [ARGUMENTS]

--dataset_file      Path of the test data file (JSON or CSV) that contains the ground truth. (default: None)
--prediction_file   Path of the prediction file: either a JSON file of the form { "predictions": ["pred1", "pred2", ...] } or a .txt file with one prediction per line ("pred1" \n "pred2" \n ...). (default: None)
--save_results      Output directory where the JSON file of results is saved. (default: None)
```
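The two accepted `--prediction_file` formats can be produced with the standard library alone. This is a minimal sketch: `write_predictions` and `read_predictions` are hypothetical helpers, and for the .txt format it assumes one unquoted prediction per line with no trailing newline, which is our reading of the format above. Predictions should follow the same order as the instances in `--dataset_file`.

```python
import json
from pathlib import Path
from typing import List, Union


def write_predictions(predictions: List[str], path: Union[str, Path]) -> Path:
    """Write predictions as .json ({"predictions": [...]}) or .txt (one per line)."""
    path = Path(path)
    if path.suffix == ".json":
        path.write_text(json.dumps({"predictions": predictions}), encoding="utf-8")
    elif path.suffix == ".txt":
        # Assumption: one prediction per line, no quotes, no trailing newline.
        if any("\n" in p or "\r" in p for p in predictions):
            raise ValueError("a .txt prediction file needs single-line predictions")
        path.write_text("\n".join(predictions), encoding="utf-8")
    else:
        raise ValueError("prediction file must end in .json or .txt")
    return path


def read_predictions(path: Union[str, Path]) -> List[str]:
    """Read a prediction file back, e.g. to check its length against the dataset."""
    path = Path(path)
    if path.suffix == ".json":
        return json.loads(path.read_text(encoding="utf-8"))["predictions"]
    return path.read_text(encoding="utf-8").splitlines()
```

Reading the file back and comparing its length with the number of test instances is a cheap check before running the evaluation script.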
Here, we provide an example for the "Document Classification" task (HoC dataset). Once you load the model from Hugging Face for inference, prepend the instruction for that dataset, given in ./templates, to the input instance. Below is an example for one instance.
```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("cogint/in-boxbart")
model = AutoModelForSeq2SeqLM.from_pretrained("cogint/in-boxbart")
# The input shows how the HoC instruction from ./templates is prepended to the instance.
input_text = "Instruction: Definition: In this task, you are given a medical text related to cancer. Your job is to classify into zero or more classes from (1) Sustaining proliferative signaling, (2) Resisting cell death, (3) Genomic instability and mutation, (4) Activating invasion and metastasis, (5) Tumor promoting inflammation, (6) Evading growth suppressors, (7) Inducing angiogenesis (8) Enabling replicative immortality, (9) Avoiding immune destruction and (10) Cellular energetics., Positive Examples: [[input: Studies of cell-cycle progression showed that the anti-proliferative effect of Fan was associated with an increase in the G1/S phase of PC3 cells ., output: Evading growth suppressors, Sustaining proliferative signaling, explanation: Given text is classified into two categories, hence, generated label is 'Evading growth suppressors, Sustaining proliferative signaling'.] ]; Instance: input: Similar to previous studies utilizing IGF-1 , pretreatment with Roscovitine leads to a significant up-regulation of p21 expression and a significant decrease in the number of PCNA positive cells ., output: ?"

# Tokenize the prompt, generate the label sequence, and decode it back to text.
inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
output_ids = model.generate(**inputs, max_length=64)  # max_length is an illustrative choice
output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
# Ideal output for this input is 'Sustaining proliferative signaling'
print(output)
```
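For multi-label tasks such as HoC, the model generates its labels as one comma-separated string, as in the positive example's output. To compare a prediction with the gold labels regardless of label order, you can parse both into sets. This is a minimal sketch, assuming class names never contain commas (true for the ten HoC classes listed in the instruction); `parse_labels` and `labels_match` are hypothetical helpers, and this set comparison is not necessarily the metric computed by scripts/evaluation.py.

```python
from typing import FrozenSet, Iterable

# The ten HoC classes, copied from the instruction in the example above.
HOC_LABELS = (
    "Sustaining proliferative signaling",
    "Resisting cell death",
    "Genomic instability and mutation",
    "Activating invasion and metastasis",
    "Tumor promoting inflammation",
    "Evading growth suppressors",
    "Inducing angiogenesis",
    "Enabling replicative immortality",
    "Avoiding immune destruction",
    "Cellular energetics",
)


def parse_labels(generated: str, allowed: Iterable[str] = HOC_LABELS) -> FrozenSet[str]:
    """Turn a comma-separated label string into a set of canonical class names.

    Unrecognised fragments are dropped, so text naming no known class yields an empty set.
    """
    lookup = {label.lower(): label for label in allowed}
    fragments = (part.strip().rstrip(".").strip().lower() for part in generated.split(","))
    return frozenset(lookup[f] for f in fragments if f in lookup)


def labels_match(prediction: str, reference: str) -> bool:
    """Order-insensitive exact match between predicted and reference label sets."""
    return parse_labels(prediction) == parse_labels(reference)
```

Matching is case-insensitive and ignores a trailing period, so "sustaining proliferative signaling." still maps to the canonical class name.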
If you are using our model, please cite our paper:
```bibtex
@article{parmar2022boxbart,
  title={{In-BoXBART: Get Instructions into Biomedical Multi-Task Learning}},
  author={Parmar, Mihir and Mishra, Swaroop and Purohit, Mirali and Luo, Man and Murad, M Hassan and Baral, Chitta},
  journal={NAACL 2022 Findings},
  year={2022}
}
```