Closed hjing100 closed 4 years ago
I ran this .py file and got an error.
Traceback (most recent call last):
File "train_company_chinese_bert_attention_git_author.py", line 126, in
In my opinion, simply changing PCNN encoder to BERT is not OK.
It seems that you are not using the latest repo. Please pull the latest update and try again.
Yes, I had not pulled the latest update. I created the BERTBagInputEncoder, but then got a RuntimeError, so I changed batch_size=64 to batch_size=16 and the error was resolved. I will pull the latest update and try again later. Thank you for your patient answers.
RuntimeError: CUDA out of memory. Tried to allocate 46.00 MiB (GPU 0; 7.79 GiB total capacity; ted; 17.62 MiB free; 7.02 GiB reserved in total by PyTorch)
I also noticed that the BERT model does not seem to use pos_embedding in BERTEncoder?
_, x = self.bert(token, attention_mask=att_mask) return x
I pulled the latest update and tried again, and I still get a TypeError.
Traceback (most recent call last):
File "train_people_chinese_bert_attention.py", line 59, in
Maybe you should reinstall opennre
after pulling the latest update.
pip uninstall opennre
python setup.py develop
Hi @hjing100, have you solved the problem yet? I also encountered the error "TypeError: forward() missing 1 required positional argument: 'pos2'".
# coding:utf-8
import sys, json import torch import os import numpy as np import opennre from opennre import encoder, model, framework import sys import os import argparse import logging
# Command-line interface of the training script. Each option is registered
# as its own statement; the section comments below were bare words in the
# pasted version and are restored as real comments.
parser = argparse.ArgumentParser()
parser.add_argument('--pretrain_path', default='bert-base-uncased',
                    help='Pre-trained ckpt path / model name (hugginface)')
parser.add_argument('--ckpt', default='',
                    help='Checkpoint name')
parser.add_argument('--only_test', action='store_true',
                    help='Only run test')
parser.add_argument('--pooler', default='entity', choices=['cls', 'entity'],
                    help='Sentence representation pooler')
parser.add_argument('--mask_entity', action='store_true',
                    help='Mask entity mentions')

# Data
parser.add_argument('--metric', default='auc', choices=['micro_f1', 'auc'],
                    help='Metric for picking up best checkpoint')
parser.add_argument('--dataset', default='none',
                    choices=['none', 'wiki_distant', 'nyt10'],
                    help='Dataset. If not none, the following args can be ignored')
parser.add_argument('--train_file', default='', type=str,
                    help='Training data file')
parser.add_argument('--val_file', default='', type=str,
                    help='Validation data file')
parser.add_argument('--test_file', default='', type=str,
                    help='Test data file')
parser.add_argument('--rel2id_file', default='', type=str,
                    help='Relation to ID file')

# Bag related
parser.add_argument('--bag_size', type=int, default=4,
                    help='Fixed bag size. If set to 0, use original bag sizes')

# Hyper-parameters
parser.add_argument('--batch_size', default=16, type=int,
                    help='Batch size')
parser.add_argument('--lr', default=2e-5, type=float,
                    help='Learning rate')
parser.add_argument('--max_length', default=128, type=int,
                    help='Maximum sentence length')
parser.add_argument('--max_epoch', default=3, type=int,
                    help='Max number of training epochs')
# Parse the command line into `args`.
args = parser.parse_args()

# Some basic settings
root_path = '.'
# Fixed: the pasted script passed the undefined name `rootpath` here, which
# raises NameError; the variable defined on the previous line is `root_path`.
sys.path.append(root_path)

# Checkpoints are stored under ./ckpt; create the directory on first use.
os.makedirs('ckpt', exist_ok=True)

if not args.ckpt:
    # No --ckpt given: derive a name from the dataset. The `_` separator was
    # lost in the pasted version (same markdown underscore pairing that turned
    # `root_path` into `rootpath`), which produced e.g. "nonepcnn_att".
    args.ckpt = '{}_{}'.format(args.dataset, 'pcnn_att')
ckpt = 'ckpt/{}.pth.tar'.format(args.ckpt)
# Dataset download / path resolution is disabled: it is kept only as an inert
# string literal because the data paths are hard-coded further down.
'''
if args.dataset != 'none':
    opennre.download(args.dataset, root_path=root_path)
    args.train_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_train.txt'.format(args.dataset))
    args.val_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_val.txt'.format(args.dataset))
    args.test_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_test.txt'.format(args.dataset))
    args.rel2id_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_rel2id.json'.format(args.dataset))
else:
    if not (os.path.exists(args.train_file) and os.path.exists(args.val_file) and os.path.exists(args.test_file) and os.path.exists(args.rel2id_file)):
        raise Exception('--train_file, --val_file, --test_file and --rel2id_file are not specified or files do not exist. Or specify --dataset')
'''

# Log every parsed argument, one per line.
logging.info('Arguments:')
for arg in vars(args):
    logging.info(' {}: {}'.format(arg, getattr(args, arg)))

# Load the relation-name -> id mapping from the hard-coded data file.
# The load from args.rel2id_file is disabled: --rel2id_file defaults to '',
# so opening it raised before the hard-coded file below was ever read, and
# its result was overwritten immediately anyway.
# rel2id = json.load(open(args.rel2id_file))
with open('./data/company-relation/people-relation_rel2id.json', encoding='utf-8') as rel2id_fp:
    rel2id = json.load(rel2id_fp)
# Define the sentence encoder
# The pre-trained weights path is hard-coded; --pretrain_path is ignored.
# NOTE(review): the discussion at the top of this file reports
# "TypeError: forward() missing 1 required positional argument: 'pos2'"
# for this script; the suggested remedy there was to pull the latest opennre
# and reinstall it. Confirm the installed version before running.
sentence_encoder = opennre.encoder.BERTEncoder(
    max_length=args.max_length,
    pretrain_path='./pretrain/chinese_wwm_pytorch',  # args.pretrain_path,
    mask_entity=args.mask_entity
)

# Pooler-based encoder selection is disabled (inert string literal), so
# --pooler currently has no effect.
'''
if args.pooler == 'entity':
    sentence_encoder = opennre.encoder.BERTEntityEncoder(
        max_length=args.max_length,
        pretrain_path=args.pretrain_path,
        mask_entity=args.mask_entity
    )
elif args.pooler == 'cls':
    sentence_encoder = opennre.encoder.BERTEncoder(
        max_length=args.max_length,
        pretrain_path=args.pretrain_path,
        mask_entity=args.mask_entity
    )
else:
    raise NotImplementedError
'''
# Define the model
# Bag-level attention model built on the sentence encoder, with one output
# class per entry in rel2id.
model = opennre.model.BagAttention(sentence_encoder, len(rel2id), rel2id)

# Define the whole training framework
# Data paths are hard-coded; --train_file / --val_file / --test_file are
# ignored.
framework = opennre.framework.BagRE(
    train_path='people-relation_distant_train.txt',  # args.train_file,
    val_path='people-relation_distant_val.txt',  # args.val_file,
    # NOTE(review): the test split points at the same file as validation,
    # so the reported test metrics are computed on the validation data.
    test_path='people-relation_distant_val.txt',  # args.test_file,
    model=model,
    ckpt=ckpt,
    batch_size=args.batch_size,
    max_epoch=args.max_epoch,
    lr=args.lr,
    opt='adamw',
    bag_size=args.bag_size)
# Train the model
# Skipped with --only_test; args.metric is passed through to train_model.
if not args.only_test:
    framework.train_model(args.metric)

# Test the model
# Load the 'state_dict' entry of the checkpoint at `ckpt`, then evaluate on
# the framework's test loader.
framework.load_state_dict(torch.load(ckpt)['state_dict'])
result = framework.eval_model(framework.test_loader)

# Print the result
logging.info('Test set results:')
logging.info('AUC: {}'.format(result['auc']))
logging.info('Micro F1: {}'.format(result['micro_f1']))