courao / ocr.pytorch

A pure pytorch implemented ocr project including text detection and recognition
MIT License
583 stars 133 forks source link

更改crnn_recognizer.py报错 #15

Closed wangguanhua closed 5 years ago

wangguanhua commented 5 years ago

您好,我和前面的朋友遇见的问题一样,修改crnn_recognizer.py文件的第100行def init(self, model_path='/root/zjut/ocr.pytorch/checkpoints/CRNN.pth')。当我执行'python demo.py'命令出错,显示如下: Traceback (most recent call last): File "/root/.vscode-server/extensions/ms-python.python-2019.11.50794/pythonFiles/ptvsd_launcher.py", line 43, in main(ptvsdArgs) File "/root/.vscode-server/extensions/ms-python.python-2019.11.50794/pythonFiles/lib/python/old_ptvsd/ptvsd/main.py", line 432, in main run() File "/root/.vscode-server/extensions/ms-python.python-2019.11.50794/pythonFiles/lib/python/old_ptvsd/ptvsd/main.py", line 316, in run_file runpy.run_path(target, run_name='main') File "/root/anaconda3/lib/python3.6/runpy.py", line 263, in run_path pkg_name=pkg_name, script_name=fname) File "/root/anaconda3/lib/python3.6/runpy.py", line 96, in _run_module_code mod_name, mod_spec, pkg_name, script_name) File "/root/anaconda3/lib/python3.6/runpy.py", line 85, in _run_code exec(code, run_globals) File "/root/zjut/ocr.pytorch/demo.py", line 10, in from ocr import ocr File "/root/zjut/ocr.pytorch/ocr.py", line 6, in recognizer = PytorchOcr() File "/root/zjut/ocr.pytorch/recognize/crnn_recognizer.py", line 111, in init self.model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(model_path).items()}) File "/root/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 845, in load_state_dict self.class.name, "\n\t".join(error_msgs))) RuntimeError: Error(s) in loading state_dict for CRNN: Missing key(s) in state_dict: "conv1.weight", "conv1.bias", "conv2.weight", "conv2.bias", "conv3_1.weight", "conv3_1.bias", "bn3.weight", "bn3.bias", "bn3.running_mean", "bn3.running_var", "conv3_2.weight", "conv3_2.bias", "conv4_1.weight", "conv4_1.bias", "bn4.weight", "bn4.bias", "bn4.running_mean", "bn4.running_var", "conv4_2.weight", "conv4_2.bias", "conv5.weight", "conv5.bias", "bn5.weight", "bn5.bias", "bn5.running_mean", "bn5.running_var". Unexpected key(s) in state_dict: "cnn.conv0.weight", "cnn.conv0.bias", "cnn.conv1.weight", "cnn.conv1.bias", "cnn.conv2.weight", "cnn.conv2.bias", "cnn.batchnorm2.weight", "cnn.batchnorm2.bias", "cnn.batchnorm2.running_mean", "cnn.batchnorm2.running_var", "cnn.batchnorm2.num_batches_tracked", "cnn.conv3.weight", "cnn.conv3.bias", "cnn.conv4.weight", "cnn.conv4.bias", "cnn.batchnorm4.weight", "cnn.batchnorm4.bias", "cnn.batchnorm4.running_mean", "cnn.batchnorm4.running_var", "cnn.batchnorm4.num_batches_tracked", "cnn.conv5.weight", "cnn.conv5.bias", "cnn.conv6.weight", "cnn.conv6.bias", "cnn.batchnorm6.weight", "cnn.batchnorm6.bias", "cnn.batchnorm6.running_mean", "cnn.batchnorm6.running_var", "cnn.batchnorm6.num_batches_tracked". size mismatch for rnn.1.embedding.weight: copying a param with shape torch.Size([5997, 512]) from checkpoint, the shape in current model is torch.Size([5835, 512]). size mismatch for rnn.1.embedding.bias: copying a param with shape torch.Size([5997]) from checkpoint, the shape in current model is torch.Size([5835]). 其中CRNN.pth是您度盘所提供的。

courao commented 5 years ago

你是改了加载的模型?我记得的默认是checkpoints/CRNN.pth 用百度网盘里的CRNN-1010.pth这个模型应该就没问题

wangguanhua commented 5 years ago

我这个跑的时候必须要绝对路径不能用相对路径(相对路径会报错)。至于模型,CRNN-1010.pth是没问题的,但您度盘里不是也提供了CRNN.pth吗,我想试试看这个。而且我之前也训练过一个,然后也报错。

courao commented 5 years ago

如果用相对路径的话,你需要进入的ocr.pytorch的目录下(cd /root/zjut/ocr.pytorch/)跑python demo.py应该就行, CRNN.pth是之前的api用的模型,最近更新的这版模型和字典有一些变化,所以不通用会报错 你之前训练是用这个代码训练的吗,如果字典是一样的话应该是不会报错的呀。

wangguanhua commented 5 years ago

是的,我发现了字典的格式不同。顺带问一下您训练crnn的数据集有多大,我这边准备的是360w的开源数据集,要是还没您训练所用的大,那我再去找一些其他的。

courao commented 5 years ago

我用的也是开源的,好像还没你的多,你可以先训着看看效果,不行的话再加别的

wangguanhua commented 5 years ago

好的,多谢您了。

courao commented 5 years ago

哈哈不客气