tensorflow / model-optimization

A toolkit to optimize ML models for deployment for Keras and TensorFlow, including quantization and pruning.
https://www.tensorflow.org/model_optimization
Apache License 2.0

sparsity.prune_low_magnitude fails with mixed precision policy mixed_float16 #409

Open dsuthar-nvidia opened 4 years ago

dsuthar-nvidia commented 4 years ago

Describe the bug When using tf.keras.mixed_precision.experimental.Policy("mixed_float16", loss_scale="dynamic"), sparsity.prune_low_magnitude fails during tensor type conversion with the error Tensor conversion requested dtype float32 for Tensor with dtype float16: <tf.Tensor 'pruning_ops/Cast_2:0' shape=(3, 3, 1, 12) dtype=float16>. Everything works fine when the precision is left at the default float32, so it looks like some piece of code is not respecting the layer's dtype.

System information

TensorFlow installed from (source or binary): pip3

TensorFlow version: 2.2.0

TensorFlow Model Optimization version:

Python version: 3.6

Describe the expected behavior

prune_low_magnitude should work with layers using the mixed_float16 policy.

Describe the current behavior

Throws error described above.

Code to reproduce the issue: see this Colab link.


dsuthar-nvidia commented 4 years ago

@alanchiao any insight into when this will be fixed?

dsuthar-nvidia commented 4 years ago

@alanchiao Is this a major change? From the stack trace it looked more like a minor bug where some piece of code is not properly respecting the dtype.
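If it really is just a dtype-handling slip, the fix suggested by the trace would indeed be small: cast the masked value to the variable's storage dtype before assigning. A hypothetical sketch of such a patched assign helper (this is an illustration, not the actual tfmot code):

```python
import tensorflow as tf


def assign(ref, value, name=None):
    # Hypothetical patched version of tfmot's compat assign helper:
    # casting to the variable's storage dtype lets a float32 variable
    # (e.g. an AutoCastVariable under mixed precision) accept a
    # float16 value instead of raising a conversion error.
    return ref.assign(tf.cast(value, ref.dtype), name=name)


# A float32 storage variable accepts a float16 value through the helper:
v = tf.Variable(tf.ones((2, 2)), dtype=tf.float32)
assign(v, tf.zeros((2, 2), dtype=tf.float16))
```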

teijeong commented 3 years ago

Hi @dd1923 , sorry for the really late response.

Just want to check whether this still affects you before picking this up again.

e-dupuis commented 3 years ago

Hi, I can confirm that the issue still exists:

ValueError: Tensor conversion requested dtype float32 for Tensor with dtype float16: <tf.Tensor 'prune_low_magnitude_conv1/Mul:0' shape=(7, 7, 3, 64) dtype=float16> 

Complete stack trace:

Traceback (most recent call last):
  File "official/vision/image_classification/classifier_trainer.py", line 530, in <module>
    app.run(main)
  File "/usr/local/lib/python3.6/dist-packages/absl/app.py", line 303, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.6/dist-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "official/vision/image_classification/classifier_trainer.py", line 517, in main
    stats = run(flags.FLAGS)
  File "official/vision/image_classification/classifier_trainer.py", line 509, in run
    return train_and_eval(params, strategy_override)
  File "official/vision/image_classification/classifier_trainer.py", line 435, in train_and_eval
    clone_function=apply_pruning,
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/models.py", line 431, in clone_model
    model, input_tensors=input_tensors, layer_fn=clone_function)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/models.py", line 201, in _clone_functional_model
    created_layers=created_layers))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py", line 1285, in reconstruct_from_config
    process_node(layer, node_data)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py", line 1233, in process_node
    output_tensors = layer(input_tensors, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py", line 952, in __call__
    input_list)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py", line 1091, in _functional_construction_call
    inputs, input_masks, args, kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py", line 822, in _keras_tensor_symbolic_call
    return self._infer_output_signature(inputs, args, kwargs, input_masks)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py", line 863, in _infer_output_signature
    outputs = call_fn(inputs, *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/impl/api.py", line 670, in wrapper
    raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:

    /usr/local/lib/python3.6/dist-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py:258 call  *
        self.add_update(self.pruning_obj.weight_mask_op())
    /usr/local/lib/python3.6/dist-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py:195 weight_mask_op  *
        return tf.group(self._weight_assign_objs())
    /usr/local/lib/python3.6/dist-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py:168 update_var  *
        return tf_compat.assign(variable, reduced_value)
    /usr/local/lib/python3.6/dist-packages/tensorflow_model_optimization/python/core/keras/compat.py:28 assign  *
        return ref.assign(value, name=name)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/mixed_precision/autocast_variable.py:237 assign  **
        name, read_value)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/mixed_precision/autocast_variable.py:209 _apply_assign_update
        assign_op = update_fn(value, use_locking, name, False)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/resource_variable_ops.py:882 assign
        value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/profiler/trace.py:163 wrapped
        return func(*args, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:1509 convert_to_tensor
        (dtype.name, value.dtype.name, value))
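The bottom of the trace points at the root cause: under mixed_float16 the masked weight is computed in the float16 compute dtype, but the underlying storage variable is float32, and convert_to_tensor refuses the implicit cast on assignment. A standalone sketch of just that failure mode (plain TensorFlow, no tfmot; the variable and mask here are illustrative, not the wrapper's actual internals):

```python
import tensorflow as tf

# Storage variable stays float32 under mixed_float16
# (as with Keras AutoCastVariable).
weight = tf.Variable(tf.ones((3, 3)), dtype=tf.float32)

# The pruning mask multiply happens in the compute dtype, float16.
masked = tf.cast(weight, tf.float16) * tf.ones((3, 3), dtype=tf.float16)

try:
    weight.assign(masked)  # implicit float16 -> float32 cast is refused
except ValueError as e:
    print(e)  # Tensor conversion requested dtype float32 for Tensor with dtype float16 ...

# Casting back to the variable's storage dtype works:
weight.assign(tf.cast(masked, weight.dtype))
```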

paulds8 commented 3 years ago

This is an issue for me as well: TF 2.4.1 with tensorflow_model_optimization 0.5.

I've been trying to wrap my head around this; I even tried setting various layer kwargs.

It seems the failure happens on the model clone step: keras.models.clone_model()

ValueError                                Traceback (most recent call last)
<command--1> in <module>
     12 
     13 with open(filename, "rb") as f:
---> 14   exec(f.read())
     15 

<string> in <module>

/databricks/python/lib/python3.8/site-packages/SCVS/cli.py in main()
      8 def main():
      9     args = parse_args()
---> 10     action_args(args)
     11 
     12 

/databricks/python/lib/python3.8/site-packages/SCVS/cli.py in action_args(args)
     79     else:
     80         # run function based on specified command
---> 81         args.func(args)

/databricks/python/lib/python3.8/site-packages/SCVS/cli.py in cmd_task(args)
     61     # execute the task
     62     from SCVS.app import main
---> 63     main(task_id, config_file, instructions)
     64     # shutdown (need to flush logs before the python interpreter terminates)
     65     logging.shutdown()

/databricks/python/lib/python3.8/site-packages/SCVS/app.py in main(task_id, config_file, instructions)
     45         util.flush_logger()
     46         time.sleep(10)
---> 47         raise e
     48 
     49 

/databricks/python/lib/python3.8/site-packages/SCVS/app.py in main(task_id, config_file, instructions)
     38     logger.info(f"starting task {task_id} (class {task_cls.__name__}): {task.summary}")
     39     try:
---> 40         task.start()
     41         t1 = time.perf_counter()
     42         logger.info(f"task {task_id} finished after {util.format_time_span(t1-t0)}")

/databricks/python/lib/python3.8/site-packages/SCVS/pipeline/task.py in start(self)
    221         """
    222         self.validate_input()
--> 223         self.run()
    224         self.validate_output()
    225 

/databricks/python/lib/python3.8/site-packages/SCVS/pipeline/train_object_detection.py in run(self)
    203                     #     record_summaries=True,
    204                     # )
--> 205                     custom_train_loop(
    206                         pipeline_config_path=self.path_pipeline_config,
    207                         model_dir=self.path_model,

/databricks/python/lib/python3.8/site-packages/SCVS/ml/tf_training.py in custom_train_loop(pipeline_config_path, model_dir, config_override, train_steps, use_tpu, save_final_config, checkpoint_every_n, checkpoint_max_to_keep, record_summaries, mlflow_log_every_n, **kwargs)
    293                 #TODO(PdS): Investigate if we should prune other parts of the architecture: _box_predictor, etc
    294 
--> 295                 detection_model._feature_extractor._efficientnet = tfmot.sparsity.keras.prune_low_magnitude(
    296                     to_prune=detection_model._feature_extractor._efficientnet,
    297                     pruning_schedule=pruning_schedule,

/databricks/python/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/prune.py in prune_low_magnitude(to_prune, pruning_schedule, block_size, block_pooling_type, **kwargs)
    182     return _prune_list(to_prune, **params)
    183   elif is_sequential_or_functional:
--> 184     return keras.models.clone_model(
    185         to_prune, input_tensors=None, clone_function=_add_pruning_wrapper)
    186   elif is_keras_layer:

/databricks/python/lib/python3.8/site-packages/tensorflow/python/keras/models.py in clone_model(model, input_tensors, clone_function)
    428         model, input_tensors=input_tensors, layer_fn=clone_function)
    429   else:
--> 430     return _clone_functional_model(
    431         model, input_tensors=input_tensors, layer_fn=clone_function)
    432 

/databricks/python/lib/python3.8/site-packages/tensorflow/python/keras/models.py in _clone_functional_model(model, input_tensors, layer_fn)
    198   # Reconstruct model from the config, using the cloned layers.
    199   input_tensors, output_tensors, created_layers = (
--> 200       functional.reconstruct_from_config(model_configs,
    201                                          created_layers=created_layers))
    202   metrics_names = model.metrics_names

/databricks/python/lib/python3.8/site-packages/tensorflow/python/keras/engine/functional.py in reconstruct_from_config(config, custom_objects, created_layers)
   1283       if layer in unprocessed_nodes:
   1284         for node_data in unprocessed_nodes.pop(layer):
-> 1285           process_node(layer, node_data)
   1286 
   1287   input_tensors = []

/databricks/python/lib/python3.8/site-packages/tensorflow/python/keras/engine/functional.py in process_node(layer, node_data)
   1231         input_tensors = (
   1232             base_layer_utils.unnest_if_single_tensor(input_tensors))
-> 1233       output_tensors = layer(input_tensors, **kwargs)
   1234 
   1235       # Update node index map.

/databricks/python/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, *args, **kwargs)
    949     # >> model = tf.keras.Model(inputs, outputs)
    950     if _in_functional_construction_mode(self, inputs, args, kwargs, input_list):
--> 951       return self._functional_construction_call(inputs, args, kwargs,
    952                                                 input_list)
    953 

/databricks/python/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in _functional_construction_call(self, inputs, args, kwargs, input_list)
   1088           layer=self, inputs=inputs, build_graph=True, training=training_value):
   1089         # Check input assumptions set after layer building, e.g. input shape.
-> 1090         outputs = self._keras_tensor_symbolic_call(
   1091             inputs, input_masks, args, kwargs)
   1092 

/databricks/python/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in _keras_tensor_symbolic_call(self, inputs, input_masks, args, kwargs)
    820       return nest.map_structure(keras_tensor.KerasTensor, output_signature)
    821     else:
--> 822       return self._infer_output_signature(inputs, args, kwargs, input_masks)
    823 
    824   def _infer_output_signature(self, inputs, args, kwargs, input_masks):

/databricks/python/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in _infer_output_signature(self, inputs, args, kwargs, input_masks)
    861           # TODO(kaftan): do we maybe_build here, or have we already done it?
    862           self._maybe_build(inputs)
--> 863           outputs = call_fn(inputs, *args, **kwargs)
    864 
    865         self._handle_activity_regularization(inputs, outputs)

/databricks/python/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
    668       except Exception as e:  # pylint:disable=broad-except
    669         if hasattr(e, 'ag_error_metadata'):
--> 670           raise e.ag_error_metadata.to_exception(e)
    671         else:
    672           raise

ValueError: in user code:

    /databricks/python/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py:258 call  *
        self.add_update(self.pruning_obj.weight_mask_op())
    /databricks/python/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py:195 weight_mask_op  *
        return tf.group(self._weight_assign_objs())
    /databricks/python/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py:190 _weight_assign_objs  *
        assign_objs.append(tf_compat.assign(weight, masked_weight))
    /databricks/python/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/keras/compat.py:28 assign  *
        return ref.assign(value, name=name)
    /databricks/python/lib/python3.8/site-packages/tensorflow/python/keras/mixed_precision/autocast_variable.py:236 assign  **
        return self._apply_assign_update(self._variable.assign, value, use_locking,
    /databricks/python/lib/python3.8/site-packages/tensorflow/python/keras/mixed_precision/autocast_variable.py:209 _apply_assign_update
        assign_op = update_fn(value, use_locking, name, False)
    /databricks/python/lib/python3.8/site-packages/tensorflow/python/distribute/values.py:781 assign
        return values_util.on_write_assign(
    /databricks/python/lib/python3.8/site-packages/tensorflow/python/distribute/values_util.py:140 on_write_assign
        return var._update(  # pylint: disable=protected-access
    /databricks/python/lib/python3.8/site-packages/tensorflow/python/distribute/values.py:940 _update
        return self._update_cross_replica(update_fn, value, **kwargs)
    /databricks/python/lib/python3.8/site-packages/tensorflow/python/distribute/values.py:893 _update_cross_replica
        return self.distribute_strategy.extended.update(
    /databricks/python/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py:2494 update
        return self._update(var, fn, args, kwargs, group)
    /databricks/python/lib/python3.8/site-packages/tensorflow/python/distribute/mirrored_strategy.py:710 _update
        fn(v, *distribute_utils.select_replica_mirrored(i, args),
    /databricks/python/lib/python3.8/site-packages/tensorflow/python/distribute/values_util.py:139 <lambda>  **
        assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
    /databricks/python/lib/python3.8/site-packages/tensorflow/python/ops/resource_variable_ops.py:882 assign
        value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
    /databricks/python/lib/python3.8/site-packages/tensorflow/python/profiler/trace.py:163 wrapped
        return func(*args, **kwargs)
    /databricks/python/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:1507 convert_to_tensor
        raise ValueError(

    ValueError: Tensor conversion requested dtype float32 for Tensor with dtype float16: <tf.Tensor 'prune_low_magnitude_stem_conv2d/Mul:0' shape=(3, 3, 3, 48) dtype=float16>

A-strategist commented 1 year ago

I found that the layer returned by prune_low_magnitude ends up with a different dtype_policy than the original one. Here is a workaround: call _set_dtype_policy after prune_low_magnitude.

For example:

from tensorflow.keras import models, layers
import tensorflow_model_optimization as tfmo

d1 = layers.Dense(5, dtype='mixed_float16')
d2 = layers.Dense(5, dtype='mixed_float16')

print(d1.dtype_policy)
print(d2.dtype_policy)

d11 = tfmo.sparsity.keras.prune_low_magnitude(d1)
d22 = tfmo.sparsity.keras.prune_low_magnitude(d2)

# The wrappers come back with the default float32 policy
print(d11.dtype_policy)
print(d22.dtype_policy)

# Restore the original mixed_float16 policy on each wrapper (private API)
d11._set_dtype_policy(d1.dtype_policy)
d22._set_dtype_policy(d2.dtype_policy)

print(d11.dtype_policy)
print(d22.dtype_policy)

# now it works fine
inp = layers.Input((10,))
tensor = d11(inp)
tensor = d22(tensor)
m = models.Model(inputs=inp, outputs=tensor)
m.compile(loss='mse')

I haven't tried the global policy, but it should behave the same way.