Cooduck / OCR-Invoice_Identification

An OCR project based on CTPN+CRNN, used for invoice recognition
6 stars, 1 fork

Hello, how can I train my own weights based on your model? #1

Open szzwy opened 4 months ago

szzwy commented 4 months ago

I would like to train it on train tickets. I already have a dataset. How should I go about it? Thanks.

Cooduck commented 4 months ago

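Hello. This project consists of two models: a CTPN model that detects and extracts text regions, and a CRNN model that recognizes the text. The two must be trained separately, and the training code is in the /train_code folder.

To train CTPN, go into /train_code/ctpn_train_code and set img_dir and label_dir in /ctpn/config.py to the paths of your images and your labels. Your dataset also needs to follow the same layout as /imagedata. If your labels are not JSON files, either convert them to JSON yourself or modify the label-reading functions read_json and readxml in /ctpn/dataset.py. Once all of that is done, running train.py should start training.

To train CRNN, go into /train_code/crnn_train_code and set config.train_infofile in /train_pytorch_ctc.py to the path of your label file. The label file should follow the format of information.txt. You can reuse the dataset prepared for CTPN training here and run data_pre.py, which generates this label file automatically. Once all of that is done, running train_pytorch_ctc.py should start training. For CRNN training you can also refer to https://github.com/courao/ocr.pytorch/blob/master/train_code/train_crnn/readme.md .

I hope this helps.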
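Before running train.py, it may be worth confirming that every image under img_dir has a matching label under label_dir. The following is only a sketch and is not part of the repository: the function name find_unpaired, the list of image suffixes, and the assumption that each image has a label file with the same base name and a .json extension are all hypothetical, and should be adapted to the actual /imagedata layout.

```python
from pathlib import Path

# Assumed image file types; adjust to whatever the dataset actually contains.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp"}


def find_unpaired(img_dir, label_dir, label_suffix=".json"):
    """Return (images_without_label, labels_without_image) as sorted lists of base names.

    Hypothetical helper: assumes one label file per image, sharing the image's base name.
    """
    img_dir, label_dir = Path(img_dir), Path(label_dir)
    if not img_dir.is_dir():
        raise FileNotFoundError(f"img_dir does not exist: {img_dir}")
    if not label_dir.is_dir():
        raise FileNotFoundError(f"label_dir does not exist: {label_dir}")

    image_stems = {p.stem for p in img_dir.iterdir()
                   if p.suffix.lower() in IMAGE_SUFFIXES}
    label_stems = {p.stem for p in label_dir.iterdir()
                   if p.suffix.lower() == label_suffix}

    # Names present on one side only indicate a broken image/label pair.
    return sorted(image_stems - label_stems), sorted(label_stems - image_stems)
```

Calling find_unpaired with the same two paths that were written into config.py should return two empty lists when the dataset is consistent.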

szzwy commented 4 months ago

OK, thanks. But when I tested the CTPN train.py, both with the data bundled in the repo and with my own data, I got the following error:

Traceback (most recent call last):
  File "/mnt/workspace/OCR-Invoice_Identification/train_code/ctpn_train_code/train.py", line 106, in <module>
    out_cls, out_regr = model(imgs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/workspace/OCR-Invoice_Identification/train_code/ctpn_train_code/ctpn/ctpn.py", line 127, in forward
    x = self.base_layers(x)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/container.py", line 215, in forward
    input = module(input)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [0]

Is this a problem with my environment?
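The last line of the traceback says conv2d received an input of size [0], that is, a one-dimensional tensor with no elements, where it expects a 3D or 4D tensor. This suggests the batch reaching model(imgs) is already empty, which would more likely point at data loading (paths, labels, or unreadable images) than at the PyTorch installation, although nothing in this thread confirms the cause. A minimal sketch of a shape check that could be called just before model(imgs) follows; the helper name check_image_batch is hypothetical, and it only assumes the object exposes a .shape attribute, as torch tensors do.

```python
def check_image_batch(batch, name="imgs"):
    """Raise a descriptive error if `batch` cannot be a valid conv2d input.

    Hypothetical helper: conv2d accepts a 3D (unbatched) or 4D (batched) input,
    so any other rank, or any zero-sized dimension, is reported here first.
    """
    shape = tuple(int(d) for d in batch.shape)
    if len(shape) not in (3, 4):
        raise ValueError(
            f"{name} has shape {shape}; expected 3 or 4 dimensions. "
            "Check that the dataset returned an image for this sample."
        )
    if 0 in shape:
        raise ValueError(
            f"{name} has an empty dimension in shape {shape}; "
            "the image may have failed to load."
        )
    return shape
```

With a check like this in place, an empty batch fails with a message naming the offending variable and shape, which makes it easier to trace the problem back to a specific image or label file.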