courao / ocr.pytorch

A pure pytorch implemented ocr project including text detection and recognition
MIT License
582 stars 133 forks source link

关于crnn训练中出现train loss=nan的问题 #25

Open wpc11 opened 4 years ago

wpc11 commented 4 years ago

1.我的数据集是这样的(有360w张图片) 6(1CATNR(T7 YAST9ME}}CB JBZM~%9{E%}Q{ZNN(DC1% R 2.我的config是这样的 D Y_8AL Y9YA7S2}T B_PYS 3、运行情况是 UO0CR5%(4ZZ_WF(HX )F5{S

烦请作者大大解答一下,感激不尽

courao commented 4 years ago

我的经验是有这么几个可能的出问题的地方, 1.alphabet的设置,新的数据集的话最好更新一下alphabet,因为之前的alphabet可能有漏掉的字符。 更新alphabet可以参考keys.py中注释掉的代码 2.网络输出的类别也就是nclass是不是和alphabet匹配上,看代码里似乎没有问题nclass=len(alphabet)+1 3.如果用的是pytorch自带的CTCloss的话有时候也会触发nan的情况,不过我在train_pytorch_ctc.py做了修正,在计算loss时用cpu计算了,试过几次已经没有出现nan的情况了,不知道你那边是不是用的pytorch自带的CTC。

wpc11 commented 4 years ago

train_wrap_ctc.py,用的是自带的from torch.nn import CTCLoss,这样的话我再去修改下代码,更新下alphabet,谢谢作者大大

courao commented 4 years ago

如果你用的train_wrap_ctc.py这个训练的话就用warp-ctc 如果想用pytorch自带的CTCloss的话,就用train_pytorch_ctc.py这个文件训练 这样的话应该是不会有nan出现的, 虽然两个文件很像,但是应该还是有细微差别的

wpc11 commented 4 years ago

okok

------------------ 原始邮件 ------------------ 发件人: "coura"<notifications@github.com>; 发送时间: 2020年3月22日(星期天) 下午2:21 收件人: "courao/ocr.pytorch"<ocr.pytorch@noreply.github.com>; 抄送: "青龙"<799691142@qq.com>; "Author"<author@noreply.github.com>; 主题: Re: [courao/ocr.pytorch] 关于crnn训练中出现train loss=nan的问题 (#25)

如果你用的train_wrap_ctc.py这个训练的话就用warp-ctc 如果想用pytorch自带的CTCloss的话,就用train_pytorch_ctc.py这个文件训练 这样的话应该是不会有nan出现的, 虽然两个文件很像,但是应该还是有细微差别的

— You are receiving this because you authored the thread. Reply to this email directly, view it on GitHub, or unsubscribe.

wpc11 commented 4 years ago

作者大大你好,,就是我在更新完alphabet后出现了这样的错误,不知道是哪里操作出现了错误,,麻烦作者大大啦

 

------------------ 原始邮件 ------------------ 发件人: "coura"<notifications@github.com>; 发送时间: 2020年3月22日(星期天) 下午2:21 收件人: "courao/ocr.pytorch"<ocr.pytorch@noreply.github.com>; 抄送: "青龙"<799691142@qq.com>; "Author"<author@noreply.github.com>; 主题: Re: [courao/ocr.pytorch] 关于crnn训练中出现train loss=nan的问题 (#25)

如果你用的train_wrap_ctc.py这个训练的话就用warp-ctc 如果想用pytorch自带的CTCloss的话,就用train_pytorch_ctc.py这个文件训练 这样的话应该是不会有nan出现的, 虽然两个文件很像,但是应该还是有细微差别的

— You are receiving this because you authored the thread. Reply to this email directly, view it on GitHub, or unsubscribe.

courao commented 4 years ago

你是贴了图吗。好像没看到你贴的内容

wpc11 commented 4 years ago

https://mail.qq.com/cgi-bin/viewfile?f=687B756EADCF463FA553967FBE8847DEAACEBA46AD314DB48E3FC6CB412A9002BC0BE251A4B30A99745C54D7F5C1C882051A548927E38219C1F45F6A79E37E15FBCD5803E274B899C14065FD2C2BA7C6AF537A6D08BB2E3049CD78475BD469F8&mailid=ZL1122-%7EnyFWQgBMZvsAS5wKdpqHa3&sid=vTutInOJJI74TWI8&net=889192575

wpc11 commented 4 years ago

问题似乎解决了,,麻烦了。。图好像没上传上来

wpc11 commented 4 years ago

作者大大你好,,我又来了。。我在加载你之前的训练模型当作我的预训练模型时发生了如下的错误: 2FM`)WXI3)L~O`B@ U7IU$T A19YN`O$V0F_%TTTGNDAYX 这是不是意味着我的数据集里的数据不符合规范,,(我训练的数据集就是那个8.6g的Synthetic Chinese String的数据集,,数据的不规范是否也会导致我出现nan呢,,其实我下午再修改了lr和batch_size,以及更新了alphabet的情况下,,用了train_pytorch_ctc.py以后仍然会出现nan的现象,😵,所以特此想问下会不会是我的数据出现了问题,期待答复,谢谢了!)

courao commented 4 years ago

加载预训练模型出错是因为alphabet发生了变化,如果要加载的话换回原来的alphabet即可, 既然还有nan发生的话有两个建议你可以尝试一下: 1.装一下warp-ctc在另一个训练代码上进行训练,不过warp-ctc貌似不支持特别高版本的pytorch,所以要退回低版本的pytorch,你要试的话可以装个虚拟环境试一试 2.可以生成一个小点并且简单点的数据集,比如只有英文+数字的数据,先跑通实验再进行大数据集上的实验,这样也方便定位问题

wpc11 commented 4 years ago

谢谢作者大大答复,我这就去尝试一下

courao commented 4 years ago

嗯这个错误应该是alphabet里面没有包含unicode为32的字符,这个字符应该是空格,按理说自己生成的alphabet应该是会包含这个字符的,如果没有的可以自己手动加一下。。。

------------------ 原始邮件 ------------------ 发件人: "wpc11"<notifications@github.com>; 发送时间: 2020年3月28日(星期六) 中午11:04 收件人: "courao/ocr.pytorch"<ocr.pytorch@noreply.github.com>; 抄送: "coura"<379186524@qq.com>; "Comment"<comment@noreply.github.com>; 主题: Re: [courao/ocr.pytorch] 关于crnn训练中出现train loss=nan的问题 (#25)

好像图没截全,,

— You are receiving this because you commented. Reply to this email directly, view it on GitHub, or unsubscribe.

wpc11 commented 4 years ago

好的好的。thanks啦

发自我的iPhone

------------------ 原始邮件 ------------------ 发件人: coura <notifications@github.com> 发送时间: 2020年3月28日 16:01 收件人: courao/ocr.pytorch <ocr.pytorch@noreply.github.com> 抄送: wpc11 <799691142@qq.com>, Author <author@noreply.github.com> 主题: 回复:[courao/ocr.pytorch] 关于crnn训练中出现train loss=nan的问题 (#25)

嗯这个错误应该是alphabet里面没有包含unicode为32的字符,这个字符应该是空格,按理说自己生成的alphabet应该是会包含这个字符的,如果没有的可以自己手动加一下。。。

------------------&nbsp;原始邮件&nbsp;------------------ 发件人: "wpc11"<notifications@github.com&gt;;
发送时间: 2020年3月28日(星期六) 中午11:04 收件人: "courao/ocr.pytorch"<ocr.pytorch@noreply.github.com&gt;;
抄送: "coura"<379186524@qq.com&gt;; "Comment"<comment@noreply.github.com&gt;;
主题: Re: [courao/ocr.pytorch] 关于crnn训练中出现train loss=nan的问题 (#25)

好像图没截全,,

— You are receiving this because you commented. Reply to this email directly, view it on GitHub, or unsubscribe. — You are receiving this because you authored the thread. Reply to this email directly, view it on GitHub, or unsubscribe.

wpc11 commented 4 years ago

QQ图片20200406145042 QQ图片20200406150300 作者大大你好,之前我按照该程序中的crnn训练模型的程序train_pytorch_ctc跑完自己的数据集后,得到了一个最终的结果模型,可是当我用这个模型去替换原先的预训练模型后出现了这样的报错,请问下这样的错误出现原因是不是因为自己的数据集需要经过处理之类的,期待您的回答

courao commented 4 years ago

这个错误的原因是你再训练时用了新的alphabet但是测试时还是旧的,你把测试的也替换一下就可以了

wpc11 commented 4 years ago

好的感谢啦

woshildh commented 4 years ago

1.我的数据集是这样的(有360w张图片) 6(1CATNR(T7 YAST9ME}}CB JBZM~%9{E%}Q{ZNN(DC1% R 2.我的config是这样的 D Y_8AL Y9YA7S2}T B_PYS 3、运行情况是 UO0CR5%(4ZZ_WF(HX )F5{S

烦请作者大大解答一下,感激不尽

请问问题解决了吗?我也是这个数据集损失为Nan。非常感谢

Xel233 commented 4 years ago

preds.log_softmax(2).to(torch.float64)

It works smoothly on 3.6m dataset

InferMaster commented 3 years ago

我在训练的时候损失为148多,精确度一直为0,这个咋搞呀

EurekaTesla commented 3 years ago

我又来补充了。 在utils.py的110行self.alphabet.append(ocr('_')),如果你的标签中含有 字符就需要注意了,标签开头 导致无法正常获取标签值,所以把 _ 改为其他特殊字符(不会出现在你的标签中的字符)

CHAMYJ commented 3 years ago

请问您是如何准备新的alphabet.pkl 呢?

在keys.py 里, 有这代码, 但是我该如何准备新的 text.txt (在 infofiles 那)

import pickle as pkl alphabet_set = set() infofiles = ['text.txt'] for infofile in infofiles: f = open(infofile, encoding='utf8') content = f.readlines() f.close() for line in content: if len(line.strip())>0: if len(line.strip().split('\t'))!=2: print(line) else: fname,label = line.strip().split('\t') for ch in label: alphabet_set.add(ch) alphabet_list = sorted(list(alphabet_set)) pkl.dump(alphabet_list,open('khmer2.pkl','wb'))