stewu5 closed 2 years ago
trainer = Trainer(
    ...
).extend(
    Validator(
        lambda y_true, y_pred: classification_metrics(y_true, y_pred, average=None),
        early_stopping=30, trace_order=5, warming_up=warming_up, accuracy=1.0, micro_f1=1.0, micro_precision=1.0, ...),
)
The recall and precision currently computed by the classification_metrics function are only appropriate for multi-class cases. The average option is set to "weighted", so even for binary classification the function returns a support-weighted average of the per-class recall or precision. In the binary case, this is not the true recall / precision of the positive class.
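To make the difference concrete, here is a minimal pure-Python sketch on made-up labels. The helper names `precision_recall` and `weighted_precision_recall` are hypothetical and are not part of the library. The sketch only illustrates the usual definitions of binary and support-weighted precision / recall. It is not the actual implementation of classification_metrics.

```python
def precision_recall(y_true, y_pred, label):
    """Precision and recall when `label` is treated as the positive class."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == label and p == label for t, p in pairs)
    fp = sum(t != label and p == label for t, p in pairs)
    fn = sum(t == label and p != label for t, p in pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


def weighted_precision_recall(y_true, y_pred):
    """Average the per-class scores, weighting each class by its support."""
    n = len(y_true)
    precision = recall = 0.0
    for label in sorted(set(y_true)):
        p, r = precision_recall(y_true, y_pred, label)
        support = sum(t == label for t in y_true)  # number of true samples of this class
        precision += p * support / n
        recall += r * support / n
    return precision, recall


# Made-up binary example: 6 negatives (0) and 4 positives (1).
y_true = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]
y_pred = [0, 0, 0, 0, 0, 1, 1, 1, 0, 0]

binary_precision, binary_recall = precision_recall(y_true, y_pred, label=1)
weighted_precision, weighted_recall = weighted_precision_recall(y_true, y_pred)

print("binary:  ", binary_precision, binary_recall)
print("weighted:", weighted_precision, weighted_recall)
```

On this data the recall of the positive class is 0.5, while the weighted recall is 0.7, which is just the overall accuracy. The weighted score hides the fact that half of the positives are missed.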