jtfields closed this issue 1 year ago.
I changed the training_args from evaluate_during_training=True to evaluation_strategy='epoch'. With this change, trainer.train() fails with AxisError: axis 1 is out of bounds for array of dimension 1. Any suggestions for how to fix this?
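The error comes from NumPy rather than from the Trainer: np.argmax(..., axis=1) needs an array with at least two dimensions, and it raises AxisError when handed something NumPy treats as 1-D. If p.predictions has become a tuple of differently shaped arrays, older NumPy versions likely turn it into a 1-D object array, which would produce exactly this message. A minimal reproduction with made-up arrays (the values and shapes are assumptions for illustration):

```python
import numpy as np

# Hypothetical logits for 3 examples and 2 classes: shape (3, 2).
logits = np.array([[0.2, 1.5], [2.0, -1.0], [0.1, 0.3]])
pred_labels = np.argmax(logits, axis=1)  # one predicted label per row
print(pred_labels)  # [1 0 1]

# A 1-D array has no axis 1, so the same call fails.
flat = np.array([0.2, 1.5, 2.0])
try:
    np.argmax(flat, axis=1)
    error_message = None
except ValueError as err:  # NumPy's AxisError subclasses ValueError
    error_message = str(err)
print(error_message)  # axis 1 is out of bounds for array of dimension 1
```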
I think this might just be an issue due to a change in how EvalPrediction returns results. In the calc_classification_metrics function, just set predictions = p.predictions[0] at the start and use predictions everywhere. I'll push a change reflecting this sometime soon.
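The suggested fix assumes p.predictions is always a tuple. If the same metrics function has to work both before and after that change, a small guard can accept either form. unwrap_predictions below is a hypothetical helper, not part of transformers, and it assumes the logits are the first element whenever a tuple is returned:

```python
import numpy as np


def unwrap_predictions(predictions):
    """Return the logits whether predictions is a bare array or a tuple/list.

    Hypothetical helper: assumes the logits are the first element of a tuple.
    """
    if isinstance(predictions, (tuple, list)):
        return predictions[0]
    return predictions


logits = np.array([[0.2, 1.5], [2.0, -1.0]])
extra = np.zeros((2, 4))  # stand-in for any additional model output

from_array = unwrap_predictions(logits)
from_tuple = unwrap_predictions((logits, extra))
print(np.argmax(from_array, axis=1), np.argmax(from_tuple, axis=1))  # [1 0] [1 0]
```

Inside calc_classification_metrics, predictions = unwrap_predictions(p.predictions) would then take the place of p.predictions[0].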
I changed evaluate_during_training=True to evaluation_strategy='epoch', made the following change to calc_classification_metrics, and everything works now. Thank you for the help resolving this issue!
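import numpy as np
from scipy.special import softmax
from sklearn.metrics import (auc, confusion_matrix, f1_score, matthews_corrcoef,
                             precision_recall_curve, roc_auc_score)
from transformers import EvalPrediction


def calc_classification_metrics(p: EvalPrediction):
    # After the switch to evaluation_strategy='epoch', p.predictions is a tuple;
    # its first element holds the logits of shape (n_examples, n_classes).
    predictions = p.predictions[0]
    pred_labels = np.argmax(predictions, axis=1)
    # Probability of the positive class (column 1), used for the ROC/PR curves.
    pred_scores = softmax(predictions, axis=1)[:, 1]
    labels = p.label_ids
    if len(np.unique(labels)) == 2:  # binary classification
        roc_auc_pred_score = roc_auc_score(labels, pred_scores)
        precisions, recalls, thresholds = precision_recall_curve(labels,
                                                                 pred_scores)
        # F1 at every candidate threshold; 0/0 (precision = recall = 0) is set to 0.
        fscore = (2 * precisions * recalls) / (precisions + recalls)
        fscore[np.isnan(fscore)] = 0
        ix = np.argmax(fscore)  # index of the threshold with the best F1
        threshold = thresholds[ix].item()
        pr_auc = auc(recalls, precisions)
        tn, fp, fn, tp = confusion_matrix(labels, pred_labels, labels=[0, 1]).ravel()
        result = {'roc_auc': roc_auc_pred_score,
                  'threshold': threshold,
                  'pr_auc': pr_auc,
                  'recall': recalls[ix].item(),
                  'precision': precisions[ix].item(), 'f1': fscore[ix].item(),
                  'tn': tn.item(), 'fp': fp.item(), 'fn': fn.item(), 'tp': tp.item()
                  }
    else:
        acc = (pred_labels == labels).mean()
        # The default average='binary' raises on multiclass labels; 'macro' is one choice.
        f1 = f1_score(y_true=labels, y_pred=pred_labels, average='macro')
        result = {
            "acc": acc,
            "f1": f1,
            "acc_and_f1": (acc + f1) / 2,
            "mcc": matthews_corrcoef(labels, pred_labels)
        }
    return result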
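The binary branch picks its threshold by computing F1 at every cut-off returned by precision_recall_curve and keeping the best one. A NumPy-only sketch of the same idea follows; best_f1_threshold and the toy labels and scores are hypothetical, and unlike scikit-learn it does not append a final precision = 1, recall = 0 point, so edge cases may differ slightly:

```python
import numpy as np


def best_f1_threshold(labels, scores):
    """Return (threshold, f1) for the cut-off 'score >= threshold' with the highest F1.

    Hypothetical helper illustrating the precision_recall_curve / argmax(fscore)
    step; it is not scikit-learn's implementation.
    """
    labels = np.asarray(labels)
    scores = np.asarray(scores, dtype=float)

    # Sort by score, highest first: each prefix means "predict positive down to here".
    order = np.argsort(-scores, kind="mergesort")
    sorted_scores = scores[order]
    sorted_labels = labels[order]
    tp = np.cumsum(sorted_labels == 1)
    fp = np.cumsum(sorted_labels != 1)

    # Keep only the last position of each run of tied scores: ties share one cut-off.
    last_of_run = np.append(np.diff(sorted_scores) != 0, True)
    tp, fp, thresholds = tp[last_of_run], fp[last_of_run], sorted_scores[last_of_run]

    precision = tp / (tp + fp)
    recall = tp / tp[-1] if tp[-1] > 0 else np.zeros(len(tp))
    denom = precision + recall
    # Same convention as the code above: F1 is 0 where precision + recall is 0.
    f1 = np.divide(2 * precision * recall, denom,
                   out=np.zeros_like(denom), where=denom > 0)

    ix = int(np.argmax(f1))
    return float(thresholds[ix]), float(f1[ix])


labels = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]  # e.g. softmax(...)[:, 1] for four examples
threshold, f1 = best_f1_threshold(labels, scores)
print(threshold, round(f1, 3))  # 0.35 0.8
```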
When running the sample Colab file, I get an error when evaluate_during_training is set to True. When I comment out evaluate_during_training, the code runs, but the calc_classification_metrics results are not generated. I did find this message on HuggingFace, which seems to be related to the evaluate_during_training issue: