TimeLovercc / CAF-GNN

[CIKM 2023] Towards Fair Graph Neural Networks via Graph Counterfactual.
https://arxiv.org/abs/2307.04937
MIT License
11 stars 3 forks source link

Request for details on the counterfactual fairness calculation implementation #5

Closed Tlhey closed 2 months ago

Tlhey commented 3 months ago

Thanks for your timely reply! In GEAR's evaluation code, a subgraph is generated for each node and the mean CF is computed over them. When deriving the CF metric, do you also generate subgraphs, or do you simply follow the definition with something like `cf = 1 - (np.sum(y_pred_cf == y_pred) / n)`?

# For convenience, here is the code GEAR uses to evaluate CF with subgraphs:
def evaluate(model, data, subgraph, cf_subgraph_list, labels, sens, idx_select, type='all'):
    """Evaluate a model on the selected nodes, optionally including fairness metrics.

    Args:
        model: Model passed to ``compute_loss`` and ``get_all_node_emb``; must
            expose ``predict(emb)``.
        data: Not used by this function (kept for call-site compatibility).
        subgraph: Factual subgraph(s) used to build node embeddings.
        cf_subgraph_list: Counterfactual subgraphs. The CF metric is the average,
            over these, of the fraction of selected nodes whose prediction flips.
        labels: Label tensor indexed over all ``n = len(labels)`` nodes.
        sens: Sensitive-attribute tensor indexed like ``labels``.
        idx_select: Index tensor of the nodes to evaluate.
        type: ``'easy'`` returns only the loss terms; ``'all'`` also returns
            acc, auc, f1, parity, equality and cf. (The name shadows the
            builtin but is kept so keyword callers keep working.)

    Returns:
        dict mapping metric names to values.

    Raises:
        ValueError: if ``type`` is neither ``'easy'`` nor ``'all'``, or if
            ``type == 'all'`` and ``cf_subgraph_list`` is empty.
    """
    loss_result = compute_loss(model, subgraph, cf_subgraph_list, labels, idx_select)
    loss_terms = {'loss': loss_result['loss'], 'loss_c': loss_result['loss_c'], 'loss_s': loss_result['loss_s']}

    if type == 'easy':
        return loss_terms
    if type != 'all':
        # Previously this fell through to an UnboundLocalError on the result dict.
        raise ValueError(f"type must be 'easy' or 'all', got {type!r}")
    if len(cf_subgraph_list) == 0:
        # The CF metric is a mean over counterfactual subgraphs; it is undefined without any.
        raise ValueError("cf_subgraph_list must be non-empty when type='all'")

    n = len(labels)
    # Boolean mask of length n marking the selected nodes.
    idx_select_mask = (torch.zeros(n).scatter_(0, idx_select, 1) > 0)

    # performance
    emb = get_all_node_emb(model, idx_select_mask, subgraph, n)
    output = model.predict(emb)
    # NOTE(review): thresholding at 0 presumes predict() returns logits — confirm.
    output_preds = (output.squeeze() > 0).type_as(labels)

    # Convert once; index on the tensor side so a GPU-resident idx_select also works.
    labels_np = labels[idx_select].cpu().numpy()
    preds_np = output_preds.cpu().numpy()

    auc_roc = roc_auc_score(labels_np, output.detach().cpu().numpy())
    f1_s = f1_score(labels_np, preds_np)
    acc = accuracy_score(labels_np, preds_np)

    # fairness (.cpu() added: .numpy() fails on CUDA tensors)
    parity, equality = fair_metric(preds_np, labels_np, sens[idx_select].cpu().numpy())

    # counterfactual fairness: mean flip rate across counterfactual subgraphs
    n_select = idx_select.shape[0]
    cf = 0.0
    for cf_subgraph in cf_subgraph_list:
        emb_cf = get_all_node_emb(model, idx_select_mask, cf_subgraph, n)
        output_preds_cf = (model.predict(emb_cf).squeeze() > 0).type_as(labels)
        cf += 1 - (output_preds.eq(output_preds_cf).sum().item() / n_select)
    cf /= len(cf_subgraph_list)

    return {'acc': acc, 'auc': auc_roc, 'f1': f1_s, 'parity': parity, 'equality': equality, 'cf': cf,
            **loss_terms}  # counterfactual_fairness