SamsungSAILMontreal / ghn3

Code for "Can We Scale Transformers to Predict Parameters of Diverse ImageNet Models?" [ICML 2023]
https://arxiv.org/abs/2303.04143
MIT License

Question regarding random init performance being too good #3

Closed: sorobedio closed this issue 2 months ago

sorobedio commented 2 months ago

Hello, thank you for your work. Why is the random init (RandInit) performance on CIFAR-10 in Table 7 so high if the model has not been trained before? Does random init mean the model is initialized with Xavier initialization? When I use SwinT on CIFAR-10 without any pretrained weights (i.e., randomly initialized), the accuracy is only around 5 to 10. So I do not understand how the RandInit baseline reaches such high performance without any pretraining. Thank you.

bknyaz commented 2 months ago

In Table 7 we always train/fine-tune the networks. RandInit means the network was trained from random initialization. GHN-2 or GHN-3 means we predict the parameters and then fine-tune. This is described in Section 5.3: "We fine-tune the networks initialized with one of these approaches...". In other figures and tables we sometimes report results without training/fine-tuning (e.g., Figure 2 or Table 2); in those cases we explicitly say No Fine-tuning or No training.
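On the Xavier question: the RandInit rows in this thread are described as using PyTorch default init, which is not Xavier initialization. The minimal sketch below (plain Python, no PyTorch required) compares the half-width of the uniform range each scheme draws weights from. It assumes the behavior of recent PyTorch versions, where `nn.Linear`/`nn.Conv2d` call `kaiming_uniform_(weight, a=math.sqrt(5))` in `reset_parameters()`; the helper names and example layer sizes are hypothetical. Either way, the init only sets the starting point: every Table 7 number is measured after training on CIFAR-10.

```python
import math


def pytorch_default_bound(fan_in: int) -> float:
    """Half-width b of U(-b, b) used by default for nn.Linear / nn.Conv2d weights.

    Assumes reset_parameters() calls kaiming_uniform_(weight, a=math.sqrt(5));
    with gain = sqrt(2 / (1 + a**2)) the bound simplifies to 1 / sqrt(fan_in).
    """
    a = math.sqrt(5)
    gain = math.sqrt(2.0 / (1.0 + a ** 2))
    return math.sqrt(3.0) * gain / math.sqrt(fan_in)


def xavier_uniform_bound(fan_in: int, fan_out: int, gain: float = 1.0) -> float:
    """Half-width b of U(-b, b) used by torch.nn.init.xavier_uniform_."""
    return gain * math.sqrt(6.0 / (fan_in + fan_out))


if __name__ == "__main__":
    # Hypothetical layers; for a conv layer, fan_in = in_channels * kernel_h * kernel_w
    # and fan_out = out_channels * kernel_h * kernel_w.
    layers = {"linear 512->10": (512, 10), "conv 64->64, 3x3": (64 * 9, 64 * 9)}
    for name, (fan_in, fan_out) in layers.items():
        print(f"{name}: default b={pytorch_default_bound(fan_in):.4f}, "
              f"xavier b={xavier_uniform_bound(fan_in, fan_out):.4f}")
```

The default bound depends only on `fan_in`, while Xavier uses both fans, so the two schemes generally produce different weight scales for the same layer.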

sorobedio commented 2 months ago

Thank you.

In Table 7, you mention using transfer learning from ImageNet to CIFAR-10/100. My understanding is that you generate the parameters using a GHN trained on ImageNet, use them to initialize the networks for CIFAR-10/CIFAR-100, and evaluate without any training step, which corresponds to the first row, while the second row represents 1 epoch of fine-tuning of the same weights. Based on your explanation, is this correct?

My objective is to reproduce the results of the first row.

bknyaz commented 2 months ago

Let me elaborate on the following 5 rows to make the pipeline explicit:

  1. The row RANDINIT NO 61.6 corresponds to random initialization (PyTorch default init) + training on CIFAR-10 for 300 epochs (I explained these hyperparameters in issue https://github.com/SamsungSAILMontreal/ghn3/issues/2).
  2. The row GHN-3-XL/M16 NO 72.9 corresponds to predicting ImageNet parameters + training on CIFAR-10 for 300 epochs.
  3. The row RANDINIT 1 EPOCH 74.0 corresponds to random initialization (PyTorch default init) + training on ImageNet for 1 epoch (as the 2nd column indicates) + training on CIFAR-10 for 300 epochs.
  4. The row GHN-3-XL/M16 1 EPOCH 77.8 corresponds to predicting ImageNet parameters + training on ImageNet for 1 epoch + training on CIFAR-10 for 300 epochs.
  5. The row RANDINIT 90-600 EP 88.7 corresponds to random initialization (PyTorch default init) + training on ImageNet for 90-600 epochs (depending on the architecture) + training on CIFAR-10 for 300 epochs.

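To make the five rows easy to compare side by side, here is a small, hypothetical Python sketch that encodes each row as an ordered list of stages. It only restates the pipeline listed above and runs no training; the names `Table7Row`, `stages()`, and `ROWS` are illustrative and are not part of the ghn3 repository.

```python
from dataclasses import dataclass
from typing import List, Optional

CIFAR10_EPOCHS = 300  # CIFAR-10 training length stated for every row above


@dataclass(frozen=True)
class Table7Row:
    label: str                      # row label as quoted in the thread
    init: str                       # "randinit" or "ghn3"
    imagenet_epochs: Optional[str]  # None = no ImageNet training before transfer
    reported: float                 # number quoted in the thread

    def stages(self) -> List[str]:
        """Return the ordered pipeline stages for this row."""
        if self.init == "randinit":
            steps = ["random initialization (PyTorch default init)"]
        elif self.init == "ghn3":
            steps = ["predict ImageNet parameters with GHN-3-XL/M16"]
        else:
            raise ValueError(f"unknown init: {self.init!r}")
        if self.imagenet_epochs is not None:
            steps.append(f"train on ImageNet for {self.imagenet_epochs} epoch(s)")
        # No row is evaluated right after initialization: all are trained on CIFAR-10.
        steps.append(f"train on CIFAR-10 for {CIFAR10_EPOCHS} epochs")
        return steps


ROWS = [
    Table7Row("RANDINIT NO", "randinit", None, 61.6),
    Table7Row("GHN-3-XL/M16 NO", "ghn3", None, 72.9),
    Table7Row("RANDINIT 1 EPOCH", "randinit", "1", 74.0),
    Table7Row("GHN-3-XL/M16 1 EPOCH", "ghn3", "1", 77.8),
    Table7Row("RANDINIT 90-600 EP", "randinit", "90-600", 88.7),
]

if __name__ == "__main__":
    for row in ROWS:
        print(f"{row.label:<22} {row.reported:>5}: " + " -> ".join(row.stages()))
```

Laid out this way, the only differences between rows are the initialization and the amount of ImageNet training; the final CIFAR-10 stage is identical for all five.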
sorobedio commented 2 months ago

Thank you very much.