stefanradev93 / BayesFlow

A Python library for amortized Bayesian workflows using generative neural networks.
https://bayesflow.org/
MIT License

Possible axis mislabelling in bayesflow.diagnostics.plot_calibration_curves #90

Closed: Chad-Chong closed this issue 1 year ago

Chad-Chong commented 1 year ago

Hello all,

Thank you all for the marvelous work on BayesFlow. I have been trying to understand the calibration curves produced by bayesflow.diagnostics.plot_calibration_curves. My understanding is that a calibration curve should show the true probability as a function of the predicted probability, i.e., with the predicted probability on the x-axis. In the expected_calibration_error function, the curves are computed as follows:

    import numpy as np
    from sklearn.calibration import calibration_curve


    def expected_calibration_error(m_true, m_pred, num_bins):
        # m_true: one-hot true model indicators, m_pred: predicted model probabilities,
        # both of shape (n_samples, n_models)
        n_models = m_true.shape[-1]
        cal_errs, probs = [], []

        # Loop for each model and compute calibration errs per bin
        for k in range(n_models):
            y_true = (m_true.argmax(axis=1) == k).astype(np.float32)
            y_prob = m_pred[:, k]
            prob_true, prob_pred = calibration_curve(y_true, y_prob, n_bins=num_bins)

            # Compute ECE by weighting bin errors by bin size
            bins = np.linspace(0.0, 1.0, num_bins + 1)
            binids = np.searchsorted(bins[1:-1], y_prob)
            bin_total = np.bincount(binids, minlength=len(bins))
            nonzero = bin_total != 0
            cal_err = np.sum(np.abs(prob_true - prob_pred) * (bin_total[nonzero] / len(y_true)))

            cal_errs.append(cal_err)
            probs.append((prob_true, prob_pred))
        return cal_errs, probs
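To make the tuple order concrete, here is a minimal pure-Python sketch of the same uniform binning. It assumes, as scikit-learn's calibration_curve does, that each curve is returned as (prob_true, prob_pred), i.e. observed frequency first and mean predicted probability second. The helpers `reliability_curve` and `ece_sketch` are hypothetical illustrations, not BayesFlow or scikit-learn API. The ECE is the bin-size-weighted average of |prob_true - prob_pred| over the non-empty bins.

```python
from bisect import bisect_left


def reliability_curve(y_true, y_prob, n_bins=10):
    """Hypothetical helper: (prob_true, prob_pred, counts) for non-empty bins.

    Uses n_bins equal-width bins on [0, 1] and returns the observed
    frequency first and the mean predicted probability second, the same
    order as sklearn.calibration.calibration_curve.
    """
    # Interior edges of the equal-width bins on [0, 1]
    inner_edges = [i / n_bins for i in range(1, n_bins)]
    sums_true = [0.0] * n_bins
    sums_pred = [0.0] * n_bins
    counts = [0] * n_bins
    for t, p in zip(y_true, y_prob):
        # bisect_left matches np.searchsorted(..., side="left")
        b = bisect_left(inner_edges, p)
        sums_true[b] += t
        sums_pred[b] += p
        counts[b] += 1
    # Drop empty bins, as calibration_curve does
    nonempty = [b for b in range(n_bins) if counts[b] > 0]
    prob_true = [sums_true[b] / counts[b] for b in nonempty]
    prob_pred = [sums_pred[b] / counts[b] for b in nonempty]
    return prob_true, prob_pred, [counts[b] for b in nonempty]


def ece_sketch(y_true, y_prob, n_bins=10):
    """Hypothetical helper: weight each bin's |observed - predicted| by its sample share."""
    prob_true, prob_pred, counts = reliability_curve(y_true, y_prob, n_bins)
    n = len(y_true)
    return sum(abs(t - p) * c / n for t, p, c in zip(prob_true, prob_pred, counts))


pt, pp, _ = reliability_curve([0, 0, 1, 1], [0.1, 0.2, 0.8, 0.9], n_bins=2)
err = ece_sketch([0, 0, 1, 1], [0.1, 0.2, 0.8, 0.9], n_bins=2)
```

With these four samples and two bins, the observed frequencies are [0.0, 1.0], the mean predictions are roughly [0.15, 0.85], and the ECE is about 0.15. Passing the first element of the tuple as the x-coordinate would therefore put the observed frequency, not the prediction, on the x-axis.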

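So each entry of the returned probs list is a (prob_true, prob_pred) tuple: the true probabilities come first and the predicted probabilities second, following the return order of scikit-learn's calibration_curve. Going back to the bayesflow.diagnostics.plot_calibration_curves function: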

    # Determine n_subplots dynamically
    n_row = int(np.ceil(num_models / 6))
    n_col = int(np.ceil(num_models / n_row))
    cal_errs, cal_probs = expected_calibration_error(true_models, pred_models, num_bins)

    # Initialize figure
    if fig_size is None:
        fig_size = (int(5 * n_col), int(5 * n_row))
    fig, axarr = plt.subplots(n_row, n_col, figsize=fig_size)

    # Plot marginal calibration curves in a loop
    if n_row > 1:
        ax = axarr.flat
    else:
        ax = axarr
    for j in range(num_models):
        # Plot calibration curve
        ax[j].plot(cal_probs[j][0], cal_probs[j][1], color=color)

Unless I have misunderstood, this plots the true probability on the x-axis and the predicted probability on the y-axis. However, the axis labels are:

        ax[j].set_xlabel("Predicted probability", fontsize=label_fontsize)
        ax[j].set_ylabel("True probability", fontsize=label_fontsize)

Did I miss anything here? Once again, thank you for the splendid work on BayesFlow!

elseml commented 1 year ago

Hi Chad, thank you very much for spotting this important issue! The axis ordering has now been corrected.
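For reference, a minimal sketch of a single calibration curve drawn in the conventional reliability-diagram orientation: mean predicted probability on the x-axis, observed (true) frequency on the y-axis, plus the diagonal for perfect calibration. It assumes matplotlib and the (prob_true, prob_pred) tuple order returned by calibration_curve; `plot_reliability` is a hypothetical helper and not necessarily how the fix was implemented in BayesFlow.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt


def plot_reliability(ax, prob_true, prob_pred):
    """Hypothetical helper: draw one calibration curve on `ax`.

    prob_true / prob_pred follow the (observed, predicted) order returned
    by sklearn.calibration.calibration_curve.
    """
    # Predicted probability on x, observed frequency on y
    ax.plot(prob_pred, prob_true)
    # Dashed diagonal marks perfect calibration
    ax.plot([0, 1], [0, 1], linestyle="--", color="black")
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_xlabel("Predicted probability")
    ax.set_ylabel("True probability")
    return ax


fig, ax = plt.subplots(figsize=(5, 5))
plot_reliability(ax, prob_true=[0.0, 1.0], prob_pred=[0.15, 0.85])
```

The key point is simply that the second tuple element is passed as the x-coordinate, so the data now match the existing axis labels.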