UKPLab / sentence-transformers

Multilingual Sentence & Image Embeddings with BERT
https://www.SBERT.net
Apache License 2.0
14.74k stars, 2.43k forks

support gradient checkpointing for training #2531

Open pszemraj opened 5 months ago

pszemraj commented 5 months ago

Hi, I am trying to use the model.fit method with gradient_checkpointing=True because of memory constraints. After trying several variations on the idea, neither I nor Claude 3/GPT-4 could figure out where to put it without causing errors. Most of the errors come from the model ending up as None.


Example attempt at adding it to the convoluted models.Transformer class before creating a SentenceTransformer:

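import torch
from sentence_transformers import SentenceTransformer, models

# max_seq_length and gradient_checkpointing are defined earlier in the script
model_name = "distilroberta-base"  # supports gradient_checkpointing in transformers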
word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length)

# Enable gradient checkpointing on the underlying transformer model
if gradient_checkpointing:
    word_embedding_model.auto_model.gradient_checkpointing_enable()

pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

# ... (rest of the code for dataloading etc)
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=dev_evaluator,
    epochs=num_epochs,
    evaluation_steps=int(len(train_dataloader) * 0.05),
    warmup_steps=warmup_steps,
    output_path=model_save_path,
    checkpoint_path=model_save_path,
    use_amp=torch.cuda.is_available(),
    checkpoint_save_total_limit=1,
)

results in:

# ...
  File "/usr/local/lib/python3.10/dist-packages/fire/core.py", line 475, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/usr/local/lib/python3.10/dist-packages/fire/core.py", line 691, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/content/train_sbert_from_encoder.py", line 350, in main
    model.fit(
AttributeError: 'NoneType' object has no attribute 'fit'

The same code works when I don't try to enable gradient checkpointing.


I also tried passing it to model_args, but it isn't accepted there. Any input on whether I am doing something wrong, or on whether this missing feature could be added to the package, would be appreciated: I cannot use the library to train the intended models without it.
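
For context on the technique being requested: gradient checkpointing saves memory by keeping only some intermediate activations during the forward pass and recomputing the rest during the backward pass. The sketch below is a toy, pure-Python illustration on a chain of scalar tanh layers; `backward_full` and `backward_checkpointed` are hypothetical names for this example, and this is not how PyTorch or transformers implement it internally.

```python
import math

# Toy network: a chain of scalar layers y = tanh(w * x). This only
# illustrates the checkpointing idea; it is not the PyTorch implementation.


def layer_forward(w, x):
    return math.tanh(w * x)


def layer_backward(w, x, grad_out):
    """Return (grad wrt input x, grad wrt weight w) for y = tanh(w * x)."""
    local = 1.0 - math.tanh(w * x) ** 2
    return grad_out * w * local, grad_out * x * local


def backward_full(weights, x0):
    """Plain backprop: keep every layer input from the forward pass."""
    acts = [x0]
    for w in weights:
        acts.append(layer_forward(w, acts[-1]))
    grad, grads_w = 1.0, [0.0] * len(weights)
    for i in reversed(range(len(weights))):
        grad, grads_w[i] = layer_backward(weights[i], acts[i], grad)
    return grads_w, len(acts)  # gradients, number of stored activations


def backward_checkpointed(weights, x0, segment):
    """Keep only every `segment`-th layer input; recompute the rest in backward."""
    n = len(weights)
    checkpoints, x = {0: x0}, x0
    for i, w in enumerate(weights):
        x = layer_forward(w, x)
        if (i + 1) % segment == 0:
            checkpoints[i + 1] = x
    peak, grad, grads_w = len(checkpoints), 1.0, [0.0] * n
    for start in reversed(range(0, n, segment)):
        end = min(start + segment, n)
        # Recompute this segment's activations from its stored checkpoint.
        acts = [checkpoints[start]]
        for i in range(start, end):
            acts.append(layer_forward(weights[i], acts[-1]))
        peak = max(peak, len(checkpoints) + len(acts))
        for i in reversed(range(start, end)):
            grad, grads_w[i] = layer_backward(weights[i], acts[i - start], grad)
    return grads_w, peak  # gradients, peak number of stored activations


weights = [0.5 + 0.05 * i for i in range(16)]
g_full, stored_full = backward_full(weights, 0.3)
g_ckpt, stored_ckpt = backward_checkpointed(weights, 0.3, segment=4)
print(stored_full, stored_ckpt)  # 17 10
```

With 16 layers and segments of 4, the checkpointed version holds at most 10 stored values instead of 17 while producing the same gradients, at the cost of running every layer's forward pass twice.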

tomaarsen commented 5 months ago

Hello!

I'm afraid that gradient checkpointing is not currently supported by the fit method. #2449 should introduce it, but that PR is still being tested. If you're running into memory issues, the best solution would be to reduce the batch_size in the DataLoader, which should reduce memory usage roughly linearly.
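
To illustrate why a smaller batch_size helps, and why it eventually stops helping, here is a deliberately crude memory model. Every constant in it (16 bytes of parameter state per parameter, the `activation_factor`, 2-byte activations) is an assumption chosen for illustration, and it ignores details such as the attention-score term that grows with the square of the sequence length. `estimate_memory_gib` is a hypothetical helper, not part of sentence-transformers.

```python
def estimate_memory_gib(
    batch_size: int,
    seq_len: int,
    hidden: int,
    layers: int,
    n_params: int,
    bytes_per_param_state: int = 16,  # assumed: fp32 weights + grads + two Adam moments
    activation_factor: int = 16,  # assumed: saved values per token per layer, in units of `hidden`
    bytes_per_act: int = 2,  # assumed: fp16 activations under mixed precision
) -> tuple[float, float]:
    """Return (fixed GiB, activation GiB) under a deliberately crude model."""
    gib = 1024**3
    fixed = n_params * bytes_per_param_state / gib  # independent of batch size
    activations = (
        batch_size * seq_len * hidden * layers * activation_factor * bytes_per_act / gib
    )  # linear in batch size
    return fixed, activations


# Numbers roughly matching distilroberta-base with 512-token inputs.
for bs in (32, 8, 1):
    fixed, acts = estimate_memory_gib(bs, seq_len=512, hidden=768, layers=6, n_params=82_000_000)
    print(f"batch_size={bs:>2}: fixed ~{fixed:.2f} GiB, activations ~{acts:.2f} GiB")
```

Under these assumptions the activation term falls from 2.25 GiB at batch_size=32 to about 0.07 GiB at batch_size=1, while the fixed term stays around 1.22 GiB. Once the batch size is already 1, further savings have to come from elsewhere, e.g. shorter sequences or gradient checkpointing for the activation term.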

pszemraj commented 5 months ago

Thanks for the reply!

So, good news first: after some initial fiddling where I almost gave up, I did get your PR to train a non-stock model on a non-stock dataset with gradient checkpointing. You can see my runs on wandb here; the .py script for each run is also saved there for details.

If you find it helpful, I can provide "qualitative feedback" in that PR thread on which features work and which don't, but I'm not sure I'll have time to track down where in the source code the problems originate, since they aren't "errors" per se but rather things like running OOM or the loss always being 0.

re: batch size

Alright, I didn't mention it explicitly, so it's a fair suggestion, but I've been on that batch_size=1 game for a while now.



BTW, since the implementation will be tied to the PR, feel free to close this issue or link it to the PR so it gets closed later; either works for me.

tomaarsen commented 5 months ago

I'll link them :)

Well done on getting that PR working! It doesn't have much documentation to guide you right now 😄 Please feel free to share any feedback in that PR. The refactor is so big that I won't be able to test every combination of settings; gradient checkpointing, for example, was unlikely to get thorough testing.

Also, that Spearman cosine looks great (0.93 Spearman correlation based on cosine similarity). Is that from a held-out test set or from the training set?
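
For reference, a "Spearman cosine" score ranks the cosine similarity of each embedding pair and correlates that ranking with the ranking of the gold similarity labels. A minimal pure-Python sketch of the computation; the helper names are hypothetical and the library's own evaluator is not reproduced here.

```python
import math


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def average_ranks(values):
    """Rank values from 1..n, giving tied values the mean of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        for k in range(i, j + 1):
            ranks[order[k]] = (i + j) / 2 + 1
        i = j + 1
    return ranks


def pearson(x, y):
    mx, my = sum(x) / len(x), sum(y) / len(y)
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)


def spearman_cosine(emb1, emb2, gold_scores):
    """Spearman correlation = Pearson correlation of the two rank vectors."""
    cos_scores = [cosine(a, b) for a, b in zip(emb1, emb2)]
    return pearson(average_ranks(cos_scores), average_ranks(gold_scores))


emb1 = [[1.0, 0.0]] * 4
emb2 = [[1.0, 0.0], [1.0, 1.0], [0.0, 1.0], [-1.0, 0.0]]
gold = [5.0, 3.5, 1.0, 0.0]
print(spearman_cosine(emb1, emb2, gold))  # ≈ 1.0: the two rankings agree exactly
```

Because only ranks matter, a high Spearman cosine says the model orders pairs the same way the labels do, not that the raw cosine values match the label scale.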

pszemraj commented 5 months ago

Great! I'll share some points in the PR later when I have a chance to write things up a bit more. The tl;dr (I have tried to validate all of these with multiple base models) is:

That was longer than I meant but you get the idea.

re: eval scores

Yeah, I was surprised by this, but it could make sense. It uses held-out validation samples split from the source data (n=500), but I'm testing a couple of spicy things at the same time, so I'm going to wait until I can validate this a bit better.
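
The thread doesn't show how the 500 validation samples were split off. A minimal sketch, assuming a seeded random split of an in-memory list; `split_holdout`, the seed value, and the placeholder data are hypothetical.

```python
import random


def split_holdout(samples, n_holdout=500, seed=42):
    """Return (train, validation) after a seeded shuffle; validation has n_holdout items."""
    if not 0 < n_holdout < len(samples):
        raise ValueError("n_holdout must be between 1 and len(samples) - 1")
    shuffled = list(samples)  # copy so the caller's list is untouched
    random.Random(seed).shuffle(shuffled)  # local RNG: does not disturb global random state
    return shuffled[n_holdout:], shuffled[:n_holdout]


pairs = [(f"sentence {i}", f"paraphrase {i}") for i in range(10_000)]  # placeholder data
train, validation = split_holdout(pairs)
print(len(train), len(validation))  # 9500 500
```

Fixing the seed makes the split reproducible across runs, which matters when comparing scores between experiments.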

Bonus (I'm mostly unsure where to put this): I have a question related to controlling tokenizer padding during both training and inference. I'm not sure if it's something you are specifically adjusting in the PR, but it is a problem in "old" SBERT and still doesn't work in the v3 PR code. Should I open a new issue, ask in your PR conversation, or something else?

tomaarsen commented 5 months ago

Those all sound good to know! I'll be glad to learn more about them in the PR comment. By "base loss", do you mean with Matryoshka? Admittedly, I've only tested that with CoSENTLoss, AnglELoss, CosineSimilarityLoss and MultipleNegativesRankingLoss, but it should work with all of them.
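
For readers unfamiliar with the setup: Matryoshka training wraps a "base loss" and applies it to several truncated prefixes of each embedding, summing the weighted results. A conceptual pure-Python sketch; the `matryoshka_loss` helper, the stand-in `squared_distance_loss`, and list-based embeddings are assumptions, and the library's MatryoshkaLoss is not reproduced here.

```python
def squared_distance_loss(batch_a, batch_b):
    """Hypothetical stand-in base loss: mean squared Euclidean distance per pair."""
    return sum(
        sum((x - y) ** 2 for x, y in zip(a, b)) for a, b in zip(batch_a, batch_b)
    ) / len(batch_a)


def matryoshka_loss(base_loss, batch_a, batch_b, dims, weights=None):
    """Apply base_loss to each embedding prefix of length d in dims; return the weighted sum."""
    if weights is None:
        weights = [1.0] * len(dims)
    total = 0.0
    for dim, weight in zip(dims, weights):
        prefix_a = [vec[:dim] for vec in batch_a]  # keep only the first `dim` values
        prefix_b = [vec[:dim] for vec in batch_b]
        total += weight * base_loss(prefix_a, prefix_b)
    return total


a = [[1.0, 1.0, 1.0, 1.0]]
b = [[0.0, 0.0, 0.0, 0.0]]
print(matryoshka_loss(squared_distance_loss, a, b, dims=(4, 2)))  # 6.0
```

Because the wrapper only changes which prefix of each embedding the base loss sees, a base loss that returns 0 on every batch still returns 0 after wrapping.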

> Bonus (I'm mostly unsure where to put this): I have a question related to controlling tokenizer padding during both training and inference. I'm not sure if it's something you are specifically adjusting in the PR, but it is a problem in "old" SBERT and still doesn't work in the v3 PR code. Should I open a new issue, ask in your PR conversation, or something else?

I think a new issue would be best. I'm aware that some old code forcibly overrides the tokenizer settings. The big problem is giving people the freedom to change them without also breaking a lot of old models.

pszemraj commented 5 months ago

Great, will add details to the PR sometime this weekend.

> By "base loss", do you mean with Matryoshka? Admittedly, I've only tested that with CoSENTLoss, AnglELoss, CosineSimilarityLoss and MultipleNegativesRankingLoss, but it should work with all of them.

I added the "base" bit because, for me, whether or not I use Matryoshka doesn't seem to be what "breaks" training. Basically, all of the losses you listed (except CosineSimilarityLoss) cause the training loss to be 0. wandb examples (the relevant code is saved with each run): MultipleNegativesRankingLoss, CoSENTLoss
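
One possibility worth checking, assuming these runs used the batch_size=1 mentioned earlier and (anchor, positive) pairs without extra hard negatives: MultipleNegativesRankingLoss and CoSENTLoss only compare examples against other examples in the same batch, so with a single example per batch there is nothing to compare and the loss is exactly 0 (AnglELoss, which as far as I know shares CoSENT's pairwise form, would behave the same). CosineSimilarityLoss instead regresses each pair against its own label. The simplified re-implementations below are illustrations, not the library's code, and the scale of 20 is an assumed value.

```python
import math


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def in_batch_negatives_loss(anchors, positives, scale=20.0):
    """MNRL-style: cross-entropy where positives[i] is the target for anchors[i]
    and the other positives in the batch act as negatives."""
    total = 0.0
    for i, a in enumerate(anchors):
        logits = [scale * cosine(a, p) for p in positives]
        m = max(logits)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
        total += log_sum_exp - logits[i]
    return total / len(anchors)


def cosent_style_loss(cos_scores, labels, scale=20.0):
    """CoSENT-style: log(1 + sum over pairs with labels[i] > labels[j]
    of exp(scale * (cos_scores[j] - cos_scores[i])))."""
    terms = [
        math.exp(scale * (cos_scores[j] - cos_scores[i]))
        for i in range(len(labels))
        for j in range(len(labels))
        if labels[i] > labels[j]
    ]
    return math.log1p(sum(terms))  # no pairs -> log1p(0) = 0


def cosine_mse_loss(cos_scores, labels):
    """CosineSimilarityLoss-style: mean squared error between cosine and label."""
    return sum((c - y) ** 2 for c, y in zip(cos_scores, labels)) / len(labels)


anchor, positive = [0.1, 0.9, 0.3], [0.8, 0.2, 0.5]
c = cosine(anchor, positive)
print(in_batch_negatives_loss([anchor], [positive]))  # 0.0
print(cosent_style_loss([c], [0.7]))  # 0.0
print(cosine_mse_loss([c], [0.7]) > 0)  # True
```

If this is the cause, a quick check would be a short run with a batch size of 2 or more for the in-batch losses.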

> I think a new issue would be best. I'm aware that some old code forcibly overrides the tokenizer settings. The big problem is giving people the freedom to change them without also breaking a lot of old models.

ty, will do!

tomaarsen commented 5 months ago

I'll experiment with these on Monday. Those are some very clean scripts you've created; I'm impressed 😄 Obviously, the losses shouldn't all just be 0, haha.