huggingface / setfit

Efficient few-shot learning with Sentence Transformers
https://hf.co/docs/setfit
Apache License 2.0

Comparing setfit with a simpler approach #297

Closed: vahuja4 closed this issue 1 year ago.

vahuja4 commented 1 year ago

Hi, I am trying to compare SetFit with another approach, which works like this:

  1. Define a list of representative sentences per class; call it `rep_sent`.
  2. Compute sentence embeddings for `rep_sent` using mpnet-base-v2.
  3. Define a list of test sentences; call it `test_sent`.
  4. Compute sentence embeddings for `test_sent`.
  5. To assign a class to each sentence in `test_sent`, compute its cosine similarity to the sentences in `rep_sent` and choose the class with the highest cosine similarity.
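Steps 1 to 5 amount to a nearest-neighbour classifier over sentence embeddings. Below is a minimal NumPy sketch of step 5, assuming the embeddings from steps 2 and 4 are already available as arrays. The helper names are hypothetical, and since the post does not say whether similarities are aggregated per class, this version simply takes the label of the single most similar representative sentence. Toy 2-d vectors stand in for real embeddings.

```python
import numpy as np


def cosine_similarity_matrix(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Cosine similarity between every row of `a` and every row of `b`."""
    a_unit = a / np.linalg.norm(a, axis=1, keepdims=True)
    b_unit = b / np.linalg.norm(b, axis=1, keepdims=True)
    return a_unit @ b_unit.T


def classify_by_nearest_rep(test_emb, rep_emb, rep_labels):
    """Hypothetical helper: give each test embedding the label of its most similar rep_sent."""
    sims = cosine_similarity_matrix(
        np.asarray(test_emb, dtype=float), np.asarray(rep_emb, dtype=float)
    )
    best = sims.argmax(axis=1)  # index of the most similar rep_sent for each test sentence
    return [rep_labels[i] for i in best]


# Toy 2-d vectors standing in for real sentence embeddings (assumption, not real data)
rep_emb = [[1.0, 0.0], [0.9, 0.2], [0.0, 1.0]]
rep_labels = ["accountClose", "accountClose", "editAccountDetails"]
test_emb = [[0.8, 0.1], [0.1, 0.9]]
print(classify_by_nearest_rep(test_emb, rep_emb, rep_labels))
# ['accountClose', 'editAccountDetails']
```

With real data, `rep_emb` and `test_emb` would come from the same encoder, for example `SentenceTransformer(model_name).encode(sentences)` from the sentence-transformers package. Nothing is trained in this approach, so its predictions depend entirely on how well the chosen `rep_sent` examples cover each class.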

For example, for the test sentence "Remove maiden name from account", the two approaches disagree: SetFit predicts 'manage account transfer', while the other approach predicts 'edit account details'.

Can someone please help me understand how SetFit's performance can be improved? SetFit was trained using `rep_sent` as the training set, which looks like this:

```csv
text,label
I want to close my account,accountClose
Close my credit card,accountClose
Mortgage payoff,accountClose
Loan payoff,accountClose
Loan pay off,accountClose
pay off,accountClose
lease payoff,accountClose
lease pay off,accountClose
account close,accountClose
close card account,accountClose
I want to open an account,accountOpenGeneral
I want to get a card,accountOpenGeneral
I want a loan,accountOpenGeneral
Refinance my car,accountOpenGeneral
Buy a car,accountOpenGeneral
Open checking,accountOpenGeneral
Open savings,accountOpenGeneral
Lease a vehicle,accountOpenGeneral
Link external bank account,accountTransferManage
verify external account,accountTransferManage
Add external account,accountTransferManage
Edit external account,accountTransferManage
Remove external account,accountTransferManage
Mortgage payment,billPaySchedulePayment
Setup Loan payment,billPaySchedulePayment
Setup auto loan payment,billPaySchedulePayment
Schedule bill payment,billPaySchedulePayment
Setup bill payment,billPaySchedulePayment
Setup automatic payment,billPaySchedulePayment
Setup auto pay,billPaySchedulePayment
Setup automatic payment,billPaySchedulePayment
Setup automatic payment,billPaySchedulePayment
Modify account details,editAccountDetails
Modify name on my account,editAccountDetails
Change address in my account,editAccountDetails
```
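Before tuning the model, it can help to check how many examples each class has. Counting the rows above gives 10 for accountClose, 8 for accountOpenGeneral, 5 for accountTransferManage, 9 for billPaySchedulePayment (three of which are the same "Setup automatic payment" row) and only 3 for editAccountDetails, the class the baseline chose for the maiden-name query. Small, unbalanced classes may make few-shot training harder. A minimal sketch of this check with Python's standard `csv` module; `label_counts` and `duplicate_rows` are hypothetical helpers:

```python
import csv
from collections import Counter
from typing import Dict, Iterable, Tuple


def label_counts(lines: Iterable[str]) -> Counter:
    """Count rows per label in `text,label` CSV data (a file object or a list of lines)."""
    return Counter(row["label"] for row in csv.DictReader(lines))


def duplicate_rows(lines: Iterable[str]) -> Dict[Tuple[str, str], int]:
    """Return (text, label) rows that occur more than once, with their counts."""
    counts = Counter((row["text"], row["label"]) for row in csv.DictReader(lines))
    return {row: n for row, n in counts.items() if n > 1}


# A few rows copied from the training data above
sample = [
    "text,label",
    "account close,accountClose",
    "Setup automatic payment,billPaySchedulePayment",
    "Setup automatic payment,billPaySchedulePayment",
    "Modify account details,editAccountDetails",
]
print(label_counts(sample))
print(duplicate_rows(sample))
```

For the full file, `with open("bank_data.csv", newline="") as f: print(label_counts(f))` works; open the file again before calling `duplicate_rows`, since a file object can only be read through once.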

vahuja4 commented 1 year ago

Can someone please tell me how to make the model work better here?

tomaarsen commented 1 year ago

Perhaps you'll have some luck playing around with the following script:

```python
from typing import List
from datasets import Dataset
from setfit import SetFitModel, SetFitTrainer

# Load some training & evaluation datasets
train_dataset = Dataset.from_csv("bank_data.csv")
eval_dataset = Dataset.from_dict(
    {
        "text": [
            "I'd like for my account to be closed",
            "Could you open an account for me please?",
            "Can you edit an external account?",
            "Please set up auto payment",
            "Can you change my name on my account?"
        ],
        "label": ["accountClose", "accountOpenGeneral", "accountTransferManage", "billPaySchedulePayment", "editAccountDetails"],
    }
)
# Extra spot-check inputs; their labels stay strings because they are only printed
dev_dataset = Dataset.from_dict(
    {
        "text": [
            "close account",
            "account open",
            "edit account",
            "auto payment",
            "change name?"
        ],
        "label": ["accountClose", "accountOpenGeneral", "accountTransferManage", "billPaySchedulePayment", "editAccountDetails"],
    }
)

def encode_labels(sample, classes: List[str]):
    sample["label"] = classes.index(sample["label"])
    return sample

# Convert labels to integers (sorted, so the label-to-id mapping does not depend on set ordering)
classes = sorted(set(train_dataset["label"]))
train_dataset = train_dataset.map(lambda sample: encode_labels(sample, classes))
eval_dataset = eval_dataset.map(lambda sample: encode_labels(sample, classes))

# Load a SetFit model from Hub
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")

# Create trainer
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    metric="accuracy",
    batch_size=16,
    num_iterations=2, # Rounds of pair generation: each round makes one positive and one negative pair per training sample
    num_epochs=1, # The number of epochs to use for contrastive learning
)

# Train and evaluate
trainer.train()
metrics = trainer.evaluate()

print(metrics)

for sample in dev_dataset:
    text = sample["text"]
    str_label = sample["label"]

    # predict() returns one label id per input; take the first and convert it to a Python int
    predicted_int_label = int(model.predict([text])[0])
    predicted_str_label = classes[predicted_int_label]
    print(f"Input: {text}")
    print(f"Predicted class: {predicted_str_label}")
    print(f"Known truth class: {str_label}\n")
```
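The `Num examples = 140` line in the training log below is consistent with how `num_iterations` is used: 35 training rows × 2 iterations × 2 pairs (one positive, one negative) gives 140 pairs, and ceil(140 / 16) = 9 optimization steps at batch size 16. The following is a simplified sketch of that pair-generation scheme, written under that assumption; `generate_pairs` is a hypothetical helper, not SetFit's own implementation, which may differ in sampling details.

```python
import math
import random
from typing import List, Tuple


def generate_pairs(
    texts: List[str], labels: List[str], num_iterations: int, seed: int = 0
) -> List[Tuple[str, str, float]]:
    """Hypothetical, simplified SetFit-style contrastive pair generation.

    Each iteration gives every sample one positive pair (a random text with the
    same label, target 1.0) and one negative pair (a random text with a different
    label, target 0.0), so the result has 2 * num_iterations * len(texts) pairs.
    """
    if len(set(labels)) < 2:
        raise ValueError("need at least two classes to build negative pairs")
    rng = random.Random(seed)
    pairs: List[Tuple[str, str, float]] = []
    for _ in range(num_iterations):
        for text, label in zip(texts, labels):
            positives = [t for t, lab in zip(texts, labels) if lab == label]
            negatives = [t for t, lab in zip(texts, labels) if lab != label]
            pairs.append((text, rng.choice(positives), 1.0))
            pairs.append((text, rng.choice(negatives), 0.0))
    return pairs


# Toy data taken from the training CSV (two rows per class, for brevity)
texts = ["account close", "close card account", "Open checking", "Open savings"]
labels = ["accountClose", "accountClose", "accountOpenGeneral", "accountOpenGeneral"]
pairs = generate_pairs(texts, labels, num_iterations=2)
print(len(pairs))  # 16 = 2 pairs * 2 iterations * 4 samples

# The same arithmetic for the full 35-row training set with batch_size=16
num_pairs = 2 * 2 * 35
print(num_pairs, math.ceil(num_pairs / 16))  # 140 9
```

Raising `num_iterations` therefore gives the sentence encoder more contrastive pairs to learn from, at the cost of longer training.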

Note that I placed your data in a file called bank_data.csv. The output that I got when running this script is:

```
model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.
***** Running training *****
  Num examples = 140
  Num epochs = 1
  Total optimization steps = 9
  Total train batch size = 16
Iteration: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:08<00:00,  1.10it/s]
Epoch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:08<00:00,  8.18s/it]
***** Running evaluation *****
{'accuracy': 1.0}
Input: close account
Predicted class: accountClose
Known truth class: accountClose

Input: account open
Predicted class: accountOpenGeneral
Known truth class: accountOpenGeneral

Input: edit account
Predicted class: editAccountDetails
Known truth class: accountTransferManage

Input: auto payment
Predicted class: billPaySchedulePayment
Known truth class: billPaySchedulePayment

Input: change name?
Predicted class: editAccountDetails
Known truth class: editAccountDetails
```

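As you can see, the accuracy on the evaluation set is 100%, and 4 of the 5 dev set predictions are correct. The only mismatch is the arguably ambiguous "edit account", which the model assigns to editAccountDetails rather than accountTransferManage.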

I'll close this now, as it's not exactly a SetFit "issue", but feel free to respond or reopen.