fdschmidt93 / trident-nllb-llm2vec

Repository for "Self-Distillation for Model Stacking Unlocks Cross-Lingual NLU in 200+ Languages"
MIT License

Disclaimer: The code is still being refactored to be more usable and clean out-of-the-box.

NLLB-LLM2Vec

Setup & Installation

Cloning repositories

git clone --recurse-submodules https://github.com/fdschmidt93/trident-nllb-llm2vec.git
cd ./trident-nllb-llm2vec/

Setting up environment

For now, miniconda is the recommended tool to manage dependencies for this project. You can install the required pinned dependencies with the command below.

conda env create -f environment.yaml

If you want to run the NLLB-GPT-2 adaptation, you will additionally have to install flash-attention, since, to the best of our knowledge, environment.yaml files do not support the required build flag.

pip install flash-attn --no-build-isolation

Initial setup of LLM2Vec

On a machine with an NVIDIA GPU, run

cd ./trident-nllb-llm2vec/
python ./prepare_model.py

This downloads Llama 3 8B and the LLM2Vec adapters, merges the required adapters into the model, and stores the result in the appropriate folder.

General adaptation

The NLLB-LLM2Vec adaptation requires downloading the FineWeb 10BT dataset with the following script.

cd ./trident-nllb-llm2vec/
bash ./download_fineweb.sh

Then the model can be trained. We trained the model on 8 A100 80GB GPUs for 10K steps (~22 hours) with the script below.

python -m trident.run experiment=adaptation_nllb-llm2vec.yaml hydra.run.dir=$OUTPUT_FOLDER

You must set $OUTPUT_FOLDER to the directory where checkpoints should be stored.
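For concreteness, a minimal sketch of launching the adaptation run. The folder name is only an illustration, and echo keeps this a dry run; drop it to actually start training.

```shell
# Hypothetical example: choose an output folder and pass it to the launcher.
OUTPUT_FOLDER=./outputs/adaptation-run1
mkdir -p "$OUTPUT_FOLDER"

# Dry run: prints the command instead of executing it.
echo python -m trident.run experiment=adaptation_nllb-llm2vec.yaml hydra.run.dir=$OUTPUT_FOLDER
```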

Task Fine-tuning

You can train LLM2Vec on a particular $TASK as follows.

cd ./trident-nllb-llm2vec/
bash train_llm2vec_$TASK.sh $SEED

We ran with seeds 42, 43, and 44 for NLI and Belebele, and with seeds 42, 43, 44, 45, and 46 for NusaX.

The outputs are then stored to ./trident-nllb-llm2vec/logs/nllb-llm2vec/llm2vec/nli/seed-$SEED/ (shown here for NLI; the path segment varies by task).
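The per-seed runs can be scripted as a loop. A dry-run sketch, assuming the task names in train_llm2vec_$TASK.sh expand to nli and nusax (replace echo with actual execution):

```shell
# Dry-run sketch: print the fine-tuning command for every seed.
for SEED in 42 43 44; do          # NLI & Belebele seeds
  echo bash train_llm2vec_nli.sh $SEED
done
for SEED in 42 43 44 45 46; do    # NusaX seeds
  echo bash train_llm2vec_nusax.sh $SEED
done
```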

Task Distillation

Task distillation first requires pre-embedding the training datasets with the fine-tuned LLM2Vec models, which takes 30-60 minutes depending on your GPU infrastructure. You need to check on wandb which checkpoint ($EPOCH) performed best on the source-language validation instances.

cd ./trident-nllb-llm2vec/
bash preembed_llm2vec_$TASK.sh $SEED $EPOCH

Then you can run

cd ./trident-nllb-llm2vec/
bash distill_nllb-llm2vec_$TASK.sh $SEED $EPOCH
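Pre-embedding and distillation can be chained per seed. A dry-run sketch, where the epoch value and the nli task name are hypothetical placeholders (remove echo to actually run):

```shell
# Dry-run sketch: chain pre-embedding and distillation for each seed.
EPOCH=2   # hypothetical: substitute the best checkpoint found on wandb
for SEED in 42 43 44; do
  echo bash preembed_llm2vec_nli.sh $SEED $EPOCH
  echo bash distill_nllb-llm2vec_nli.sh $SEED $EPOCH
done
```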

Finally, you can evaluate your model as follows.

cd ./trident-nllb-llm2vec/
# evaluates on all 3 (or 5, for NusaX) seeds already
bash test_nllb-llm2vec_$TASK.sh $EPOCH

Evaluation can be very costly due to the number of languages being evaluated. Belebele requires about 3 hours on an A100 40GB.