Diego999 / pyGAT

PyTorch implementation of the Graph Attention Network model by Veličković et al. (2017, https://arxiv.org/abs/1710.10903)
MIT License

Batch Size bigger than 1 #36

Open brickee opened 4 years ago

brickee commented 4 years ago

How can we use this code on other datasets with a batch size bigger than 1?

gloriatao commented 4 years ago

I have the same problem; the current setting only allows batch size = 1.

gloriatao commented 4 years ago

Got it! Use torch.matmul instead of torch.mm in class GraphAttentionLayer.
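
For context, torch.mm only accepts 2-D tensors, while torch.matmul broadcasts over leading batch dimensions, so the same weight matrix can be applied to a whole batch of node-feature matrices. A minimal sketch (the shapes here are illustrative):

import torch

h = torch.randn(4, 10, 16)   # (batch_size, num_nodes, in_features)
W = torch.randn(16, 8)       # (in_features, out_features)

# torch.mm(h, W) raises an error because torch.mm only accepts 2-D inputs;
# torch.matmul broadcasts the multiplication over the leading batch dimension.
Wh = torch.matmul(h, W)      # (4, 10, 8) = (batch_size, num_nodes, out_features)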

tanjia123456 commented 3 years ago

Hi, there seems to be no batch size in the code. How do you set the batch size? I always get: RuntimeError: CUDA out of memory. Tried to allocate 6.25 GiB (GPU 0; 31.75 GiB total capacity; 25.01 GiB already allocated; 4.51 GiB free; 1.18 GiB cached)

Lijiachen1018 commented 3 years ago

I've modified the GraphAttentionLayer in layers.py and added comments to show how the tensor dimensions change. Please let me know if there is a mistake.

import torch
import torch.nn as nn
import torch.nn.functional as F


class GraphAttentionLayer(nn.Module):
    """
    https://github.com/Diego999/pyGAT/blob/master/layers.py
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
    """

    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.empty(size=(2 * out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, h, adj):
        """

        :param h: (batch_size, number_nodes, in_features)
        :param adj: (batch_size, number_nodes, number_nodes)
        :return: (batch_size, number_nodes, out_features)
        """
        # batchwise matrix multiplication
        # (batch_size, number_nodes, in_features) * (in_features, out_features)
        # -> (batch_size, number_nodes, out_features)
        Wh = torch.matmul(h, self.W)

        # (batch_size, number_nodes, number_nodes, 2 * out_features)
        a_input = self.batch_prepare_attentional_mechanism_input(Wh)

        # (batch_size, number_nodes, number_nodes, 2 * out_features) * (2 * out_features, 1)
        # -> (batch_size, number_nodes, number_nodes, 1)
        e = torch.matmul(a_input, self.a)

        # (batch_size, number_nodes, number_nodes)
        e = self.leakyrelu(e.squeeze(-1))

        # mask value for node pairs without an edge: (batch_size, number_nodes, number_nodes)
        zero_vec = -9e15 * torch.ones_like(e)

        # (batch_size, number_nodes, number_nodes)
        attention = torch.where(adj > 0, e, zero_vec)

        # (batch_size, number_nodes, number_nodes)
        attention = F.softmax(attention, dim=-1)

        # (batch_size, number_nodes, number_nodes)
        attention = F.dropout(attention, self.dropout, training=self.training)

        # batched matrix multiplication -> (batch_size, number_nodes, out_features)
        h_prime = torch.matmul(attention, Wh)
        h_prime = torch.matmul(attention, Wh)

        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime

    def batch_prepare_attentional_mechanism_input(self, Wh):
        """
        with batch training
        :param Wh: (batch_size, number_nodes, out_features)
        :return: (batch_size, number_nodes, number_nodes, 2 * out_features)
        """
        B, M, E = Wh.shape  # (batch_size, number_nodes, out_features)
        Wh_repeated_in_chunks = Wh.repeat_interleave(M, dim=1)  # (B, M*M, E)
        Wh_repeated_alternating = Wh.repeat(1, M, 1)  # (B, M*M, E)
        all_combinations_matrix = torch.cat([Wh_repeated_in_chunks, Wh_repeated_alternating], dim=-1)  # (B, M*M, 2E)
        return all_combinations_matrix.view(B, M, M, 2 * E)

    def _prepare_attentional_mechanism_input(self, Wh_n):
        """
        without a batch dimension
        :param Wh_n: (number_nodes, out_features)
        :return: (number_nodes, number_nodes, 2 * out_features)
        """
        M = Wh_n.size()[0]  # number of nodes
        Wh_repeated_in_chunks = Wh_n.repeat_interleave(M, dim=0)  # (M*M, E)
        Wh_repeated_alternating = Wh_n.repeat(M, 1)  # (M*M, E)
        all_combinations_matrix = torch.cat([Wh_repeated_in_chunks, Wh_repeated_alternating], dim=1)  # (M*M, 2E)
        return all_combinations_matrix.view(M, M, 2 * self.out_features)  # (M, M, 2E)

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'
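
A quick shape check for the batched layer above (the layer sizes and tensors here are just illustrative):

layer = GraphAttentionLayer(in_features=16, out_features=8, dropout=0.6, alpha=0.2)
h = torch.randn(4, 10, 16)                   # (batch_size, number_nodes, in_features)
adj = (torch.rand(4, 10, 10) > 0.5).float()  # (batch_size, number_nodes, number_nodes)
out = layer(h, adj)
print(out.shape)                             # torch.Size([4, 10, 8])
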
baniks commented 3 years ago

Also, in models.py the forward should concatenate along the last dimension.
Existing code: x = torch.cat([att(x, adj) for att in self.attentions], dim=1)
Changed code: x = torch.cat([att(x, adj) for att in self.attentions], dim=-1)
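
For reference, a sketch of how the GAT forward in models.py could look after that change, assuming the standard pyGAT model structure (self.attentions, self.out_att, self.dropout); the final log_softmax is likely also better taken over the last axis once a batch dimension is present:

    def forward(self, x, adj):
        x = F.dropout(x, self.dropout, training=self.training)
        # concatenate the attention heads along the last (feature) axis,
        # which works both with and without a leading batch dimension
        x = torch.cat([att(x, adj) for att in self.attentions], dim=-1)
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.elu(self.out_att(x, adj))
        return F.log_softmax(x, dim=-1)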