huggingface / setfit

Efficient few-shot learning with Sentence Transformers
https://hf.co/docs/setfit
Apache License 2.0

Unfreezing and freezing in new version #474

Closed · abedini-arteriaai closed 10 months ago

abedini-arteriaai commented 10 months ago

The migration guide says the following:

Refactor multiple trainer.train(), trainer.freeze() and trainer.unfreeze() calls that were previously necessary to train the differentiable head into just one trainer.train() call by setting batch_size and num_epochs on the TrainingArguments dataclass with tuples. The first value in the tuple is for training the embeddings, and the second is for training the classifier.

This implies we no longer need freeze and unfreeze, and that those calls can be replaced with a single trainer.train() call as long as batch_size and num_epochs are tuples. However, the guide also says:

keep_body_frozen from SetFitModel.unfreeze has been deprecated, simply either pass "head", "body" or no arguments to unfreeze both.

It seems like we still use unfreeze, so what does the change look like? If I could get some sample code, it would be very helpful.

Thank you.

tomaarsen commented 10 months ago

Hello!

This implies we no longer need freeze and unfreeze, and that those calls can be replaced with a single trainer.train() call as long as batch_size and num_epochs are tuples.

Indeed. The reason that SetFitModel.freeze and SetFitModel.unfreeze still exist is that they are now called automatically in the trainer, rather than relying on the user to perform the (un)freezing themselves.
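
As the migration guide notes, unfreezing no longer uses keep_body_frozen; you just pass the component name (or nothing) directly. A small sketch (untested, following the migration guide) of the manual API, in case you still need it:

model.freeze("head")    # freeze only the differentiable head
model.freeze("body")    # freeze only the SentenceTransformer body
model.unfreeze("head")  # unfreeze only the head
model.unfreeze()        # no arguments: unfreeze both the head and the body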

For example:

Before

from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitModel, SetFitTrainer

# train_dataset and eval_dataset are assumed to be prepared datasets.Dataset objects

# Load a SetFit model from the Hub with a differentiable classification head
model: SetFitModel = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-mpnet-base-v2",
    use_differentiable_head=True,
    head_params={"out_features": 2},
)

# Create trainer
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss_class=CosineSimilarityLoss,
    metric="accuracy",
    learning_rate=2e-5,
    batch_size=16,
    num_iterations=20,
    num_epochs=1,
)

trainer.freeze() # Freeze the head
trainer.train() # Train only the body

# Unfreeze the head and unfreeze the body -> end-to-end training
trainer.unfreeze(keep_body_frozen=False)

trainer.train(
    num_epochs=16,            # epochs for this second, end-to-end phase
    batch_size=2,
    body_learning_rate=1e-5,  # LR for the SentenceTransformer body
    learning_rate=1e-2,       # LR for the differentiable head
)
metrics = trainer.evaluate()

After

from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitModel, Trainer, TrainingArguments

# Load a SetFit model from the Hub with a differentiable classification head
model: SetFitModel = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-mpnet-base-v2",
    use_differentiable_head=True,
    head_params={"out_features": 2},
)

# Create Training Arguments
args = TrainingArguments(
    # When an argument is a tuple, the first value is for training the embeddings,
    # and the latter is for training the differentiable classification head:
    batch_size=(16, 2),
    num_iterations=20,
    num_epochs=(1, 16),
    body_learning_rate=(2e-5, 1e-5),
    head_learning_rate=1e-2,
    end_to_end=True,  # also train the body during the classifier phase
    loss=CosineSimilarityLoss,
)

# Create Trainer
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    metric="accuracy",
)

# Train and evaluate
trainer.train()
metrics = trainer.evaluate()

(I didn't run these snippets, so they might have a small mistake somewhere, but they should be roughly correct)

Hope this helps!

abedini-arteriaai commented 10 months ago

This is great, thank you, I'll try it.

Is body_learning_rate a tuple value? There's already a head_learning_rate parameter; I assumed that's why the learning rate was split into two.

tomaarsen commented 10 months ago

body_learning_rate is indeed a tuple value. This is a bit unfortunate, I agree, but there are essentially three learning rates to consider:

1. The learning rate of the body during the embedding fine-tuning phase: the first value of the body_learning_rate tuple.
2. The learning rate of the body during the classifier training phase: the second value of the body_learning_rate tuple. This one is only used if end_to_end=True; otherwise the body stays frozen in that phase.
3. The learning rate of the head during the classifier training phase: head_learning_rate.

You can also give body_learning_rate just one float value, and then that LR will be used for both the embedding and classifier phases. I hope this makes some sense! This is perhaps the most complex part of all of the training arguments, so it's all easier from here, haha.
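
To illustrate, a small sketch (untested) of the two ways to set it:

from setfit import TrainingArguments

# Tuple form: 2e-5 for the body during the embedding phase,
# 1e-5 for the body during the classifier phase (only used with end_to_end=True)
args = TrainingArguments(
    body_learning_rate=(2e-5, 1e-5),
    head_learning_rate=1e-2,  # LR for the head during the classifier phase
)

# Float form: the same body LR is used for both phases
args = TrainingArguments(
    body_learning_rate=2e-5,
    head_learning_rate=1e-2,
)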