Open Dahaoge666 opened 3 years ago
class transformAttention(nn.Module):
    """
    Transform attention mechanism: attends from the predicted time steps
    (queries from STE_pred) to the historical time steps (keys from
    STE_his, values from X).

    X:        [batch_size, num_his, num_vertex, D]
    STE_his:  [batch_size, num_his, num_vertex, D]
    STE_pred: [batch_size, num_pred, num_vertex, D]
    K:        number of attention heads
    d:        dimension of each attention head's output (D = K * d)
    return:   [batch_size, num_pred, num_vertex, D]
    """
    def __init__(self, K, d, bn_decay):
        super(transformAttention, self).__init__()
        D = K * d
        self.K = K
        self.d = d
        # Dedicated projections for query/key/value, plus a final output
        # projection applied after the heads are merged back together.
        self.FC_q = FC(input_dims=D, units=D, activations=F.relu,
                       bn_decay=bn_decay)
        self.FC_k = FC(input_dims=D, units=D, activations=F.relu,
                       bn_decay=bn_decay)
        self.FC_v = FC(input_dims=D, units=D, activations=F.relu,
                       bn_decay=bn_decay)
        self.FC = FC(input_dims=D, units=D, activations=F.relu,
                     bn_decay=bn_decay)

    def forward(self, X, STE_his, STE_pred):
        batch_size = X.shape[0]
        # [batch_size, num_step, num_vertex, K * d]
        # BUG FIX: the original applied the shared output projection
        # self.FC to all three inputs; use the dedicated q/k/v layers
        # (confirmed as typos by the repository maintainer).
        query = self.FC_q(STE_pred)
        key = self.FC_k(STE_his)
        value = self.FC_v(X)
        # Split last dim into K heads of width d, stacked along batch:
        # [K * batch_size, num_step, num_vertex, d]
        # BUG FIX: torch.split takes a chunk SIZE (unlike tf.split in the
        # TensorFlow original, which takes a chunk COUNT), so the split
        # size must be self.d to yield K head chunks; the original
        # self.K only produced the intended shape when K == d.
        query = torch.cat(torch.split(query, self.d, dim=-1), dim=0)
        key = torch.cat(torch.split(key, self.d, dim=-1), dim=0)
        value = torch.cat(torch.split(value, self.d, dim=-1), dim=0)
        # query: [K * batch_size, num_vertex, num_pred, d]
        # key:   [K * batch_size, num_vertex, d, num_his]
        # value: [K * batch_size, num_vertex, num_his, d]
        query = query.permute(0, 2, 1, 3)
        key = key.permute(0, 2, 3, 1)
        value = value.permute(0, 2, 1, 3)
        # Scaled dot-product attention scores:
        # [K * batch_size, num_vertex, num_pred, num_his]
        attention = torch.matmul(query, key)
        attention /= (self.d ** 0.5)
        attention = F.softmax(attention, dim=-1)
        # Weighted sum over history, then merge heads back into D:
        # [batch_size, num_pred, num_vertex, D]
        X = torch.matmul(attention, value)
        X = X.permute(0, 2, 1, 3)
        X = torch.cat(torch.split(X, batch_size, dim=0), dim=-1)
        X = self.FC(X)
        # Free intermediates promptly to reduce peak memory.
        del query, key, value, attention
        return X
Hmm — for the third through fifth lines in `forward`, I think they should be
query = self.FC_q(STE_pred)
key = self.FC_k(STE_his)
value = self.FC_v(X)
but in the code, it is
query = self.FC(STE_pred)
key = self.FC(STE_his)
value = self.FC(X)
Thank you.
Yep, sure — those are typos. They should be `query = self.FC_q(STE_pred)`, `key = self.FC_k(STE_his)`, and `value = self.FC_v(X)`. I have revised them in model.py; thanks for your attention.
您好,您的代码非常棒,在跑实验时与tensorflow版本有差异,在model中,transformAttention的FC层是否应改为FC_q等