facebookresearch / ppuda

Code for Parameter Prediction for Unseen Deep Architectures (NeurIPS 2021)
MIT License

Improved training speed of GHNs, extra results for CIFAR-10 #2

Closed bknyaz closed 2 years ago

bknyaz commented 2 years ago

Training times

The implementation of several steps in the GHN decoder is improved to reduce training time without altering the overall behavior of GHNs. These improvements mainly affect speed when a meta-batch size > 1 is used (see the tables below).

Speed is measured on an NVIDIA Quadro RTX 6000 in seconds per training iteration (averaged over the first 100 iterations).
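The measurement described above can be sketched roughly as follows. This is a minimal illustration, not the script used for this PR: `mean_seconds_per_iteration` and `dummy_step` are hypothetical names, and `dummy_step` stands in for one real training iteration. When timing GPU work with PyTorch, one would typically also call `torch.cuda.synchronize()` before reading the clock, since CUDA kernels run asynchronously.

```python
import time


def mean_seconds_per_iteration(step_fn, n_iters=100):
    """Average wall-clock time of the first `n_iters` calls to `step_fn`."""
    start = time.perf_counter()
    for _ in range(n_iters):
        step_fn()
    return (time.perf_counter() - start) / n_iters


def dummy_step():
    # Hypothetical placeholder for one training iteration
    # (forward pass, backward pass, optimizer step).
    sum(i * i for i in range(1000))


sec_per_iter = mean_seconds_per_iteration(dummy_step, n_iters=100)
print(f"{sec_per_iter:.6f} sec/iter")
```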

CIFAR-10

| Model | Current version | Our PR | Estimated total speed up for 300 epochs |
| --- | --- | --- | --- |
| MLP with meta-batch size bm = 1 | 0.21 sec/iter | 0.13 sec/iter | 0.5 days -> 0.3 days |
| MLP with meta-batch size bm = 8 | 6.35 sec/iter | 0.89 sec/iter | 15.5 days -> 2.2 days |
| GHN-2 with meta-batch size bm = 1 | 0.77 sec/iter | 0.72 sec/iter | 1.9 days -> 1.8 days |
| GHN-2 with meta-batch size bm = 8 | 7.74 sec/iter | 1.99 sec/iter | 18.9 days -> 4.9 days |
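The estimated totals in the last column follow from seconds per iteration times the number of iterations, converted to days. The sketch below shows that arithmetic for the CIFAR-10 rows. The number of iterations per epoch is not stated in this PR; `ITERS_PER_EPOCH = 703` is an assumption (for example, 45,000 training images with a batch size of 64), chosen because it yields estimates that round to the values in the table.

```python
SECONDS_PER_DAY = 24 * 60 * 60


def estimated_days(sec_per_iter, iters_per_epoch, epochs):
    """Total training time in days for a given per-iteration cost."""
    return sec_per_iter * iters_per_epoch * epochs / SECONDS_PER_DAY


# ASSUMPTION: not given in the PR; picked so that the estimates
# round to the figures reported in the CIFAR-10 table.
ITERS_PER_EPOCH = 703
EPOCHS = 300

# (sec/iter before, sec/iter after), copied from the CIFAR-10 table.
cifar10_rows = {
    "MLP, bm = 1": (0.21, 0.13),
    "MLP, bm = 8": (6.35, 0.89),
    "GHN-2, bm = 1": (0.77, 0.72),
    "GHN-2, bm = 8": (7.74, 1.99),
}

for name, (before, after) in cifar10_rows.items():
    days_before = estimated_days(before, ITERS_PER_EPOCH, EPOCHS)
    days_after = estimated_days(after, ITERS_PER_EPOCH, EPOCHS)
    print(f"{name}: {days_before:.1f} days -> {days_after:.1f} days "
          f"({before / after:.2f}x faster per iteration)")
```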

ImageNet

| Model | Current version | Our PR | Estimated total speed up for 300/150 epochs* |
| --- | --- | --- | --- |
| MLP with bm = 1 | 0.53 sec/iter | 0.37 sec/iter | 8.9 days -> 6.2 days |
| MLP with bm = 8 (4 GPUs) | 1.78 sec/iter | 1.36 sec/iter | 14.9 days -> 11.4 days |
| GHN-2 with bm = 1 | 1.08 sec/iter | 0.92 sec/iter | 18.0 days -> 15.4 days |
| GHN-2 with bm = 8 (4 GPUs) | 3.78 sec/iter | 3.50 sec/iter | 31.6 days -> 29.2 days |

*To estimate the total training time, 300 epochs are used for bm = 1 and 150 epochs for bm = 8 (according to the paper).

When 4 GPUs are used with bm = 8, the speed up is not significant because each GPU receives only two architectures.
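The point about the per-GPU load can be made concrete with a small calculation. This sketch assumes the meta-batch is split evenly across GPUs, which is consistent with the "two architectures per GPU" statement above. The function names are hypothetical, and the timings are copied from the two tables.

```python
def architectures_per_gpu(meta_batch_size, n_gpus):
    """Architectures each GPU handles, assuming an even split (assumption)."""
    if meta_batch_size % n_gpus != 0:
        raise ValueError("meta-batch size must be divisible by the GPU count")
    return meta_batch_size // n_gpus


def speed_up(sec_before, sec_after):
    """Per-iteration speed-up factor of the PR over the current version."""
    return sec_before / sec_after


# ImageNet, bm = 8 on 4 GPUs: each GPU sees a small local meta-batch.
print(architectures_per_gpu(8, 4), "architectures per GPU")
print(f"ImageNet MLP:   {speed_up(1.78, 1.36):.2f}x")
print(f"ImageNet GHN-2: {speed_up(3.78, 3.50):.2f}x")

# CIFAR-10, bm = 8 on one GPU: all 8 architectures share one device.
print(architectures_per_gpu(8, 1), "architectures per GPU")
print(f"CIFAR-10 MLP:   {speed_up(6.35, 0.89):.2f}x")
print(f"CIFAR-10 GHN-2: {speed_up(7.74, 1.99):.2f}x")
```

The larger factors on CIFAR-10 line up with the observation that the improvements matter most when many architectures are processed on the same device.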

Evaluation of GHNs

To confirm that this PR reproduces the evaluation results reported in the paper (classification accuracies obtained with predicted parameters), the GHNs were evaluated on selected architectures, and the same results were obtained (see the table below).

| Model | ResNet-50 | ViT | Test Architecture (index in the test split) |
| --- | --- | --- | --- |
| GHN-2-CIFAR-10 | 58.6 | 11.4 | 77.1 (210) |
| GHN-2-ImageNet | 5.3 | 4.4 | 48.3 (85) |

Extra results on CIFAR-10

[Figure: c10_extended_results]

Other minor updates