MIT License

NetAdapt: Platform-Aware Neural Network Adaptation for Mobile Applications

This repo contains the official PyTorch reimplementation of the paper "NetAdapt: Platform-Aware Neural Network Adaptation for Mobile Applications" [paper] [project]. The results in the paper were generated by the TensorFlow implementation from Google AI.

Summary

  1. Requirements
  2. Usage
  3. Example
  4. Customization
  5. Citation

Requirements

The code base has been tested with the following settings:

  1. Python 3.7.0
  2. CUDA 10.0
  3. PyTorch 1.2.0
  4. torchvision 0.4.0
  5. numpy 1.17.0
  6. scipy 1.3.1
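Before installing, it can help to compare your environment against the tested settings above. The following is a minimal sketch; the names TESTED, installed_versions, and mismatches are illustrative and not part of the repo. It reads each package's __version__ attribute, so it also runs on Python 3.7. CUDA is not checked here; once PyTorch is installed, torch.version.cuda reports the CUDA version PyTorch was built with.

```python
import importlib
import platform

# Versions listed under "Requirements". "torch" is the import name of PyTorch.
TESTED = {
    "python": "3.7.0",
    "torch": "1.2.0",
    "torchvision": "0.4.0",
    "numpy": "1.17.0",
    "scipy": "1.3.1",
}


def installed_versions(names):
    """Map each name to its installed version string, or None if it cannot be imported."""
    found = {"python": platform.python_version()}
    for name in names:
        if name == "python":
            continue
        try:
            module = importlib.import_module(name)
        except ImportError:
            found[name] = None
            continue
        found[name] = getattr(module, "__version__", None)
    return found


def mismatches(tested, found):
    """Return {name: (tested_version, found_version)} for every entry that differs."""
    result = {}
    for name, want in tested.items():
        have = found.get(name)
        # Ignore local build labels such as "+cu100" when comparing.
        if have is None or have.split("+")[0] != want:
            result[name] = (want, have)
    return result


if __name__ == "__main__":
    problems = mismatches(TESTED, installed_versions(TESTED))
    for name, (want, have) in problems.items():
        print(f"{name}: tested with {want}, found {have}")
    if not problems:
        print("All versions match the tested settings.")
```

A mismatch is not necessarily fatal; it only means your setup differs from the one the code base was tested with.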

First, clone the repo into the directory where you want to work:

    git clone https://github.com/denru01/netadapt.git  
    cd netadapt

In what follows, we assume you are in the repo root.

If your Python and CUDA versions match those listed above, you can install the required Python packages with:

    pip install -r requirements.txt

To verify that the downloaded code base works correctly, run either

    sh scripts/unittest.sh

or

    sh scripts/unittest_helloworld.sh
    sh scripts/unittest_alexnet.sh
    sh scripts/unittest_mobilenet.sh

If everything is set up correctly, none of the tests should report FAIL.

Usage

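To apply NetAdapt, run master.py as shown below. The positional argument input_data_shape takes three values, which is why its name appears three times in the usage string: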

    python master.py [-h] [-gp GPUS [GPUS ...]] [-re] [-im INIT_MODEL_PATH]
             [-mi MAX_ITERS] [-lr FINETUNE_LR] [-bu BUDGET]
             [-bur BUDGET_RATIO] [-rt RESOURCE_TYPE]
             [-ir INIT_RESOURCE_REDUCTION]
             [-irr INIT_RESOURCE_REDUCTION_RATIO]
             [-rd RESOURCE_REDUCTION_DECAY]
             [-st SHORT_TERM_FINE_TUNE_ITERATION] [-lt LOOKUP_TABLE_PATH]
             [-dp DATASET_PATH] [-a ARCH] [-si SAVE_INTERVAL]
             working_folder input_data_shape input_data_shape
             input_data_shape
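The budget and reduction flags above jointly determine how aggressively each iteration shrinks the network. As an illustration only, the sketch below assumes a geometric schedule: the per-iteration reduction starts at INIT_RESOURCE_REDUCTION_RATIO times the initial resource and is multiplied by RESOURCE_REDUCTION_DECAY after every iteration, and the run stops once the resource reaches BUDGET_RATIO times the initial resource or MAX_ITERS iterations have run. It also assumes each iteration lands exactly on its target. The function name resource_constraints is hypothetical; master.py remains the reference for the actual schedule.

```python
from typing import List


def resource_constraints(initial_resource: float,
                         budget_ratio: float,
                         init_reduction_ratio: float,
                         reduction_decay: float,
                         max_iters: int) -> List[float]:
    """Per-iteration resource targets under an assumed geometric-decay schedule."""
    budget = budget_ratio * initial_resource             # plays the role of -bur
    reduction = init_reduction_ratio * initial_resource  # plays the role of -irr
    resource = initial_resource
    targets = []
    for _ in range(max_iters):                           # plays the role of -mi
        if resource <= budget:
            break
        # Assume the selected proposal lands exactly on its constraint.
        resource -= reduction
        targets.append(resource)
        reduction *= reduction_decay                     # plays the role of -rd
    return targets


if __name__ == "__main__":
    # Constant reductions (decay 1.0) reach a 50% budget in two iterations.
    print(resource_constraints(100.0, 0.5, 0.25, 1.0, max_iters=10))
    # -> [75.0, 50.0]
    # A fast decay (0.5) shrinks the steps too quickly to ever reach the budget.
    print(resource_constraints(100.0, 0.5, 0.25, 0.5, max_iters=5))
    # -> [75.0, 62.5, 56.25, 53.125, 51.5625]
```

The second call shows why the decay matters: when the reductions shrink as fast as the remaining gap, the targets approach the budget without reaching it, so the run ends at MAX_ITERS instead.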

Example

We provide a simple example of applying NetAdapt to a very small network:

    sh scripts/netadapt_helloworld.sh

Detailed examples of applying NetAdapt to AlexNet/MobileNet on CIFAR-10 are shown here (AlexNet) and here (MobileNet).

If you want to apply the algorithm to different networks or even different tasks, please see the following Customization section.

Customization

To apply NetAdapt to different networks or different tasks, please follow these steps:

  1. Create your own network_utils Python file (e.g., network_utils_yourNetworkOrTask.py) and place it in the network_utils directory.

  2. Implement the functions described in network_utils_abstract.py.

  3. Since we provide an example of applying NetAdapt to AlexNet, you can also build your network_utils on top of network_utils_alexnet.py:

        cd network_utils
        cp network_utils_alexnet.py ./network_utils_yourNetworkOrTask.py

  4. Add from .network_utils_yourNetworkOrTask import * to the __init__.py file in the same directory.

  5. Rename class networkUtils_alexnet(...) on line 44 of network_utils_yourNetworkOrTask.py to class networkUtils_yourNetworkOrTask(...).

  6. Change def alexnet(...) on lines 325-326 to:

        def yourNetworkOrTask(model, input_data_shape, dataset_path, finetune_lr=1e-3):
            return networkUtils_yourNetworkOrTask(model, input_data_shape, dataset_path, finetune_lr)

  7. In network_utils_yourNetworkOrTask.py, adapt the training/validation data loaders, loss functions, optimizers, network architecture, training method, and evaluation method wherever they differ from the AlexNet example:

    • Modify the data loaders and loss functions in def __init__(...) (line 52).

    • If your network has additive skip connections, specify them and modify def simplify_network_def_based_on_constraint(...) in network_utils_yourNetworkOrTask.py accordingly. You can see how our implementation handles additive skip connections here.

    • Modify the training method (short-term fine-tuning) in def fine_tune(...) (line 245).

    • Modify the evaluation method in def evaluate(...) (line 291).

    You can see how these methods are utilized by the framework here.

  8. Our current code base supports pruning Conv2d, ConvTranspose2d, and Linear layers, including networks with additive skip connections. If your network architecture is not supported, please modify this. If you want to prune networks based on other metrics (resource types), please modify this.

  9. You can then apply NetAdapt to your network or task by setting --arch yourNetworkOrTask in scripts/netadapt_alexnet-0.5mac.sh. For the values of the other arguments, see Usage. When applying NetAdapt to a different task, you will typically need to change input_data_shape. If your network architecture differs substantially from MobileNet, you may need to tune --init_resource_reduction_ratio and --resource_reduction_decay to obtain a suitable resource reduction schedule.
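Steps 4, 6, and 9 are tied together by a name lookup. The sketch below assumes that master.py resolves --arch by looking up a function with that exact name in the network_utils package (a common PyTorch pattern); check master.py for the actual code. The stand-in module, the placeholder class body, and build_network_utils are hypothetical and only illustrate why the names must match.

```python
import types

# Hypothetical stand-in for the network_utils package. In the repo, step 4
# (`from .network_utils_yourNetworkOrTask import *` in __init__.py) is what
# exposes the factory function as an attribute of the package.
network_utils = types.ModuleType("network_utils")


class networkUtils_yourNetworkOrTask:
    """Placeholder for the class renamed in step 5; the real one implements
    the methods described in network_utils_abstract.py."""

    def __init__(self, model, input_data_shape, dataset_path, finetune_lr=1e-3):
        self.model = model
        self.input_data_shape = input_data_shape
        self.dataset_path = dataset_path
        self.finetune_lr = finetune_lr


def yourNetworkOrTask(model, input_data_shape, dataset_path, finetune_lr=1e-3):
    # Factory from step 6: its name is the string passed to --arch.
    return networkUtils_yourNetworkOrTask(model, input_data_shape, dataset_path, finetune_lr)


network_utils.yourNetworkOrTask = yourNetworkOrTask  # the effect of step 4


def build_network_utils(arch, model, input_data_shape, dataset_path, finetune_lr=1e-3):
    """Resolve an --arch value to a factory by name (assumed lookup) and call it."""
    factory = getattr(network_utils, arch, None)
    if not callable(factory):
        raise ValueError(f"network_utils has no factory named {arch!r}; "
                         "is it exported in network_utils/__init__.py?")
    return factory(model, input_data_shape, dataset_path, finetune_lr)
```

Under this assumption, --arch yourNetworkOrTask only works once both step 4 and step 6 are done; skipping either leaves the name unresolvable.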

Citation

If you use our code or method in your work, please consider citing the following:

@InProceedings{eccv_2018_yang_netadapt,
    author = {Yang, Tien-Ju and Howard, Andrew and Chen, Bo and Zhang, Xiao and Go, Alec and Sandler, Mark and Sze, Vivienne and Adam, Hartwig},
    title = {NetAdapt: Platform-Aware Neural Network Adaptation for Mobile Applications},
    booktitle = {The European Conference on Computer Vision (ECCV)},
    month = {September},
    year = {2018}
}

Please direct any questions to the authors: Tien-Ju Yang (tjy@mit.edu) and Yi-Lun Liao (ylliao@mit.edu).