wujcan / GIF-torch


Intuition behind the update_edge_index_unlearn() function #3

Open yashpaneliya opened 5 months ago

yashpaneliya commented 5 months ago

I am trying to understand the GIF paper through its code. While reading it, I came across the function update_edge_index_unlearn in the ExpGraphInfluenceFunction class. I understand the individual operations, but could someone please explain the intuition behind them in layman's terms? What is the purpose of generating those encodings and of the union operation?

Reference:

 def update_edge_index_unlearn(self, delete_nodes, delete_edge_index=None):
        edge_index = self.data.edge_index.numpy()

        unique_indices = np.where(edge_index[0] < edge_index[1])[0]
        unique_indices_not = np.where(edge_index[0] > edge_index[1])[0]

        if self.args["unlearn_task"] == 'edge':
            remain_indices = np.setdiff1d(unique_indices, delete_edge_index)
        else:
            unique_edge_index = edge_index[:, unique_indices]
            delete_edge_indices = np.logical_or(np.isin(unique_edge_index[0], delete_nodes),
                                                np.isin(unique_edge_index[1], delete_nodes))
            remain_indices = np.logical_not(delete_edge_indices)
            remain_indices = np.where(remain_indices == True)
        remain_encode = edge_index[0, remain_indices] * edge_index.shape[1] * 2 + edge_index[1, remain_indices]
        unique_encode_not = edge_index[1, unique_indices_not] * edge_index.shape[1] * 2 + edge_index[0, unique_indices_not]
        sort_indices = np.argsort(unique_encode_not)
        remain_indices_not = unique_indices_not[sort_indices[np.searchsorted(unique_encode_not, remain_encode, sorter=sort_indices)]]
        remain_indices = np.union1d(remain_indices, remain_indices_not)

        return torch.from_numpy(edge_index[:, remain_indices])

@wujcan

li-yang23 commented 2 months ago

I'm not the author, but I believe I can give you a hint. First of all, I think there is a mistake in this function: every line from remain_encode = edge_index[0, remain_indices]... onward (except the final return) should belong to the else branch. After I changed the function as shown below, the node-level unlearning task runs normally.

def update_edge_index_unlearn(self, delete_nodes, delete_edge_index=None):
    edge_index = self.data.edge_index.numpy()

    unique_indices = np.where(edge_index[0] < edge_index[1])[0]
    unique_indices_not = np.where(edge_index[0] > edge_index[1])[0]

    if self.args["unlearn_task"] == 'edge':
        remain_indices = np.setdiff1d(unique_indices, delete_edge_index)
    else:
        unique_edge_index = edge_index[:, unique_indices]
        delete_edge_indices = np.logical_or(np.isin(unique_edge_index[0], delete_nodes),
                                            np.isin(unique_edge_index[1], delete_nodes))
        remain_indices = np.logical_not(delete_edge_indices)
        remain_indices = np.where(remain_indices == True)

        # I believe the original code made a mistake here: remain_encode should be built from
        # unique_edge_index using remain_indices, since the np.logical_or() and np.logical_not()
        # masks above are all computed on unique_edge_index.
        remain_encode = unique_edge_index[0, remain_indices] * edge_index.shape[1] * 2 + unique_edge_index[1, remain_indices]
        unique_encode_not = edge_index[1, unique_indices_not] * edge_index.shape[1] * 2 + edge_index[0, unique_indices_not]
        sort_indices = np.argsort(unique_encode_not)
        remain_indices_not = unique_indices_not[sort_indices[np.searchsorted(unique_encode_not, remain_encode, sorter=sort_indices)]]
        remain_indices = np.union1d(remain_indices, remain_indices_not)

    return torch.from_numpy(edge_index[:, remain_indices])

So the purpose of this function, I believe, is to update the graph structure after an unlearning request is received. For an edge unlearning task, we only need to remove the unlearned edges from edge_index. For a node unlearning task, however, we need to remove every edge incident to the unlearned nodes.
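To make the two cases concrete, here is a minimal sketch on a toy graph. The helper names drop_edges_touching_nodes and drop_edges_by_position are hypothetical and not part of GIF-torch, and the sketch uses a plain boolean mask rather than the encoding trick from the repository, so treat it as an illustration of the intended result only.

```python
import numpy as np

def drop_edges_touching_nodes(edge_index, delete_nodes):
    """Node unlearning (hypothetical helper): drop every edge with a deleted endpoint."""
    touches_deleted = (np.isin(edge_index[0], delete_nodes)
                       | np.isin(edge_index[1], delete_nodes))
    return edge_index[:, ~touches_deleted]

def drop_edges_by_position(edge_index, delete_positions):
    """Edge unlearning (hypothetical helper): drop the listed columns of edge_index."""
    keep = np.setdiff1d(np.arange(edge_index.shape[1]), delete_positions)
    return edge_index[:, keep]

# Undirected path 0-1-2-3, stored as both (u, v) and (v, u).
edge_index = np.array([[0, 1, 1, 2, 2, 3],
                       [1, 0, 2, 1, 3, 2]])

# Unlearning node 3 removes (2, 3) and (3, 2).
after_node = drop_edges_touching_nodes(edge_index, delete_nodes=[3])

# Unlearning the edge 1-2 removes columns 2 and 3, i.e. both directions.
after_edge = drop_edges_by_position(edge_index, delete_positions=[2, 3])
```

In this simplified edge case both directions must be listed explicitly; the function in the repository instead works from the u < v half and recovers the reverse half afterwards.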

The function first splits the edges by node id: unique_indices holds the indices of edges (u, v) with u < v, and unique_indices_not holds the indices of edges (u, v) with u > v. Then, for a node unlearning task, it finds every edge among unique_indices (u < v) that needs to be deleted: the np.logical_or(np.isin(...), np.isin(...)) expression marks every edge (u, v) with u < v where u or v needs to be deleted.

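Then the function has to treat the edges with u > v in the same way. It computes u*M+v for the edges with u < v and v*M+u (endpoints swapped) for the edges with u > v, where M is edge_index.shape[1] * 2 in the code, so that (u, v) and (v, u) map to the same unique number. With these codes it can easily look up, for each selected edge with u < v, the corresponding reverse edge among the edges with u > v.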
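The lookup can be tried in isolation on a small, deliberately shuffled edge list. This is only a sketch: the variable names are made up for illustration, and it assumes every u < v edge has its reverse stored somewhere in edge_index and that the multiplier is larger than the largest node id (otherwise two different edges could share a code).

```python
import numpy as np

# Undirected path 0-1-2-3 with both directions stored in arbitrary order.
edge_index = np.array([[2, 0, 1, 3, 1, 2],
                       [1, 1, 0, 2, 2, 3]])
base = edge_index.shape[1] * 2  # same multiplier as in the function

forward_pos = np.where(edge_index[0] < edge_index[1])[0]  # columns with u < v
reverse_pos = np.where(edge_index[0] > edge_index[1])[0]  # columns with u > v

# Both directions of an edge are encoded as min(u, v) * base + max(u, v).
forward_code = edge_index[0, forward_pos] * base + edge_index[1, forward_pos]
reverse_code = edge_index[1, reverse_pos] * base + edge_index[0, reverse_pos]

# searchsorted needs sorted data, so pass the sorting permutation as `sorter`
# and map the hits back through it to positions in edge_index.
order = np.argsort(reverse_code)
hit = np.searchsorted(reverse_code, forward_code, sorter=order)
partner_pos = reverse_pos[order[hit]]

# partner_pos[i] is the column holding the reverse of column forward_pos[i].
```

If a forward edge had no stored reverse, np.searchsorted would still return an insertion point, so the result would silently point at an unrelated edge or run out of range.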

li-yang23 commented 2 months ago

I'm not sure whether I made a mistake: after I changed the function, the node-level unlearning task for the SGC model performs poorly, with the F1 score and accuracy reaching only 0.5.
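One thing that may be worth checking, offered as a guess rather than a confirmed diagnosis: in the modified function, np.where(...) yields positions within unique_edge_index, while remain_indices_not holds positions within edge_index, and np.union1d merges the two before they are used to index edge_index. Unless all u < v edges happen to sit at the front of edge_index, the first half would then select different columns than intended. The sketch below maps the kept positions back through unique_indices first. The function name and variable names are hypothetical, it returns a NumPy array instead of a tensor to stay dependency-light, and it assumes every u < v edge has a stored reverse.

```python
import numpy as np

def remaining_edges_after_node_unlearn(edge_index, delete_nodes):
    """Hypothetical sketch: keep only edges whose endpoints both survive."""
    forward_pos = np.where(edge_index[0] < edge_index[1])[0]
    reverse_pos = np.where(edge_index[0] > edge_index[1])[0]

    forward_edges = edge_index[:, forward_pos]
    touches_deleted = (np.isin(forward_edges[0], delete_nodes)
                       | np.isin(forward_edges[1], delete_nodes))
    kept_local = np.where(~touches_deleted)[0]   # positions within forward_edges
    kept_forward_pos = forward_pos[kept_local]   # positions within edge_index

    # Encode kept u < v edges and all u > v edges as min * base + max.
    base = edge_index.shape[1] * 2
    kept_code = forward_edges[0, kept_local] * base + forward_edges[1, kept_local]
    reverse_code = edge_index[1, reverse_pos] * base + edge_index[0, reverse_pos]

    # Find the reverse partner of every kept forward edge.
    order = np.argsort(reverse_code)
    hit = np.searchsorted(reverse_code, kept_code, sorter=order)
    kept_reverse_pos = reverse_pos[order[hit]]

    # Both index sets now refer to columns of edge_index.
    keep = np.union1d(kept_forward_pos, kept_reverse_pos)
    return edge_index[:, keep]

edge_index = np.array([[2, 0, 1, 3, 1, 2],
                       [1, 1, 0, 2, 2, 3]])
remaining = remaining_edges_after_node_unlearn(edge_index, delete_nodes=[3])
```

On this toy graph the result keeps both directions of 0-1 and 1-2 and drops both directions of 2-3. Whether this accounts for the low scores is not something the sketch can show.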