oandrienko / fast-semantic-segmentation

ICNet and PSPNet-50 in Tensorflow for real-time semantic segmentation

I cannot compress the model #29

Open dannadori opened 4 years ago

dannadori commented 4 years ago

Thanks for your great work.

I tried training on my own dataset, following https://github.com/oandrienko/fast-semantic-segmentation/blob/master/docs/icnet.md

Stage 1 works fine, but I cannot compress the model at stage 2 with this command:

python compress.py --prune_config configs/compression/icnet_resnet_v1_pruner_v2.prune_config --input_checkpoint stage2/model.ckpt --output_dir stage2_compress --compression_factor 0.5

I got this error:

Traceback (most recent call last):
  File "compress.py", line 103, in <module>
    tf.app.run()
  File "/home/wataru/git_work/fast-semantic-segmentation/venv/lib/python3.7/site-packages/tensorflow_core/python/platform/app.py", line 40, in run
    _run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
  File "/home/wataru/git_work/fast-semantic-segmentation/venv/lib/python3.7/site-packages/absl/app.py", line 299, in run
    _run_main(main, args)
  File "/home/wataru/git_work/fast-semantic-segmentation/venv/lib/python3.7/site-packages/absl/app.py", line 250, in _run_main
    sys.exit(main(argv))
  File "compress.py", line 96, in main
    compressor.compress(FLAGS.input_checkpoint)
  File "/home/wataru/git_work/fast-semantic-segmentation/libs/filter_pruner.py", line 385, in compress
    self._create_pruner_specs_recursively(self.input_node)
  File "/home/wataru/git_work/fast-semantic-segmentation/libs/filter_pruner.py", line 374, in _create_pruner_specs_recursively
    self._create_pruner_specs_recursively(next_node)
  File "/home/wataru/git_work/fast-semantic-segmentation/libs/filter_pruner.py", line 351, in _create_pruner_specs_recursively
    curr_node_name)
  File "/home/wataru/git_work/fast-semantic-segmentation/libs/filter_pruner.py", line 323, in _get_following_bn_and_conv_names
    raise ValueError('Incompatable model file.')
ValueError: Incompatable model file.

To find out which node caused the failure, I inserted print(next_node.op) into filter_pruner.py. The output was 'FusedBatchNormV3'.

Do you have any idea how to work around this?
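Assuming the pruner recognizes batch-norm nodes by comparing each node's op type against the single string 'FusedBatchNorm', a graph whose nodes report 'FusedBatchNormV3' (as the print above shows) falls through to the error branch. The sketch below is a minimal, version-tolerant check. `FUSED_BATCH_NORM_OPS`, `is_fused_batch_norm`, and the `Node` stand-in are hypothetical names, not part of filter_pruner.py. The only assumption about the real nodes is that they expose an `.op` string, as `next_node.op` does.

```python
from collections import namedtuple

# Fused batch-norm op types across TensorFlow versions. The graph in this
# thread (TF 1.15.2) reports FusedBatchNormV3.
FUSED_BATCH_NORM_OPS = frozenset({"FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3"})


def is_fused_batch_norm(node) -> bool:
    """True if a graph node (anything with an `.op` string) is a fused batch norm."""
    return node.op in FUSED_BATCH_NORM_OPS


# Hypothetical stand-in for the graph nodes the pruner walks.
Node = namedtuple("Node", ["name", "op"])
```

Matching against a set, rather than a single string, lets the same pruner accept graphs exported by older and newer TensorFlow releases.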

oandrienko commented 4 years ago

Hey @dannadori, thanks for your interest. Which version of TensorFlow are you using? When I implemented the pruner, I think it assumed every batch norm layer had the op name FusedBatchNorm. My guess is that TensorFlow has since added multiple versions of the batch norm kernel, so the easiest fix may be to replace all references to FusedBatchNorm with FusedBatchNormV3.
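A blind find-and-replace has a trap. Rewriting 'FusedBatchNorm' to 'FusedBatchNormV3' also turns any text that already reads 'FusedBatchNormV3' into 'FusedBatchNormV3V3'. The sketch below changes only whole-word occurrences, so it is safe to run more than once. `rename_batch_norm_op` and `rename_in_files` are hypothetical helpers, and which files need the change depends on your checkout.

```python
import re
from pathlib import Path

# Whole-word match only: "FusedBatchNormV3" and "FusedBatchNormGrad" are left
# untouched, so a second run does not produce "FusedBatchNormV3V3".
_OLD_OP = re.compile(r"\bFusedBatchNorm\b")


def rename_batch_norm_op(text: str, new_op: str = "FusedBatchNormV3") -> str:
    """Replace whole-word 'FusedBatchNorm' references with new_op."""
    return _OLD_OP.sub(new_op, text)


def rename_in_files(paths):
    """Rewrite each file in place and return the paths whose contents changed."""
    changed = []
    for path in map(Path, paths):
        original = path.read_text()
        updated = rename_batch_norm_op(original)
        if updated != original:
            path.write_text(updated)
            changed.append(path)
    return changed
```

For example, calling `rename_in_files` on configs/compression/icnet_resnet_v1_pruner_v2.prune_config and libs/filter_pruner.py would update both files and return the ones that changed. Hard-coding V3 still ties the code to one TensorFlow version. Matching a set of op types avoids that.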

Let me know if that works for you, but if not, I can take a look later in the week.

dannadori commented 4 years ago

Thanks @oandrienko

I tried it: I replaced FusedBatchNorm with FusedBatchNormV3 in two files:

  1. configs/compression/icnet_resnet_v1_pruner_v2.prune_config
  2. libs/filter_pruner.py

But I got the error message below:
Traceback (most recent call last):
  File "compress.py", line 103, in <module>
    tf.app.run()
  File "/home/wataru/git_work/fast-semantic-segmentation/tf/lib/python3.7/site-packages/tensorflow_core/python/platform/app.py", line 40, in run
    _run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
  File "/home/wataru/git_work/fast-semantic-segmentation/tf/lib/python3.7/site-packages/absl/app.py", line 299, in run
    _run_main(main, args)
  File "/home/wataru/git_work/fast-semantic-segmentation/tf/lib/python3.7/site-packages/absl/app.py", line 250, in _run_main
    sys.exit(main(argv))
  File "compress.py", line 99, in main
    output_checkpoint_name=output_path_name)
  File "/home/wataru/git_work/fast-semantic-segmentation/libs/filter_pruner.py", line 405, in save
    session.graph, trainable_var, init_value)
  File "/home/wataru/git_work/fast-semantic-segmentation/libs/graph_utils.py", line 128, in add_variable_to_graph
    validate_shape=False)
  File "/home/wataru/git_work/fast-semantic-segmentation/tf/lib/python3.7/site-packages/tensorflow_core/python/ops/variables.py", line 260, in __call__
    return cls._variable_v2_call(*args, **kwargs)
TypeError: _variable_v2_call() got an unexpected keyword argument 'collections'

So I inserted print('new_name->>',new_name) into add_variable_to_graph and got:

new_name->> CascadeFeatureFusion/Conv/BatchNorm/beta

Any idea?

oandrienko commented 4 years ago

Can you let me know what version of TensorFlow you are using?

dannadori commented 4 years ago

1.15.2

oandrienko commented 4 years ago

Hey @dannadori, sorry for the delay; I've been super busy. I can try to check it this weekend. I think v1.15 introduced a new version of TF variables, so it might just be a matter of fixing the import in libs/graph_utils.py. Let me know if you figure it out. If not, I'll check it myself and update the repo soon.
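The traceback shows the call going through a metaclass `__call__` into `_variable_v2_call`, which rejects `collections`. The toy model below (not TensorFlow's code; every class and function in it is hypothetical) mimics that dispatch. Constructing the base class routes to a V2-style constructor with no `collections` keyword. Constructing the V1 subclass routes to a V1-style constructor that still accepts it. This is why swapping `Variable` for `VariableV1` can help. In TensorFlow 1.14/1.15, `tf.compat.v1.Variable` refers to that V1 class.

```python
class _VariableDispatch(type):
    """Toy metaclass that routes construction by class, mimicking the traceback."""

    def __call__(cls, *args, **kwargs):
        if cls is ToyVariableV1:
            return cls._v1_call(*args, **kwargs)
        return cls._v2_call(*args, **kwargs)


class ToyVariable(metaclass=_VariableDispatch):
    def __init__(self, initial_value, name=None, collections=None):
        self.initial_value = initial_value
        self.name = name
        self.collections = list(collections or [])

    @classmethod
    def _v2_call(cls, initial_value=None, name=None, validate_shape=True):
        # V2-style constructor: no `collections` parameter (validate_shape is ignored).
        return type.__call__(cls, initial_value, name=name)


class ToyVariableV1(ToyVariable):
    @classmethod
    def _v1_call(cls, initial_value=None, name=None, collections=None, validate_shape=True):
        # V1-style constructor: still accepts the TF1 `collections` keyword.
        return type.__call__(cls, initial_value, name=name, collections=collections)


def accepts_collections(var_cls) -> bool:
    """Repeat the kind of call that failed in graph_utils.py; report whether it works."""
    try:
        var_cls(0.0, name="beta", collections=["trainable_variables"], validate_shape=False)
    except TypeError:
        return False
    return True
```

Here `accepts_collections(ToyVariable)` is False and `accepts_collections(ToyVariableV1)` is True, mirroring the error and the fix reported in this thread.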

dannadori commented 4 years ago

Sorry, I couldn't do anything about this because I was very busy. I can try it today or tomorrow if you give me more concrete instructions. How should I fix libs/graph_utils.py? I cannot find the string "BatchNorm" in it...

tzhong518 commented 4 years ago

Hi @oandrienko, I ran into the same problem: TypeError: _variable_v2_call() got an unexpected keyword argument 'collections'. I'm using TensorFlow 1.14. I replaced Variable with VariableV1 in libs/graph_utils.py and it worked. However, the pruned model performs very poorly (around 0.015 mIoU). Is this normal? Will retraining fix it? To evaluate the pruned model, I used two_stage_icnet_0.5_1025_resnet_v1_stage_2.config.

Thank you in advance.