EdwardCuiPeacock closed this issue 3 years ago.
A workaround has been identified. The problem is caused by the call to `metric.set_weights(weights)` in the `tensorflow_model_analysis.metrics.tf_metric_wrapper` module:
```python
def extract_output(
    self, accumulator: _CompilableMetricsAccumulator
) -> Dict[metric_types.MetricKey, Any]:
  self._process_batch(accumulator)
  result = {}
  for output_index, output_name in enumerate(self._output_names):
    for metric_index, metric in enumerate(self._metrics[output_name]):
      key = metric_types.MetricKey(
          name=metric.name,
          model_name=self._model_name,
          output_name=output_name,
          sub_key=self._sub_key)
      weights = accumulator.get_weights(output_index, metric_index)
      if weights is not None:
        # Fails if this metric instance has not created its weights yet.
        metric.set_weights(weights)
      else:
        metric.reset_states()
      result[key] = metric.result().numpy()
  return result
```
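To show why this line fails, here is a framework-free sketch of the mechanism. It assumes, as the issue below describes, that the metric creates its weights lazily on its first `update_state` call. It also imitates the length check in `tf.keras` `Layer.set_weights`, which raises a `ValueError` when the number of arrays passed differs from the number of weights the layer owns. `LazyMetric` and `combine_like_extract_output` are hypothetical names, not TFMA or Keras APIs, and plain Python lists stand in for tensors.

```python
from typing import List, Optional


class LazyMetric:
    """Hypothetical stand-in for a Keras metric whose weights are created lazily."""

    def __init__(self) -> None:
        self._weights: List[List[float]] = []  # nothing exists until update_state

    def update_state(self, values: List[float]) -> None:
        if not self._weights:
            # The shape is only known once the first input arrives.
            self._weights = [[0.0] * len(values), [0.0] * len(values)]
        total, count = self._weights
        for i, v in enumerate(values):
            total[i] += v
            count[i] += 1.0

    def get_weights(self) -> List[List[float]]:
        return [list(w) for w in self._weights]

    def set_weights(self, weights: List[List[float]]) -> None:
        # Imitates the length check performed by tf.keras Layer.set_weights.
        if len(weights) != len(self._weights):
            raise ValueError(
                f"expected {len(self._weights)} weights, got {len(weights)}")
        self._weights = [list(w) for w in weights]


def combine_like_extract_output(
        weights: Optional[List[List[float]]]) -> LazyMetric:
    """Hypothetical analogue of the final combine step: a metric instance
    that never saw update_state receives the accumulated weights."""
    metric = LazyMetric()
    if weights is not None:
        metric.set_weights(weights)
    return metric


worker_metric = LazyMetric()
worker_metric.update_state([1.0, 2.0, 3.0])  # succeeds: weights are built here
accumulated = worker_metric.get_weights()

try:
    combine_like_extract_output(accumulated)
except ValueError as err:
    print("final combine failed:", err)  # expected 0 weights, got 2
```

The per-worker update succeeds because the first input builds the weights. The fresh instance used for the final step owns zero weights, so loading two arrays into it is rejected.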
A workaround is to override the `set_weights` method in the custom Keras metric class:
```python
def set_weights(self, weights: List):
    """Override for proper beam combining."""
    if not self._is_built:
        self._build((1, len(weights[0])))
    self.score.assign(weights[0])
```
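The override works because it builds the missing weights from the shape of the incoming arrays before assigning them, so the shape never has to be fixed in advance. The following is a framework-free sketch of that pattern. It assumes a `MeanTensor`-like metric with `_total` and `_count` weights, an `update_state(y_true, y_pred)` signature and a scalar `result()`. `DynamicMeanMetric` is hypothetical, lists replace TensorFlow variables, accumulating the absolute error is an illustrative choice, and summing the workers' weights is an assumed merge rule.

```python
from typing import List, Sequence


class DynamicMeanMetric:
    """Hypothetical MeanTensor-like metric: per-element running means, built lazily."""

    def __init__(self) -> None:
        self._is_built = False
        self._total: List[float] = []
        self._count: List[float] = []

    def _build(self, size: int) -> None:
        # The weight shape comes from data instead of being fixed up front.
        self._total = [0.0] * size
        self._count = [0.0] * size
        self._is_built = True

    def update_state(self, y_true: Sequence[float], y_pred: Sequence[float]) -> None:
        if not self._is_built:
            self._build(len(y_pred))
        # Illustrative choice: accumulate the element-wise absolute error.
        for i, (t, p) in enumerate(zip(y_true, y_pred)):
            self._total[i] += abs(t - p)
            self._count[i] += 1.0

    def result(self) -> float:
        # Collapse the per-element means into a single scalar.
        means = [t / c for t, c in zip(self._total, self._count) if c > 0]
        return sum(means) / len(means) if means else 0.0

    def get_weights(self) -> List[List[float]]:
        return [list(self._total), list(self._count)]

    def set_weights(self, weights: List[List[float]]) -> None:
        # Build from the incoming arrays first, then restore every array.
        if not self._is_built:
            self._build(len(weights[0]))
        self._total, self._count = list(weights[0]), list(weights[1])


# Two "workers" accumulate separately; their weights are summed and loaded
# into a fresh instance that never called update_state.
a, b = DynamicMeanMetric(), DynamicMeanMetric()
a.update_state([1.0, 2.0], [1.5, 2.0])
b.update_state([0.0, 0.0], [1.0, 3.0])
merged = [[x + y for x, y in zip(wa, wb)]
          for wa, wb in zip(a.get_weights(), b.get_weights())]
fresh = DynamicMeanMetric()
fresh.set_weights(merged)
print(fresh.result())  # 1.125
```

Summing is used here because a mean's `_total`/`_count` pair can be merged that way without losing information. A metric with a different state layout would need its own merge rule.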
The custom class is a modification of Keras's built-in `MeanTensor` metric. I changed the `update_state` method so that it accepts `y_true` and `y_pred` as inputs, and I changed the `result` method to return a scalar instead of a vector (see the original docstring). The `MeanTensor` class does not create any weights (i.e. `_total` and `_count`) via `add_weight` until it is first called via `update_state`, because it has to determine the shapes of these weights from the input.

When running the evaluation with `tfma`, I receive the following error:

The error suggests that, when the metrics are combined by the Beam combiner, another instance of the metric class is created but never called, and is therefore never initialized via `update_state` -> `_build`. This appears to be a problem only at the final combining step, since the calls to `update_state` succeeded while each Beam container was running.

Setting a fixed shape avoids this problem. However, in our TFX application the output shape is dynamic and cannot be determined until the `Trainer` component of the pipeline has finished running, so we would like to be able to initialize the shape dynamically.