This PyTorch package implements MoEBERT: from BERT to Mixture-of-Experts via Importance-Guided Adaptation (NAACL 2022).
- Create the conda environment: `conda env create -f environment.yml`
- Install the package: `pip install -e .`
MoEBERT targets task-specific distillation. Before running any distillation code, a pre-trained BERT model should be fine-tuned on the target task, and the path to the fine-tuned model should be passed to `--model_name_or_path`.
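For concreteness, a fine-tuning command might look like the sketch below. It uses standard `run_glue.py` arguments from this repository's examples; the task, hyperparameters, and output path (`ckpt/bert_mnli`) are placeholder assumptions rather than values prescribed by MoEBERT.

```bash
# Hypothetical fine-tuning run; hyperparameters and paths are assumptions.
python examples/text-classification/run_glue.py \
  --model_name_or_path bert-base-uncased \
  --task_name mnli \
  --do_train \
  --do_eval \
  --max_seq_length 128 \
  --per_device_train_batch_size 32 \
  --learning_rate 2e-5 \
  --num_train_epochs 3 \
  --output_dir ckpt/bert_mnli
```

The resulting checkpoint directory (`ckpt/bert_mnli` here) is what later commands receive via `--model_name_or_path`.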
To compute the importance scores, run `bert_base_mnli_example.sh` with a `--preprocess_importance` argument added and the `--do_train` argument removed. An `importance_[rank].pkl` file will be saved for each GPU.
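Concretely, the importance pass reuses the fine-tuning command. The sketch below is an assumption of what that looks like, derived by adding `--preprocess_importance` and dropping `--do_train`; whether `--do_eval` or other flags are also required is not specified here.

```bash
# Importance-score pass: --preprocess_importance added, --do_train removed.
# Paths are placeholders; each GPU writes an importance_[rank].pkl file.
python examples/text-classification/run_glue.py \
  --model_name_or_path ckpt/bert_mnli \
  --task_name mnli \
  --do_eval \
  --preprocess_importance \
  --max_seq_length 128 \
  --output_dir ckpt/bert_mnli_importance
```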
Use `merge_importance.py` to merge these files, and pass the merged file to `--moebert_load_importance`.
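This README does not show the command line of `merge_importance.py`, so the invocation below is an assumption; consult the script itself for its actual arguments.

```bash
# Hypothetical invocation: merge the per-GPU importance_[rank].pkl files
# into a single file for --moebert_load_importance. The --num_files flag
# is an assumption; check merge_importance.py for its real CLI.
python merge_importance.py --num_files 8
```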
For model adaptation and distillation:
- For GLUE tasks, use `examples/text-classification/run_glue.py`.
- For question answering tasks, use `examples/question-answering/run_qa.py`.
- Run `bash bert_base_mnli_example.sh` as an example; a hedged sketch of such a run follows this list.
- The routing strategy is selected with `--moebert_route_method`.
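A distillation run might look like the sketch below. Only `--moebert_load_importance` and `--moebert_route_method` are flags named in this README; the route-method value and the remaining settings are assumptions, so check `bert_base_mnli_example.sh` for the authoritative command.

```bash
# Hypothetical distillation run; flag values are assumptions, not the
# example script's actual settings.
python examples/text-classification/run_glue.py \
  --model_name_or_path ckpt/bert_mnli \
  --task_name mnli \
  --do_train \
  --do_eval \
  --moebert_load_importance importance.pkl \
  --moebert_route_method hash-random \
  --output_dir ckpt/moebert_mnli
```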
To use hash routing, compute a hash list with `hash_balance.py`; the path to the saved hash list should be passed to `--moebert_route_hash_list`. When using trainable gating mechanisms, add a load-balancing loss via `--moebert_load_balance`. Hedged sketches of both variants follow.
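The two routing variants might be wired up as follows; the route-method names (`hash-balance`, `gate-token`), the hash-list filename, and the load-balance weight are all assumptions for illustration.

```bash
# Compute a balanced hash list first (hash_balance.py's CLI is not shown in
# this README; check the script for its actual arguments).
python hash_balance.py

# Hash routing: pass the saved list (filename is a placeholder).
python examples/text-classification/run_glue.py \
  --model_name_or_path ckpt/bert_mnli \
  --task_name mnli \
  --do_train \
  --do_eval \
  --moebert_route_method hash-balance \
  --moebert_route_hash_list hash_list.pkl \
  --output_dir ckpt/moebert_mnli_hash

# Trainable gating instead: swap the routing flag and add the
# load-balancing loss (method name and weight are assumptions):
#   --moebert_route_method gate-token --moebert_load_balance 0.1
```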