thunlp / OpenNRE

An Open-Source Package for Neural Relation Extraction (NRE)
MIT License

Bert+attention -2 #277

Closed: hjing100 closed this issue 4 years ago

hjing100 commented 4 years ago


```python
import sys
import os
import json
import argparse
import logging

import numpy as np
import torch
import opennre
from opennre import encoder, model, framework

parser = argparse.ArgumentParser()
parser.add_argument('--pretrain_path', default='bert-base-uncased', help='Pre-trained ckpt path / model name (huggingface)')
parser.add_argument('--ckpt', default='', help='Checkpoint name')
parser.add_argument('--only_test', action='store_true', help='Only run test')
parser.add_argument('--pooler', default='entity', choices=['cls', 'entity'], help='Sentence representation pooler')
parser.add_argument('--mask_entity', action='store_true', help='Mask entity mentions')

parser.add_argument('--metric', default='auc', choices=['micro_f1', 'auc'], help='Metric for picking up best checkpoint')
parser.add_argument('--dataset', default='none', choices=['none', 'wiki_distant', 'nyt10'], help='Dataset. If not none, the following args can be ignored')
parser.add_argument('--train_file', default='', type=str, help='Training data file')
parser.add_argument('--val_file', default='', type=str, help='Validation data file')
parser.add_argument('--test_file', default='', type=str, help='Test data file')
parser.add_argument('--rel2id_file', default='', type=str, help='Relation to ID file')

# Bag related
parser.add_argument('--bag_size', type=int, default=4, help='Fixed bag size. If set to 0, use original bag sizes')

parser.add_argument('--batch_size', default=16, type=int, help='Batch size')
parser.add_argument('--lr', default=2e-5, type=float, help='Learning rate')
parser.add_argument('--max_length', default=128, type=int, help='Maximum sentence length')
parser.add_argument('--max_epoch', default=3, type=int, help='Max number of training epochs')

args = parser.parse_args()

# Some basic settings
root_path = '.'
sys.path.append(root_path)
if not os.path.exists('ckpt'):
    os.mkdir('ckpt')
if len(args.ckpt) == 0:
    args.ckpt = '{}_{}'.format(args.dataset, 'pcnn_att')
ckpt = 'ckpt/{}.pth.tar'.format(args.ckpt)

# Dataset selection, disabled by wrapping it in a string (data paths are hard-coded below)
'''
if args.dataset != 'none':
    opennre.download(args.dataset, root_path=root_path)
    args.train_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_train.txt'.format(args.dataset))
    args.val_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_val.txt'.format(args.dataset))
    args.test_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_test.txt'.format(args.dataset))
    args.rel2id_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_rel2id.json'.format(args.dataset))
else:
    if not (os.path.exists(args.train_file) and os.path.exists(args.val_file) and os.path.exists(args.test_file) and os.path.exists(args.rel2id_file)):
        raise Exception('--train_file, --val_file, --test_file and --rel2id_file are not specified or files do not exist. Or specify --dataset')
'''

logging.info('Arguments:')
for arg in vars(args):
    logging.info('    {}: {}'.format(arg, getattr(args, arg)))

# rel2id = json.load(open(args.rel2id_file))
rel2id = json.load(open('./data/company-relation/people-relation_rel2id.json', encoding='utf-8'))

# Define the sentence encoder
sentence_encoder = opennre.encoder.BERTEncoder(
    max_length=args.max_length,
    pretrain_path='./pretrain/chinese_wwm_pytorch',  # args.pretrain_path,
    mask_entity=args.mask_entity
)
'''
if args.pooler == 'entity':
    sentence_encoder = opennre.encoder.BERTEntityEncoder(
        max_length=args.max_length, pretrain_path=args.pretrain_path, mask_entity=args.mask_entity)
elif args.pooler == 'cls':
    sentence_encoder = opennre.encoder.BERTEncoder(
        max_length=args.max_length, pretrain_path=args.pretrain_path, mask_entity=args.mask_entity)
else:
    raise NotImplementedError
'''

# Define the model
model = opennre.model.BagAttention(sentence_encoder, len(rel2id), rel2id)

# Define the whole training framework
framework = opennre.framework.BagRE(
    train_path='people-relation_distant_train.txt',  # args.train_file,
    val_path='people-relation_distant_val.txt',  # args.val_file,
    test_path='people-relation_distant_val.txt',  # args.test_file,
    model=model,
    ckpt=ckpt,
    batch_size=args.batch_size,
    max_epoch=args.max_epoch,
    lr=args.lr,
    opt='adamw',
    bag_size=args.bag_size)

# Train the model
if not args.only_test:
    framework.train_model(args.metric)

# Test the model
framework.load_state_dict(torch.load(ckpt)['state_dict'])
result = framework.eval_model(framework.test_loader)

# Print the result
logging.info('Test set results:')
logging.info('AUC: {}'.format(result['auc']))
logging.info('Micro F1: {}'.format(result['micro_f1']))
```

hjing100 commented 4 years ago

I ran this .py file and got an error:

```
Traceback (most recent call last):
  File "", line 126, in <module>
    framework.train_model(args.metric)
TypeError: train_model() takes 1 positional argument but 2 were given
```

In my opinion, simply swapping the PCNN encoder for BERT doesn't work.

gaotianyu1350 commented 4 years ago

It seems that you are not using the latest repo. Please pull the latest update and try again.
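The first TypeError in this thread is what you get when `train_model(args.metric)` is called on an installed framework whose `train_model` accepts only `self`. A minimal sketch, using two hypothetical stand-in classes (`OldBagRE`, `NewBagRE`, not OpenNRE code), reproduces the message and shows one way to check with `inspect.signature` whether the installed method takes a positional metric argument before calling it:

```python
import inspect


class OldBagRE:
    """Hypothetical stand-in for an older framework whose train_model takes no metric."""

    def train_model(self):
        return "trained (default metric)"


class NewBagRE:
    """Hypothetical stand-in for a newer framework that accepts a metric name."""

    def train_model(self, metric="auc"):
        return "trained (metric={})".format(metric)


def accepts_metric(framework):
    # A bound method's signature excludes `self`; look for any parameter
    # that can receive a positional argument.
    positional = (inspect.Parameter.POSITIONAL_ONLY,
                  inspect.Parameter.POSITIONAL_OR_KEYWORD,
                  inspect.Parameter.VAR_POSITIONAL)
    params = inspect.signature(framework.train_model).parameters.values()
    return any(p.kind in positional for p in params)


def train(framework, metric):
    # Pass the metric only if the installed version can accept it.
    if accepts_metric(framework):
        return framework.train_model(metric)
    return framework.train_model()


try:
    OldBagRE().train_model("auc")
except TypeError as err:
    print(err)  # e.g. "train_model() takes 1 positional argument but 2 were given"

print(train(OldBagRE(), "auc"))
print(train(NewBagRE(), "micro_f1"))
```

The check only makes a version mismatch explicit; updating the installed package so that script and library agree remains the actual fix.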

hjing100 commented 4 years ago

Yes, I hadn't pulled the latest update. I also created a BERTBagInputEncoder, but it raised a RuntimeError, so I changed batch_size=64 to batch_size=16 and that solved the error. I will pull the latest update and try again later. Thank you for your patient answers.

```
RuntimeError: CUDA out of memory. Tried to allocate 46.00 MiB (GPU 0; 7.79 GiB total capacity; … already allocated; 17.62 MiB free; 7.02 GiB reserved in total by PyTorch)
```
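Lowering `batch_size` helps because activation memory grows roughly linearly with the number of sentences per step. Assuming each bag is padded or sampled to exactly `--bag_size 4` sentences, `batch_size=64` pushes about 64 × 4 = 256 sentences through BERT per step, versus 16 × 4 = 64 at `batch_size=16`. A minimal sketch of automating that search, assuming a hypothetical `run_step(batch_size)` callable that raises `RuntimeError` with "out of memory" in its message, as the PyTorch error above does:

```python
def find_fitting_batch_size(run_step, start=64, minimum=1):
    """Halve the batch size until run_step succeeds without an out-of-memory error.

    `run_step` is a hypothetical callable: it runs one training step with the
    given batch size and raises RuntimeError("... out of memory ...") on OOM.
    """
    batch_size = start
    while batch_size >= minimum:
        try:
            run_step(batch_size)
            return batch_size
        except RuntimeError as err:
            if "out of memory" not in str(err):
                raise  # an unrelated error: do not hide it
            # In a real PyTorch loop you may also want torch.cuda.empty_cache() here.
            batch_size //= 2
    raise RuntimeError("no batch size >= {} fits in memory".format(minimum))


def fake_step(batch_size):
    # Pretend that anything above 16 bags per step exhausts GPU memory.
    if batch_size > 16:
        raise RuntimeError("CUDA out of memory. Tried to allocate 46.00 MiB")
    return "ok"


print(find_fitting_batch_size(fake_step, start=64))  # → 16
```

Halving keeps the number of retries logarithmic in the starting size; gradient accumulation would be the usual way to recover the larger effective batch, but that is outside this sketch.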

hjing100 commented 4 years ago

Also, am I right that the BERT model in BERTEncoder doesn't use pos_embedding?

```python
_, x = self.bert(token, attention_mask=att_mask)
return x
```
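For context: BERT adds its own learned absolute position embeddings inside the model, and with the tuple-returning transformers API the second element of `self.bert(...)` is the pooled [CLS] vector, so the snippet needs no extra inputs. What it does not consume are PCNN-style `pos1`/`pos2` features, i.e. each token's offset from the head and tail entity. A minimal sketch of how such relative-position indices are typically computed, using a hypothetical helper `relative_positions`; the clipping and the `+ max_length` shift are an assumption modeled on common CNN/PCNN relation-extraction encoders, not copied from OpenNRE:

```python
def relative_positions(seq_len, entity_start, max_length):
    """Offset of every token from an entity, shifted to a non-negative index.

    Offsets are clipped to [-max_length, max_length - 1] and then shifted by
    max_length, so each result indexes an embedding table of size 2 * max_length.
    """
    positions = []
    for i in range(seq_len):
        offset = i - entity_start
        offset = max(-max_length, min(offset, max_length - 1))
        positions.append(offset + max_length)
    return positions


# Tokens: ["Bill", "Gates", "founded", "Microsoft"]; head starts at 0, tail at 3.
pos1 = relative_positions(4, entity_start=0, max_length=128)
pos2 = relative_positions(4, entity_start=3, max_length=128)
print(pos1)  # [128, 129, 130, 131]
print(pos2)  # [125, 126, 127, 128]
```

A PCNN-style encoder typically embeds these indices and concatenates them with the word embeddings; a [CLS]-pooled BERT encoder like the one quoted above has no input slot for them.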

hjing100 commented 4 years ago

I pulled the latest update and tried again, and I still get a TypeError:

```
Traceback (most recent call last):
  File "", line 59, in <module>
    framework.train_model()
  File "/opennre/framework/", line 121, in train_model
    logits = self.model(label, scope, *args, bag_size=self.bag_size)
  File "/root/miniconda3/envs/python36j/lib/python3.6/site-packages/torch/nn/modules/", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/root/miniconda3/envs/python36j/lib/python3.6/site-packages/torch/nn/parallel/", line 153, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/root/miniconda3/envs/python36j/lib/python3.6/site-packages/torch/nn/modules/", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
TypeError: forward() missing 1 required positional argument: 'pos2'
```
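One plausible reading of this traceback, not confirmed in the thread: the framework passes whatever the encoder's tokenizer returned straight through as `*args`, the bag model's `forward` expects three positional inputs (`token`, `pos1`, `pos2`) as a CNN/PCNN encoder provides, and a [CLS]-pooled BERT encoder yields only two (`token`, `att_mask`). The sketch below reproduces that arity mismatch with hypothetical pure-Python stand-ins (`tokenize_pcnn`, `tokenize_bert`, `forward`, `train_step`), not OpenNRE's real classes:

```python
def tokenize_pcnn(tokens):
    """Stand-in for a CNN/PCNN-style tokenizer: token ids plus two position sequences."""
    token = list(range(len(tokens)))
    pos1 = [0] * len(tokens)
    pos2 = [0] * len(tokens)
    return token, pos1, pos2


def tokenize_bert(tokens):
    """Stand-in for a [CLS]-pooled BERT tokenizer: token ids and an attention mask only."""
    token = list(range(len(tokens)))
    att_mask = [1] * len(tokens)
    return token, att_mask


def forward(label, scope, token, pos1, pos2, mask=None, train=True, bag_size=0):
    """Stand-in for a bag model whose forward expects PCNN-style inputs."""
    return len(token)


def train_step(tokenize, sentence):
    label, scope = 0, [(0, 1)]
    args = tokenize(sentence)  # forwarded unchanged, like `*args` in the traceback
    return forward(label, scope, *args, bag_size=4)


sentence = ["Bill", "Gates", "founded", "Microsoft"]
print(train_step(tokenize_pcnn, sentence))  # 4
try:
    train_step(tokenize_bert, sentence)
except TypeError as err:
    print(err)  # forward() missing 1 required positional argument: 'pos2'
```

Under this reading, the attention mask silently lands in `pos1` and `pos2` is missing, so the fix is to make the encoder's tokenizer output and the model's `forward` signature agree, rather than swapping the encoder alone.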

gaotianyu1350 commented 4 years ago

Maybe you should reinstall opennre after pulling the latest update.

```
pip uninstall opennre
python setup.py develop
```

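A stale install is easy to miss when an old copy of a package shadows the freshly pulled source. A minimal sketch, using only the standard library's `importlib.util.find_spec`, that reports which file Python would import a module from, so you can confirm `opennre` resolves to the pulled checkout rather than an older copy under `site-packages`:

```python
import importlib.util


def module_origin(name):
    """Return the file a top-level module would be imported from, or None if it is not installed."""
    spec = importlib.util.find_spec(name)
    if spec is None:
        return None
    return spec.origin


for name in ("opennre", "json"):
    print(name, "->", module_origin(name))
```

After reinstalling in development mode, the reported path for `opennre` should point into the cloned repository; if it still points under `site-packages`, the old copy is the one being imported.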
DaddyJin commented 3 years ago

Hi @hjing100, have you solved the problem yet? I also ran into `TypeError: forward() missing 1 required positional argument: 'pos2'`.