as754770178 opened this issue 5 years ago.
ChannelPrunedLearner does not remove pruned channels from the .ckpt files. Instead, it only sets the corresponding channels' weights to zero. The channels are actually removed in the *.pb and *.tflite models. Therefore, the weight sizes in the checkpoints stay the same.
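A minimal NumPy sketch of why the checkpoint sizes do not change, assuming the pruned channels are input channels of a conv kernel stored in TensorFlow's HWIO layout (height, width, in_channels, out_channels). The helpers zero_out_channels and count_live_input_channels are hypothetical names, not PocketFlow functions. Masking keeps the tensor shape, so the number of surviving channels has to be counted from the non-zero weights rather than from the tensor size.

```python
import numpy as np


def zero_out_channels(kernel, pruned_idx):
    """Mask-style pruning: zero the given input channels but keep the shape.

    kernel is assumed to use TensorFlow's HWIO layout:
    (height, width, in_channels, out_channels).
    """
    masked = kernel.copy()
    masked[:, :, pruned_idx, :] = 0.0
    return masked


def count_live_input_channels(kernel):
    """Count input channels that still have at least one non-zero weight."""
    alive = np.any(kernel != 0, axis=(0, 1, 3))  # one bool per input channel
    return int(np.count_nonzero(alive))


rng = np.random.default_rng(0)
kernel = rng.standard_normal((3, 3, 16, 32))
masked = zero_out_channels(kernel, pruned_idx=list(range(0, 16, 2)))

print(kernel.shape, masked.shape)         # same shape, so same checkpoint size
print(count_live_input_channels(kernel))  # 16
print(count_live_input_channels(masked))  # 8
```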
Can you tell me where you convert the model? I didn't find the .pb models in cp_best_path or cp_channel_pruned_path.
The .pb and .tflite models are generated using tools/conversion/export_chn_pruned_tflite_model.py.
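The export step is where zeroed channels can be physically dropped. The NumPy sketch below illustrates that idea under the same HWIO assumption; it is not the code of export_chn_pruned_tflite_model.py, and strip_dead_input_channels is a hypothetical name. Removing an input channel of one conv also lets the matching output channel of the conv that feeds it be removed, which is what actually shrinks the exported model.

```python
import numpy as np


def strip_dead_input_channels(kernel, prev_kernel):
    """Physically remove input channels whose weights are all zero.

    kernel:      (h, w, c_in, c_out) weights of the pruned conv
    prev_kernel: (h, w, c_prev_in, c_in) weights of the conv that feeds it
    Returns both kernels sliced to the surviving channels, plus their indices.
    """
    keep = np.flatnonzero(np.any(kernel != 0, axis=(0, 1, 3)))
    # the pruned conv loses its all-zero input channels ...
    slim_kernel = kernel[:, :, keep, :]
    # ... so the preceding conv no longer needs to produce those outputs
    slim_prev = prev_kernel[..., keep]
    return slim_kernel, slim_prev, keep


rng = np.random.default_rng(0)
prev_kernel = rng.standard_normal((3, 3, 8, 16))
kernel = rng.standard_normal((3, 3, 16, 32))
kernel[:, :, 1::2, :] = 0.0  # pretend the odd input channels were pruned

slim_kernel, slim_prev, keep = strip_dead_input_channels(kernel, prev_kernel)
print(kernel.shape, "->", slim_kernel.shape)     # (3, 3, 16, 32) -> (3, 3, 8, 32)
print(prev_kernel.shape, "->", slim_prev.shape)  # (3, 3, 8, 16) -> (3, 3, 8, 8)
```

In a ResNet, channels that feed a residual addition are shared across blocks, so real export logic has to handle those (and the matching batch-norm parameters) more carefully than this two-layer sketch.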
In __build_pruned_train_model, the input and output collections are train_images and logits; the code is shown below:
logits = tf.get_collection('logits')[0]
train_images = tf.get_collection('train_images')[0]
train_labels = tf.get_collection('train_labels')[0]
mem_images = tf.get_collection('mem_images')[0]
mem_labels = tf.get_collection('mem_labels')[0]
In tools/conversion/export_chn_pruned_tflite_model.py, should the parameters input_coll and output_coll be train_images and logits?
You need to export the *.ckpt files saved from the evaluation graph (defined in __build_pruned_evaluate_model()), not from the training graph.
If the model structure has not changed, can the code self.saver_train = tf.train.import_meta_graph(path + '.meta') in __build_pruned_train_model be changed to the following?
# model definition
with tf.variable_scope(self.model_scope):
  # forward pass
  logits = self.forward_train(mem_images)
  loss, accuracy = self.calc_loss(mem_labels, logits, self.trainable_vars)
  self.accuracy_keys = list(accuracy.keys())
  for key in self.accuracy_keys:
    tf.add_to_collection(key, accuracy[key])
  tf.add_to_collection('loss', loss)
  tf.add_to_collection('logits', logits)
The eval model is only used in __train_pruned_model, so is self.__build(is_train=False) in __init__ unnecessary?
Regarding your two questions above, we need to review the detailed implementation of ChannelPrunedLearner over the next few days and refactor it if needed.
I pruned resnet_20 on cifar_10 with ChannelPrunedLearner, but when I read original_model.ckpt, pruned_model.ckpt, and best_model.ckpt, the weight sizes are the same. "origin val" is the weight size in original_model.ckpt, and "pruned val" is the weight size in pruned_model.ckpt or best_model.ckpt.
Command:
./scripts/run_local.sh nets/resnet_at_cifar10_run.py --learner channel
Model: https://api.ai.tencent.com/pocketflow/models_resnet_20_at_cifar_10.tar.gz
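To check pruning in a checkpoint, it helps to compare non-zero channels rather than tensor sizes. A sketch, assuming TensorFlow's checkpoint reader (tf.train.load_checkpoint with its get_variable_to_shape_map and get_tensor methods) and that conv kernels are the 4-D variables in HWIO layout; the helper names and the checkpoint path in the usage comment are hypothetical. TensorFlow is imported lazily so the summary function can be tried on plain arrays.

```python
import numpy as np


def summarize_kernels(weights):
    """Map each 4-D (HWIO) tensor name to (shape, live_in_channels, total_in_channels)."""
    summary = {}
    for name, value in sorted(weights.items()):
        if value.ndim != 4:  # skip biases, BN parameters and other non-4-D tensors
            continue
        alive = np.any(value != 0, axis=(0, 1, 3))  # one bool per input channel
        summary[name] = (value.shape, int(alive.sum()), value.shape[2])
    return summary


def load_checkpoint_weights(ckpt_path):
    """Read every variable stored in a TensorFlow checkpoint into NumPy arrays."""
    import tensorflow as tf  # lazy import: only needed for real checkpoints
    reader = tf.train.load_checkpoint(ckpt_path)
    return {name: reader.get_tensor(name)
            for name in reader.get_variable_to_shape_map()}


# Usage on a real checkpoint (path is hypothetical):
#   weights = load_checkpoint_weights("models/pruned_model.ckpt")
#   for name, info in summarize_kernels(weights).items():
#       print(name, info)
demo = {
    "conv1/kernel": np.ones((3, 3, 16, 16)),
    "conv1/bias": np.ones(16),
}
demo["conv1/kernel"][:, :, :4, :] = 0.0
print(summarize_kernels(demo))  # {'conv1/kernel': ((3, 3, 16, 16), 12, 16)}
```

On checkpoints written by ChannelPrunedLearner, the shapes should match the original model while the live-channel count is what drops.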