Tencent / PocketFlow

An Automatic Model Compression (AutoMC) framework for developing smaller and faster AI applications.
https://pocketflow.github.io
Other

How to use a pre-trained model for weight sparsification? #69

Closed: umadevimcw closed this issue 5 years ago

umadevimcw commented 5 years ago

Hi, I have followed the installation and usage steps for weight sparsification, and I have a question: how do I use a pre-trained TensorFlow model with the weight sparsification algorithm? My understanding is that the pre-trained model downloaded by the script is used in the "Optimal pruning protocol" module to find the optimal pruning ratio (correct me if I am wrong). How can I use that pre-trained model in the main pruning module? Is it possible to use a pre-trained TensorFlow model with the weight sparsification algorithm, and if so, how?

jiaxiang-wu commented 5 years ago

It is possible, but it is not implemented in the weight sparsification algorithm, since we observe that training from scratch already gives relatively satisfying results. If you do want to use a pre-trained model, you can create another copy of the model in the training graph under a different name scope, restore its weights from the pre-trained checkpoint files, and then assign them to the model whose weights will be sparsified. You may take a look at UniformQuantTFLearner's implementation to see how a pre-trained model is used for warm start.
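The warm-start approach described above comes down to pairing each variable of the pre-trained copy with its counterpart in the model being sparsified. Below is a minimal, framework-free sketch of that pairing step, assuming the pre-trained copy lives under a hypothetical scope `pretrained/` and the sparsified model under `model/`. The scope names, the `_mask` suffix convention, and the helper `pair_variables` are illustrative assumptions, not part of PocketFlow's API.

```python
from typing import Dict, List, Tuple


def pair_variables(src_names: List[str], dst_names: List[str],
                   src_scope: str, dst_scope: str,
                   skip_suffixes: Tuple[str, ...] = ("_mask",)) -> List[Tuple[str, str]]:
    """Pair variables of a pre-trained copy (under src_scope) with the
    variables of the model to be sparsified (under dst_scope).

    Names are matched after stripping the scope prefix. Destination
    variables ending in any of skip_suffixes (assumed here to be pruning
    masks, which have no pre-trained counterpart) are skipped.
    """
    # Index the pre-trained copy by its scope-relative name.
    src_by_rel: Dict[str, str] = {}
    for name in src_names:
        if name.startswith(src_scope + "/"):
            src_by_rel[name[len(src_scope) + 1:]] = name

    pairs: List[Tuple[str, str]] = []
    for name in dst_names:
        if not name.startswith(dst_scope + "/"):
            continue  # not part of the model being sparsified
        if name.endswith(skip_suffixes):
            continue  # masks are created by the pruning step, not copied
        rel = name[len(dst_scope) + 1:]
        if rel in src_by_rel:
            pairs.append((src_by_rel[rel], name))
    return pairs


if __name__ == "__main__":
    # Hypothetical variable names, for illustration only.
    pairs = pair_variables(["pretrained/conv2d/kernel"],
                           ["model/conv2d/kernel", "model/conv2d/kernel_mask"],
                           "pretrained", "model")
    print(pairs)  # [('pretrained/conv2d/kernel', 'model/conv2d/kernel')]
```

In a TensorFlow 1.x graph, the names would come from `v.op.name for v in tf.global_variables()` (`v.name` carries a `:0` suffix). After restoring the `pretrained/` copy with a `tf.train.Saver` limited to those variables, each pair could become a `tf.assign(dst_var, src_var)` op, grouped with `tf.group(...)` and run once before sparsification starts.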

umadevimcw commented 5 years ago

Thanks. I will look into the UniformQuantTFLearner implementation.

umadevimcw commented 5 years ago

Hi @jiaxiang-wu,

NotFoundError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error: Key model/mask/model/resnet_model/conv2d/kernel_mask not found in checkpoint [[node save/RestoreV2 (defined at PocketFlow/learners/weight_sparsification/learner.py:241) = RestoreV2[dtypes=[DT_INT64, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, ..., DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]

I am getting this error when I try to restore the pre-trained checkpoints.
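The missing key ends in `kernel_mask`, which suggests (though the thread does not confirm it) that the saver is trying to restore every graph variable, including mask variables that a checkpoint trained without pruning would not contain. One common workaround is to restore only the variables the checkpoint actually has and initialize the rest normally. The sketch below shows the filtering step in plain Python; the helper `split_restorable` and its prefix-renaming parameters are hypothetical, and the variable names in the example are illustrative.

```python
from typing import Dict, Iterable, List, Tuple


def split_restorable(graph_names: Iterable[str],
                     ckpt_names: Iterable[str],
                     graph_prefix: str = "",
                     ckpt_prefix: str = "") -> Tuple[Dict[str, str], List[str]]:
    """Split graph variables into restorable and missing ones.

    Returns (restore_map, missing): restore_map maps checkpoint key ->
    graph variable name for every variable whose key exists in the
    checkpoint; missing lists graph variables with no checkpoint entry
    (e.g. pruning masks), which must be initialized instead of restored.
    A leading graph_prefix is rewritten to ckpt_prefix before lookup.
    """
    ckpt = set(ckpt_names)
    restore_map: Dict[str, str] = {}
    missing: List[str] = []
    for name in graph_names:
        key = name
        if name.startswith(graph_prefix):
            key = ckpt_prefix + name[len(graph_prefix):]
        if key in ckpt:
            restore_map[key] = name
        else:
            missing.append(name)
    return restore_map, missing


if __name__ == "__main__":
    # Illustrative names modeled on the error message above.
    graph = ["model/resnet_model/conv2d/kernel",
             "model/mask/model/resnet_model/conv2d/kernel_mask"]
    ckpt = ["model/resnet_model/conv2d/kernel"]
    restore_map, missing = split_restorable(graph, ckpt)
    print(restore_map)
    print(missing)
```

In TensorFlow 1.x, `ckpt_names` could come from `[n for n, _ in tf.train.list_variables(ckpt_path)]` and `graph_names` from `v.op.name for v in tf.global_variables()`. The restore map can then be turned into a dict of checkpoint key to variable object and passed as `var_list` to `tf.train.Saver`, so only existing keys are restored, while the variables in `missing` are initialized with `tf.variables_initializer(...)`.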