tensorflow / adanet

Fast and flexible AutoML with learning guarantees.
https://adanet.readthedocs.io
Apache License 2.0

Correct place to add custom metric_fn? #150

Closed le-dawg closed 3 years ago

le-dawg commented 4 years ago

I would like to add my implementation of the MCC (Matthews correlation coefficient) so that I can watch it converge during training. In the plain Estimator API, a custom metric_fn can be added as part of the output_spec, depending on the ModeKeys inside the model_fn. In AdaNet this neat structure has been shuffled around a bit. Where should I look?

import tensorflow as tf  # TF 1.x API: tf.metrics, tf.div_no_nan

def metric_fn(per_example_loss, label_ids, logits, is_real_example):
  """Compute the Matthews correlation coefficient (MCC)."""
  predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
  # https://en.wikipedia.org/wiki/Matthews_correlation_coefficient
  # tf.metrics.* take (labels, predictions) in that order; both are cast to bool.
  tp, tp_op = tf.metrics.true_positives(
      label_ids, predictions, weights=is_real_example)
  tn, tn_op = tf.metrics.true_negatives(
      label_ids, predictions, weights=is_real_example)
  fp, fp_op = tf.metrics.false_positives(
      label_ids, predictions, weights=is_real_example)
  fn, fn_op = tf.metrics.false_negatives(
      label_ids, predictions, weights=is_real_example)

  # Compute the Matthews correlation coefficient
  mcc = tf.div_no_nan(
      tp * tn - fp * fn,
      tf.pow((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn), 0.5))
  # Estimator metrics are (value_tensor, update_op) pairs.
  return {"MCC": (mcc, tf.group(tp_op, tn_op, fp_op, fn_op))}
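To sanity-check the streamed value, the same formula can be evaluated on plain confusion-matrix counts. The following is a minimal stdlib sketch; the helper names `mcc_from_counts` and `mcc_from_labels` are hypothetical. Like `tf.div_no_nan`, it returns 0.0 when the denominator is zero (for example, when every label or every prediction falls in one class).

```python
import math


def mcc_from_counts(tp, tn, fp, fn):
  """Matthews correlation coefficient from confusion-matrix counts."""
  denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
  # Mirror tf.div_no_nan: a zero denominator yields 0.0 instead of NaN.
  if denom == 0:
    return 0.0
  return (tp * tn - fp * fn) / denom


def mcc_from_labels(labels, predictions):
  """Count TP/TN/FP/FN for binary 0/1 labels and predictions, then score."""
  pairs = list(zip(labels, predictions))
  tp = sum(1 for y, p in pairs if y and p)
  tn = sum(1 for y, p in pairs if not y and not p)
  fp = sum(1 for y, p in pairs if not y and p)
  fn = sum(1 for y, p in pairs if y and not p)
  return mcc_from_counts(tp, tn, fp, fn)


# tp=2, tn=2, fp=1, fn=1 -> (4 - 1) / sqrt(3 * 3 * 3 * 3) = 3 / 9
print(mcc_from_labels([1, 1, 0, 0, 1, 0], [1, 0, 0, 1, 1, 0]))  # 0.3333333333333333
```

Feeding the same toy batch through the TensorFlow metric should give the same value, which makes it a quick check that labels and predictions are wired correctly.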
cweill commented 4 years ago

See the metric_fn argument in the Estimator and AutoEnsembleEstimator constructors: https://adanet.readthedocs.io/en/v0.8.0/adanet.html#adanet.Estimator
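For reference, the v0.8.0 docs linked above describe `metric_fn` as accepting only `predictions`, `features`, and `labels` (in any order) and returning a dict of `(metric_tensor, update_op)` tuples. A function with the BERT-style signature in the question does not match that contract, so it would need adapting. Below is a minimal sketch of such an adapter. It assumes a binary classification head whose `predictions` dict has a `class_ids` entry of 0/1 values, and integer 0/1 labels. The name `mcc_metric_fn` is hypothetical.

```python
import tensorflow.compat.v1 as tf


def mcc_metric_fn(labels, predictions):
  """Streaming MCC in the (labels, predictions) -> metrics-dict form."""
  # Assumption: the head's predictions dict has "class_ids" of shape
  # [batch_size, 1] with values in {0, 1}; flatten both to [batch_size].
  class_ids = tf.reshape(predictions["class_ids"], [-1])
  label_ids = tf.reshape(labels, [-1])

  # tf.metrics.* take (labels, predictions) and cast both to bool.
  tp, tp_op = tf.metrics.true_positives(label_ids, class_ids)
  tn, tn_op = tf.metrics.true_negatives(label_ids, class_ids)
  fp, fp_op = tf.metrics.false_positives(label_ids, class_ids)
  fn, fn_op = tf.metrics.false_negatives(label_ids, class_ids)

  mcc = tf.div_no_nan(
      tp * tn - fp * fn,
      tf.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)))
  # One (value_tensor, update_op) pair per metric name.
  return {"mcc": (mcc, tf.group(tp_op, tn_op, fp_op, fn_op))}
```

It would then be passed as `metric_fn=mcc_metric_fn` to the `adanet.Estimator` (or `AutoEnsembleEstimator`) constructor. If the input pipeline provides a weighting feature such as BERT's `is_real_example`, it could be read from an additional `features` argument and passed as `weights` to each metric.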

le-dawg commented 4 years ago

Passing my metric function there resulted in an error that I cannot remember off the top of my head. I will post the error when I return to my workstation.

cweill commented 3 years ago

Closing for now.