weiren1998 / weiren1998.github.io

This is my blog.
1 stars 0 forks source link

OCR实验 | J球星的博客 #14

Open weiren1998 opened 2 years ago

weiren1998 commented 2 years ago

https://weiren1998.github.io/archives/36aea8fe.html

在探索网络架构的过程中,需要做很多尝试和思考,同时也需要把实验数据和对于结果的思考等记录下来,从而一点点积累感觉

weiren1998 commented 2 years ago
  1. 探究ocr模型可以在几张卡上跑(怕CPU爆炸),每张卡跑多少个batch(怕GPU爆炸)

    应该首先按原论文的setting还原:4卡 v100(16G); batchsize=8; lr=0.01->0.001; epoch=20 图片大小未知,这里为了32G显存可以放下,因此选择resize(512, 640)

    内存:所有数据一次性load进CPU内存后进行augment处理;同时还有模型的架构和参数

    显存:存储整套模型和参数,以及一个或者多个batch的数据

    序号 参数 实验名称 时间 经验
    1 gpu=4; batchsize=8/gpu resnet101-pre-trainset1-512640-e20
    2 gpu=8; batchsize=8/gpu resnet101-pre-trainset1-512640-e20-gpu8-patch8
    3 gpu=8; batchsize=16/gpu resnet101-pre-trainset1-512640-e20-gpu8-patch16
  2. 尝试每个将模型直接存到s3上:更换output路径

weiren1998 commented 2 years ago

复现TextFuseNet网络

环境配置

需要用到的python库 https://github.com/ying09/TextFuseNet/blob/master/step-by-step%20installation.txt pytorch版本1.4

pip install opencv-python
pip install tensorboard
pip install yacs
pip install tqdm
pip install termcolor
pip install tabulate
pip install matplotlib
pip install cloudpickle
pip install wheel
pip install pycocotools
pip install timm

pip install fvcore-master.zip

python setup.py build develop
weiren1998 commented 2 years ago

运行训练程序

# 用tnt做backbone
python tools/train_net.py --num-gpus 2 --config-file configs/ocr/icdar2013_tnt_FPN.yaml

# 用resnet做backbone
1. 修改Prejects/TextFuseNet/detectron2/modeling/backbone/fpn.py 文件名
python tools/train_net.py --num-gpus 1 --config-file configs/ocr/synthtext_pretrain_101_FPN.yaml
weiren1998 commented 2 years ago

Text Fuse Net Training

步骤:
  1. 用在imagenet上训练好的resnet101或者tnt模型在synthtext数据集上做预训练

    • synthtext数据处理(内存不够原因,需将数据集分割成几部分,再进行训练)【先按照现在的来就行了】

    • resnet101的预训练模型需要搞到手

      1. pytorch model zoo / 、pytorch resnet模型处 下载即可
      2. 因为pytorch版本为1.4,而模型加载保存时为1.6以上,因此需要调整
       import torch
       state_dict = torch.load('./resnet101.pth', map_location="cpu")
       torch.save(state_dict, './resnet101_new.pth', _use_new_zipfile_serialization=False)
    • 在modelArt上训练tnt模型

    • 在modelArt上训练resnet101模型

      python tools/train_net.py --num-gpus 2 --config-file configs/ocr/synthtext_pretrain_101_FPN.yaml
  2. 用预训练好的模型在下游数据集上微调

    1. ICDAR2013

      • 数据预处理

      • 在步骤1的基础上进行微调训练

      # 在测试环境训练
      python tools/train_net.py --num-gpus 2 --config-file configs/ocr/icdar2013_tnt_FPN.yaml
      • 在test数据集上进行测试
      python demo/icdar2013_detection.py
      • 将测试结果在官网进行评估
    2. ICDAR2015

    3. Total-Text

    4. CTW-1500

leaf-yej@whu.edu.cn

OCR Pretrain Setting

  1. SegLink:

    • 20w iterations ~ 5 epochs
    • 10e-4 learning rate
    • 512*512 or 384*384 前者比后者提高了1%
  2. TextSnake: A Flexible Representation for Detecting Text of Arbitrary Shapes

    • 在synthtext上训练1个epoch,learning rate为10e-3 fixed
       transform = Augmentation(
           size=512, mean=means, std=stds
       )
    
       trainset = SynthText(
           data_root='data/SynthText',
           is_training=True,
           transform=transform
       )
  3. PAN: Efficient and Accurate Arbitrary-Shaped Text Detection with Pixel Aggregation Network

    data = dict(
       batch_size=16,
       train=dict(
           type='PAN_Synth',
           is_transform=True,
           img_size=640,
           short_size=640,
           kernel_scale=0.5,
           read_type='cv2'
       )
    )
    train_cfg = dict(
       lr=1e-3,
       schedule='polylr',
       epoch=1,
       optimizer='Adam'
    )
    # https://github.com/whai362/pan_pp.pytorch/blob/master/config/pan/pan_r18_synth.py
  4. PSENet: 论文中说没用synthtext做pretrain

    data = dict(
       batch_size=16,
       train=dict(
           type='PSENET_Synth',
           is_transform=True,
           img_size=736,
           short_size=736,
           kernel_num=7,
           min_scale=0.7,
           read_type='cv2'
       )
    )
    train_cfg = dict(
       lr=1e-3,
       schedule=(200, 400,),
       epoch=1,
       optimizer='SGD'
    )
    # https://github.com/whai362/pan_pp.pytorch/blob/master/config/psenet/psenet_r50_synth.py