UKPLab / sentence-transformers

State-of-the-Art Text Embeddings
https://www.sbert.net
Apache License 2.0

FSDP Training with Sentence Transformer #3023

Open ShengYun-Peng opened 2 weeks ago

ShengYun-Peng commented 2 weeks ago

Given that there are so many LLM-based models at the top of the MTEB benchmark nowadays, is there a canonical way to train with FSDP? I'm trying to explore this direction, but I want to ask whether some examples already exist before I reinvent the wheel.

ShengYun-Peng commented 2 weeks ago

I took a stab at training with FSDP and encountered quite a few issues: https://github.com/huggingface/accelerate/issues/3201

tomaarsen commented 2 weeks ago

Hello!

There are some details here on how I originally got it running: https://sbert.net/docs/sentence_transformer/training/distributed.html#fsdp But I stopped trying to build a neat and convenient integration once I realised that DDP outperformed FSDP for most small models. I'm definitely open to improving it, though.
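For reference, a minimal sketch of the kind of setup that docs page describes; the dataset, loss, and the layer class to wrap are placeholders and depend on the actual model you train:

```python
# Minimal FSDP sketch with the SentenceTransformerTrainer; launch with
# `accelerate launch train.py` or `torchrun --nproc_per_node=N train.py`.
from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss

model = SentenceTransformer("all-MiniLM-L6-v2")
train_dataset = load_dataset("sentence-transformers/all-nli", "pair", split="train")
loss = MultipleNegativesRankingLoss(model)

args = SentenceTransformerTrainingArguments(
    output_dir="output/fsdp-example",
    # `fsdp` / `fsdp_config` come straight from transformers' TrainingArguments;
    # the layer class to wrap depends on the underlying architecture.
    fsdp=["full_shard", "auto_wrap"],
    fsdp_config={"transformer_layer_cls_to_wrap": "BertLayer"},
)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()
```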

ShengYun-Peng commented 2 weeks ago

Thanks! I am working in this direction now and would like to hear your input! When you subclass the transformers Trainer class to create the SentenceTransformerTrainer, is there a guideline you follow for writing the customized trainer? I notice that you override compute_loss, prepare inputs, and other class methods. Are you following a template or guidelines, or do you check the Trainer source code line by line to make changes?

tomaarsen commented 2 weeks ago

I don't really check it line-by-line, but I'm somewhat familiar with the overall structure of the transformers Trainer. It's set up in quite a modular way, which means that it's rather feasible to subclass some "high level" methods like compute_loss and get_train_dataloader while leaving lower level methods like training_step, _inner_training_loop, etc. intact.

That's why the Sentence Transformers trainer file is only ~900 lines long, compared to ~5k for the base Trainer.
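As a rough illustration of that pattern (not the actual SentenceTransformerTrainer code), a subclass only needs to override the high-level hooks and can leave everything else to the base class:

```python
# Sketch of the subclassing pattern: override high-level hooks such as
# compute_loss / get_train_dataloader, keep training_step and
# _inner_training_loop from the base transformers Trainer.
from transformers import Trainer


class EmbeddingTrainer(Trainer):  # illustrative name
    def __init__(self, *args, loss_fn=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_fn = loss_fn  # a Sentence Transformers style loss module

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Delegate to the embedding loss instead of expecting the model to
        # return a `loss` key, as a language-modeling head would.
        features, labels = inputs["features"], inputs.get("labels")
        loss = self.loss_fn(features, labels)
        return (loss, None) if return_outputs else loss

    def get_train_dataloader(self):
        # A real trainer might plug in custom batch samplers here (e.g. to
        # avoid duplicate texts in a batch); the default is fine for a sketch.
        return super().get_train_dataloader()
```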

ShengYun-Peng commented 1 week ago

Hi @tomaarsen, if the model is wrapped, can we directly update the model in the loss function with loss_fn.model = self.model here before calling loss_fn.forward? Basically, I'm curious about the purpose of the override_model_in_loss method in the SentenceTransformerTrainer.

ShengYun-Peng commented 1 week ago

Based on my experiments, the evaluator does not work out of the box with FSDP; it keeps throwing RuntimeError: 'weight' must be 2-D. I also recall the docs saying that the evaluator doesn't work with FSDP. I'm curious why that is the case.
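My guess (unverified) is that FSDP keeps the embedding weights flattened/sharded outside of the wrapped forward pass, so when the evaluator calls the underlying modules directly it hits a 1-D weight. If that's the cause, gathering the full parameters around the evaluation call might avoid it; a rough, untested sketch:

```python
# Untested workaround sketch: gather full (unsharded) parameters for the
# duration of the evaluator call. Function and argument names are illustrative.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def run_evaluator_under_fsdp(evaluator, fsdp_wrapped_model, st_model):
    # `fsdp_wrapped_model` is the FSDP-wrapped module prepared by the trainer,
    # `st_model` is the SentenceTransformer instance the evaluator expects.
    with FSDP.summon_full_params(fsdp_wrapped_model):
        return evaluator(st_model)
```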

ShengYun-Peng commented 5 days ago

I have successfully finetuned Llama 3 for text embedding with FSDP and sentence-transformers, with some modifications.

ShengYun-Peng commented 5 days ago

The core issue is making the model inside the loss function aware of the FSDP setup. I may open a new PR if necessary.
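Roughly, the change boils down to something like the following (illustrative names, not the final patch): before computing the loss, make sure the loss's internal model reference points at the FSDP-wrapped model the trainer prepared, so forward passes inside the loss go through FSDP's parameter gathering.

```python
# Sketch of the idea: keep the loss pointing at the wrapped model.
def sync_loss_with_wrapped_model(loss_fn, wrapped_model):
    # `loss_fn` is a Sentence Transformers style loss holding a `model`
    # attribute; `wrapped_model` is whatever accelerate/FSDP prepared.
    if getattr(loss_fn, "model", None) is not wrapped_model:
        loss_fn.model = wrapped_model
    return loss_fn

# Inside a compute_loss override one would then call, e.g.:
#   loss_fn = sync_loss_with_wrapped_model(self.loss, model)
#   loss = loss_fn(features, labels)
```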