facebookresearch / ppuda

Code for Parameter Prediction for Unseen Deep Architectures (NeurIPS 2021)
MIT License

Improved fine-tuning, ConvNeXt support, improved training speed of GHNs #7

Closed bknyaz closed 2 years ago

bknyaz commented 2 years ago

Training times

The implementation of some steps in the forward pass of GHNs is improved to reduce GHN training time without altering their overall behavior.

Speed is measured on an NVIDIA A100-40GB GPU in seconds per training iteration on ImageNet (averaged over the first 50 iterations). 4xA100 GPUs are used for meta-batch size (bm) = 8. Measurements may be noisy because other users can share computational resources of the same cluster node.

| Model | AMP* | Current version | This PR | Estimated total speed-up for 300/150 epochs** |
|---|---|---|---|---|
| MLP with bm = 1 | | 0.30 sec/iter | 0.22 sec/iter | 5.0 days -> 3.7 days |
| MLP with bm = 8 | | 1.64 sec/iter | 1.01 sec/iter | 13.7 days -> 8.4 days |
| GHN-2 with bm = 1 | | 0.77 sec/iter | 0.70 sec/iter | 12.9 days -> 11.7 days |
| GHN-2 with bm = 8 | | 3.80 sec/iter | 3.08 sec/iter | 31.7 days -> 25.7 days |
| GHN-2 with bm = 8 | ✓ | 3.45 sec/iter | 2.80 sec/iter | 28.8 days -> 23.4 days |
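The "estimated total speed-up" column is a simple extrapolation from per-iteration speed. A minimal sketch of that arithmetic, assuming roughly 4800 iterations per ImageNet epoch and 300 epochs for bm = 1 versus 150 epochs for bm = 8 (both values inferred to be consistent with the table, not stated in the PR):

```python
def estimated_days(sec_per_iter, epochs, iters_per_epoch=4800):
    """Extrapolate total training time (in days) from per-iteration speed.

    iters_per_epoch is an assumed value chosen to match the table above.
    """
    return sec_per_iter * iters_per_epoch * epochs / 86400  # 86400 s per day

# bm = 1 runs train for 300 epochs, bm = 8 for 150 (hence "300/150 epochs")
print(round(estimated_days(0.30, 300), 1))  # -> 5.0  (MLP, bm = 1, current)
print(round(estimated_days(0.22, 300), 1))  # -> 3.7  (MLP, bm = 1, this PR)
print(round(estimated_days(3.80, 150), 1))  # -> 31.7 (GHN-2, bm = 8, current)
print(round(estimated_days(3.08, 150), 1))  # -> 25.7 (GHN-2, bm = 8, this PR)
```

Under these assumptions, the per-iteration savings reproduce every day-count in the table to one decimal place.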

Fine-tuning and ConvNeXt support

Following the report (Pretraining a Neural Network before Knowing Its Architecture), which shows improved fine-tuning results, the following arguments are added to the code: `--opt`, `--init`, `--imsize`, `--beta`, `--layer`, plus the file `ppuda/utils/init.py` with initialization functions. The argument `--val` is also added to enable evaluation on validation data rather than test data during training.
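A minimal sketch of how these flags might be declared with `argparse`; the flag names match the PR, but the types, defaults, and help strings are assumptions, not the actual ppuda definitions:

```python
import argparse

# Sketch of the new fine-tuning options; defaults/types are hypothetical.
parser = argparse.ArgumentParser(description='fine-tuning options (sketch)')
parser.add_argument('--opt', type=str, default='sgd',
                    help='optimizer choice (assumed, e.g. sgd or adam)')
parser.add_argument('--init', type=str, default='rand',
                    help='initialization scheme (see ppuda/utils/init.py)')
parser.add_argument('--imsize', type=int, default=224,
                    help='input image size (assumed default)')
parser.add_argument('--beta', type=float, default=0.9,
                    help='hypothetical coefficient used during fine-tuning')
parser.add_argument('--layer', type=int, default=0,
                    help='hypothetical layer index for initialization')
parser.add_argument('--val', action='store_true',
                    help='evaluate on validation data instead of test data')

args = parser.parse_args(['--opt', 'adam', '--val'])
print(args.opt, args.val)  # -> adam True
```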

A simple example to try parameter prediction for ConvNeXt is to run:

```
python examples/torch_models.py cifar10 convnext_base
```

Code correctness

To verify that the evaluation results (classification accuracies of predicted parameters) reported in the paper are unchanged by this PR, the GHNs were evaluated on selected architectures and the same results were obtained (see the table below).

| Model | ResNet-50 | ViT | Test Architecture (index in the test split) |
|---|---|---|---|
| GHN-2-CIFAR-10 (top-1 acc) | 58.6% | 11.4% | 77.1% (210) |
| GHN-2-ImageNet (top-5 acc) | 5.3% | 4.4% | 48.3% (85) |
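For reference, the top-1/top-5 metrics in the table can be computed as below; this is a generic sketch in pure Python, not the ppuda evaluation code:

```python
def topk_accuracy(logits, labels, k=1):
    """Fraction of samples whose true label is among the k highest logits."""
    correct = 0
    for scores, label in zip(logits, labels):
        # indices of the k largest scores for this sample
        topk = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
        correct += label in topk
    return correct / len(labels)

# Toy example with 3 samples and 3 classes
logits = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.2, 0.6]]
labels = [1, 2, 0]
print(topk_accuracy(logits, labels, k=1))  # -> 0.333...
print(topk_accuracy(logits, labels, k=2))  # -> 0.666...
```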

To further confirm the correctness of the updated code, the training loss and top-1 accuracy from training GHN-2 on CIFAR-10 for 3 epochs are reported in the table below. The command used in this benchmark is `python experiments/train_ghn.py -m 8 -n -v 50 --ln`.

| Version | Epoch 1 | Epoch 2 | Epoch 3 |
|---|---|---|---|
| Current version | loss=2.41, top1=17.23 | loss=2.02, top1=20.62 | loss=1.94, top1=24.56 |
| This PR | loss=2.51, top1=17.58 | loss=2.01, top1=21.62 | loss=1.90, top1=25.88 |

These results can be noisy due to several factors, such as random batches and the random initialization of the GHN.

Other

The Python script `experiments/train_ghn_stable.py` is added to automatically resume GHN training from the last saved checkpoint (if any) when a run fails for some reason (e.g. OOM, NaN loss). Instead of running `python experiments/train_ghn.py -m 8 -n -v 50 --ln`, one can now use `python experiments/train_ghn_stable.py experiments/train_ghn.py -m 8 -n -v 50 --ln`.
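A minimal sketch of such an auto-resume wrapper, in the spirit of `experiments/train_ghn_stable.py`; the checkpoint file pattern and the `--ckpt` flag here are assumptions, not the actual ppuda interface:

```python
import glob
import os
import subprocess
import sys

def build_command(train_args, ckpt_dir='checkpoints'):
    """Return the command to (re)launch training, resuming if possible."""
    cmd = [sys.executable] + list(train_args)
    # Find saved checkpoints; '*.pt' and '--ckpt' are hypothetical choices
    ckpts = sorted(glob.glob(os.path.join(ckpt_dir, '*.pt')),
                   key=os.path.getmtime)
    if ckpts:
        cmd += ['--ckpt', ckpts[-1]]  # resume from the newest checkpoint
    return cmd

def run_until_done(train_args, max_restarts=10):
    """Relaunch the training script until it exits cleanly (returncode 0)."""
    for _ in range(max_restarts):
        if subprocess.call(build_command(train_args)) == 0:
            return True  # clean exit: training finished
    return False  # gave up after max_restarts failures
```

The wrapper only decides how to relaunch; the training script itself is expected to load the checkpoint passed to it and continue from there.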