VincLee8188 / GMAN-PyTorch

Implementation of Graph Multi-Attention Network with PyTorch

tf.split and torch.split #3

Open wanzhixiao opened 3 years ago

wanzhixiao commented 3 years ago

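Note that torch.split is different from tf.split. The second argument of torch.split is the size of each chunk, while an integer second argument to tf.split is the number of chunks.

Current PyTorch version: query = torch.cat(torch.split(query, self.K, dim=-1), dim=0), which gives shape (d*b, t, n, K)

Original TF version: query = tf.concat(tf.split(query, K, axis = -1), axis = 0), which gives shape (K*b, t, n, d)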
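
A minimal sketch of the mismatch, for anyone who wants to check the shapes. The sizes below are made-up values for illustration (not taken from the repository), with K and d chosen unequal so the two layouts are easy to tell apart:

```python
import torch

# Illustrative sizes only (assumptions, not the repository's settings).
b, t, n = 2, 3, 4   # batch size, time steps, vertices
K, d = 8, 16        # number of heads, per-head dimension
D = K * d

query = torch.randn(b, t, n, D)

# torch.split takes the SIZE of each chunk, so passing K yields D // K = d chunks of width K.
chunks = torch.split(query, K, dim=-1)
stacked = torch.cat(chunks, dim=0)

print(len(chunks))    # number of chunks, equal to d here
print(stacked.shape)  # (d * b, t, n, K) rather than the intended (K * b, t, n, d)
```

When K happens to equal d, both layouts have the same shape, so the difference is easy to miss.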

wanzhixiao commented 3 years ago

Those are the differences between PyTorch and TensorFlow syntax.

Yes, so this part of the reproduction does not match the original TensorFlow version. You should modify it.

szloveyyy commented 3 years ago

It should be query = torch.cat(torch.split(query, self.d, dim=-1), dim=0). Otherwise the shape of query is [d*batch_size, num_step, num_vertex, K], because the second parameter of torch.split() is split_size (the size of each chunk), not the number of chunks.
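
A rough sketch of the suggested fix, again with made-up sizes (assumptions, not the repository's settings). It also shows torch.chunk, which takes the number of chunks and so reads closer to the tf.split call; that alternative is my own suggestion, not something from the repository:

```python
import torch

# Illustrative sizes only (assumptions, not the repository's settings).
b, t, n = 2, 3, 4   # batch size, time steps, vertices
K, d = 8, 16        # number of heads, per-head dimension

query = torch.randn(b, t, n, K * d)

# Chunk size d gives K chunks, matching tf.split(query, K, axis=-1).
by_size = torch.cat(torch.split(query, d, dim=-1), dim=0)

# torch.chunk takes the NUMBER of chunks; K * d is divisible by K, so this also gives K chunks.
by_count = torch.cat(torch.chunk(query, K, dim=-1), dim=0)

print(by_size.shape)                   # (K * b, t, n, d)
print(torch.equal(by_size, by_count))  # both forms produce the same tensor here
```

Each block of b rows along dim 0 then holds one head: rows k*b to (k+1)*b correspond to query[..., k*d:(k+1)*d].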