RapBleach opened this issue 3 years ago
More than happy to answer some of your questions. I'm sure you would have fewer questions if I had actually gotten around to updating the README.
Excuse me, which networks are used for the teacher that performs the distillation and for the student, respectively?
We use a normally trained resnet32x4 as the teacher. The student in this case is an identical copy of the trained teacher model, pruned gradually through the distillation process until it reaches the desired sparsity. You can think of the teacher model / the distillation loss as a regularizer nudging the pruned model back toward its original dense behavior.
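Roughly, the setup looks like the sketch below (illustrative, not our exact code; `load_pretrained_resnet32x4`, `train_loader`, and the pruning schedule are placeholders):

```python
import copy
import torch
import torch.nn.functional as F
import torch.nn.utils.prune as prune

teacher = load_pretrained_resnet32x4()  # placeholder: load the dense teacher
student = copy.deepcopy(teacher)        # student starts as a copy of the dense teacher
teacher.eval()
optimizer = torch.optim.SGD(student.parameters(), lr=0.01, momentum=0.9)

def distill_loss(s_logits, t_logits, labels, T=4.0, alpha=0.9):
    # Standard KD objective: KL between temperature-softened logits,
    # mixed with the hard-label cross-entropy.
    soft = F.kl_div(
        F.log_softmax(s_logits / T, dim=1),
        F.softmax(t_logits / T, dim=1),
        reduction="batchmean",
    ) * (T * T)
    return alpha * soft + (1 - alpha) * F.cross_entropy(s_logits, labels)

for step, (x, y) in enumerate(train_loader):  # placeholder loader
    with torch.no_grad():
        t_logits = teacher(x)
    loss = distill_loss(student(x), t_logits, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Gradual pruning: every so often remove a further 5% of the remaining
    # conv weights until the target sparsity is reached (schedule is illustrative).
    if step > 0 and step % 390 == 0:
        for m in student.modules():
            if isinstance(m, torch.nn.Conv2d):
                prune.l1_unstructured(m, name="weight", amount=0.05)
```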
Are the datasets used by the teacher and the student both unbalanced datasets produced by resampling?
The teacher model is trained only on the full dataset. When conducting the unbalanced experiments we distill the student on the resampled dataset. The student model is of course initialized with the teacher model's weights, so in a sense it has "seen" data from the full set, and yet it still begins to forget when pruned/finetuned on this resampled set. It isn't a perfect experiment for simulating biased data, but I think it still provides a strong signal.
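Schematically, the resampling looks something like this (a sketch, assuming an exponential per-class decay; the `imbalance_ratio` value and profile are illustrative, not our exact settings):

```python
import numpy as np
from torchvision.datasets import CIFAR100

def make_imbalanced_indices(targets, num_classes=100, imbalance_ratio=0.1):
    # Keep a decaying number of samples per class, from the full
    # max_per_class for class 0 down to imbalance_ratio * max_per_class
    # for the last class.
    targets = np.asarray(targets)
    max_per_class = len(targets) // num_classes
    indices = []
    for c in range(num_classes):
        n_c = int(max_per_class * imbalance_ratio ** (c / (num_classes - 1)))
        class_idx = np.where(targets == c)[0]
        indices.extend(class_idx[:n_c].tolist())
    return indices

full_train = CIFAR100(root="./data", train=True, download=True)
subset_idx = make_imbalanced_indices(full_train.targets)
# The student is then distilled/finetuned only on this subset, while the
# teacher was trained on the full, balanced dataset.
```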
I will make an effort to provide a more in-depth overview in the readme and clean up the repo soon!
Thank you for your answer. I have some problems when I run the code. I would appreciate it if you could answer my questions.
When I ran this part of the code, it raised an error:
File "/home/haha/huangbowen/pruning-distilation-bias -main/dataset/cifar100.py", line 256, in get_ cifar100_ imbalanced for index, label in enumerate (train_ ,loader . dataset. targets): AttributeError: 'CIFAR100Instance' object has no attribute ' targets '
I've tried to find a solution through search engines, but nothing worked.
I think this may be an issue with the version of pytorch/torchvision you have installed. Can you tell me which version you have in the environment you are trying to run this? (I may need to update the environment.yml file for others)
Try upgrading to torchvision >= 0.9 and see if that fixes your error.
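For context, older torchvision releases (before roughly 0.2.2, if I remember right) exposed the CIFAR labels as train_labels/test_labels rather than the unified targets attribute, which would explain that AttributeError. If upgrading isn't an option, a fallback like this sketch should work:

```python
# Fallback sketch: older torchvision CIFAR datasets stored labels in
# train_labels/test_labels instead of the unified `targets` attribute.
dataset = train_loader.dataset
if hasattr(dataset, "targets"):
    labels = dataset.targets
elif hasattr(dataset, "train_labels"):  # old torchvision, train split
    labels = dataset.train_labels
else:                                   # old torchvision, test split
    labels = dataset.test_labels

for index, label in enumerate(labels):
    ...
```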
Thank you! I got it! But which model should be the teacher model? I first trained a resnet110 with 0.3 target_sparsity on the normal dataset, then tried passing it to "train_students.py" as "model_path", but loading the state_dict into the resnet failed with some keys missing. I also tried not using the filepath, but that didn't work either. How should I solve this problem?
So, in our experiments we do not use pruned models as teachers; we use dense models, which you can download with the fetch_pretrained_teachers.sh script: https://github.com/codestar12/pruning-distilation-bias/blob/main/scripts/fetch_pretrained_teachers.sh
As for your sparse teacher model: the PyTorch pruning library changes the key names of the weights in the model, see https://pytorch.org/tutorials/intermediate/pruning_tutorial.html#serializing-a-pruned-model
We have an example of how we dealt with this in the analytics file we use to run eval on our pruned students: https://github.com/codestar12/pruning-distilation-bias/blob/fea14561859661b49f370c8bc5cad17bf0e7c5f5/src/analytics.py#L12
You basically want to rename the weight/bias keys to what the model function expects (or build a new state dict, like we do).
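Concretely, the pruning utilities replace each pruned parameter `weight` in the state dict with a `weight_orig` parameter and a `weight_mask` buffer, so you can rebuild a plain state dict by multiplying them back together. A sketch along those lines (not a verbatim copy of analytics.py; the checkpoint path and model variable are placeholders):

```python
import torch

def densify_state_dict(sparse_sd):
    # Collapse torch.nn.utils.prune's weight_orig/weight_mask pairs back
    # into plain `weight` entries so the keys match an unpruned model.
    new_sd = {}
    for key, value in sparse_sd.items():
        if key.endswith("_orig"):
            base = key[: -len("_orig")]  # e.g. "conv1.weight_orig" -> "conv1.weight"
            new_sd[base] = value * sparse_sd[base + "_mask"]
        elif key.endswith("_mask"):
            continue                     # consumed above
        else:
            new_sd[key] = value
    return new_sd

state = torch.load("pruned_teacher.pth", map_location="cpu")  # placeholder path
model.load_state_dict(densify_state_dict(state))              # model: a freshly built, unpruned net
```

Alternatively, calling prune.remove(module, "weight") on each pruned module before saving makes the pruning permanent and keeps the original key names, as the tutorial linked above describes.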
Thank you for sharing the code. I still have a small doubt. Which networks are the teacher that performs the distillation and the student, respectively? Is this transferring network knowledge from the pre-pruning model to the post-pruning model? And are the datasets used by the teacher and the student both unbalanced datasets produced by resampling?
If you see these questions, could you please help me answer them? Thank you very much.