huggingface / setfit

Efficient few-shot learning with Sentence Transformers
https://hf.co/docs/setfit
Apache License 2.0

Model checkpoints saved during the training are unusable #526

Open n-splv opened 4 months ago

n-splv commented 4 months ago

Step 1: Train a model:

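from setfit import SetFitModel, Trainer, TrainingArguments

# Defined elsewhere in the reporter's code (not shown): id2label, MODEL_DIR (a pathlib.Path,
# since it is joined with `/` below), dataset (with "train" and "validation" splits)
# and batch_multi_label_metric.
model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
model = SetFitModel.from_pretrained(
    model_name,
    multi_target_strategy="multi-output",
    use_differentiable_head=True,
    head_params={"out_features": len(id2label)},
)

args = TrainingArguments(
    output_dir=MODEL_DIR,
    batch_size=32,
    num_epochs=20,
    evaluation_strategy='epoch',
    save_strategy='epoch',  # save a checkpoint every epoch
    save_total_limit=4,
    sampling_strategy='unique',
)
# Workaround kept from the original report; presumably a compatibility shim for
# transformers versions that read `eval_strategy` instead of `evaluation_strategy`
args.eval_strategy = args.evaluation_strategy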

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    metric=batch_multi_label_metric,
)
trainer.train()  # fine-tunes the embedding body, then trains the classification head

Step 2: Save the model explicitly. The examples in the docs always do this, but they never make clear that it is absolutely necessary and, in fact, the only way to use the trained model later:

model.save_pretrained(MODEL_DIR / "explicit_save")

Step 3: Try to load the model from the latest checkpoint:

checkpoint_model = SetFitModel.from_pretrained(
    MODEL_DIR / "step_26560",
)
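Picking "the latest checkpoint" by sorting directory names as strings can go wrong, because `step_9` sorts after `step_10`. A minimal sketch, assuming checkpoints are written as `step_<N>` subdirectories of `MODEL_DIR` like the `step_26560` folder above; `latest_checkpoint` is a hypothetical helper, not part of SetFit:

```python
import re
from pathlib import Path
from typing import Optional

# Hypothetical helper: assumes checkpoint folders are named "step_<N>",
# matching the "step_26560" directory above; other naming schemes are ignored.
STEP_DIR = re.compile(r"^step_(\d+)$")


def latest_checkpoint(output_dir: Path) -> Optional[Path]:
    """Return the step_<N> subdirectory with the largest N, or None if there is none."""
    candidates = []
    for child in Path(output_dir).iterdir():
        match = STEP_DIR.match(child.name)
        if match and child.is_dir():
            # Compare on the integer step, not the string, so step_10 beats step_9.
            candidates.append((int(match.group(1)), child))
    if not candidates:
        return None
    return max(candidates, key=lambda c: c[0])[1]
```

This only locates the checkpoint directory; it does nothing about the head weights discussed next.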

Without any warning, this model performs poorly, because the classifier (head) weights have not been loaded, or were never even saved in the first place. Comparing this model's head with the explicitly saved one makes the difference obvious:

explicit_model = SetFitModel.from_pretrained(
    MODEL_DIR / "explicit_save",
)

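import plotly.express as px

# Take the first parameter tensor of each head (named_parameters() yields (name, tensor) pairs)
checkpoint_head_weights = next(checkpoint_model.model_head.named_parameters())[1]
explicit_head_weights = next(explicit_model.model_head.named_parameters())[1]

# Plot the flattened weights of each head for a visual comparison
fig1 = px.line(checkpoint_head_weights.detach().numpy().ravel())
fig2 = px.line(explicit_head_weights.detach().numpy().ravel())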
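The plots show the difference visually; a single number can make the same point. A minimal sketch with NumPy, where `max_abs_diff` is a hypothetical helper that assumes the tensors have already been converted with `.detach().numpy()` as in the plotting code:

```python
import numpy as np


def max_abs_diff(a, b) -> float:
    """Largest absolute element-wise difference between two same-shaped arrays."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    if a.size == 0:
        return 0.0
    return float(np.max(np.abs(a - b)))
```

With the variables above, `max_abs_diff(checkpoint_head_weights.detach().numpy(), explicit_head_weights.detach().numpy())` should be clearly non-zero if the checkpoint head was never trained. Since `next(...named_parameters())` only inspects the first tensor, looping over `model_head.state_dict().items()` would cover the whole head.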

[Plots: fig1 (checkpoint head weights) and fig2 (explicitly saved head weights)]

So, unless I messed something up, my proposal would be to either make this behavior clear to the user or, better, fix it so that the checkpoints are usable.

cjuracek-tess commented 1 week ago

+1, I am wondering how the intermediate model checkpoints are supposed to be used. They are saved during the embedding fine-tuning phase, not during the classifier-training phase. So my interpretation is that their classifier still needs to be trained afterward?

Update: I have tried evaluating my own model from its checkpoints (I cannot share the code). Precision drops drastically compared to the fully trained model, which suggests to me that the classifier head is not trained in the checkpoints.
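The interpretation above can be illustrated with a toy model (plain Python, not SetFit code): if checkpoints are written only while the body is being fine-tuned, each checkpoint captures the head in its pre-training state. `ToyModel` and `toy_two_phase_training` are hypothetical stand-ins, with single floats in place of real weights:

```python
import copy
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class ToyModel:
    body: float = 0.0  # stand-in for the Sentence Transformer body weights
    head: float = 0.0  # stand-in for the classifier head weights


def toy_two_phase_training(body_epochs: int, head_epochs: int) -> Tuple[ToyModel, List[ToyModel]]:
    """Return (final_model, checkpoints); checkpoints are taken only in phase 1."""
    model = ToyModel()
    checkpoints: List[ToyModel] = []
    # Phase 1: fine-tune the body and snapshot the whole model after each epoch.
    for _ in range(body_epochs):
        model.body += 1.0
        checkpoints.append(copy.deepcopy(model))
    # Phase 2: train the head; no snapshots are written in this phase.
    for _ in range(head_epochs):
        model.head += 1.0
    return model, checkpoints
```

In this toy, every checkpoint has `head == 0.0` while the final model has a trained head, which is consistent with the precision drop described above.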

cjuracek-tess commented 1 week ago

@n-splv If you are interested in using these checkpoints, the workaround for me was to train the classifier head myself ("filling in" the missing logic from Trainer.train()):

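from setfit import SetFitModel, Trainer

# <checkpoint>: a checkpoint directory saved during training
model = SetFitModel.from_pretrained(<checkpoint>)
trainer = Trainer(
    args=args,
    model=model,
    ...
)
# Turn the training dataset into the inputs train_classifier expects,
# then run only the classifier-training phase on the checkpoint's body
train_parameters = trainer.dataset_to_parameters(trainer.train_dataset)
trainer.train_classifier(*train_parameters, args=trainer.args)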