Closed vahuja4 closed 1 year ago
Can someone please tell me how to make the model work better here?
Perhaps you'll have some luck playing around with the following script:
from typing import List
from datasets import Dataset
from setfit import SetFitModel, SetFitTrainer
# Load some training & evaluation datasets
train_dataset = Dataset.from_csv("bank_data.csv")
eval_dataset = Dataset.from_dict(
    {
        "text": [
            "I'd like for my account to be closed",
            "Could you open an account for me please?",
            "Can you edit an external account?",
            "Please set up auto payment",
            "Can you change my name on my account?",
        ],
        "label": ["accountClose", "accountOpenGeneral", "accountTransferManage", "billPaySchedulePayment", "editAccountDetails"],
    }
)
dev_dataset = Dataset.from_dict(
    {
        "text": [
            "close account",
            "account open",
            "edit account",
            "auto payment",
            "change name?",
        ],
        "label": ["accountClose", "accountOpenGeneral", "accountTransferManage", "billPaySchedulePayment", "editAccountDetails"],
    }
)
def encode_labels(sample, classes: List[str]):
    # Replace the string label with its integer index in `classes`
    sample["label"] = classes.index(sample["label"])
    return sample

# Convert labels to integers; sorting keeps the label order stable across runs
classes = sorted(set(train_dataset["label"]))
train_dataset = train_dataset.map(lambda sample: encode_labels(sample, classes))
eval_dataset = eval_dataset.map(lambda sample: encode_labels(sample, classes))
# Load a SetFit model from Hub
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
# Create trainer
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    metric="accuracy",
    batch_size=16,
    num_iterations=2,  # The number of text pairs to generate for contrastive learning
    num_epochs=1,  # The number of epochs to use for contrastive learning
)
# Train and evaluate
trainer.train()
metrics = trainer.evaluate()
print(metrics)
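# Compare predictions on the dev set against the known labels
for sample in dev_dataset:
    text = sample["text"]
    str_label = sample["label"]
    # The model returns one prediction per input text; take the first as a plain int
    predicted_int_label = int(model([text])[0])
    predicted_str_label = classes[predicted_int_label]
    print(f"Input: {text}")
    print(f"Predicted class: {predicted_str_label}")
    print(f"Known truth class: {str_label}\n")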
Note that I placed your data in a file called `bank_data.csv`.
The output that I got when running this script is:
model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.
***** Running training *****
Num examples = 140
Num epochs = 1
Total optimization steps = 9
Total train batch size = 16
Iteration: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:08<00:00, 1.10it/s]
Epoch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:08<00:00, 8.18s/it]
***** Running evaluation *****
{'accuracy': 1.0}
Input: close account
Predicted class: accountClose
Known truth class: accountClose
Input: account open
Predicted class: accountOpenGeneral
Known truth class: accountOpenGeneral
Input: edit account
Predicted class: editAccountDetails
Known truth class: accountTransferManage
Input: auto payment
Predicted class: billPaySchedulePayment
Known truth class: billPaySchedulePayment
Input: change name?
Predicted class: editAccountDetails
Known truth class: editAccountDetails
As you can see, the accuracy on the evaluation set is 100%, and four of the five dev-set predictions are correct; only "edit account" is predicted as editAccountDetails rather than accountTransferManage.
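For reference, the "Num examples = 140" and "Total optimization steps = 9" lines in the log can be reproduced with some quick arithmetic. This is only a sketch: it assumes one positive and one negative pair is generated per training sample per iteration, and that `bank_data.csv` holds the 35 rows posted in this thread. The helper names are made up for illustration and are not part of SetFit.

```python
import math


def contrastive_pair_count(num_samples: int, num_iterations: int) -> int:
    # Assumption: each iteration yields one positive and one negative pair per sample
    return 2 * num_iterations * num_samples


def optimization_steps(num_pairs: int, batch_size: int, num_epochs: int = 1) -> int:
    # One optimizer step per batch; the last batch may be only partially filled
    return math.ceil(num_pairs / batch_size) * num_epochs


pairs = contrastive_pair_count(num_samples=35, num_iterations=2)
steps = optimization_steps(pairs, batch_size=16, num_epochs=1)
print(pairs, steps)  # 140 9
```

With so few pairs, raising `num_iterations` is a cheap thing to experiment with.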
I'll close this now, as it's not exactly a SetFit "issue", but feel free to respond or reopen.
Hi, I am trying to compare SetFit with another approach. The other approach works like this:
1. Take a set of representative sentences for each class: `rep_sent`
2. Encode `rep_sent` using `mpnet-base-v2`
3. For each `test_sent`, compute the cosine similarity with `rep_sent` and choose the class based on the highest cosine similarity.

If we consider a particular test set sentence, "Remove maiden name from account", the results from the two approaches are as follows:
- SetFit predicts this to be 'manage account transfer'
- the other approach predicts this to be 'edit account details'
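To make the cosine-similarity approach concrete, here is a minimal sketch of the nearest-representative lookup. The 3-dimensional vectors are toy stand-ins for real `mpnet-base-v2` sentence embeddings, and the function names are hypothetical, chosen only for this illustration.

```python
import math
from typing import List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    # Dot product divided by the product of the vector norms
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def predict_by_similarity(
    test_vec: Sequence[float],
    rep_vecs: List[Sequence[float]],
    rep_labels: List[str],
) -> Tuple[str, float]:
    # Score the test vector against every representative sentence vector
    scores = [cosine_similarity(test_vec, rep_vec) for rep_vec in rep_vecs]
    # Pick the label of the most similar representative sentence
    best_idx = max(range(len(scores)), key=scores.__getitem__)
    return rep_labels[best_idx], scores[best_idx]


# Toy vectors (assumption): in practice these come from the sentence encoder
rep_vecs = [[1.0, 0.1, 0.0], [0.0, 1.0, 0.2], [0.1, 0.0, 1.0]]
rep_labels = ["accountClose", "accountTransferManage", "editAccountDetails"]
test_vec = [0.2, 0.1, 0.9]

label, score = predict_by_similarity(test_vec, rep_vecs, rep_labels)
print(label, round(score, 3))
```

Note that this picks the class of the single closest representative sentence, so one well-matched example is enough to decide the prediction.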
Can someone please help me understand how SetFit's performance can be improved? SetFit was trained using `rep_sent` as the training set. Here is what it looks like:
`text,label
I want to close my account,accountClose
Close my credit card,accountClose
Mortgage payoff,accountClose
Loan payoff,accountClose
Loan pay off,accountClose
pay off,accountClose
lease payoff,accountClose
lease pay off,accountClose
account close,accountClose
close card account,accountClose
I want to open an account,accountOpenGeneral
I want to get a card,accountOpenGeneral
I want a loan,accountOpenGeneral
Refinance my car,accountOpenGeneral
Buy a car,accountOpenGeneral
Open checking,accountOpenGeneral
Open savings,accountOpenGeneral
Lease a vehicle,accountOpenGeneral
Link external bank account,accountTransferManage
verify external account,accountTransferManage
Add external account,accountTransferManage
Edit external account,accountTransferManage
Remove external account,accountTransferManage
Mortgage payment,billPaySchedulePayment
Setup Loan payment,billPaySchedulePayment
Setup auto loan payment,billPaySchedulePayment
Schedule bill payment,billPaySchedulePayment
Setup bill payment,billPaySchedulePayment
Setup automatic payment,billPaySchedulePayment
Setup auto pay,billPaySchedulePayment
Setup automatic payment,billPaySchedulePayment
Setup automatic payment,billPaySchedulePayment
Modify account details,editAccountDetails
Modify name on my account,editAccountDetails
Change address in my account,editAccountDetails`