A PyTorch implementation of the T-PAMI 2021 paper "Convolutional Neural Networks with Gated Recurrent Connections". The paper is an extended journal version of the earlier work "Gated Recurrent Convolution Neural Network for OCR" (https://github.com/Jianf-Wang/GRCNN-for-OCR), presented at NeurIPS 2017, and adds extensive experiments.
This GRCNN implementation is built on PyTorch. The requirements are:
To train on CIFAR-10, run the following command.

GRCNN-56:

```bash
python train_cifar.py --gpu-id 0,1 -a grcnn56
```
For other network architectures, change the value of `-a`.
To use the weight-sharing setting, set `--weight-sharing` to `True`.
To train on CIFAR-100, add `--dataset cifar100` to the command.
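When sweeping over several architectures or datasets, it can help to assemble these options programmatically. The helper below is a hypothetical sketch (`build_cifar_command` is not part of the repository). It uses only the flags named above (`--gpu-id`, `-a`, `--weight-sharing`, `--dataset`) and assumes `train_cifar.py` accepts them in any order.

```python
import shlex
from typing import List, Optional


def build_cifar_command(arch: str = "grcnn56",
                        gpu_ids: str = "0,1",
                        dataset: Optional[str] = None,
                        weight_sharing: bool = False) -> List[str]:
    """Assemble a train_cifar.py invocation from the documented flags.

    Hypothetical helper: it only builds the argument list and does not
    check that `arch` exists in the repository.
    """
    cmd = ["python", "train_cifar.py", "--gpu-id", gpu_ids, "-a", arch]
    if dataset is not None:
        # e.g. "cifar100"; leave as None to train on the default CIFAR-10
        cmd += ["--dataset", dataset]
    if weight_sharing:
        # Weight sharing is enabled by passing the string "True"
        cmd += ["--weight-sharing", "True"]
    return cmd


if __name__ == "__main__":
    # Print a shell-ready command; pass the list to subprocess.run to execute it.
    print(shlex.join(build_cifar_command(dataset="cifar100", weight_sharing=True)))
```

Returning a list instead of a single string avoids quoting problems when the command is handed to `subprocess.run`.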
To train GRCNN or SK-GRCNN on ImageNet, run the following commands.

GRCNN-55:

```bash
python imagenet_train.py \
  --epochs 100 \
  --dist-url 'tcp://localhost:10010' --multiprocessing-distributed --world-size 1 --rank 0 \
  --arch grcnn55 \
```
SK-GRCNN-55:

```bash
python imagenet_train.py \
  --epochs 120 \
  --dist-url 'tcp://localhost:10010' --multiprocessing-distributed --world-size 1 --rank 0 \
  --arch skgrcnn55 \
  --cos \
```
For GRCNN-109 and SK-GRCNN-109, change the value of `--arch` accordingly.
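The ImageNet commands depend on two conventions that this README does not spell out. The sketch below rests on two assumptions. First, `imagenet_train.py` follows the official PyTorch ImageNet example for `--multiprocessing-distributed`: one process per GPU, with `--world-size` and `--rank` counting nodes rather than GPUs. Second, `--cos` selects a half-cosine learning-rate decay over all epochs, which is a common convention. The function names, the base learning rate, and the batch size are hypothetical and are not taken from the repository.

```python
import math
from dataclasses import dataclass


@dataclass
class ProcessConfig:
    world_size: int  # total number of training processes (one per GPU)
    rank: int        # global rank of this process
    batch_size: int  # per-process batch size
    workers: int     # per-process data-loading workers


def per_process_config(node_world_size: int, node_rank: int, gpus_per_node: int,
                       gpu: int, batch_size: int, workers: int) -> ProcessConfig:
    """Per-GPU settings under the PyTorch ImageNet example's convention.

    Assumption: --world-size/--rank describe nodes, each node spawns one
    process per GPU, and the given batch size is split across those processes.
    """
    return ProcessConfig(
        world_size=node_world_size * gpus_per_node,
        rank=node_rank * gpus_per_node + gpu,
        batch_size=batch_size // gpus_per_node,
        workers=(workers + gpus_per_node - 1) // gpus_per_node,
    )


def cosine_lr(base_lr: float, epoch: int, total_epochs: int) -> float:
    """Half-cosine decay from base_lr at epoch 0 towards 0 at total_epochs.

    Assumption: this is what --cos selects; the actual script may differ
    (for example by adding warm-up or a non-zero floor).
    """
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * epoch / total_epochs))


if __name__ == "__main__":
    # Single node (--world-size 1 --rank 0) with 8 GPUs; batch size 256 is hypothetical.
    print(per_process_config(1, 0, 8, gpu=3, batch_size=256, workers=32))
    # SK-GRCNN-55 trains for 120 epochs; the base learning rate 0.1 is hypothetical.
    for epoch in (0, 60, 119):
        print(epoch, round(cosine_lr(0.1, epoch, 120), 5))
```

Under these assumptions, `--world-size 1 --rank 0` on an 8-GPU machine yields 8 processes with global ranks 0 to 7. Each process sees one eighth of the batch.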
The ImageNet-pretrained models are released. We also release weight-sharing versions of GRCNN-55 and GRCNN-109. These weight-sharing GRCNNs have fewer parameters and achieve competitive ImageNet accuracy compared with other lightweight models.
Model | Params | Top-1 Acc. (%) | Model (Google Drive) | Model (Baidu Disk) |
---|---|---|---|---|
GRCNN-55 | 24.9M | 77.02 | download | download (code: vdb1) |
SK-GRCNN-55 | 27.4M | 79.38 | download | download (code: temi) |
GRCNN-109 | 45.1M | 78.20 | download | download (code: sxcd) |
SK-GRCNN-109 | 50.0M | 80.01 | download | download (code: 93tr) |
GRCNN-55 (weight sharing) | *12.0M* | 75.49 | download | download (code: s11g) |
GRCNN-109 (weight sharing) | *12.1M* | 76.00 | download | download (code: 4eiv) |
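The weight-sharing variants save parameters by reusing convolution kernels across recurrent iterations. The arithmetic below is a simplified illustration, not the exact GRCNN layer. It assumes the non-shared variant allocates a separate recurrent kernel for each unrolled iteration, while the weight-sharing variant keeps a single copy. It also ignores gates and batch-norm parameters. The channel count and the number of iterations are hypothetical.

```python
def conv_params(in_ch: int, out_ch: int, k: int, bias: bool = False) -> int:
    """Number of weights in a k x k convolution (bias optional)."""
    return in_ch * out_ch * k * k + (out_ch if bias else 0)


def recurrent_block_params(channels: int, k: int, iterations: int,
                           weight_sharing: bool) -> int:
    """Illustrative count for a block with one feed-forward kernel and a
    recurrent kernel applied over `iterations` unrolled steps.

    Assumption: without sharing, each step owns its own recurrent kernel;
    with sharing, all steps reuse a single one.
    """
    feed_forward = conv_params(channels, channels, k)
    recurrent = conv_params(channels, channels, k)
    copies = 1 if weight_sharing else iterations
    return feed_forward + copies * recurrent


if __name__ == "__main__":
    shared = recurrent_block_params(256, 3, iterations=3, weight_sharing=True)
    unshared = recurrent_block_params(256, 3, iterations=3, weight_sharing=False)
    print(f"shared: {shared:,}  unshared: {unshared:,}  ratio: {shared / unshared:.2f}")
    # Reported totals from the table (millions of parameters)
    print(f"GRCNN-55:  weight sharing keeps {12.0 / 24.9:.2f} of the parameters")
    print(f"GRCNN-109: weight sharing keeps {12.1 / 45.1:.2f} of the parameters")
```

In this toy setting, the saving grows with the number of iterations, because only the recurrent kernels are duplicated in the non-shared case.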
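To use a pretrained model, run the following code (GRCNN-55 shown):

```python
import torch
import models.imagenet.GRCNN as grcnn

# Build GRCNN-55 and load the released ImageNet weights (map to CPU so no GPU is required)
model = grcnn.grcnn55()
model.load_state_dict(torch.load('checkpoint_params_grcnn55.pt', map_location='cpu'))
model.eval()  # switch batch norm to inference mode before evaluation
```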
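A checkpoint saved from a model wrapped in `torch.nn.DataParallel` or `DistributedDataParallel` stores every key with a leading `module.` prefix. `load_state_dict` then reports missing and unexpected keys. This README does not say whether the released files contain that prefix. If loading fails for that reason, a small helper like the hypothetical `strip_module_prefix` below can normalize the keys first. It works on any mapping, so the example uses plain values in place of tensors.

```python
from collections import OrderedDict
from typing import Mapping, TypeVar

V = TypeVar("V")


def strip_module_prefix(state_dict: Mapping[str, V],
                        prefix: str = "module.") -> "OrderedDict[str, V]":
    """Return a copy of state_dict with a leading wrapper prefix removed from keys.

    Keys without the prefix are kept unchanged, so the helper is also safe to
    call on checkpoints saved from an unwrapped model.
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


if __name__ == "__main__":
    # Stand-in for torch.load(...): real checkpoints map parameter names to tensors.
    ckpt = {"module.conv1.weight": 1, "module.fc.bias": 2, "fc.weight": 3}
    print(list(strip_module_prefix(ckpt)))
```

With PyTorch, the result would be passed on as `model.load_state_dict(strip_module_prefix(torch.load(path, map_location='cpu')))`.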
The object detection experiments in the paper were conducted with the original repositories of the respective detection methods. To evaluate GRCNN more broadly across detection methods, we also integrated it into MMDetection, a well-known object detection toolbox.
To install, go to the directory `./mmdetection` and run the following commands:

```bash
conda env create -f GRCNN_Detect.yaml
conda activate GRCNN_Detect
pip install -v -e .
```
After installation, the following command trains a Mask R-CNN with a GRCNN-109 backbone on 8 GPUs:

```bash
./tools/dist_train.sh configs/GRCNN/mask_rcnn_grcnn109_fpn_2x_coco.py 8
```
More information about the configuration files and the GRCNN backbone can be found in `./mmdetection/configs/GRCNN` and `./mmdetection/mmdet/models/backbones/GRCNN.py`.
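The last argument of `./tools/dist_train.sh` is the number of GPUs. MMDetection's COCO configs are usually tuned for 8 GPUs with 2 images per GPU, a total batch size of 16. The MMDetection documentation recommends scaling the learning rate linearly when the total batch size changes. This sketch assumes the GRCNN configs keep that convention. Check `data.samples_per_gpu` and `optimizer.lr` in the config files (MMDetection 2.x naming) before relying on it. The function `scaled_lr` is hypothetical.

```python
def scaled_lr(num_gpus: int, samples_per_gpu: int = 2,
              base_lr: float = 0.02, base_batch_size: int = 16) -> float:
    """Linear scaling rule: keep lr / total_batch_size constant.

    Assumption: base_lr=0.02 for a total batch of 16 (8 GPUs x 2 images) is the
    usual MMDetection default; the GRCNN configs may use other values.
    """
    total_batch_size = num_gpus * samples_per_gpu
    return base_lr * total_batch_size / base_batch_size


if __name__ == "__main__":
    for gpus in (8, 4, 2):
        print(f"{gpus} GPUs -> lr = {scaled_lr(gpus):.4f}")
```

For example, under these assumptions, running `./tools/dist_train.sh configs/GRCNN/mask_rcnn_grcnn109_fpn_2x_coco.py 4` would call for halving the learning rate to 0.01.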
```bibtex
@Article{jianfeng2021grcnn,
  author  = {Jianfeng Wang and Xiaolin Hu},
  title   = {Convolutional Neural Networks with Gated Recurrent Connections},
  journal = {TPAMI},
  year    = {2021},
}
```