chen700564 / sdnet

Other
38 stars 6 forks source link

Can you provide your BERT training code? I found you can get 47.6 in I2B2 #5

Open ToneLi opened 1 year ago

ToneLi commented 1 year ago

Can you provide your BERT training code? I found you can get 47.6 in I2B2. So thanks about it

chen700564 commented 1 year ago

Here is the BERT training code; you may need to adjust it. By the way, I think the results mostly depend on how the support set is split — if you need our support set, you can email me.

from logging import log
from transformers import AutoModelForSeq2SeqLM,AutoModelForTokenClassification
from allennlp.models.model import Model
from allennlp.nn import InitializerApplicator
import torch.nn as nn
import torch
import os
import numpy as np
import json
from seqeval.metrics import f1_score, precision_score, recall_score
from allennlp.training.metrics.metric import Metric
from overrides import overrides
from seqeval.scheme import IOB2
from allennlp.data import DatasetReader
from allennlp.data.fields import LabelField, TextField, ArrayField, ListField, MetadataField
from allennlp.data.instance import Instance
import json
from transformers import AutoTokenizer
import numpy as np
import random
import copy

class Bertbiofewreader(DatasetReader):
    """Few-shot NER dataset reader producing BERT token-classification
    Instances with BIO label ids.

    N/K/Q describe the episode setup (per the author, only K is actually
    used). `init(file)` loads a jsonlines dataset and collects the label
    inventory; `text_to_instance` encodes a (subset of a) dataset for a
    given list of target classes.
    """

    def __init__(self, N, K, Q, pretrainedfile=None, file=None, lazy=False) -> None:
        super().__init__(lazy)
        self.tokenizer = AutoTokenizer.from_pretrained(pretrainedfile)
        self.N = N
        self.K = K
        self.Q = Q
        if file is not None:
            self.init(file)

    def getinstance(self, data):
        """Encode one example dict ({'tokens': [...], 'entity': [...]}) as an
        allennlp Instance.

        Word-level BIO labels are built first, then each word is wordpiece-
        tokenized; only the first subword of each word keeps its label, the
        remaining pieces (and [CLS]/[SEP]) get -1 so loss/metric can ignore
        them. Input is truncated at 510 subwords to leave room for
        [CLS]/[SEP] within BERT's 512-token limit.
        """
        text = ' '.join(data['tokens'])
        # Default every word to 'O' (the last id in label2id by construction,
        # see buildlabel2id).
        labels = [len(self.label2id) - 1] * len(data['tokens'])
        entitys = []
        for entity in data['entity']:
            if entity['type'] in self.target_classes:
                entitys.append(entity)
                for i in range(entity['offset'][0], entity['offset'][1]):
                    if i == entity['offset'][0]:
                        label = 'B-' + entity['type']
                    else:
                        label = 'I-' + entity['type']
                    labels[i] = self.label2id[label]
        label = [-1]  # -1 slot for [CLS]
        tokenid = []
        for i in range(len(data['tokens'])):
            token = data['tokens'][i]
            inputid = self.tokenizer.encode(token, add_special_tokens=False)
            if len(tokenid) + len(inputid) > 510:
                break  # truncate to fit BERT's max length with specials
            tokenid = tokenid + inputid
            for j in range(len(inputid)):
                if j == 0:
                    label.append(labels[i])
                else:
                    label.append(-1)  # subword continuations are ignored
        label.append(-1)  # -1 slot for [SEP]
        tokenid = [self.tokenizer.cls_token_id] + tokenid + [self.tokenizer.sep_token_id]
        mask = [1] * len(tokenid)
        field = {
            'inputid': ArrayField(np.array(tokenid)),
            'mask': ArrayField(np.array(mask)),
            'label': ListField([LabelField(int(i), skip_indexing=True) for i in label]),
            'text': MetadataField(text),
            'entity': MetadataField(entitys),
            'tokens': MetadataField(data['tokens'])
        }
        return Instance(field)

    def init(self, file, labels=None):
        """Load a jsonlines dataset into self.dataset; collect the label
        inventory and attach per-line entity-type counts as
        line['class_count']. If `labels` is given it overrides the
        collected inventory."""
        self.dataset = []
        self.labels = []
        with open(file) as f:
            for line in f:
                line = json.loads(line)
                class_count = {}
                for entity in line['entity']:
                    if entity['type'] not in self.labels:
                        self.labels.append(entity['type'])
                    if entity['type'] not in class_count:
                        class_count[entity['type']] = 1
                    else:
                        class_count[entity['type']] += 1
                line['class_count'] = class_count
                self.dataset.append(line)
        if labels is not None:
            self.labels = labels
        self.target_classes = self.labels

    def buildlabel2id(self):
        """Build the label->id map over target classes: a B-/I- pair per
        entity class, 'O' kept as-is. getinstance relies on 'O' being last
        so that len(label2id) - 1 is the 'O' id."""
        self.label2id = {}
        for label in self.target_classes:
            if label != 'O':
                self.label2id['B-' + label] = len(self.label2id)
                self.label2id['I-' + label] = len(self.label2id)
            else:
                self.label2id[label] = len(self.label2id)

    def text_to_instance(self, idx=None, label=True, dataset=None, target_classes=None):
        """Encode `dataset` (or the stored one) into a list of Instances.

        Args:
            idx: optional list of indices into the dataset; None encodes all.
            label: unused; kept for interface compatibility.
            dataset: optional list of example dicts replacing self.dataset.
            target_classes: optional class list; triggers a label2id rebuild.
        """
        results = []
        if dataset is not None:
            self.dataset = dataset
        if target_classes is not None:
            self.target_classes = target_classes
            self.buildlabel2id()
        if idx is None:
            for data in self.dataset:
                results.append(self.getinstance(data))
        else:
            for index in idx:
                results.append(self.getinstance(self.dataset[index]))
        # Fix: dropped the original dead tail that recomputed `idx` after the
        # encoding loop -- the value was never used before returning.
        return results

    def _read(self, file):
        # Fix: the original did `yield self.text_to_instance(json.loads(line))`,
        # which bound the parsed dict to the `idx` parameter and then indexed
        # self.dataset with the dict's keys -- a guaranteed crash. Yield one
        # Instance per jsonlines example instead, matching allennlp's
        # DatasetReader._read contract (an iterator of Instances).
        # NOTE(review): like the original, this requires label2id /
        # target_classes to have been set beforehand (via text_to_instance or
        # buildlabel2id) -- confirm against callers.
        with open(file) as f:
            for line in f:
                yield self.getinstance(json.loads(line))

def finetuning(model, reader, testreader, file, lr, batch_size, num_epochs, cuda_device=-1, modelfile = None,constraint='bio_decay'):
    """Episode-wise fine-tuning and evaluation for few-shot NER.

    For every line (episode) in `file`, deep-copies `model`, fine-tunes the
    copy on the episode's support set with allennlp's trainer, then decodes
    and records predictions on the episode's query set.

    Args:
        model: CPU-resident Tagmodel; a fresh deep copy is trained per episode.
        reader: Bertbiofewreader used to encode the support set.
        testreader: Bertbiofewreader used to encode the query set.
        file: jsonlines episodes with 'support', 'target_label', optional 'query'.
        lr: AdamW learning rate.
        batch_size: support-set batch size.
        num_epochs: epochs per episode.
        cuda_device: GPU id for training (-1 = CPU).
        modelfile: output directory for the result/record files.
        constraint: substring flags -- 'bio' expands labels to B-/I- pairs and
            decodes entities with BIO rules; 'decay' enables polynomial LR decay.
    """
    # Fix: `math` and `gc` were used below but never imported in the original.
    import gc
    import math
    from allennlp.data.dataloader import PyTorchDataLoader
    from transformers import AdamW
    from allennlp.training.trainer import GradientDescentTrainer
    from allennlp.training.learning_rate_schedulers.polynomial_decay import PolynomialDecay
    cpumodel = model
    if constraint == 'none':
        resultfile = modelfile + '/result.txt'
        recordfile = modelfile + '/record.txt'
    else:
        resultfile = modelfile + '/result_' + constraint + '.txt'
        recordfile = modelfile + '/record_' + constraint + '.txt'
    with open(resultfile, 'w') as f1, open(recordfile, 'w') as f2, open(file) as f:
        step = -1
        for line in f:
            step += 1
            dataset = json.loads(line)
            support = dataset['support']
            target_classes = dataset['target_label']
            target_classes.append('O')
            # Train a fresh copy so episodes do not contaminate each other.
            model = copy.deepcopy(cpumodel)
            if 'bio' in constraint:
                # Expand each class into a B-/I- pair; keep 'O' as-is. Order
                # must match Bertbiofewreader.buildlabel2id.
                id2label = []
                for label in target_classes:
                    if label != 'O':
                        id2label.append('B-' + label)
                        id2label.append('I-' + label)
                    else:
                        id2label.append(label)
            else:
                id2label = target_classes
            model.f1.setlabel(id2label)
            if 'query' in dataset:
                query = dataset['query']
                query_set = testreader.text_to_instance(dataset=query, target_classes=target_classes)
            else:
                query_set = testreader.text_to_instance(None, target_classes=target_classes)
            print('step: ' + str(step))
            print('finetuning')

            model = model.cuda(cuda_device)

            model.train()
            result = {
                'pred': [],
                'target_classes': target_classes,
            }
            support_set = reader.text_to_instance(dataset=support, target_classes=target_classes)
            data_loader = PyTorchDataLoader(support_set, batch_size, shuffle=True)
            parameters_to_optimize = list(model.named_parameters())
            # Standard BERT fine-tuning: no weight decay on biases / LayerNorm.
            no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
            parameters_to_optimize = [
                    {'params': [p for n, p in parameters_to_optimize
                                if not any(nd in n for nd in no_decay) and p.requires_grad], 'weight_decay': 0.01},
                    {'params': [p for n, p in parameters_to_optimize
                                if any(nd in n for nd in no_decay) and p.requires_grad], 'weight_decay': 0.0},
                ]
            optimizer = AdamW(parameters_to_optimize, lr=lr, correct_bias=False)
            if 'decay' in constraint:
                steps_per_epoch = math.ceil(len(support_set) / batch_size)
                warm = int(num_epochs * steps_per_epoch * 0.01)  # ~1% warmup steps
                learning_rate_scheduler = PolynomialDecay(optimizer, num_epochs, steps_per_epoch, 1, warm, 0)
            else:
                learning_rate_scheduler = None
            trainer = GradientDescentTrainer(
                        model=model,
                        optimizer=optimizer,
                        data_loader=data_loader,
                        num_epochs=num_epochs,
                        cuda_device=cuda_device,
                        learning_rate_scheduler=learning_rate_scheduler,
                    )
            trainer.train()
            del trainer
            del optimizer
            del parameters_to_optimize
            print('pred')
            model.eval()
            model.f1.reset()
            batchsize = 32
            qnum = 0
            with torch.no_grad():
                # Fix: the original wrapped this range in tqdm.tqdm without
                # importing tqdm (NameError). A plain range preserves behavior;
                # tqdm only adds a progress bar.
                for i in range(math.ceil(len(query_set) / batchsize)):
                    data = query_set[i * batchsize: (i + 1) * batchsize]
                    y = model.forward_on_instances(data)
                    for j in range(len(data)):
                        preds = y[j]['pred']
                        # Fix: get_padding_lengths was passed as an unbound
                        # method (missing parentheses) on these three lines,
                        # inconsistent with the correct call for 'label' below.
                        gold = data[j]['entity'].as_tensor(data[j]['entity'].get_padding_lengths())
                        text = data[j]['text'].as_tensor(data[j]['text'].get_padding_lengths())
                        tokens = data[j]['tokens'].as_tensor(data[j]['tokens'].get_padding_lengths())
                        goldlabel = data[j]['label'].as_tensor(data[j]['label'].get_padding_lengths())
                        goldlabel = goldlabel.numpy()
                        # Fix: `pred` was unbound when 'bio' not in constraint
                        # (and otherwise stale from the previous instance).
                        pred = []
                        if 'bio' in constraint:
                            pred = model.f1.getbioentity(preds, goldlabel, tokens)
                        result['pred'].append(pred)
                        f1.write(str(qnum) + "\n" + str(text) + "\n" + str(gold) + "\n" + str(pred) + "\n")
                        qnum += 1
            print(model.get_metrics())
            # Drop the heavy transformer weights before freeing GPU memory.
            model.model = None
            model.to("cpu")
            del model
            del data_loader
            del support_set
            del query_set
            del data
            gc.collect()
            torch.cuda.empty_cache()
            f2.write(json.dumps(result) + "\n")

def getseqeval(pred, gold, id2label, bio=False):
    """Convert (batch, seq) id matrices into seqeval-style tag lists.

    Positions where gold < 0 (special tokens / subword continuations) are
    skipped. When `bio` is False, id2label holds plain class names, so every
    non-'O' tag is prefixed with 'I-' to satisfy seqeval's IOB2 scheme.

    Returns:
        (goldlabel, predlabel): two lists of per-sequence tag lists.
    """
    n_rows, n_cols = gold.shape
    goldlabel = [[] for _ in range(n_rows)]
    predlabel = [[] for _ in range(n_rows)]
    for row in range(n_rows):
        for col in range(n_cols):
            if gold[row, col] < 0:
                continue  # ignored position
            g_tag = id2label[gold[row][col]]
            p_tag = id2label[pred[row][col]]
            if not bio:
                # Promote bare class names to 'I-<class>' for IOB2 scoring.
                g_tag = g_tag if g_tag == 'O' else 'I-' + g_tag
                p_tag = p_tag if p_tag == 'O' else 'I-' + p_tag
            goldlabel[row].append(g_tag)
            predlabel[row].append(p_tag)
    return goldlabel, predlabel

class F1(Metric):
    """Span-level precision/recall/F1 accumulated across batches and scored
    with seqeval's strict IOB2 scheme.

    setlabel() must be called with the id->label list before the metric is
    used; __call__ accumulates one batch, get_metric() scores everything
    accumulated so far.
    """
    def __init__(self,bio=False) -> None:
        # bio=True: id2label entries already carry B-/I- prefixes;
        # bio=False: getseqeval prefixes non-'O' classes with 'I-'.
        self.pred = []
        self.gold = []
        self.bio = bio

    def __call__(self, pred, gold):
        # Accumulate one batch of (batch, seq) id tensors. Positions with
        # gold < 0 (special tokens / subword pieces) are dropped by getseqeval.
        pred = pred.detach().cpu().numpy()
        gold = gold.detach().cpu().numpy()
        goldlabel,predlabel = getseqeval(pred, gold, self.id2label, self.bio)
        self.pred += predlabel
        self.gold += goldlabel

    def setlabel(self,id2label):
        # id2label: list mapping label ids to label strings for this episode.
        self.id2label = id2label

    def getbioentity(self,pred,gold,tokens):
        """Decode predicted label ids into entity spans under BIO rules.

        Only positions with gold >= 0 are considered; these align 1:1 with the
        original word-level tokens (subword continuations and special tokens
        carry gold == -1 and are skipped WITHOUT advancing `index`, keeping
        `index` aligned with `tokens`). Returns a list of
        {'text', 'offset': [start, end+1], 'type'} dicts over token indices.
        """
        entitys = []
        lastlabel = 'O'  # type of the previous B-/I- tag (I- continuation check)
        nowlabel = 'O'   # type of the currently open entity; 'O' means none open
        start = -1
        end = -1
        index = 0        # position in the word-level token list
        for i in range(gold.shape[0]):
            if gold[i] >= 0:
                label = self.id2label[pred[i]]
                if label[0] == 'B':
                    # B- always starts a new entity; flush any open one first.
                    if nowlabel != 'O':
                        entitys.append({'text':' '.join(entity),'offset':[start,end+1],'type':nowlabel})
                        nowlabel = 'O'

                    entity = [tokens[index]]
                    start = index
                    end = index
                    lastlabel = label[2:]
                    nowlabel = label[2:]
                elif label[0] == 'I':
                    if lastlabel == label[2:]:
                        # Same type as the running entity: extend the span.
                        entity.append(tokens[index])
                        end = index
                    else:
                        # Type mismatch: close the open entity. NOTE(review):
                        # the mismatching I- token itself is silently dropped
                        # (it starts no new entity) -- confirm this strictness
                        # is intended.
                        if nowlabel != 'O':
                            entitys.append({'text':' '.join(entity),'offset':[start,end+1],'type':nowlabel})
                            nowlabel = 'O'
                            lastlabel = 'O'
                else:
                    # 'O' closes any currently open entity.
                    if nowlabel != 'O':
                        entitys.append({'text':' '.join(entity),'offset':[start,end+1],'type':nowlabel})
                        nowlabel = 'O'
                    entity = []
                    start = -1
                    end = -1
                    lastlabel = 'O'
                index += 1
        # Flush an entity still open at the end of the sequence.
        if nowlabel != 'O':
            entitys.append({'text':' '.join(entity),'offset':[start,end+1],'type':nowlabel})
            nowlabel = 'O'
        return entitys

    def get_metric(self, reset: bool = False):
        """Score everything accumulated so far with seqeval (strict IOB2)."""
        result = {
            'p':precision_score(self.gold,self.pred,mode='strict',scheme=IOB2),
            'r':recall_score(self.gold,self.pred,mode='strict',scheme=IOB2),
            'f1':f1_score(self.gold,self.pred,mode='strict',scheme=IOB2),
        }
        if reset:
            self.reset()
        return result

    @overrides
    def reset(self):
        # Clear accumulated sequences between evaluation runs.
        self.pred = []
        self.gold = []

class Tagmodel(Model):
    """Thin allennlp Model wrapper around a HuggingFace token-classification
    model, with a span-level F1 metric attached for evaluation."""
    def __init__(self, pretrainedfile,num_labels,bio=False,
                 initializer: InitializerApplicator = InitializerApplicator()):
        super(Tagmodel, self).__init__(None, None)
        self.model = AutoModelForTokenClassification.from_pretrained(pretrainedfile,num_labels=num_labels)
        self.config = self.model.config
        self.f1 = F1(bio=bio)
        # NOTE(review): this passes `self` to InitializerApplicator's
        # constructor (its regexes slot) and discards the result; the
        # `initializer` parameter above is never applied. Looks like it was
        # meant to be `initializer(self)` -- confirm.
        InitializerApplicator(self)

    def get_metrics(self, reset=False):
        # F1 is only accumulated in eval mode (see forward); during training
        # return a placeholder so the trainer has something to log.
        if not self.training:
            return self.f1.get_metric(reset)
        else:
            return {'f1':0}

    def forward(self,inputid,mask=None,label=None,**kargs):
        """Run the underlying model; when labels are given, also compute the
        loss, argmax predictions, and (in eval mode) accumulate the F1 metric.

        Label positions equal to -1 (special tokens / subword pieces) are
        remapped to -100, HuggingFace's ignore index for the token-
        classification loss.
        """
        inputid = inputid.long()
        if label is not None:
            label = label.masked_fill(label==-1,-100)
            output_dict = self.model(input_ids = inputid, attention_mask = mask, labels = label,return_dict=True)
            logits = output_dict['logits']
            pred = torch.argmax(logits,dim=-1)
            output_dict['pred'] = pred
            if not self.training:
                self.f1(pred,label)
            self.get_metrics()
        else:
            output_dict = self.model(input_ids = inputid, attention_mask = mask,return_dict=True)

        return output_dict

You can refer to the code above, which includes the reader, model and metric. The function "finetuning" is used to fine-tune on and evaluate the few-shot dataset.

You should adjust the codes to run it, if the codes work, you can run as:

N,K,Q = 5,5,5 # only K is useful
dataset = 'i2b2'
pretrain = 'bert-base-uncased'
file = 'data/'+dataset+'/'+str(K)+'shot.json'
testreader = Bertbiofewreader(N,K,Q,pretrain, 'data/' + dataset + '/test.json')
reader = Bertbiofewreader(N,K,Q,pretrain, 'data/' + dataset + '/train.json')

batch_size = 8
lr = 2e-5
num_epochs = 50
finetuning(model, reader, testreader, file, lr, batch_size, num_epochs, cuda_device=0, modelfile = modelfile,constraint='bio_decay')
ToneLi commented 1 year ago

so thanks for that!!