Open — ps3-app opened this issue 3 years ago
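It is the same as multiclass classification, with a few modifications: the classifier head outputs a single logit (`num_classes=1`), the loss is `nn.BCEWithLogitsLoss` with float labels instead of `nn.CrossEntropyLoss`, and predictions come from thresholding `sigmoid(logits)` at 0.5 instead of taking an `argmax`.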
Here are some steps that might help you achieve this:
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
class SentimentClassifier(nn.Module):
    """BERT encoder followed by dropout and a linear classification head.

    For binary classification, construct with ``num_classes=1`` so the head
    emits one logit per example; pair it with ``nn.BCEWithLogitsLoss`` and
    threshold ``torch.sigmoid(logits)`` at prediction time. For multiclass
    classification, use ``num_classes=K`` with ``nn.CrossEntropyLoss``.
    """

    def __init__(self, pretrained_model_name, num_classes=2, dropout_prob=0.1):
        """Load the pretrained encoder and build the classification head.

        Args:
            pretrained_model_name: Model name or path passed to
                ``BertModel.from_pretrained`` (e.g. ``"bert-base-uncased"``).
            num_classes: Number of output logits. Use 1 for binary
                classification with ``BCEWithLogitsLoss``.
            dropout_prob: Dropout probability applied to the pooled output
                before the linear head.
        """
        # Must be spelled ``__init__`` (and call ``super().__init__()``) so
        # nn.Module sets up its parameter/submodule registries before the
        # submodules below are assigned.
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        self.dropout = nn.Dropout(dropout_prob)
        # Head input width follows the loaded checkpoint's hidden size.
        self.linear = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return raw (unnormalized) logits of shape ``(batch, num_classes)``.

        Args:
            input_ids: Token ids produced by the matching ``BertTokenizer``.
            attention_mask: Mask produced alongside ``input_ids``.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # pooler_output: the [CLS] representation after BERT's pooling layer.
        pooled_output = self.dropout(outputs.pooler_output)
        return self.linear(pooled_output)
# Example: one fine-tuning step plus a prediction for binary sentiment
# classification with SentimentClassifier.
pretrained_model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name)

text = "This movie is really good!"
labels = torch.tensor([1])  # 1 for positive sentiment, 0 for negative

inputs = tokenizer(text, return_tensors='pt')
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']

# Binary classification: a single logit per example. BCEWithLogitsLoss
# applies the sigmoid internally, so the model returns raw logits.
model = SentimentClassifier(pretrained_model_name, num_classes=1)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

# --- one training step ---
# from_pretrained() leaves the BERT submodule in eval mode; switch the whole
# model to training mode so dropout is active while fine-tuning.
model.train()
logits = model(input_ids, attention_mask)  # shape (batch, 1)
# BCEWithLogitsLoss needs float targets with the same shape as its input,
# hence the float cast and flattening the (batch, 1) logits to (batch,).
labels = labels.float()
loss = criterion(logits.view(-1), labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# --- inference ---
# eval() disables dropout for deterministic predictions; no_grad() skips
# building the autograd graph. Logits are recomputed so the prediction
# reflects the updated weights.
model.eval()
with torch.no_grad():
    logits = model(input_ids, attention_mask)
predictions = torch.sigmoid(logits) > 0.5  # Threshold at 0.5
predicted_labels = predictions.long()
# .item() assumes a single example, as in this demo.
print("Predicted label:", predicted_labels.item())
Hope this helps
Thanks
I use BERT for sentiment analysis, but my setup is multiclass classification. How can I change it to binary text classification?