klrc / RACNN-pytorch

pytorch implementation of Recurrent Attention CNN.

RACNN Pytorch Implementation

NOTE: There seems to be an issue with the margin loss in this code. Sorry, I'm not planning to fix it; you can try another implementation instead: https://github.com/jeong-tae/RACNN-pytorch.
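For readers unsure what the margin loss refers to, here is a minimal, framework-free sketch of the pairwise ranking loss as I understand it from the RA-CNN paper: the finer scale is penalized unless its true-class probability beats the coarser scale's by at least a margin. This is not the code in this repository; the function name `rank_loss`, its arguments, and the default margin of 0.05 are assumptions made for illustration.

```python
def rank_loss(probs_coarse, probs_fine, labels, margin=0.05):
    """Mean pairwise ranking loss over a batch (illustrative sketch).

    probs_coarse / probs_fine: per-sample lists of class probabilities
    at scale s and scale s+1. labels: true class index per sample.
    Each sample contributes max(0, p_coarse - p_fine + margin), where
    p is the probability assigned to the true class.
    """
    losses = []
    for p_coarse, p_fine, label in zip(probs_coarse, probs_fine, labels):
        losses.append(max(0.0, p_coarse[label] - p_fine[label] + margin))
    return sum(losses) / len(losses)
```

The loss is zero once the finer scale is more confident than the coarser one by at least `margin`, so it only pushes the attention network while zooming in has not yet helped.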

This is a mobilenet version of RACNN.

Based on a raw PyTorch implementation.

Requirements

Changes

Unlike the original code, several possibly important changes are applied here:

Results

APN pretrained with a mobilenet-v2 backbone (ImageNet-pretrained):

[Images: pretrain_apn_imagenet-1577260547]
Zoomed input after apn-1; zoomed input after apn-2.

I pretrained the mobilenet on the CUB_200 dataset before training, and it helps a lot, as the following shows:

[Images: pretrain_apn_imagenet-1577260547]
Zoomed input after apn-1 (with pretraining on CUB_200_2011); zoomed input after apn-2 (with pretraining on CUB_200_2011).

Final accuracy

Accuracy at epoch-50:

[2019-12-31 20:06:50]    :: Testing on test set ...
[2019-12-31 20:07:10]           Accuracy clsf-0@top-1 (201/725) = 79.95050%
[2019-12-31 20:07:10]           Accuracy clsf-0@top-5 (201/725) = 94.61634%
[2019-12-31 20:07:10]           Accuracy clsf-1@top-1 (201/725) = 74.25743%
[2019-12-31 20:07:10]           Accuracy clsf-1@top-5 (201/725) = 91.39851%
[2019-12-31 20:07:10]           Accuracy clsf-2@top-1 (201/725) = 74.62871%
[2019-12-31 20:07:10]           Accuracy clsf-2@top-5 (201/725) = 90.71782%
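To compare runs, it can help to pull these numbers out of the log. The sketch below parses lines in the format shown above into a dictionary; the regular expression is derived only from this sample, and the names `LINE_RE` and `parse_accuracy` are hypothetical, not part of the repository.

```python
import re

# Matches e.g. "Accuracy clsf-0@top-1 (201/725) = 79.95050%"
LINE_RE = re.compile(r"Accuracy (clsf-\d+)@top-(\d+) \(\d+/\d+\) = ([\d.]+)%")


def parse_accuracy(log_text):
    """Return {(classifier_name, k): accuracy_percent} for every match."""
    results = {}
    for match in LINE_RE.finditer(log_text):
        name, k, value = match.groups()
        results[(name, int(k))] = float(value)
    return results
```

For example, `parse_accuracy(open(path).read())[("clsf-0", 1)]` would give the top-1 accuracy of the first classifier, where `path` points to a log file under build/ (the exact file name is not assumed here).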

Accuracy per epoch:

Figure_3

Usage

Download the CUB_200_2011 dataset and extract it to external/.

  1. pretrain a mobilenet-v2 on CUB_200_2011 (optional):

    $ python src/recurrent_attention_network_paper/pretrain_mobilenet.py
  2. pretrain the apn:

    edit the configuration in pretrain_apn.py:

    if __name__ == "__main__":
        clean()
        run(pretrained_backbone='build/mobilenet_v2_cub200-e801577256085.pt')

    set pretrained_backbone to your backbone model, then:

    $ python src/recurrent_attention_network_paper/pretrain_apn.py
  3. training:

    edit the same configuration in forge.py, then:

    $ python src/recurrent_attention_network_paper/forge.py

Outputs are generated in build/, including logs, frozen optimizers and models, and some GIFs for visualization.
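Steps 2 and 3 need the path of a checkpoint saved by an earlier step. Rather than copying the file name by hand, a small helper can pick the most recently written one. This is a sketch under assumptions: the glob pattern is guessed from the example file name in step 2, and `latest_checkpoint` is a hypothetical helper, not something the repository provides.

```python
from pathlib import Path


def latest_checkpoint(build_dir="build", pattern="mobilenet_v2_cub200-*.pt"):
    """Return the newest file in build_dir matching pattern, or None.

    "Newest" means the latest modification time; the default pattern is
    an assumption based on the example file name shown in step 2.
    """
    candidates = sorted(Path(build_dir).glob(pattern),
                        key=lambda p: p.stat().st_mtime)
    return str(candidates[-1]) if candidates else None
```

With this, the call in pretrain_apn.py could read `run(pretrained_backbone=latest_checkpoint())`, assuming the script is started from the repository root.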

Issues

References