VincLee8188 / GMAN-PyTorch

Implementation of Graph Multi-Attention Network with PyTorch

some question #4

Open Dahaoge666 opened 3 years ago

Dahaoge666 commented 3 years ago

Hello, your code is really great. When running experiments, I found some differences from the TensorFlow version: in model.py, should the FC layers in transformAttention be changed to FC_q, etc.?

Dahaoge666 commented 3 years ago
class transformAttention(nn.Module):
    '''
    transform attention mechanism
    X:        [batch_size, num_his, num_vertex, D]
    STE_his:  [batch_size, num_his, num_vertex, D]
    STE_pred: [batch_size, num_pred, num_vertex, D]
    K:        number of attention heads
    d:        dimension of each attention head output
    return:   [batch_size, num_pred, num_vertex, D]
    '''

    def __init__(self, K, d, bn_decay):
        super(transformAttention, self).__init__()
        D = K * d
        self.K = K
        self.d = d
        self.FC_q = FC(input_dims=D, units=D, activations=F.relu,
                       bn_decay=bn_decay)
        self.FC_k = FC(input_dims=D, units=D, activations=F.relu,
                       bn_decay=bn_decay)
        self.FC_v = FC(input_dims=D, units=D, activations=F.relu,
                       bn_decay=bn_decay)
        self.FC = FC(input_dims=D, units=D, activations=F.relu,
                     bn_decay=bn_decay)

    def forward(self, X, STE_his, STE_pred):
        batch_size = X.shape[0]
        # [batch_size, num_step, num_vertex, K * d]
        query = self.FC(STE_pred)
        key = self.FC(STE_his)
        value = self.FC(X)
        # [K * batch_size, num_step, num_vertex, d]
        query = torch.cat(torch.split(query, self.K, dim=-1), dim=0)
        key = torch.cat(torch.split(key, self.K, dim=-1), dim=0)
        value = torch.cat(torch.split(value, self.K, dim=-1), dim=0)
        # query: [K * batch_size, num_vertex, num_pred, d]
        # key:   [K * batch_size, num_vertex, d, num_his]
        # value: [K * batch_size, num_vertex, num_his, d]
        query = query.permute(0, 2, 1, 3)
        key = key.permute(0, 2, 3, 1)
        value = value.permute(0, 2, 1, 3)
        # [K * batch_size, num_vertex, num_pred, num_his]
        attention = torch.matmul(query, key)
        attention /= (self.d ** 0.5)
        attention = F.softmax(attention, dim=-1)
        # [batch_size, num_pred, num_vertex, D]
        X = torch.matmul(attention, value)
        X = X.permute(0, 2, 1, 3)
        X = torch.cat(torch.split(X, batch_size, dim=0), dim=-1)
        X = self.FC(X)
        del query, key, value, attention
        return X

Emm, for the third through fifth lines in forward, I think it should be query = self.FC_q(STE_pred), key = self.FC_k(STE_his), value = self.FC_v(X), but in the code it is query = self.FC(STE_pred), key = self.FC(STE_his), value = self.FC(X).

Thank you!

VincLee8188 commented 3 years ago

Yep, sure, those are typos. It should be query = self.FC_q(STE_pred), key = self.FC_k(STE_his), value = self.FC_v(X). I revised them in model.py. Thanks for your attention!
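
For reference, the corrected projection lines in forward look like this (a minimal sketch of the fix described above; the rest of the method is unchanged):

        # [batch_size, num_step, num_vertex, K * d]
        query = self.FC_q(STE_pred)  # query projection from the prediction-step embedding
        key = self.FC_k(STE_his)     # key projection from the historical-step embedding
        value = self.FC_v(X)         # value projection from the input features

With separate FC_q, FC_k, and FC_v, the query, key, and value projections are learned independently, matching the TensorFlow version; the final self.FC(X) at the end of forward remains the shared output projection.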
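
As a quick sanity check (a hypothetical snippet, not from the repo, assuming the transformAttention class above and its FC dependency are importable, and using small made-up sizes), the shape contract in the docstring can be verified like this:

import torch

K, d = 8, 8                                 # 8 heads, 8 dims per head -> D = 64
model = transformAttention(K, d, bn_decay=0.1)
X        = torch.randn(2, 12, 10, K * d)    # [batch_size, num_his, num_vertex, D]
STE_his  = torch.randn(2, 12, 10, K * d)    # same shape as X
STE_pred = torch.randn(2, 6, 10, K * d)     # num_pred = 6
out = model(X, STE_his, STE_pred)
print(out.shape)                            # expected: torch.Size([2, 6, 10, 64])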