Closed wangguanhua closed 5 years ago
你是改了加载的模型?我记得的默认是checkpoints/CRNN.pth 用百度网盘里的CRNN-1010.pth这个模型应该就没问题
我这个跑的时候必须要绝对路径不能用相对路径(相对路径会报错)。至于模型,CRNN-1010.pth是没问题的,但您度盘里不是也提供了CRNN.pth吗,我想试试看这个。而且我之前也训练过一个,然后也报错。
如果用相对路径的话,你需要进入的ocr.pytorch的目录下(cd /root/zjut/ocr.pytorch/)跑python demo.py应该就行, CRNN.pth是之前的api用的模型,最近更新的这版模型和字典有一些变化,所以不通用会报错 你之前训练是用这个代码训练的吗,如果字典是一样的话应该是不会报错的呀。
是的,我发现了字典的格式不同。顺带问一下您训练crnn的数据集有多大,我这边准备的是360w的开源数据集,要是还没您训练所用的大,那我再去找一些其他的。
我用的也是开源的,好像还没你的多,你可以先训着看看效果,不行的话再加别的
好的,多谢您了。
哈哈不客气
您好,我和前面的朋友遇见的问题一样,修改crnn_recognizer.py文件的第100行def init(self, model_path='/root/zjut/ocr.pytorch/checkpoints/CRNN.pth')。当我执行'python demo.py'命令出错,显示如下: Traceback (most recent call last): File "/root/.vscode-server/extensions/ms-python.python-2019.11.50794/pythonFiles/ptvsd_launcher.py", line 43, in
main(ptvsdArgs)
File "/root/.vscode-server/extensions/ms-python.python-2019.11.50794/pythonFiles/lib/python/old_ptvsd/ptvsd/main.py", line 432, in main
run()
File "/root/.vscode-server/extensions/ms-python.python-2019.11.50794/pythonFiles/lib/python/old_ptvsd/ptvsd/main.py", line 316, in run_file
runpy.run_path(target, run_name='main')
File "/root/anaconda3/lib/python3.6/runpy.py", line 263, in run_path
pkg_name=pkg_name, script_name=fname)
File "/root/anaconda3/lib/python3.6/runpy.py", line 96, in _run_module_code
mod_name, mod_spec, pkg_name, script_name)
File "/root/anaconda3/lib/python3.6/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/root/zjut/ocr.pytorch/demo.py", line 10, in
from ocr import ocr
File "/root/zjut/ocr.pytorch/ocr.py", line 6, in
recognizer = PytorchOcr()
File "/root/zjut/ocr.pytorch/recognize/crnn_recognizer.py", line 111, in init
self.model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(model_path).items()})
File "/root/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 845, in load_state_dict
self.class.name, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for CRNN:
Missing key(s) in state_dict: "conv1.weight", "conv1.bias", "conv2.weight", "conv2.bias", "conv3_1.weight", "conv3_1.bias", "bn3.weight", "bn3.bias", "bn3.running_mean", "bn3.running_var", "conv3_2.weight", "conv3_2.bias", "conv4_1.weight", "conv4_1.bias", "bn4.weight", "bn4.bias", "bn4.running_mean", "bn4.running_var", "conv4_2.weight", "conv4_2.bias", "conv5.weight", "conv5.bias", "bn5.weight", "bn5.bias", "bn5.running_mean", "bn5.running_var".
Unexpected key(s) in state_dict: "cnn.conv0.weight", "cnn.conv0.bias", "cnn.conv1.weight", "cnn.conv1.bias", "cnn.conv2.weight", "cnn.conv2.bias", "cnn.batchnorm2.weight", "cnn.batchnorm2.bias", "cnn.batchnorm2.running_mean", "cnn.batchnorm2.running_var", "cnn.batchnorm2.num_batches_tracked", "cnn.conv3.weight", "cnn.conv3.bias", "cnn.conv4.weight", "cnn.conv4.bias", "cnn.batchnorm4.weight", "cnn.batchnorm4.bias", "cnn.batchnorm4.running_mean", "cnn.batchnorm4.running_var", "cnn.batchnorm4.num_batches_tracked", "cnn.conv5.weight", "cnn.conv5.bias", "cnn.conv6.weight", "cnn.conv6.bias", "cnn.batchnorm6.weight", "cnn.batchnorm6.bias", "cnn.batchnorm6.running_mean", "cnn.batchnorm6.running_var", "cnn.batchnorm6.num_batches_tracked".
size mismatch for rnn.1.embedding.weight: copying a param with shape torch.Size([5997, 512]) from checkpoint, the shape in current model is torch.Size([5835, 512]).
size mismatch for rnn.1.embedding.bias: copying a param with shape torch.Size([5997]) from checkpoint, the shape in current model is torch.Size([5835]).
其中CRNN.pth是您度盘所提供的。