gitabtion / BertBasedCorrectionModels

PyTorch impelementations of BERT-based Spelling Error Correction Models. 基于BERT的文本纠错模型,使用PyTorch实现。
Apache License 2.0
265 stars 43 forks source link

无法加载训练的模型,程序自动从HuggingFace下载模型,这是什么原因? #35

Closed TGLTommy closed 2 years ago

TGLTommy commented 2 years ago

你好,通过调用inference.py中的load_model_directly()方法,无法加载训练的模型,具体代码如下:

① 代码部分:

def load_model_directly(): ckpt_file = 'SoftMaskedBert/epoch=05-val_loss=0.03253.ckpt' config_file = 'csc/train_SoftMaskedBert.yml'

from bbcm.config import cfg
cp = get_abs_path('checkpoints', ckpt_file)
cfg.merge_from_file(get_abs_path('configs', config_file))
tokenizer = BertTokenizer.from_pretrained(cfg.MODEL.BERT_CKPT)
print("###tokenizer加载完毕")
print("### tokenizer: ", tokenizer)
if cfg.MODEL.NAME in ['bert4csc', 'macbert4csc']:
    model = BertForCsc.load_from_checkpoint(cp,
                                            cfg=cfg,
                                            tokenizer=tokenizer)
else:
    print("###加载模型")
    print("###cp : ", cp)
    model = SoftMaskedBertModel.load_from_checkpoint(cp,
                                                     cfg=cfg,
                                                     tokenizer=tokenizer)
print("###model加载完毕")
model.eval()
model.to(cfg.MODEL.DEVICE)
return model

② 问题: 感觉这段代码没有起作用,ckpt文件无法加载,程序还是自动从huggingface下载了。 model = SoftMaskedBertModel.load_from_checkpoint(cp, cfg=cfg, tokenizer=tokenizer) 我查了一下load_from_checkpoint() 方法,对于参数cp, cfg的传递,没有看明白。