boncey / Flickr4Java

Java API For Flickr. Fork of FlickrJ
BSD 2-Clause "Simplified" License
176 stars 155 forks source link

Bert for text classification #527

Closed insignia1 closed 2 years ago

insignia1 commented 3 years ago

I am trying to build a BERT model for text classification, based on this code [https://towardsdatascience.com/bert-text-classification-using-pytorch-723dfb8b6b5b]. My dataset contains two columns (label, text). Each label takes one of three values (0, 1, 2). The code runs without any errors, but all values of the confusion matrix are 0. Is there something wrong with my code?

import matplotlib.pyplot as plt
import pandas as pd
import torch
from torchtext.data import Field, TabularDataset, BucketIterator, Iterator
import torch.nn as nn
from transformers import BertTokenizer, BertForSequenceClassification
import torch.optim as optim
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

import seaborn as sns

# Seed torch's RNG and select the first CUDA device when one is available,
# otherwise fall back to the CPU.
torch.manual_seed(42)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Tokenizer loaded from the pretrained 'bert-base-uncased' checkpoint.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# MAX_SEQ_LEN is passed to the text field as fix_length; PAD_INDEX / UNK_INDEX
# are the tokenizer's ids for its pad and unknown tokens.
MAX_SEQ_LEN = 128
PAD_INDEX = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
UNK_INDEX = tokenizer.convert_tokens_to_ids(tokenizer.unk_token)

# Labels are read without a vocabulary as float tensors; the training and
# evaluation loops further down cast them to LongTensor before using them.
label_field = Field(sequential=False, use_vocab=False, batch_first=True, dtype=torch.float)
# Text is tokenized with tokenizer.encode (no vocabulary lookup), with
# fix_length=MAX_SEQ_LEN and PAD_INDEX as the pad token.
# NOTE(review): this line is cut off in this copy (it ends in `unk_t>`);
# presumably the last argument is unk_token=UNK_INDEX followed by the closing
# parenthesis -- confirm against the original script.
text_field = Field(use_vocab=False, tokenize=tokenizer.encode, lower=False, include_lengths=False, batch_first=True, fix_length=MAX_SEQ_LEN, pad_token=PAD_INDEX, unk_t>
# (name, field) pairs handed to TabularDataset.splits below.
fields = [('label', label_field), ('text', text_field)]
# NOTE(review): not referenced anywhere in the visible code.
CLASSIFICATION_REPORT = "classification_report.jsonl"

# Load the three CSV splits from the current directory, skipping each file's
# header row.
# NOTE(review): the name `train` is rebound by `def train(...)` further down;
# the dataset is only used here, to build train_iter.
train, valid, test = TabularDataset.splits(path='', train='train.csv', validation='validate.csv', test='test.csv', format='CSV', fields=fields, skip_header=True)

# Training and validation batches of 16, sorted by len(x.text); both iterators
# are created with train=True. The test iterator is neither shuffled nor sorted.
train_iter = BucketIterator(train, batch_size=16, sort_key=lambda x: len(x.text), device=device, train=True, sort=True, sort_within_batch=True)
valid_iter = BucketIterator(valid, batch_size=16, sort_key=lambda x: len(x.text), device=device, train=True, sort=True, sort_within_batch=True)
test_iter = Iterator(test, batch_size=16, device=device, train=False, shuffle=False, sort=False)

class BERT(nn.Module):
    """Thin wrapper around a pretrained ``BertForSequenceClassification``.

    The wrapped model is stored as ``self.encoder``. ``forward`` passes the
    labels through to it and returns the first two elements of its output.
    """

    def __init__(self, options_name="bert-base-uncased", num_labels=3):
        """Load the pretrained classifier.

        Args:
            options_name: Checkpoint name handed to ``from_pretrained``.
                Defaults to the value that was previously hard-coded.
            num_labels: Number of target classes for the classification
                head. Defaults to the previously hard-coded 3.
        """
        super().__init__()
        self.encoder = BertForSequenceClassification.from_pretrained(
            options_name, num_labels=num_labels
        )

    def forward(self, text, label):
        """Run the encoder on ``text`` with ``label`` as the targets.

        Args:
            text: Batch of token ids forwarded positionally to the encoder.
            label: Targets forwarded as the encoder's ``labels`` argument.

        Returns:
            A ``(loss, text_fea)`` tuple: the first two elements of the
            encoder's output.
        """
        # Keep only the first two outputs; anything after them is discarded.
        loss, text_fea = self.encoder(text, labels=label)[:2]
        return loss, text_fea

# Fine-tune `model` for `num_epochs` passes over `train_loader`; every
# `eval_every` optimizer steps, compute the average loss over `valid_loader`
# and print both averages.
#
# NOTE(review): the signature line below is cut off in this copy (`file_pat>`)
# and is fused with the first body statement (`running_loss = 0.0`).
# Presumably the missing tail is `file_path = ..., best_valid_loss = ...):`
# -- confirm against the original script before running.
# NOTE(review): `criterion` is never referenced in the visible body; the loss
# is taken from the first element returned by `model(text, label)`.
# NOTE(review): the default `eval_every = len(train_iter) // 2` is 0 when
# train_iter has fewer than two batches, which would make
# `global_step % eval_every` raise ZeroDivisionError.
def train(model, optimizer, criterion = nn.BCELoss(), train_loader = train_iter, valid_loader = valid_iter, num_epochs = 5, eval_every = len(train_iter) // 2, file_pat>        running_loss = 0.0
        valid_running_loss = 0.0
        global_step = 0
        # Loss history per evaluation point. These lists are only appended to
        # in the visible body; they are neither returned nor saved.
        train_loss_list = []
        valid_loss_list = []
        global_steps_list = []

        model.train()

        for epoch in range(num_epochs):
                for (label, text), _ in train_loader:
                        # Cast both tensors to LongTensor, then move them to `device`.
                        label = label.type(torch.LongTensor)
                        label = label.to(device)
                        text = text.type(torch.LongTensor)
                        text = text.to(device)
                        output = model(text, label)
                        # Only the loss is used during training.
                        loss, _ = output
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()
                        running_loss += loss.item()
                        global_step += 1
                        if global_step % eval_every == 0:
                                # Switch to eval mode and run the whole validation set
                                # with gradient tracking disabled.
                                model.eval()
                                with torch.no_grad():
                                        for (label, text), _ in valid_loader:
                                                label = label.type(torch.LongTensor)
                                                label = label.to(device)
                                                text = text.type(torch.LongTensor)
                                                text = text.to(device)
                                                output = model(text, label)
                                                loss, _ = output
                                                valid_running_loss += loss.item()

                                # Training loss is averaged over the steps since the last
                                # evaluation; validation loss over the validation batches.
                                average_train_loss = running_loss / eval_every
                                average_valid_loss = valid_running_loss / len(valid_loader)
                                train_loss_list.append(average_train_loss)
                                valid_loss_list.append(average_valid_loss)
                                global_steps_list.append(global_step)

                                # resetting running values
                                running_loss = 0.0
                                valid_running_loss = 0.0
                                model.train()

                                # print progress
                                # NOTE(review): the line below is cut off in this copy
                                # (it ends in `len(tra>`).
                                print('Epoch [{}/{}], Step [{}/{}], Train Loss: {:.4f}, Valid Loss: {:.4f}'.format(epoch+1, num_epochs, global_step, num_epochs*len(tra>
                                # NOTE(review): best_valid_loss is not assigned anywhere in the
                                # visible body before this comparison; presumably it is a
                                # parameter from the truncated signature. Only the number is
                                # updated here -- no model state is saved in the visible code.
                                if best_valid_loss > average_valid_loss:
                                        best_valid_loss = average_valid_loss
        print('Finished Training!')

# Build the classifier, place it on the selected device, and fine-tune it
# with Adam at a learning rate of 2e-5 using train()'s default settings.
model = BERT()
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=2e-5)

train(model, optimizer)

def evaluate(model, test_loader, labels=(0, 1, 2)):
    """Run ``model`` over ``test_loader`` and print a classification report.

    Args:
        model: Module called as ``model(text, label)`` that returns a
            ``(loss, logits)`` pair.
        test_loader: Iterable yielding ``((label, text), _)`` batches.
        labels: Class ids to include in the report. Defaults to all three
            classes the classifier in this file is built for; the previous
            hard-coded ``[1, 0]`` silently dropped class 2 from the report.

    Returns:
        A ``(y_true, y_pred)`` tuple of lists, so callers can compute further
        metrics (for example a confusion matrix) from the same predictions.
    """
    y_pred = []
    y_true = []
    model.eval()
    with torch.no_grad():
        for (label, text), _ in test_loader:
            # Cast to LongTensor first, then move to the target device.
            label = label.type(torch.LongTensor).to(device)
            text = text.type(torch.LongTensor).to(device)

            # The loss is not needed here; keep only the second output.
            _, logits = model(text, label)
            y_pred.extend(torch.argmax(logits, 1).tolist())
            y_true.extend(label.tolist())
    print('Classification Report:')
    print(classification_report(y_true, y_pred, labels=list(labels), digits=4))
    return y_true, y_pred
# Evaluate the model that was just fine-tuned. Building a new BERT() here
# would reload the pretrained checkpoint with a freshly initialised
# classification head and discard everything train() learned; no checkpoint
# is saved or loaded in this script, so the in-memory `model` is the only
# trained instance available.
best_model = model
evaluate(best_model, test_iter)