meijieru / crnn.pytorch

Convolutional recurrent network in pytorch
MIT License

How to use a pre-trained model #134

Open · jjccyy opened this issue 6 years ago

jjccyy commented 6 years ago

I modified my key.py. When I load the pre-trained model, I get this error:

RuntimeError: While copying the parameter named rnn.1.embedding.weight, whose dimensions in the model are torch.Size([12, 512]) and whose dimensions in the checkpoint are torch.Size([5530, 512]).

I only want to modify the last layer. How can I solve this? Thanks.
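Since only the last layer differs, one common approach is to copy every checkpoint tensor whose name and shape both match the model and leave the rest freshly initialized. The helper below is a minimal, framework-agnostic sketch of that filtering step. The name select_loadable is hypothetical. It assumes both arguments are name-to-tensor mappings, such as the dicts returned by torch.load(...) for a saved state dict and by model.state_dict().

```python
def select_loadable(checkpoint_state, model_state):
    """Split checkpoint entries into those that can be copied into the model
    (same key and same shape) and the names that must be skipped.

    Works with any mapping whose values expose a ``.shape`` attribute,
    e.g. PyTorch state dicts (hypothetical helper, not part of crnn.pytorch).
    """
    loadable, skipped = {}, []
    for name, tensor in checkpoint_state.items():
        target = model_state.get(name)
        # Keep the tensor only if the model has a parameter of the same shape.
        if target is not None and tuple(target.shape) == tuple(tensor.shape):
            loadable[name] = tensor
        else:
            skipped.append(name)
    return loadable, skipped
```

With PyTorch, the filtered dict could then be passed to model.load_state_dict(loadable, strict=False). The strict=False setting tolerates the skipped keys being missing, but a shape mismatch would still raise, which is why the filtering is needed. The skipped rnn.1.embedding weight and bias keep the model's own initialization and are learned during fine-tuning.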

leonzhu211 commented 6 years ago

Maybe you need to change 'nclass' to match the pre-trained model:

model = crnn.CRNN(opt.imgH, nc, nclass, opt.nh)

ZhuLingfeng1993 commented 6 years ago

What is nclass? Is it the number of character classes? My data has 100 character classes, but my error is: RuntimeError: While copying the parameter named module.rnn.1.embedding.weight, whose dimensions in the model are torch.Size([37, 512]) and whose dimensions in the checkpoint are torch.Size([63, 512]). So I set nclass to 63, and the problem was solved. But I'm confused: why is nclass 63? Shouldn't it be 100?
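The numbers in the error can be read directly. nclass is fixed by the first dimension of the checkpoint's output-layer weight, not by how many character classes your own data contains. Below is a small sketch that assumes the convention explained later in this thread (nclass = alphabet length + 1). All function names are hypothetical.

```python
def nclass_from_alphabet(alphabet):
    # One output per character, plus one extra class for the CTC blank.
    return len(alphabet) + 1


def alphabet_size_from_checkpoint(output_weight_shape):
    # The output layer weight has shape (nclass, 512) in the errors above,
    # so its first dimension is nclass and nclass - 1 is the alphabet length.
    return output_weight_shape[0] - 1


def check_nclass(alphabet, output_weight_shape):
    expected = nclass_from_alphabet(alphabet)
    found = output_weight_shape[0]
    if expected != found:
        raise ValueError(
            f"alphabet gives nclass={expected}, but the checkpoint has "
            f"nclass={found} ({found - 1} characters)"
        )
```

So a 63-row checkpoint was trained with a 62-character alphabet, and the 37-row model built here corresponds to 36 characters. Setting nclass to 63 lets the weights load, but predictions are only meaningful with the checkpoint's own 62-character alphabet. To recognize 100 classes, nclass would have to be 101 and the last layer retrained.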

ZhuLingfeng1993 commented 6 years ago

Could you explain why this error occurs, and how you solved it?

leonzhu211 commented 6 years ago

My modified code contains these lines:

self.nclass = len(self.alphabet) + 1
self.converter = utils.strLabelConverter(self.alphabet)
self.model = CRNN(self.image_height, 3, self.nclass, self.hidden_size, self.bone)

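'nclass' is the length of your alphabet plus 1. According to the comments in the original code, the extra 1 is presumably required by CTCLoss: one additional class is added to represent the "invalid" (blank) character.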
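The following sketch shows why one extra class is needed. CTC reserves one label for the blank, and the real characters take the remaining labels. It assumes blank = 0 and characters at labels 1..len(alphabet), which is the convention the original code's comment points to; verify it against your copy of utils.strLabelConverter. The names build_index, encode and greedy_decode are illustrative, not the repository's API.

```python
BLANK = 0  # label reserved for the CTC blank (assumed convention)


def build_index(alphabet):
    # Characters map to 1..len(alphabet); 0 stays free for the blank,
    # which is why the network needs len(alphabet) + 1 outputs.
    return {ch: i + 1 for i, ch in enumerate(alphabet)}


def encode(text, index):
    return [index[ch] for ch in text]


def greedy_decode(frame_labels, alphabet):
    # CTC best-path decoding: merge repeated labels, then drop blanks.
    out = []
    prev = None
    for label in frame_labels:
        if label != BLANK and label != prev:
            out.append(alphabet[label - 1])
        prev = label
    return "".join(out)
```

For example, greedy_decode([1, 1, 0, 1, 2, 2], "ab") gives "aab". The blank between the two runs of label 1 is what allows a doubled character to survive the merging step.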

For the pretrained model you obtained, check whether the alphabet it was trained with is the same as the one you are using.
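Checking the alphabet means comparing order as well as content. If each character's label is its position in the alphabet plus one, two alphabets with the same characters in a different order produce the same nclass. The weights then load without error, yet every prediction is mapped to the wrong character. The following hypothetical helper reports the first difference:

```python
def first_alphabet_mismatch(pretrained_alphabet, my_alphabet):
    """Return None if both alphabets assign identical labels to identical
    characters, otherwise a short description of the first difference."""
    if len(pretrained_alphabet) != len(my_alphabet):
        return f"length differs: {len(pretrained_alphabet)} vs {len(my_alphabet)}"
    for i, (a, b) in enumerate(zip(pretrained_alphabet, my_alphabet)):
        if a != b:
            return f"position {i}: {a!r} vs {b!r}"
    return None
```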

liuyiyiyiyi commented 5 years ago

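Have you trained on Chinese text? I have a question about nclass for Chinese. Suppose I have 10 Chinese character classes. When training with Python 3, nclass is 11. When training with Python 2, however, each Chinese character has length 3, so nclass becomes 31, which is a problem. If I need to train with Python 2, is there a way around this?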
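The 31 comes from counting bytes rather than characters. In Python 2, a plain string read from a UTF-8 file is a byte string, and common Chinese characters take 3 bytes each in UTF-8, so 10 characters give 30 + 1. One possible workaround is to decode the alphabet, and the label strings, to text before counting or indexing them. This is offered as an assumption, not a fix tested against this repository. The helper names below are hypothetical.

```python
def alphabet_chars(alphabet, encoding="utf-8"):
    # A UTF-8 byte string has one entry per byte, not per character;
    # decoding first yields one entry per character.
    if isinstance(alphabet, bytes):
        alphabet = alphabet.decode(encoding)
    return alphabet


def nclass_for(alphabet):
    # Alphabet length in characters, plus one class for the CTC blank.
    return len(alphabet_chars(alphabet)) + 1
```

The code is written for Python 3. The isinstance(..., bytes) check also works on Python 2, where bytes is an alias of str. However, a Python 2 source file containing Chinese literals would additionally need a # -*- coding: utf-8 -*- declaration or u'...' literals. Labels must be decoded the same way so that character lookups match the decoded alphabet.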