henryzhongsc / adv_robust_gkp

Official implementation for Zhong et al., One Less Reason for Filter Pruning: Gaining Free Adversarial Robustness with Structured Grouped Kernel Pruning. NeurIPS 2023

Training a network from scratch: not doable #1

Open giorgiopiras opened 7 months ago

giorgiopiras commented 7 months ago

Hi @henryzhongsc, thanks for your work on this repo.

I was wondering whether it is possible to train a model other than ResNet20 with the current state of the code. I am trying to prune a ResNet18, but it looks like I cannot perform the standard pretraining with the code as it stands: 2024-02-19 15:39:53,474 | ERROR : Input task <train> is not supported. Likewise, it is a bit hard to load the state dict of an externally pretrained model into your ModifiedResNet class in prune_cifar.py (line 33).

Any chance you are planning to address these issues in the code, or to release some pretrained checkpoints? Thanks, Giorgio

henryzhongsc commented 7 months ago

Hi Giorgio,

Thank you for checking out our work. It is possible to train/prune a model other than ResNet20. The current prune_cifar.py supports all BasicBlock CifarResNets (namely ResNet20/32/56/110, as reported in the paper). The case of ResNet18 is a bit tricky, as it is normally viewed as an ImageNet ResNet, for which we opted to support the BottleNeck variants of the family (ResNet50 and ResNet101) in the prune_imagenet.py series of pipelines for ImageNet-like evaluations. Given that ResNet18 is built from BasicBlocks rather than BottleNecks, prune_imagenet.py won't be directly applicable either.
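For a quick sanity check of this block-type split, here is a small illustration using the torchvision model definitions (this snippet is not part of our pipeline, just a way to see which block family each model belongs to):

from torchvision.models import resnet18, resnet50
from torchvision.models.resnet import BasicBlock, Bottleneck

# ResNet18 is assembled from BasicBlocks, the same block type as the CifarResNets.
print(isinstance(resnet18().layer1[0], BasicBlock))   # True
# ResNet50 is assembled from Bottlenecks, the block type prune_imagenet.py targets.
print(isinstance(resnet50().layer1[0], Bottleneck))   # True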

However, this should not affect the training of ResNet18, since our method is applied post-training. The supplied training code here is vanilla and model-agnostic. Suppose you supply a ResNet18 model definition, randomly initialize it, save the randomly initialized model as a checkpoint (a rough sketch of this step follows the command below), and then hit it with a script like:

!python main.py \
--exp_desc resnet18_ba_train \
--setting_dir /content/drive/MyDrive/adv_robust_gkp/settings/cifar_ba_train_setting.json \
--dataset cifar10 \
--model_dir /content/drive/MyDrive/adv_robust_gkp/ckpts/resnet18_init.pt \
--output_folder_dir /content/drive/MyDrive/adv_robust_gkp/output/resnet18_ba_trained/ \
--task train \
--adv_attack no_attack

(The cifar_ba_train_setting.json here is just gkp_cifar10_finetune.json with the initial "lr" set to 0.1, if you want to be consistent with our CIFAR baseline training setting.)
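To make the checkpoint step concrete, producing the resnet18_init.pt referenced above could look roughly like the following. The torchvision model definition, the CIFAR-style stem tweaks, and the state_dict format are all assumptions on my end; adapt them to whatever model definition and loading convention you end up using:

import torch
import torch.nn as nn
from torchvision.models import resnet18

model = resnet18(num_classes=10)  # 10 classes for CIFAR-10
# Common CIFAR adaptation (assumption): 3x3 stem conv and no initial max-pool.
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.maxpool = nn.Identity()

# Save the untrained weights; switch to torch.save(model, ...) if the pipeline
# expects a full pickled model rather than a state_dict.
torch.save(model.state_dict(), 'resnet18_init.pt')

Deriving the settings file is then just a matter of copying gkp_cifar10_finetune.json and bumping the learning rate, something like this (assuming a flat top-level "lr" key, as quoted above):

import json

with open('settings/gkp_cifar10_finetune.json') as f:
    cfg = json.load(f)
cfg['lr'] = 0.1  # match our CIFAR baseline training setting
with open('settings/cifar_ba_train_setting.json', 'w') as f:
    json.dump(cfg, f, indent=4)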

It should totally work. Based on the error you are showing, my guess is that you ran the gkp_main.py file, which is dedicated to pruning, not training/finetuning. This is my bad; I will update the doc/demo to highlight that.

For now, you would need to modify the prune_imagenet code yourself to make it prune ResNet18 (whether trained on CIFAR or not). We do plan to support more models, but not in the hard-coded way we are doing here: we are working on an index-based implementation for some typical pruning granularities (basically filter and grouped kernel), under which we can easily extend our pruning implementation to ResNet18. It might take a while, though, as that implementation is still in progress. My plan is to get the model checkpoints and the (hard-coded) implementations of other pruning methods out first.