PaddlePaddle / PaddleOCR

Awesome multilingual OCR toolkits based on PaddlePaddle (practical ultra lightweight OCR system, support 80+ languages recognition, provide data annotation and synthesis tools, support training and deployment among server, mobile, embedded and IoT devices)
https://paddlepaddle.github.io/PaddleOCR/
Apache License 2.0
42.66k stars 7.67k forks

PaddleOCR 2.3: error when training the recognition distillation model #5835

Closed Onmyown577 closed 2 years ago

Onmyown577 commented 2 years ago

Running python tools/train.py -c configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml fails with the following error:

    Traceback (most recent call last):
      File "tools/train.py", line 131, in <module>
        main(config, device, logger, vdl_writer)
      File "tools/train.py", line 108, in main
        eval_class, pre_best_model_dict, logger, vdl_writer)
      File "/home/wtt/PaddleOCR2.3/tools/program.py", line 245, in train
        eval_class(post_result, batch)
      File "/home/wtt/PaddleOCR2.3/ppocr/metrics/distillation_metric.py", line 46, in __call__
        assert isinstance(preds, dict)
    AssertionError

On the same dataset, the ch_PP-OCRv2_rec.yml model works; only the distillation model fails.

Contents of ch_PP-OCRv2_rec_distillation.yml:

    Global:
      debug: false
      use_gpu: true
      epoch_num: 100
      log_smooth_window: 20
      print_batch_step: 1
      save_model_dir: ./output/rec_pp-OCRv2_distillation
      save_epoch_step: 3
      # evaluation is run every 2000 iterations after the 0th iteration
      eval_batch_step: 80000000
      cal_metric_during_train: true
      pretrained_model:
      checkpoints:
      save_inference_dir:
      use_visualdl: false
      infer_img:
      character_dict_path: ppocr/utils/ppocr_keys_v1.txt
      character_type: ch
      max_text_length: 25
      infer_mode: false
      use_space_char: true
      distributed: true
      save_res_path: ./output/rec/predicts_pp-OCRv2_distillation.txt

    Optimizer:
      name: Adam
      beta1: 0.9
      beta2: 0.999
      lr:
        name: Piecewise
        decay_epochs : [700, 800]
        values : [0.001, 0.0001]
        warmup_epoch: 5
      regularizer:
        name: L2
        factor: 2.0e-05

    Architecture:
      model_type: &model_type "rec"
      name: DistillationModel
      algorithm: Distillation
      Models:
        Teacher:
          pretrained: ./pretrain_models/ch_PP-OCRv2_rec_train/best_accuracy
          freeze_params: false
          return_all_feats: true
          model_type: *model_type
          algorithm: CRNN
          Transform:
          Backbone:
            name: MobileNetV1Enhance
            scale: 0.5
          Neck:
            name: SequenceEncoder
            encoder_type: rnn
            hidden_size: 64
          Head:
            name: CTCHead
            mid_channels: 96
            fc_decay: 0.00002
        Student:
          pretrained: ./pretrain_models/ch_PP-OCRv2_rec_train/best_accuracy
          freeze_params: false
          return_all_feats: true
          model_type: *model_type
          algorithm: CRNN
          Transform:
          Backbone:
            name: MobileNetV1Enhance
            scale: 0.5
          Neck:
            name: SequenceEncoder
            encoder_type: rnn
            hidden_size: 64
          Head:
            name: CTCHead
            mid_channels: 96
            fc_decay: 0.00002

    Loss:
      name: CombinedLoss
      loss_config_list:

    PostProcess:
      name: DistillationCTCLabelDecode
      model_name: ["Student", "Teacher"]
      key: head_out

    Metric:
      name: DistillationMetric
      base_metric_name: RecMetric
      main_indicator: acc
      key: "Student"

    Train:
      dataset:
        name: SimpleDataSet
        data_dir: ./ch_data/yl
        label_file_list:
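The posted config refers to the sub-models by name in three places: the keys under Architecture.Models, PostProcess.model_name, and Metric.key. A mismatch between them is one thing that can be ruled out quickly. The sketch below assumes the config has already been loaded into a plain Python dict; `check_distillation_names` is a hypothetical helper written for illustration and is not part of PaddleOCR.

```python
# Hypothetical helper for illustration only; not part of PaddleOCR.
def check_distillation_names(config):
    """Return names used by PostProcess/Metric that are missing from Architecture.Models."""
    models = set(config["Architecture"]["Models"])
    referenced = list(config["PostProcess"]["model_name"]) + [config["Metric"]["key"]]
    return [name for name in referenced if name not in models]


# Only the name-related parts of the posted config are reproduced here.
config = {
    "Architecture": {"Models": {"Teacher": {}, "Student": {}}},
    "PostProcess": {
        "name": "DistillationCTCLabelDecode",
        "model_name": ["Student", "Teacher"],
        "key": "head_out",
    },
    "Metric": {"name": "DistillationMetric", "key": "Student"},
}

missing = check_distillation_names(config)
print("names missing from Architecture.Models:", missing)
```

For the config in this post the names are consistent, so a check like this would not flag anything, which suggests the cause lies elsewhere.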

LDOUBLEV commented 2 years ago

The error says the model output is not a dict, but on the 2.4 branch the model output is a dict: https://github.com/PaddlePaddle/PaddleOCR/blob/e0a52ee5110235d5ba5313f95726110e11eb16bf/ppocr/modeling/architectures/distillation_model.py#L60 Could you update the code and try again?
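To illustrate the contract behind that assertion, here is a minimal sketch in plain Python. `ToyRecMetric` and `ToyDistillationMetric` are hypothetical stand-ins, not PaddleOCR's actual classes. The assumption is only what the traceback shows: the distillation metric expects a dict, here taken to hold one entry per sub-model name.

```python
# Toy illustration only: hypothetical stand-ins, not PaddleOCR's implementation.

class ToyRecMetric:
    """Exact-match accuracy over (predicted text, label text) pairs."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def __call__(self, preds, labels):
        for pred, label in zip(preds, labels):
            self.correct += int(pred == label)
            self.total += 1

    def accuracy(self):
        return self.correct / self.total if self.total else 0.0


class ToyDistillationMetric:
    """Keeps one base metric per sub-model; expects a dict keyed by sub-model name."""

    def __init__(self, key="Student"):
        self.key = key
        self.metrics = {}

    def __call__(self, preds, labels):
        # Same kind of check as the one that failed in the traceback.
        assert isinstance(preds, dict)
        for name, texts in preds.items():
            self.metrics.setdefault(name, ToyRecMetric())(texts, labels)

    def main_accuracy(self):
        return self.metrics[self.key].accuracy()


labels = ["abc", "de"]
metric = ToyDistillationMetric(key="Student")

# A dict with one entry per sub-model passes the check.
metric({"Student": ["abc", "dx"], "Teacher": ["abc", "de"]}, labels)
student_acc = metric.main_accuracy()

# A plain list trips the assertion, as in the reported error.
try:
    metric(["abc", "de"], labels)
    list_rejected = False
except AssertionError:
    list_rejected = True

print(student_acc, list_rejected)
```

If the model and post-processing code come from a version that does not return a dict while the metric expects one, the assertion fails regardless of the config, which fits the suggestion to update the code.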

Onmyown577 commented 2 years ago

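> The error says the model output is not a dict, but on the 2.4 branch the model output is a dict:
>
> https://github.com/PaddlePaddle/PaddleOCR/blob/e0a52ee5110235d5ba5313f95726110e11eb16bf/ppocr/modeling/architectures/distillation_model.py#L60
>
> Could you update the code and try again?

It works after updating the code. Thank you!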