tensorflow / model-analysis

Model analysis tools for TensorFlow
Apache License 2.0

Custom multilabel Keras metrics: dynamically initialize weight shape #137

Closed EdwardCuiPeacock closed 3 years ago

EdwardCuiPeacock commented 3 years ago

The custom class below is a modification of the MeanTensor metric built into Keras. I changed update_state so that it accepts y_true and y_pred as inputs, and I changed result to return a scalar instead of a vector (the class docstring below is still the original Keras one and describes the vector output). MeanTensor does not create its weights (_total and _count) via add_weight until update_state is called for the first time, because the shape of these weights is determined from the input.

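# Imports assumed to mirror tensorflow/python/keras/metrics.py in the TF 2.x
# release this class was copied from. These are private modules and their
# paths may differ between TensorFlow versions.
import numpy as np
import tensorflow as tf
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend as K
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.ops.losses import util as tf_losses_utils


class MeanTensor(tf.keras.metrics.Metric):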
    """Computes the element-wise (weighted) mean of the given tensors.

    `MeanTensor` returns a tensor with the same shape of the input tensors. The
    mean value is updated by keeping local variables `total` and `count`. The
    `total` tracks the sum of the weighted values, and `count` stores the sum of
    the weighted counts.

    Args:
      name: (Optional) string name of the metric instance.
      dtype: (Optional) data type of the metric result.

    Standalone usage:

    >>> m = tf.keras.metrics.MeanTensor()
    >>> m.update_state([0, 1, 2, 3])
    >>> m.update_state([4, 5, 6, 7])
    >>> m.result().numpy()
    array([2., 3., 4., 5.], dtype=float32)

    >>> m.update_state([12, 10, 8, 6], sample_weight= [0, 0.2, 0.5, 1])
    >>> m.result().numpy()
    array([2.       , 3.6363635, 4.8      , 5.3333335], dtype=float32)
    """

    def __init__(self, name="mean_tensor", dtype=None):
        super(MeanTensor, self).__init__(name=name, dtype=dtype)
        self._shape = None
        self._total = None
        self._count = None
        self._built = False

    def _build(self, shape):
        self._shape = tensor_shape.TensorShape(shape)
        self._build_input_shape = self._shape
        # Create new state variables
        self._total = self.add_weight(
            "total", shape=shape, initializer=init_ops.zeros_initializer
        )
        self._count = self.add_weight(
            "count", shape=shape, initializer=init_ops.zeros_initializer
        )
        with ops.init_scope():
            if not context.executing_eagerly():
                K._initialize_variables(
                    K._get_session()
                )  # pylint: disable=protected-access
        self._built = True

    @property
    def total(self):
        return self._total if self._built else None

    @property
    def count(self):
        return self._count if self._built else None

    def update_state(self, y_true, values, sample_weight=None):
        """Accumulates statistics for computing the element-wise mean.

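        Args:
          y_true: Ground-truth labels. Unused; accepted only so the metric
            matches the `(y_true, y_pred)` call signature.
          values: Per-example value (the predictions, `y_pred`).
          sample_weight: Optional weighting of each example. Defaults to 1.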

        Returns:
          Update op.
        """
        values = math_ops.cast(values, self._dtype)
        if not self._built:
            self._build(values.shape)
        elif values.shape != self._shape:
            raise ValueError(
                "MeanTensor input values must always have the same "
                "shape. Expected shape (set during the first call): {}. "
                "Got: {}".format(self._shape, values.shape)
            )

        num_values = array_ops.ones_like(values)
        if sample_weight is not None:
            sample_weight = math_ops.cast(sample_weight, self._dtype)

            # Update dimensions of weights to match with values if possible.
            values, _, sample_weight = tf_losses_utils.squeeze_or_expand_dimensions(
                values, sample_weight=sample_weight
            )
            try:
                # Broadcast weights if possible.
                sample_weight = weights_broadcast_ops.broadcast_weights(
                    sample_weight, values
                )
            except ValueError:
                # Reduce values to same ndim as weight array
                ndim = K.ndim(values)
                weight_ndim = K.ndim(sample_weight)
                values = math_ops.reduce_mean(
                    values, axis=list(range(weight_ndim, ndim))
                )

            num_values = math_ops.multiply(num_values, sample_weight)
            values = math_ops.multiply(values, sample_weight)

        update_total_op = self._total.assign_add(values)
        with ops.control_dependencies([update_total_op]):
            return self._count.assign_add(num_values)

    def result(self):
        if not self._built:
            raise ValueError(
                "MeanTensor does not have any result yet. Please call the MeanTensor "
                "instance or use `.update_state(value)` before retrieving the result."
            )
        return tf.reduce_mean(math_ops.div_no_nan(self.total, self.count))

    def reset_states(self):
        if self._built:
            K.batch_set_value(
                [(v, np.zeros(self._shape.as_list())) for v in self.variables]
            )
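
For reference, the figures in the class docstring can be reproduced by hand, which also shows what the modified result() returns: the mean over the element-wise means. This is a minimal pure-Python sketch without TensorFlow; the helper name `accumulate` is made up for illustration and is not part of Keras.

```python
# Plain-Python check of the docstring example.
# `accumulate` is a hypothetical helper, not a Keras function.
def accumulate(total, count, values, sample_weight=None):
    """Add weighted values to `total` and the weights to `count`, element-wise."""
    if sample_weight is None:
        sample_weight = [1.0] * len(values)
    total = [t + v * w for t, v, w in zip(total, values, sample_weight)]
    count = [c + w for c, w in zip(count, sample_weight)]
    return total, count


total, count = [0.0] * 4, [0.0] * 4
total, count = accumulate(total, count, [0, 1, 2, 3])
total, count = accumulate(total, count, [4, 5, 6, 7])
total, count = accumulate(total, count, [12, 10, 8, 6], [0, 0.2, 0.5, 1])

# Element-wise mean, as returned by the original Keras MeanTensor.result().
elementwise = [t / c if c else 0.0 for t, c in zip(total, count)]
# Scalar, as returned by the modified result(): mean of the element-wise means.
scalar = sum(elementwise) / len(elementwise)

print(elementwise)
print(scalar)
```

The element-wise values match the second docstring output (2, 3.6363635, 4.8, 5.3333335 up to float precision), and the scalar is their average.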

When running the evaluation with tfma, I am receiving the following error:

Traceback (most recent call last):
  File "/usr/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.7/dist-packages/tfx/orchestration/kubeflow/container_entrypoint.py", line 360, in <module>
    main()
  File "/usr/local/lib/python3.7/dist-packages/tfx/orchestration/kubeflow/container_entrypoint.py", line 353, in main
    execution_info = launcher.launch()
  File "/usr/local/lib/python3.7/dist-packages/tfx/orchestration/launcher/base_component_launcher.py", line 209, in launch
    copy.deepcopy(execution_decision.exec_properties))
  File "/usr/local/lib/python3.7/dist-packages/tfx/orchestration/launcher/in_process_component_launcher.py", line 72, in _run_executor
    copy.deepcopy(input_dict), output_dict, copy.deepcopy(exec_properties))
  File "/usr/local/lib/python3.7/dist-packages/tfx/components/evaluator/executor.py", line 259, in Do
    tensor_adapter_config=tensor_adapter_config))
  File "/usr/local/lib/python3.7/dist-packages/apache_beam/pipeline.py", line 582, in __exit__
    self.result = self.run()
  File "/usr/local/lib/python3.7/dist-packages/apache_beam/pipeline.py", line 561, in run
    return self.runner.run_pipeline(self, self._options)
  File "/usr/local/lib/python3.7/dist-packages/apache_beam/runners/portability/fn_api_runner/fn_runner.py", line 183, in run_pipeline
    pipeline.to_runner_api(default_environment=self._default_environment))
  File "/usr/local/lib/python3.7/dist-packages/apache_beam/runners/portability/fn_api_runner/fn_runner.py", line 193, in run_via_runner_api
    return self.run_stages(stage_context, stages)
  File "/usr/local/lib/python3.7/dist-packages/apache_beam/runners/portability/fn_api_runner/fn_runner.py", line 360, in run_stages
    bundle_context_manager,
  File "/usr/local/lib/python3.7/dist-packages/apache_beam/runners/portability/fn_api_runner/fn_runner.py", line 556, in _run_stage
    bundle_manager)
  File "/usr/local/lib/python3.7/dist-packages/apache_beam/runners/portability/fn_api_runner/fn_runner.py", line 596, in _run_bundle
    data_input, data_output, input_timers, expected_timer_output)
  File "/usr/local/lib/python3.7/dist-packages/apache_beam/runners/portability/fn_api_runner/fn_runner.py", line 897, in process_bundle
    result_future = self._worker_handler.control_conn.push(process_bundle_req)
  File "/usr/local/lib/python3.7/dist-packages/apache_beam/runners/portability/fn_api_runner/worker_handlers.py", line 380, in push
    response = self.worker.do_instruction(request)
  File "/usr/local/lib/python3.7/dist-packages/apache_beam/runners/worker/sdk_worker.py", line 607, in do_instruction
    getattr(request, request_type), request.instruction_id)
  File "/usr/local/lib/python3.7/dist-packages/apache_beam/runners/worker/sdk_worker.py", line 644, in process_bundle
    bundle_processor.process_bundle(instruction_id))
  File "/usr/local/lib/python3.7/dist-packages/apache_beam/runners/worker/bundle_processor.py", line 1000, in process_bundle
    element.data)
  File "/usr/local/lib/python3.7/dist-packages/apache_beam/runners/worker/bundle_processor.py", line 228, in process_encoded
    self.output(decoded_value)
  File "apache_beam/runners/worker/operations.py", line 357, in apache_beam.runners.worker.operations.Operation.output
  File "apache_beam/runners/worker/operations.py", line 359, in apache_beam.runners.worker.operations.Operation.output
  File "apache_beam/runners/worker/operations.py", line 221, in apache_beam.runners.worker.operations.SingletonConsumerSet.receive
  File "apache_beam/runners/worker/operations.py", line 927, in apache_beam.runners.worker.operations.CombineOperation.process
  File "apache_beam/runners/worker/operations.py", line 931, in apache_beam.runners.worker.operations.CombineOperation.process
  File "apache_beam/runners/worker/operations.py", line 359, in apache_beam.runners.worker.operations.Operation.output
  File "apache_beam/runners/worker/operations.py", line 221, in apache_beam.runners.worker.operations.SingletonConsumerSet.receive
  File "apache_beam/runners/worker/operations.py", line 927, in apache_beam.runners.worker.operations.CombineOperation.process
  File "apache_beam/runners/worker/operations.py", line 931, in apache_beam.runners.worker.operations.CombineOperation.process
  File "/usr/local/lib/python3.7/dist-packages/apache_beam/transforms/combiners.py", line 974, in extract_only
    return self.combine_fn.extract_output(accumulator)
  File "/usr/local/lib/python3.7/dist-packages/tensorflow_model_analysis/evaluators/metrics_plots_and_validations_evaluator.py", line 357, in extract_output
    output = c.extract_output(a)
  File "/usr/local/lib/python3.7/dist-packages/tensorflow_model_analysis/metrics/tf_metric_wrapper.py", line 707, in extract_output
    metric.set_weights(weights)
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/base_layer.py", line 1810, in set_weights
    (self.name, len(weights), expected_num_weights, str(weights)[:50]))
ValueError: You called `set_weights(weights)` on layer "mean_tensor" with a weight list of length 2, but the layer was expecting 0 weights. Provided weights: [array([[0.22180203, 0.2250297 , 0.23172688, ..., ...

The error suggests that when the metrics are combined by the Beam combiner, a new instance of the metric class is created but never called, so its weights are never created via update_state -> _build. This appears to be a problem only at the final combining step (extract_output), since the calls to update_state succeeded while each Beam container was running.
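
The failure mode described above can be illustrated without TensorFlow or Beam. The sketch below uses a framework-free stand-in (the class name `LazyMetric` and its methods are assumptions for illustration, not the Keras implementation): state is created only on the first update, so a fresh instance that receives accumulated weights rejects them because it still expects zero weights.

```python
# Framework-free stand-in for the lazy-build pattern; NOT the Keras code.
class LazyMetric:
    def __init__(self):
        self.variables = []   # nothing exists until the first update
        self._built = False

    def _build(self, size):
        self.variables = [[0.0] * size, [0.0] * size]  # total, count
        self._built = True

    def update_state(self, values):
        if not self._built:
            self._build(len(values))
        total, count = self.variables
        for i, v in enumerate(values):
            total[i] += v
            count[i] += 1.0

    def get_weights(self):
        return [list(v) for v in self.variables]

    def set_weights(self, weights):
        # Mimics the length check behind the ValueError in the traceback.
        if len(weights) != len(self.variables):
            raise ValueError(
                "weight list of length %d, but expecting %d weights"
                % (len(weights), len(self.variables))
            )
        self.variables = [list(w) for w in weights]


worker = LazyMetric()
worker.update_state([1.0, 2.0, 3.0])  # state is built on the first call
accumulated = worker.get_weights()

fresh = LazyMetric()                  # like the instance used at extract time
try:
    fresh.set_weights(accumulated)
    error = None
except ValueError as exc:
    error = str(exc)

print(error)
```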

Setting a fixed shape avoids this problem. However, in our application within TFX, the output shape is dynamic and cannot be set until the Trainer component of the pipeline has finished running, so we would like to be able to initialize the shape dynamically.

EdwardCuiPeacock commented 3 years ago

I found a workaround. The problem comes from the metric.set_weights(weights) call in extract_output in tensorflow_model_analysis/metrics/tf_metric_wrapper.py:

def extract_output(
      self, accumulator: _CompilableMetricsAccumulator
  ) -> Dict[metric_types.MetricKey, Any]:
    self._process_batch(accumulator)
    result = {}
    for output_index, output_name in enumerate(self._output_names):
      for metric_index, metric in enumerate(self._metrics[output_name]):
        key = metric_types.MetricKey(
            name=metric.name,
            model_name=self._model_name,
            output_name=output_name,
            sub_key=self._sub_key)
        weights = accumulator.get_weights(output_index, metric_index)
        if weights is not None:
          metric.set_weights(weights)
        else:
          metric.reset_states()
        result[key] = metric.result().numpy()
    return result

The workaround is to override the set_weights method in the custom Keras metric class:

def set_weights(self, weights: List):
    """Override for proper beam combining."""
    if not self._is_built:
        self._build((1, len(weights[0])))
    self.score.assign(weights[0])
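
Note that the override above refers to `_is_built` and `score`, which do not appear in the MeanTensor class as posted (that class uses `_built`, `_total` and `_count`), so it presumably comes from a different version of the metric. The same idea for a metric with two state variables can be sketched with a framework-free stand-in; the class `LazyMetric` is hypothetical. For the posted Keras class, the equivalent would presumably be `self._build(weights[0].shape)` followed by `self._total.assign(weights[0])` and `self._count.assign(weights[1])`, but that is an untested assumption.

```python
# Stand-in showing "build from the incoming weights"; NOT the Keras code.
class LazyMetric:
    def __init__(self):
        self._total = None
        self._count = None
        self._built = False

    def _build(self, size):
        self._total = [0.0] * size
        self._count = [0.0] * size
        self._built = True

    def set_weights(self, weights):
        # Instead of rejecting weights on an unbuilt instance, infer the
        # shape from them and create the state first.
        total, count = weights
        if not self._built:
            self._build(len(total))
        elif len(total) != len(self._total):
            raise ValueError("weights do not match the built shape")
        self._total = list(total)
        self._count = list(count)

    def result(self):
        means = [t / c if c else 0.0 for t, c in zip(self._total, self._count)]
        return sum(means) / len(means)


fresh = LazyMetric()  # never updated, as at the final combine step
fresh.set_weights([[4.0, 8.0, 12.0], [2.0, 2.0, 3.0]])
value = fresh.result()
print(value)
```

Building inside set_weights keeps the shape dynamic: it is taken from whatever the accumulator holds rather than fixed when the metric is constructed.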