TimeLovercc / CAF-GNN

[CIKM 2023] Towards Fair Graph Neural Networks via Graph Counterfactual.
https://arxiv.org/abs/2307.04937
MIT License
11 stars 3 forks source link

Request code for reproducing Countertactual Fairness Metric for Synthetic dataset #4

Open Tlhey opened 6 days ago

Tlhey commented 6 days ago

Hi Could you provide the code or method for generating the counterfactual fairness metrics that may yield the data in Table 3 and Figure 3 in the original paper? Is it aligned with the method with that in GEAR? Thanks a lot! image image

TimeLovercc commented 6 days ago

Sure. Here is the code in the evaluation part: elif args.model in ['our', 'our2']: model.load_state_dict(torch.load(f'./weights/weights_{args.model}_{args.dataset}_{args.seed}.pt')) model.eval() output, _ = model(features.to(device), edge_index.to(device)) counter_features = features_cf.clone() counter_output, _ = model(counter_features.to(device), edge_index_cf.to(device))

For datasets generation, here is the code: ` def generate_synthetic_data(path, n, z_dim, p, q, alpha, beta, threshold, dim): sens = np.random.binomial(n=1, p=p, size=n) sens_repeat = np.repeat(sens.reshape(-1, 1), z_dim, axis=1) sens_embedding = np.random.normal(loc=sens_repeat, scale=1, size=(n, z_dim)) labels = np.random.binomial(n=1, p=q, size=n) labels_repeat = np.repeat(labels.reshape(-1, 1), z_dim, axis=1) labels_embedding = np.random.normal(loc=labels_repeat, scale=1, size=(n, z_dim)) features_embedding = np.concatenate((sens_embedding, labels_embedding), axis=1) weight = np.random.normal(loc=0, scale=1, size=(z_dim*2, dim))

features = np.matmul(features_embedding, weight)

features = np.matmul(features_embedding, weight) + np.random.normal(loc=0, scale=1, size=(n, dim))

adj = np.zeros((n, n))
sens_sim = np.zeros((n, n))
labels_sim = np.zeros((n, n))
for i in range(n):
    for j in range(i, n):  # i<=j
        if i == j:
            sens_sim[i][j] = -1
            labels_sim[i][j] = -1
            continue
        sens_sim[i][j] = sens_sim[j][i] = (sens[i] == sens[j])
        labels_sim[i][j] = labels_sim[j][i] = (labels[i] == labels[j])
        # sim_ij = 1 - spatial.distance.cosine(embedding[i], embedding[j])  # [-1, 1]
        # adj[i][j] = adj[j][i] = sim_ij + alpha * (sens[i] == sens[j])

similarities = cosine_similarity(features_embedding)  # n x n
similarities[np.arange(n), np.arange(n)] = -1
adj = similarities + alpha * sens_sim + beta * labels_sim
print('adj max: ', adj.max(), ' min: ', adj.min())
adj[np.where(adj >= threshold)] = 1
adj[np.where(adj < threshold)] = 0
edge_num = adj.sum()
adj = sparse.csr_matrix(adj)
# features = np.concatenate((sens.reshape(-1,1), features), axis=1)

# generate counterfactual
sens_flip = 1 - sens
sens_flip_repeat = np.repeat(sens_flip.reshape(-1, 1), z_dim, axis=1)
sens_flip_embedding = np.random.normal(loc=sens_flip_repeat, scale=1, size=(n, z_dim))
features_embedding = np.concatenate((sens_flip_embedding, labels_embedding), axis=1)
features_cf = np.matmul(features_embedding, weight) + np.random.normal(loc=0, scale=1, size=(n, dim))

adj_cf = np.zeros((n, n))
sens_cf_sim = np.zeros((n, n))
labels_cf_sim = np.zeros((n, n))
for i in range(n):
    for j in range(i, n):
        if i == j:
            sens_cf_sim[i][j] = -1
            labels_cf_sim[i][j] = -1
            continue
        sens_cf_sim[i][j] = sens_cf_sim[j][i] = (sens_flip[i] == sens_flip[j])
        labels_cf_sim[i][j] = labels_cf_sim[j][i] = (labels[i] == labels[j])

similarities_cf = cosine_similarity(features_cf)  # n x n
similarities_cf[np.arange(n), np.arange(n)] = -1
adj_cf = similarities_cf + alpha * sens_cf_sim + beta * labels_cf_sim
print('adj_cf max', adj_cf.max(), ' min: ', adj_cf.min())
adj_cf[np.where(adj_cf >= threshold)] = 1
adj_cf[np.where(adj_cf < threshold)] = 0
adj_cf = sparse.csr_matrix(adj_cf)
# features_cf = np.concatenate((sens_flip.reshape(-1,1), features_cf), axis=1)

# statistics
# pre_analysis(adj, labels, sens)
print('edge num: ', edge_num)
data = {'x': features, 'adj': adj, 'labels': labels, 'sens': sens, 'x_cf': features_cf, 'adj_cf': adj_cf}
# scio.savemat(path, data)
print('data saved in ', path)
return data

`

The evaluation is same for GEAR and our method. Thank you for your questions.

Tlhey commented 5 days ago

Thanks for your timely reply! In GEAR's evaluation part, they generate subgraphs for each node to calculate the mean cf. Do you also implement with generating subgraphs when derive cf metric or simply follows the definition with something like cf = 1 - (np.sum(y_pred_cf == y_pred) / n):

# For convenience, attached is code for how GEAR evaluate cf with subgraphs
def evaluate(model, data, subgraph, cf_subgraph_list, labels, sens, idx_select, type='all'):
    loss_result = compute_loss(model, subgraph, cf_subgraph_list, labels, idx_select)
    if type == 'easy':
        eval_results = {'loss': loss_result['loss'], 'loss_c': loss_result['loss_c'], 'loss_s': loss_result['loss_s']}

    elif type == 'all':
        n = len(labels)
        idx_select_mask = (torch.zeros(n).scatter_(0, idx_select, 1) > 0)  # size = n, bool

        # performance
        emb = get_all_node_emb(model, idx_select_mask, subgraph, n)
        output = model.predict(emb)
        output_preds = (output.squeeze() > 0).type_as(labels)

        auc_roc = roc_auc_score(labels.cpu().numpy()[idx_select], output.detach().cpu().numpy())
        f1_s = f1_score(labels[idx_select].cpu().numpy(), output_preds.cpu().numpy())
        acc = accuracy_score(labels[idx_select].cpu().numpy(), output_preds.cpu().numpy())

        # fairness
        parity, equality = fair_metric(output_preds.cpu().numpy(), labels[idx_select].cpu().numpy(),
                                       sens[idx_select].numpy())
        # counterfactual fairness
        cf = 0.0
        for si in range(len(cf_subgraph_list)):
            cf_subgraph = cf_subgraph_list[si]
            emb_cf = get_all_node_emb(model, idx_select_mask, cf_subgraph, n)
            output_cf = model.predict(emb_cf)
            output_preds_cf = (output_cf.squeeze() > 0).type_as(labels)
            cf_si = 1 - (output_preds.eq(output_preds_cf).sum().item() / idx_select.shape[0])
            cf += cf_si
        cf /= len(cf_subgraph_list)

        eval_results = {'acc': acc, 'auc': auc_roc, 'f1': f1_s, 'parity': parity, 'equality': equality, 'cf': cf,
                        'loss': loss_result['loss'], 'loss_c': loss_result['loss_c'], 'loss_s': loss_result['loss_s']}  # counterfactual_fairness
    return eval_results