huggingface / setfit

Efficient few-shot learning with Sentence Transformers
https://hf.co/docs/setfit
Apache License 2.0

Hyperparameter tuning for AbsaModel #494

Open ysapolovych opened 9 months ago

ysapolovych commented 9 months ago

Hi! I would like to do a hyperparameter search for an ABSA model (polarity only, to be precise). Since there is no hyperparameter_search method on AbsaTrainer, I do it as in the code below, but I am not sure whether this is correct/optimal.

import optuna
from sklearn.metrics import f1_score
from setfit import AbsaModel, AbsaTrainer, TrainingArguments


def objective(trial):
    params = {
        'batch_size': trial.suggest_categorical('batch_size', [8, 16, 32]),
        'body_learning_rate': trial.suggest_float('body_learning_rate', 1e-5, 1e-3),
        # note: Optuna only officially supports None/bool/int/float/str categorical
        # choices, so the tuples below will trigger a warning
        'span_contexts': trial.suggest_categorical('span_contexts', [(0, 3), (0, 4), (0, 5)])
    }

    model = AbsaModel.from_pretrained(
        model_id='all-mpnet-base-v2',
        spacy_model='en_core_web_sm',
        span_contexts=params.pop('span_contexts')
    )

    trainer = AbsaTrainer(
        model,
        train_dataset=train,
        metric=compute_metrics
    )

    # train only the polarity model with the sampled hyperparameters
    trainer.train_polarity(args=TrainingArguments(**params))

    res = model.predict(eval)

    f1 = f1_score(res['label'].tolist(), res['pred_polarity'].tolist(), average='macro')

    return f1


study = optuna.create_study(direction='maximize', sampler=optuna.samplers.TPESampler(seed=99))

study.optimize(objective, n_trials=100, gc_after_trial=True)
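
For context, train and eval are datasets I prepare beforehand, and compute_metrics is just a small macro-F1 callable, roughly along these lines (SetFit accepts a callable that receives predictions and references):

from sklearn.metrics import f1_score

def compute_metrics(y_pred, y_test):
    # macro-averaged F1 over the polarity labels
    return {'f1_macro': f1_score(y_test, y_pred, average='macro')}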

Some questions:

# extra imports for the manual construction below (module paths may differ
# slightly between setfit versions)
from sentence_transformers import SentenceTransformer
from sklearn.linear_model import LogisticRegression
from setfit.span.aspect_extractor import AspectExtractor
from setfit.span.modeling import PolarityModel

aspect_extractor = AspectExtractor(spacy_model='en_core_web_sm')

def objective(trial):
    params = {
        'batch_size': trial.suggest_categorical('batch_size', [8, 16, 32]),
        'body_learning_rate': trial.suggest_float('body_learning_rate', 1e-5, 1e-3),
        'multi_target_strategy': trial.suggest_categorical(
            'multi_target_strategy', ['one-vs-rest', 'multi-output', 'classifier-chain']
        )
    }

    st_model = SentenceTransformer('all-mpnet-base-v2')
    lr = LogisticRegression()  # classification head (LinearRegression is a regressor)
    pol_model = PolarityModel(
        st_model,
        lr,
        multi_target_strategy=params.pop('multi_target_strategy'),
        spacy_model='en_core_web_sm')

    # aspect_model is assumed to be built elsewhere, analogously to pol_model
    model = AbsaModel(
        aspect_extractor=aspect_extractor,
        aspect_model=aspect_model,
        polarity_model=pol_model
    )

    trainer = AbsaTrainer(
        model,
        train_dataset=train,
        metric=compute_metrics
    )

    # (...)
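
Whichever variant I use, once study.optimize finishes I read off the winning configuration with the standard Optuna accessors:

# standard Optuna accessors on the finished study
print(study.best_value)   # best macro-F1 seen across trials
print(study.best_params)  # hyperparameter values that produced it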