GuYuc / WS-DAN.PyTorch (MIT License)

WS-DAN.PyTorch

A neat PyTorch implementation of WS-DAN (Weakly Supervised Data Augmentation Network) for FGVC (Fine-Grained Visual Classification). (Hu et al., "See Better Before Looking Closer: Weakly Supervised Data Augmentation Network for Fine-Grained Visual Classification", arXiv:1901.09891)

NOTICE: This is NOT an official implementation by authors of WS-DAN. The official implementation is available at tau-yihouxiang/WS_DAN (and there's another unofficial PyTorch version wvinzh/WS_DAN_PyTorch).

Innovations

  1. Data Augmentation: Attention Cropping and Attention Dropping

    Fig1
  2. Bilinear Attention Pooling (BAP) for Feature Generation

    Fig3
  3. Training Process and Testing Process

    Fig2a Fig2b
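
To make the two augmentations concrete, here is a minimal NumPy sketch of the idea behind attention cropping and attention dropping: a single attention map is thresholded against a fraction of its maximum to derive either a crop bounding box (zoom into the attended part) or a drop mask (erase the attended part). Function names and the threshold parameters `theta_c` / `theta_d` are illustrative, not taken from this repo's code.

```python
import numpy as np

def attention_crop_box(attn, theta_c=0.5):
    """Bounding box of the region where attn exceeds theta_c * max (crop mask).

    Sketch only: `attn` is a single (H, W) attention map.
    """
    mask = attn >= theta_c * attn.max()
    ys, xs = np.where(mask)
    return ys.min(), ys.max() + 1, xs.min(), xs.max() + 1  # y0, y1, x0, x1

def attention_drop_mask(attn, theta_d=0.5):
    """Binary mask that zeroes out the high-attention region (drop mask)."""
    return (attn < theta_d * attn.max()).astype(attn.dtype)

# Toy 6x6 attention map with a hot 2x2 region in the centre.
attn = np.zeros((6, 6))
attn[2:4, 2:4] = 1.0
y0, y1, x0, x1 = attention_crop_box(attn)   # bounding box around the hot region
dropped = attn * attention_drop_mask(attn)  # hot region erased
```

During training, cropping zooms the network in on discriminative parts ("looking closer"), while dropping forces it to find complementary parts instead of relying on a single region.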

Performance

| Dataset | Object | Category | Train | Test | Accuracy (Paper) | Accuracy (PyTorch) | Feature Net |
|---|---|---|---|---|---|---|---|
| FGVC-Aircraft | Aircraft | 100 | 6,667 | 3,333 | 93.0 | 93.28 | inception_mixed_6e |
| CUB-200-2011 | Bird | 200 | 5,994 | 5,794 | 89.4 | 88.28 | inception_mixed_6e |
| Stanford Cars | Car | 196 | 8,144 | 8,041 | 94.5 | 94.38 | inception_mixed_6e |
| Stanford Dogs | Dog | 120 | 12,000 | 8,580 | 92.2 | 89.66 | inception_mixed_7c |

Usage

WS-DAN

This repo implements WS-DAN with several feature extractors: VGG19 ('vgg19', 'vgg19_bn'), ResNet-34/50/101/152 ('resnet34', 'resnet50', 'resnet101', 'resnet152'), and Inception v3 ('inception_mixed_6e', 'inception_mixed_7c'); see ./models/wsdan.py.

```python
net = WSDAN(num_classes=num_classes, M=num_attentions, net='inception_mixed_6e', pretrained=True)
net = WSDAN(num_classes=num_classes, M=num_attentions, net='inception_mixed_7c', pretrained=True)
net = WSDAN(num_classes=num_classes, M=num_attentions, net='vgg19_bn', pretrained=True)
net = WSDAN(num_classes=num_classes, M=num_attentions, net='resnet50', pretrained=True)
```
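
The M parameter above is the number of attention maps fed into Bilinear Attention Pooling. As a rough sketch of what BAP computes (in plain NumPy, not this repo's PyTorch code): each attention map weights the backbone's feature maps elementwise, the result is spatially pooled, and the M pooled part features are stacked into an M x C feature matrix. Shapes below are toy values; the function name is illustrative.

```python
import numpy as np

def bilinear_attention_pooling(features, attentions):
    """BAP sketch: features (C, H, W) and attentions (M, H, W) -> (M, C) matrix.

    Row k is the spatially averaged, attention-weighted feature for part k:
    fm[k, c] = mean over (h, w) of attentions[k, h, w] * features[c, h, w].
    """
    h, w = features.shape[1], features.shape[2]
    fm = np.einsum('khw,chw->kc', attentions, features) / (h * w)
    return fm

rng = np.random.default_rng(0)
F = rng.standard_normal((768, 28, 28))  # toy backbone feature maps, C = 768
A = rng.random((32, 28, 28))            # toy attention maps, M = 32
fm = bilinear_attention_pooling(F, A)   # (32, 768) feature matrix
```

The flattened feature matrix then goes into the final classifier, so the classifier sees M part-specific descriptors rather than one global pooled vector.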

Dataset Directory

Run

  1. git clone this repo.

  2. Prepare data and modify DATAPATH in datasets/<abcd>_dataset.py.

  3. Set configurations in config.py (Training Config, Model Config, Dataset/Path Config):

    tag = 'aircraft'  # 'aircraft', 'bird', 'car', or 'dog'
  4. Run $ nohup python3 train.py > progress.bar & to start training.

  5. Run $ tail -f progress.bar to watch training progress (the tqdm package is required; other logs are written to <config.save_dir>/train.log).

  6. Set configurations in config.py (Eval Config) and run $ python3 eval.py for evaluation and visualization.
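
For steps 3 and 6, the settings live in config.py as plain module-level assignments. The fragment below illustrates the shape of such a file; the field names and values are hypothetical and may differ from the actual config.py in this repo.

```python
# config.py (illustrative fragment; field names are hypothetical and may
# differ from the actual config.py in this repo)

# Training Config
epochs = 160
batch_size = 12
learning_rate = 1e-3

# Model Config
net = 'inception_mixed_6e'  # feature extractor, see ./models/wsdan.py
num_attentions = 32         # M: number of attention maps

# Dataset/Path Config
tag = 'aircraft'            # 'aircraft', 'bird', 'car', or 'dog'
save_dir = './ckpt/aircraft/'
```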

Attention Maps Visualization

Code in eval.py generates attention map visualizations for each image: the raw image, the heat attention map, and the image weighted by the attention map.
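
The "image x attention map" panel boils down to blending a normalized attention map over the input image. A minimal NumPy sketch (not the repo's actual eval.py code, which uses its own color map and resizing) assuming the attention map has already been upsampled to the image resolution:

```python
import numpy as np

def overlay_attention(image, attn, alpha=0.5):
    """Blend a single attention map onto an RGB image.

    Sketch only: `image` is (H, W, 3) float in [0, 1]; `attn` is (H, W) and is
    assumed to be already resized to the image resolution.
    """
    attn = (attn - attn.min()) / (attn.max() - attn.min() + 1e-8)  # normalize to [0, 1]
    # Crude blue (low attention) -> red (high attention) pseudo-color map.
    heat = np.stack([attn, np.zeros_like(attn), 1.0 - attn], axis=-1)
    blended = (1 - alpha) * image + alpha * heat
    return np.clip(blended, 0.0, 1.0)

img = np.full((8, 8, 3), 0.5)                    # gray toy image
attn = np.linspace(0, 1, 64).reshape(8, 8)       # toy attention map
out = overlay_attention(img, attn)               # (8, 8, 3) blended visualization
```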
