codestar12 / pruning-distilation-bias

BSD 2-Clause "Simplified" License

A question about the distillation work #1

Open RapBleach opened 3 years ago

RapBleach commented 3 years ago

Thank you for sharing the code. I have a few small questions. Which networks are used as the teacher and the student for distillation, respectively? Is knowledge transferred from the pre-pruning network to the post-pruning network? And are the datasets used by the teacher and the student both unbalanced datasets produced by resampling?

If you see these questions, could you please help me answer them? Thank you very much.

codestar12 commented 3 years ago

More than happy to answer some of your questions. I'm sure you would have fewer questions if I had actually gotten around to updating the README.

Which networks are used as the teacher and the student for distillation, respectively?

We use a normally trained resnet32x4 as the teacher. The students in this case are identical copies of the trained teacher model. They are pruned gradually through the distillation process until they reach the desired sparsity. In this case, you can think of the teacher model, or the distillation, as a regularization nudging the pruned models back towards their original dense behavior.
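As a rough sketch of what that regularization looks like as a loss term (illustrative names, standard Hinton-style KD; the repo's actual distillation code may differ):

```python
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.5):
    """Cross-entropy on the labels plus a KL term pulling the pruned
    student's outputs back toward the dense teacher's outputs."""
    ce = F.cross_entropy(student_logits, labels)
    kd = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction="batchmean",
    ) * (T * T)
    return alpha * ce + (1 - alpha) * kd
```

As sparsity increases during pruning, the KL term keeps penalizing the student for drifting away from the dense teacher's output distribution.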

Are the datasets used by the teacher and the student both unbalanced datasets produced by resampling?

The teacher model is trained only on the full dataset. When conducting the unbalanced experiments, we distill the student on the resampled dataset. The student model is of course initialized with the teacher model's weights, so in a sense it has "seen" data from the full set, and yet it still begins to forget when pruned/finetuned on this resampled set. It isn't a perfect experiment for simulating biased data, but I think it still provides a strong signal.
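For intuition, a minimal sketch of the kind of resampling involved (hypothetical helper with illustrative class ratios; the repo's actual get_cifar100_imbalanced differs):

```python
import numpy as np
from torchvision.datasets import CIFAR100

def make_imbalanced_indices(dataset, keep_frac=0.1, biased_classes=(0, 1, 2), seed=0):
    """Keep all samples for most classes but only a fraction for the
    'biased' classes, simulating an unbalanced finetuning set."""
    rng = np.random.default_rng(seed)
    targets = np.array(dataset.targets)
    indices = []
    for c in np.unique(targets):
        idx = np.where(targets == c)[0]
        if c in biased_classes:
            idx = rng.choice(idx, size=int(len(idx) * keep_frac), replace=False)
        indices.extend(idx.tolist())
    return indices

# Usage sketch: wrap in a Subset and prune/distill the student on it.
# train_set = CIFAR100(root="./data", train=True, download=True)
# subset = torch.utils.data.Subset(train_set, make_imbalanced_indices(train_set))
```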

I will make an effort to provide a more in-depth overview in the readme and clean up the repo soon!

RapBleach commented 3 years ago

Thank you for your answer. I ran into some problems when running the code. I would appreciate it if you could answer my questions.

python train_student.py --batch_size 128 --epochs 150 --learning_rate 0.05 --lr_decay_epochs 50,90,120 --dataset cifar100 --model_s resnet110 --distill kd -r 0.5 -a 0.5 --bias True --target_sparsity 0.3 --path_t ./save/student_model/resnet110_cifar100_lr_0.1_decay_0.0005_trial_0.ts:0.3_strat:struct/resnet110_best.pth

When I ran this command, it raised an error:

File "/home/haha/huangbowen/pruning-distilation-bias-main/dataset/cifar100.py", line 256, in get_cifar100_imbalanced
    for index, label in enumerate(train_loader.dataset.targets):
AttributeError: 'CIFAR100Instance' object has no attribute 'targets'

I've tried to find a solution with search engines, but nothing worked.

codestar12 commented 3 years ago

I think this may be an issue with the version of pytorch/torchvision you have installed. Can you tell me which versions you have in the environment you are trying to run this in? (I may need to update the environment.yml file for others.)

codestar12 commented 3 years ago

Try upgrading to torchvision >= 0.9 and see if that fixes your error.
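If upgrading isn't an option, a small compatibility shim might also work: older torchvision releases stored CIFAR labels in train_labels/test_labels rather than targets. A hypothetical helper (untested against this repo's CIFAR100Instance):

```python
import torchvision

print(torchvision.__version__)  # quick check of the installed version

def get_targets(dataset):
    """Fallback accessor: newer torchvision exposes `targets`,
    older releases used `train_labels` / `test_labels`."""
    if hasattr(dataset, "targets"):
        return dataset.targets
    return dataset.train_labels if dataset.train else dataset.test_labels
```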

RapBleach commented 3 years ago

Thank you! That fixed it! But which model should be the teacher model? I first trained a resnet110 with target_sparsity 0.3 on the normal dataset, and I tried passing its path to train_student.py as the model path, but load_state_dict() raised an error loading the resnet, saying some keys are missing. Then I tried not using the file path, but that doesn't work either. How should I solve this problem?

codestar12 commented 3 years ago

So, in our experiments, we do not use pruned models as teachers; we use dense models, which you can download with the fetch_pretrained_teachers.sh script: https://github.com/codestar12/pruning-distilation-bias/blob/main/scripts/fetch_pretrained_teachers.sh

As for your sparse teacher model: the pytorch pruning library changes the key names of the weights in the model. See https://pytorch.org/tutorials/intermediate/pruning_tutorial.html#serializing-a-pruned-model

We have an example of how we dealt with this in the analytics file we use for running eval on our pruned students: https://github.com/codestar12/pruning-distilation-bias/blob/fea14561859661b49f370c8bc5cad17bf0e7c5f5/src/analytics.py#L12

You basically want to rename the weight/bias keys to what the model function expects (or build a new state dict, like we do).
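As an illustration of that renaming, a minimal sketch assuming the standard torch.nn.utils.prune naming (weight_orig / weight_mask pairs); the checkpoint layout in the usage comment is an assumption, see analytics.py for the repo's actual handling:

```python
import torch

def densify_state_dict(pruned_sd):
    """Collapse the pruning library's `weight_orig` / `weight_mask`
    pairs back into plain `weight` keys so a vanilla model can load them."""
    new_sd = {}
    for key, value in pruned_sd.items():
        if key.endswith("_orig"):
            base = key[: -len("_orig")]  # e.g. "layer1.0.conv1.weight"
            mask = pruned_sd[base + "_mask"]
            new_sd[base] = value * mask  # apply the mask to recover the sparse weight
        elif key.endswith("_mask"):
            continue  # consumed above
        else:
            new_sd[key] = value
    return new_sd

# Usage sketch (checkpoint key 'model' is an assumption):
# model.load_state_dict(densify_state_dict(torch.load(path)["model"]))
```

Alternatively, calling torch.nn.utils.prune.remove(module, "weight") on each pruned module before saving bakes the mask in and restores the plain key names.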