weiren1998 opened 2 years ago
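Goal: find out how many GPUs the OCR model can be trained on (without exhausting CPU memory) and how large a batch each GPU can take (without exhausting GPU memory).
Start by reproducing the original paper's setting: 4 V100 GPUs (16 GB); batch size = 8; lr = 0.01 -> 0.001; 20 epochs. The paper does not give the image size, so images are resized to (512, 640) here, which fits in 32 GB of GPU memory.
CPU memory (RAM): all data is loaded into RAM at once and then augmented; RAM also holds the model architecture and parameters.
GPU memory: holds the whole model and its parameters, plus one or more batches of data.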
No. | Parameters | Experiment name | Time | Notes |
---|---|---|---|---|
1 | gpu=4; batchsize=8/gpu | resnet101-pre-trainset1-512640-e20 | ||
2 | gpu=8; batchsize=8/gpu | resnet101-pre-trainset1-512640-e20-gpu8-patch8 | ||
3 | gpu=8; batchsize=16/gpu | resnet101-pre-trainset1-512640-e20-gpu8-patch16 | ||
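To compare the three runs in the table, the global batch size (GPUs × per-GPU batch) and the size of one GPU's input batch can be computed directly. This is a minimal sketch, assuming float32 inputs of shape (batch, 3, 512, 640); it deliberately ignores activations, gradients and optimizer state, which usually dominate GPU memory, so it only gives a lower bound. `EXPERIMENTS` and `input_batch_mib` are hypothetical names.

```python
BYTES_PER_FLOAT32 = 4

# (experiment name, number of GPUs, batch size per GPU), copied from the table
EXPERIMENTS = [
    ("resnet101-pre-trainset1-512640-e20", 4, 8),
    ("resnet101-pre-trainset1-512640-e20-gpu8-patch8", 8, 8),
    ("resnet101-pre-trainset1-512640-e20-gpu8-patch16", 8, 16),
]


def input_batch_mib(per_gpu_batch, height=512, width=640, channels=3):
    """MiB taken by one GPU's float32 input images only (no activations/gradients)."""
    n_bytes = per_gpu_batch * channels * height * width * BYTES_PER_FLOAT32
    return n_bytes / 2**20


if __name__ == "__main__":
    for name, gpus, per_gpu in EXPERIMENTS:
        total = gpus * per_gpu  # global batch size across all GPUs
        print(f"{name}: global batch {total}, input {input_batch_mib(per_gpu):.1f} MiB/GPU")
```

Under these assumptions runs 1/2/3 have global batch sizes of 32, 64 and 128, and the raw images take only 30 MiB (batch 8) or 60 MiB (batch 16) per GPU, so whether batch 16 fits is most likely decided by the network's activations rather than by the input images.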
Try saving each model directly to S3: change the output path.
Reproduce the TextFuseNet network
Required Python libraries: https://github.com/ying09/TextFuseNet/blob/master/step-by-step%20installation.txt (PyTorch version 1.4)
pip install opencv-python
pip install tensorboard
pip install yacs
pip install tqdm
pip install termcolor
pip install tabulate
pip install matplotlib
pip install cloudpickle
pip install wheel
pip install pycocotools
pip install timm
pip install fvcore-master.zip
python setup.py build develop
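Before building, it can help to confirm that every package in the list is importable and that PyTorch really is 1.4. A minimal standard-library sketch; the pip-name-to-module mapping (e.g. opencv-python -> cv2) is assumed here and is not part of the TextFuseNet instructions, and `missing_modules` / `version_prefix_matches` are hypothetical helpers.

```python
import importlib.util

# pip package name -> top-level import name (mapping assumed)
REQUIRED_MODULES = {
    "opencv-python": "cv2",
    "tensorboard": "tensorboard",
    "yacs": "yacs",
    "tqdm": "tqdm",
    "termcolor": "termcolor",
    "tabulate": "tabulate",
    "matplotlib": "matplotlib",
    "cloudpickle": "cloudpickle",
    "wheel": "wheel",
    "pycocotools": "pycocotools",
    "timm": "timm",
    "fvcore": "fvcore",
    "torch": "torch",
}


def missing_modules(required):
    """Return the pip names whose import module cannot be found (nothing is imported)."""
    return [pip_name for pip_name, module in required.items()
            if importlib.util.find_spec(module) is None]


def version_prefix_matches(version, expected=(1, 4)):
    """True if a version string such as '1.4.0+cu101' has major.minor == expected."""
    parts = version.split("+")[0].split(".")
    return tuple(int(p) for p in parts[:2]) == expected


if __name__ == "__main__":
    print("missing:", missing_modules(REQUIRED_MODULES))
    try:
        import torch
        print("torch", torch.__version__, "is 1.4:", version_prefix_matches(torch.__version__))
    except ImportError:
        print("torch is not installed")
```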
# Use TNT as the backbone
python tools/train_net.py --num-gpus 2 --config-file configs/ocr/icdar2013_tnt_FPN.yaml
# Use ResNet as the backbone
1. Change the file name of Prejects/TextFuseNet/detectron2/modeling/backbone/fpn.py
python tools/train_net.py --num-gpus 1 --config-file configs/ocr/synthtext_pretrain_101_FPN.yaml
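Pre-train on the SynthText dataset, starting from a ResNet-101 or TNT model trained on ImageNet
SynthText data processing (because of insufficient memory, the dataset has to be split into several parts before training) [for now, just keep the current approach]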
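One way to split SynthText is by contiguous index ranges, loading one shard into RAM at a time. The sketch below is hypothetical (`make_shards` and `shard_gib` are illustrative names, not part of TextFuseNet) and assumes every image is held decoded as uint8 at 512×640×3; real memory use also includes annotations and augmentation buffers.

```python
def make_shards(num_samples, num_shards):
    """Split range(num_samples) into num_shards contiguous (start, stop) ranges
    whose sizes differ by at most one."""
    if num_shards <= 0:
        raise ValueError("num_shards must be positive")
    base, extra = divmod(num_samples, num_shards)
    shards, start = [], 0
    for i in range(num_shards):
        stop = start + base + (1 if i < extra else 0)  # first `extra` shards get one more
        shards.append((start, stop))
        start = stop
    return shards


def shard_gib(num_images, height=512, width=640, channels=3):
    """Approximate RAM (GiB) for num_images decoded uint8 images of the given size."""
    return num_images * height * width * channels / 2**30


if __name__ == "__main__":
    for start, stop in make_shards(10, 3):
        print(start, stop, f"{shard_gib(stop - start):.4f} GiB")
```

With the commonly quoted size of about 800,000 SynthText images, `shard_gib(800_000)` gives roughly 732 GiB, so under these assumptions even 10 shards would still need about 73 GiB each.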
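Get hold of a pre-trained ResNet-101 model
1. Download it from the PyTorch model zoo or the PyTorch ResNet model page.
2. The local PyTorch version is 1.4, but the checkpoint was saved with PyTorch 1.6 or later (new zipfile format), so it has to be converted: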
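# Run this conversion under PyTorch >= 1.6, which can read the new zipfile checkpoint;
# the re-saved legacy-format file can then be loaded by PyTorch 1.4.
import torch
state_dict = torch.load('./resnet101.pth', map_location="cpu")
torch.save(state_dict, './resnet101_new.pth', _use_new_zipfile_serialization=False)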
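To check that the conversion worked: since PyTorch 1.6, `torch.save` writes a zip archive by default, while the legacy format produced with `_use_new_zipfile_serialization=False` is not a zip file. A small standard-library sketch (`checkpoint_format` is a hypothetical helper) that tells the two apart heuristically without importing torch:

```python
import os
import zipfile


def checkpoint_format(path):
    """Classify a checkpoint file as 'missing', 'zip' (new format) or 'legacy'."""
    if not os.path.exists(path):
        return "missing"
    # Heuristic: new-style checkpoints are zip archives, legacy ones are not.
    return "zip" if zipfile.is_zipfile(path) else "legacy"


if __name__ == "__main__":
    for p in ("./resnet101.pth", "./resnet101_new.pth"):
        print(p, checkpoint_format(p))
```

After a successful conversion, `./resnet101_new.pth` should report `legacy`.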
Train the TNT model on ModelArts
Train the ResNet-101 model on ModelArts
python tools/train_net.py --num-gpus 2 --config-file configs/ocr/synthtext_pretrain_101_FPN.yaml
Fine-tune the pre-trained model on downstream datasets
ICDAR2013
Data preprocessing
Fine-tune starting from the result of step 1
# Train in the test environment
python tools/train_net.py --num-gpus 2 --config-file configs/ocr/icdar2013_tnt_FPN.yaml
python demo/icdar2013_detection.py
ICDAR2015
Total-Text
CTW-1500
leaf-yej@whu.edu.cn
SegLink:
TextSnake: A Flexible Representation for Detecting Text of Arbitrary Shapes
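# Augmentation and SynthText presumably come from the TextSnake.pytorch code base (imports not shown).
# means/stds are assumed to be the usual ImageNet statistics.
means = (0.485, 0.456, 0.406)
stds = (0.229, 0.224, 0.225)
transform = Augmentation(
    size=512, mean=means, std=stds
)
trainset = SynthText(
    data_root='data/SynthText',
    is_training=True,
    transform=transform
)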
PAN: Efficient and Accurate Arbitrary-Shaped Text Detection with Pixel Aggregation Network
data = dict(
    batch_size=16,
    train=dict(
        type='PAN_Synth',
        is_transform=True,
        img_size=640,
        short_size=640,
        kernel_scale=0.5,
        read_type='cv2'
    )
)
train_cfg = dict(
    lr=1e-3,
    schedule='polylr',
    epoch=1,
    optimizer='Adam'
)
# https://github.com/whai362/pan_pp.pytorch/blob/master/config/pan/pan_r18_synth.py
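`schedule='polylr'` refers to polynomial learning-rate decay. A commonly used form is lr = base_lr × (1 − iter / max_iter)^power; the sketch below assumes power = 0.9 (a typical choice) and a hypothetical `max_iter`, so the exact rule should be checked against the pan_pp.pytorch training script.

```python
def poly_lr(base_lr, cur_iter, max_iter, power=0.9):
    """Polynomial decay from base_lr at iteration 0 down to 0 at max_iter."""
    if max_iter <= 0 or not 0 <= cur_iter <= max_iter:
        raise ValueError("need max_iter > 0 and 0 <= cur_iter <= max_iter")
    return base_lr * (1 - cur_iter / max_iter) ** power


if __name__ == "__main__":
    base_lr = 1e-3   # train_cfg.lr in the PAN config
    max_iter = 1000  # hypothetical total number of iterations
    for it in (0, 250, 500, 750, 1000):
        print(it, f"{poly_lr(base_lr, it, max_iter):.2e}")
```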
PSENet: the paper says SynthText was not used for pre-training
data = dict(
    batch_size=16,
    train=dict(
        type='PSENET_Synth',
        is_transform=True,
        img_size=736,
        short_size=736,
        kernel_num=7,
        min_scale=0.7,
        read_type='cv2'
    )
)
train_cfg = dict(
    lr=1e-3,
    schedule=(200, 400,),
    epoch=1,
    optimizer='SGD'
)
# https://github.com/whai362/pan_pp.pytorch/blob/master/config/psenet/psenet_r50_synth.py
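The tuple `schedule=(200, 400,)` reads as a milestone ("step") schedule. A minimal sketch, assuming the learning rate is multiplied by 0.1 each time a milestone is reached; whether pan_pp.pytorch counts the milestones in epochs or iterations is not recorded in these notes. The same shape covers the paper setting lr = 0.01 -> 0.001 with a single milestone (where that milestone falls within the 20 epochs is not given).

```python
def step_lr(base_lr, step, milestones, gamma=0.1):
    """Multiply base_lr by gamma once for every milestone that `step` has reached."""
    passed = sum(1 for m in milestones if step >= m)
    return base_lr * gamma ** passed


if __name__ == "__main__":
    milestones = (200, 400)  # train_cfg.schedule in the PSENet config
    for s in (0, 199, 200, 399, 400):
        print(s, f"{step_lr(1e-3, s, milestones):.0e}")
```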
https://weiren1998.github.io/archives/36aea8fe.html
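Exploring network architectures takes many attempts and a lot of thinking; it also means recording the experimental data and your reflections on the results, so that intuition builds up bit by bit.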