peteryuX / pcdarts-tf2

PC-DARTS (PC-DARTS: Partial Channel Connections for Memory-Efficient Differentiable Architecture Search, published in ICLR 2020) implemented in Tensorflow 2.0+. This is an unofficial implementation.
MIT License

pcdarts-tf2



PC-DARTS is a memory-efficient differentiable architecture search method that can be trained with a larger batch size and, consequently, enjoys both faster speed and higher training stability. The paper's experiments report an error rate of 2.57% on CIFAR-10 with only 0.1 GPU-days spent on architecture search.

Original Paper: arXiv, OpenReview

Official Implementation: PyTorch



Installation

:pizza:

Create a new Python virtual environment with Anaconda, or use pip in your existing Python environment, then clone this repository as follows.

Clone this repo

git clone https://github.com/peteryuX/pcdarts-tf2.git
cd pcdarts-tf2

Conda

conda env create -f environment.yml
conda activate pcdarts-tf2

Pip

pip install -r requirements.txt

Usage

:lollipop:

Config File

You can set your own dataset path and other model settings for training and testing in ./configs/*.yaml, which look like the example below.

# general setting
batch_size: 128
input_size: 32
init_channels: 36
layers: 20
num_classes: 10
auxiliary_weight: 0.4
drop_path_prob: 0.3
arch: PCDARTS
sub_name: 'pcdarts_cifar10'
using_normalize: True

# training dataset
dataset_len: 50000  # number of training samples
using_crop: True
using_flip: True
using_cutout: True
cutout_length: 16

# training setting
epoch: 600
init_lr: 0.025
lr_min: 0.0
momentum: 0.9
weights_decay: !!float 3e-4
grad_clip: 5.0

val_steps: 1000
save_steps: 1000
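
To get a feel for how the training settings interact, here is a minimal sketch that derives the number of optimizer steps from `dataset_len`, `batch_size` and `epoch`, and evaluates a cosine-annealed learning rate between `init_lr` and `lr_min`. The cosine schedule and the dropping of the incomplete last batch are assumptions borrowed from the usual DARTS training recipe, not something this README confirms; `cosine_lr` is a hypothetical helper, not a function of this repository.

```python
import math

# Values copied from the config example above.
cfg = {
    "batch_size": 128,
    "dataset_len": 50000,
    "epoch": 600,
    "init_lr": 0.025,
    "lr_min": 0.0,
}

# Assumption: the incomplete final batch of each epoch is dropped.
steps_per_epoch = cfg["dataset_len"] // cfg["batch_size"]
total_steps = steps_per_epoch * cfg["epoch"]


def cosine_lr(step, total_steps, init_lr, lr_min):
    """Cosine-annealed learning rate (assumed schedule, hypothetical helper)."""
    progress = min(step, total_steps) / total_steps
    return lr_min + 0.5 * (init_lr - lr_min) * (1.0 + math.cos(math.pi * progress))


print("steps per epoch:", steps_per_epoch)
print("total steps:", total_steps)
for step in (0, total_steps // 2, total_steps):
    print(step, cosine_lr(step, total_steps, cfg["init_lr"], cfg["lr_min"]))
```

Under these assumptions the rate starts at `init_lr`, passes roughly half of it at the midpoint, and ends at `lr_min`. Settings such as `val_steps` and `save_steps` can be compared against `total_steps` to estimate how often validation and checkpointing happen.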

Note:

Architecture Searching on CIFAR-10 (using small proxy model)

Step1: Search for a cell architecture on CIFAR-10 using a small proxy model.

python train_search.py --cfg_path="./configs/pcdarts_cifar10_search.yaml" --gpu=0

Note:

Step2: After the search completes, you can find the resulting genotypes in ./logs/{sub_name}/search_arch_genotype.py. Open it and copy the latest genotype into ./modules/genotypes.py, where it will be used for the full training later. A genotype looks like this:

TheNameYouWantToCall = Genotype(
    normal=[
        ('sep_conv_3x3', 1),
        ('skip_connect', 0),
        ('sep_conv_3x3', 0),
        ('dil_conv_3x3', 1),
        ('sep_conv_5x5', 0),
        ('sep_conv_3x3', 1),
        ('avg_pool_3x3', 0),
        ('dil_conv_3x3', 1)],
    normal_concat=range(2, 6),
    reduce=[
        ('sep_conv_5x5', 1),
        ('max_pool_3x3', 0),
        ('sep_conv_5x5', 1),
        ('sep_conv_5x5', 2),
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 3),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 2)],
    reduce_concat=range(2, 6))
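
A genotype can be read as a list of edges: every two consecutive `(operation, input_index)` pairs feed one intermediate node, where inputs 0 and 1 are the outputs of the two previous cells and higher indices refer to earlier intermediate nodes. The sketch below makes that grouping explicit. It assumes `Genotype` has the same field layout as the named tuple in the original DARTS code; `describe_cell` is a hypothetical helper written for illustration only.

```python
from collections import namedtuple

# Assumption: same field layout as the Genotype tuple in the original DARTS code.
Genotype = namedtuple("Genotype", "normal normal_concat reduce reduce_concat")

example = Genotype(
    normal=[
        ("sep_conv_3x3", 1), ("skip_connect", 0),
        ("sep_conv_3x3", 0), ("dil_conv_3x3", 1),
        ("sep_conv_5x5", 0), ("sep_conv_3x3", 1),
        ("avg_pool_3x3", 0), ("dil_conv_3x3", 1)],
    normal_concat=range(2, 6),
    reduce=[
        ("sep_conv_5x5", 1), ("max_pool_3x3", 0),
        ("sep_conv_5x5", 1), ("sep_conv_5x5", 2),
        ("sep_conv_3x3", 0), ("sep_conv_3x3", 3),
        ("sep_conv_3x3", 1), ("sep_conv_3x3", 2)],
    reduce_concat=range(2, 6))


def describe_cell(ops, first_node=2):
    """Group (operation, input_index) pairs two at a time, one pair per node."""
    return {first_node + i // 2: ops[i:i + 2] for i in range(0, len(ops), 2)}


for name in ("normal", "reduce"):
    print(f"{name} cell:")
    for node, edges in describe_cell(getattr(example, name)).items():
        inputs = ", ".join(f"{op} <- input {src}" for op, src in edges)
        print(f"  node {node}: {inputs}")
```

The `normal_concat` and `reduce_concat` ranges list the intermediate nodes whose outputs are concatenated to form the cell output.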

Note:

Training on CIFAR-10 (using full-sized model)

Step1: Make sure that you have already modified the arch flag in ./configs/pcdarts_cifar10.yaml to match the genotype you want to use from ./modules/genotypes.py.
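
A quick way to catch a mismatch between the `arch` flag and the genotype names is to look the name up before training starts. The following is only a sketch: the `genotypes` object is a stand-in for ./modules/genotypes.py with placeholder contents, and `resolve_genotype` is a hypothetical helper, not part of this repository.

```python
from collections import namedtuple
from types import SimpleNamespace

# Assumption: same field layout as the Genotype tuple in the original DARTS code.
Genotype = namedtuple("Genotype", "normal normal_concat reduce reduce_concat")

# Stand-in for ./modules/genotypes.py; the op lists are left empty as placeholders.
genotypes = SimpleNamespace(
    PCDARTS=Genotype(normal=[], normal_concat=range(2, 6),
                     reduce=[], reduce_concat=range(2, 6)),
)


def resolve_genotype(arch_name, module=genotypes):
    """Return the genotype named by the arch flag, or fail with a readable error."""
    genotype = getattr(module, arch_name, None)
    if not isinstance(genotype, Genotype):
        raise ValueError(f"arch '{arch_name}' is not defined in genotypes.py")
    return genotype


cfg = {"arch": "PCDARTS"}  # value of the arch flag from the YAML config
print(resolve_genotype(cfg["arch"]))
```

With a real module in place of the stand-in, a misspelled `arch` value fails immediately instead of partway through model construction.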

Note:

Step2: Train the full-sized model on CIFAR-10 with the chosen genotype.

python train.py --cfg_path="./configs/pcdarts_cifar10.yaml" --gpu=0

Testing on CIFAR-10 (using full-sized model)

Evaluate the full-sized model on the test dataset with the corresponding cfg file. Instead of training it yourself, you can also download my trained model from Models; its default arch (PCDARTS) is the best cell proposed in the paper.

python test.py --cfg_path="./configs/pcdarts_cifar10.yaml" --gpu=0

Benchmark

:coffee:

Results on CIFAR-10

| Method | Search Method | Params (M) | Test Error (%) | Search Cost (GPU-days) |
| --- | --- | --- | --- | --- |
| NASNet-A | RL | 3.3 | 2.65 | 1800 |
| AmoebaNet-B | Evolution | 2.8 | 2.55 | 3150 |
| ENAS | RL | 4.6 | 2.89 | 0.5 |
| DARTSV1 | gradient-based | 3.3 | 3.00 | 0.4 |
| DARTSV2 | gradient-based | 3.3 | 2.76 | 1.0 |
| SNAS | gradient-based | 2.8 | 2.85 | 1.5 |
| PC-DARTS (official PyTorch version) | gradient-based | 3.63 | 2.57 | 0.1 |
| PC-DARTS TF2 (paper architecture) | gradient-based | 3.63 | 2.73 | - |
| PC-DARTS TF2 (searched by myself) | gradient-based | 3.56 | 2.88 | 0.12 |

Note:


Models

:doughnut:

Download the models below, then extract them into ./checkpoints/ for restoring.

| Model Name | Config File | arch | Download Link |
| --- | --- | --- | --- |
| PC-DARTS (CIFAR-10, paper architecture) | pcdarts_cifar10.yaml | PCDARTS | GoogleDrive |
| PC-DARTS (CIFAR-10, searched by myself) | pcdarts_cifar10_TF2.yaml | PCDARTS_TF2_SEARCH | GoogleDrive |
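
If you prefer to script the extraction step, something like the sketch below may help. It assumes the download is a zip or tar archive; the file name `pcdarts_cifar10.zip` is hypothetical, so substitute whatever the downloaded file is actually called. `extract_checkpoint` is an illustrative helper, not part of this repository.

```python
import shutil
from pathlib import Path


def extract_checkpoint(archive_path, checkpoints_dir="./checkpoints"):
    """Unpack a downloaded archive into the checkpoints directory."""
    target = Path(checkpoints_dir)
    target.mkdir(parents=True, exist_ok=True)
    # unpack_archive infers the format (zip, tar, ...) from the file extension.
    shutil.unpack_archive(str(archive_path), str(target))
    return sorted(p.name for p in target.iterdir())


if __name__ == "__main__":
    archive = Path("pcdarts_cifar10.zip")  # hypothetical file name
    if archive.exists():
        print(extract_checkpoint(archive))
    else:
        print(f"{archive} not found; download it first")
```

The function returns the entries now present in the checkpoints directory, which makes it easy to confirm that the expected model folder is there before running test.py.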

Note:


References

:hamburger:

Thanks to these source codes for providing me with the knowledge to complete this repository.