Holmeyoung / crnn-pytorch

Pytorch implementation of CRNN (CNN + RNN + CTCLoss) for all language OCR.
MIT License
378 stars 105 forks source link

字典里面字符的排列顺序重要吗? #28

Closed deep-practice closed 5 years ago

deep-practice commented 5 years ago

碰到了一个很奇怪的事情,字典内容相同,把数字分别放在中文前面和中文后面,效果完全不同。当数字放在中文后面,然后训练,数字识别率为0,当数字放在中文前面,数字能被识别出来。求大佬解惑。 @Holmeyoung

Holmeyoung commented 5 years ago

这个我没有测试呢,但是整个内容是以二进制存储的,汉字会占多位,会不会和这个有关呢~有详细的信息吗,这个应该好好看一下

deep-practice commented 5 years ago

看了大佬解码部分,貌似是最优路径解码,有试过更复杂的解码算法吗?

deep-practice commented 5 years ago

还有就是我自己的字典数比提供的模型大,finetune的时候,面临2个选择 1)前面cnn部分不进行参与梯度计算 2)前面cnn部分用很小的学习率学习,后面的RNN用较大的学习率学习 从经验来看,哪种方式好一点呢,主要是设备渣渣,跑一次需要的时间比较长,目前我用了第一种,rnn学习率设为0.01,训练过程中有必要去调整学习率吗?

deep-practice commented 5 years ago

我字典里面有字符'-',程序中空格也是用'-'表示,不冲突吧?

Holmeyoung commented 5 years ago

嗨, 第一个问题,是指 decode 部分吗

第二个问题,其实有一种暴力的 crnn 是直接 lstm 的,所以咱们前面加的 cnn 也是为了更好的提取特征,因为cnn部分没有提供很好的预训练权重,所以还是建议参与更新。学习率建议先大后小,先快速收敛,然后寻找最优~~~具体可以看看https://github.com/Holmeyoung/crnn-pytorch/issues/2

第三个问题,不影响训练,会影响显示,

    def __init__(self, alphabet, ignore_case=False):
        self._ignore_case = ignore_case
        if self._ignore_case:
            alphabet = alphabet.lower()
        self.alphabet = alphabet + '-'  # for `-1` index

        self.dict = {}
        for i, char in enumerate(alphabet):
            # NOTE: 0 is reserved for 'blank' required by wrap_ctc
            self.dict[char] = i + 1

因为 self.dictdict 结构,你字符里面的 - 必然会被冲掉。因为在 ctcloss 中 blank 被 0 表示,所以 encode 的时候需要后移一位,解码前移一位,0 被解码为 -1 ,也就是 - 也就是 空,也就是说,val的时候 空 和 - 都被 - 显示了

deep-practice commented 5 years ago

1.对,是指decoder 2.那我把空白字符换种表示方法就可以了吧?

Holmeyoung commented 5 years ago

关于decode要是有更好的思路提一个 pull request 就好了

空白字符换个表示是可以的

deep-practice commented 5 years ago

求助,实在搞不明白哪儿的问题,用5529那个字典预训练模型finetune,训练过程也是正常的,在torch.save(crnn.state_dict(),'xxxx.pth')之前打印crnn里面rnn层,类别数也是正确的,但保存之后,再次加载就报错: RuntimeError: Error(s) in loading state_dict for CRNN: size mismatch for rnn.1.embedding.weight: copying a param with shape torch.Size([5530, 512]) from checkpoint, the shape in current model is torch.Size([6290, 512]). size mismatch for rnn.1.embedding.bias: copying a param with shape torch.Size([5530]) from checkpoint, the shape in current model is torch.Size([6290]).

Holmeyoung commented 5 years ago

加载之前手动把rnn层修改一下试试

deep-practice commented 5 years ago

嗯嗯,我的错,没有重新初始化