Closed TGLTommy closed 2 years ago
你好,通过调用inference.py中的load_model_directly()方法,无法加载训练的模型,具体代码如下:
① 代码部分:
def load_model_directly(): ckpt_file = 'SoftMaskedBert/epoch=05-val_loss=0.03253.ckpt' config_file = 'csc/train_SoftMaskedBert.yml'
from bbcm.config import cfg cp = get_abs_path('checkpoints', ckpt_file) cfg.merge_from_file(get_abs_path('configs', config_file)) tokenizer = BertTokenizer.from_pretrained(cfg.MODEL.BERT_CKPT) print("###tokenizer加载完毕") print("### tokenizer: ", tokenizer) if cfg.MODEL.NAME in ['bert4csc', 'macbert4csc']: model = BertForCsc.load_from_checkpoint(cp, cfg=cfg, tokenizer=tokenizer) else: print("###加载模型") print("###cp : ", cp) model = SoftMaskedBertModel.load_from_checkpoint(cp, cfg=cfg, tokenizer=tokenizer) print("###model加载完毕") model.eval() model.to(cfg.MODEL.DEVICE) return model
② 问题: 感觉这段代码没有起作用,ckpt文件无法加载,程序还是自动从huggingface下载了。 model = SoftMaskedBertModel.load_from_checkpoint(cp, cfg=cfg, tokenizer=tokenizer) 我查了一下load_from_checkpoint() 方法,对于参数cp, cfg的传递,没有看明白。
你好,通过调用inference.py中的load_model_directly()方法,无法加载训练的模型,具体代码如下:
① 代码部分:
def load_model_directly(): ckpt_file = 'SoftMaskedBert/epoch=05-val_loss=0.03253.ckpt' config_file = 'csc/train_SoftMaskedBert.yml'
② 问题: 感觉这段代码没有起作用,ckpt文件无法加载,程序还是自动从huggingface下载了。 model = SoftMaskedBertModel.load_from_checkpoint(cp, cfg=cfg, tokenizer=tokenizer) 我查了一下load_from_checkpoint() 方法,对于参数cp, cfg的传递,没有看明白。