[Open] wanzhixiao opened this issue 3 years ago
Those are differences between PyTorch and TensorFlow syntax.
Yes, so the PyTorch reproduction of this line does not match the original TensorFlow version; you should modify it.
It should be query = torch.cat(torch.split(query, self.d, dim=-1), dim=0). Otherwise the shape of query is [d*batch_size, num_step, num_vertex, K], because the second parameter of torch.split() is split_size (the size of each chunk), not the number of chunks.
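A minimal sketch of the shape difference, assuming hypothetical sizes (batch B = 2, T = 3 steps, N = 4 vertices, K = 8 heads of d = 5 features each, so D = K*d = 40) and a random tensor standing in for the model's projected query:

```python
import torch

# Hypothetical sizes for illustration: batch, time steps, vertices,
# K attention heads and d features per head (so D = K * d).
B, T, N, K, d = 2, 3, 4, 8, 5
query = torch.randn(B, T, N, K * d)

# Current port: an int second argument to torch.split is the chunk SIZE,
# so this produces d chunks of width K.
wrong = torch.cat(torch.split(query, K, dim=-1), dim=0)
print(wrong.shape)  # torch.Size([10, 3, 4, 8]), i.e. (d*B, T, N, K)

# Fixed port: chunks of width d give exactly K chunks, one per head.
fixed = torch.cat(torch.split(query, d, dim=-1), dim=0)
print(fixed.shape)  # torch.Size([16, 3, 4, 5]), i.e. (K*B, T, N, d)

# torch.chunk takes the NUMBER of chunks, like tf.split with an int.
also_fixed = torch.cat(torch.chunk(query, K, dim=-1), dim=0)
assert torch.equal(fixed, also_fixed)
```

torch.chunk(query, self.K, dim=-1) is an equivalent alternative to splitting by self.d here, since D is divisible by K.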
Note that torch.split is different from tf.split (an integer num_or_size_splits in tf.split is the number of chunks):
current PyTorch version: query = torch.cat(torch.split(query, self.K, dim=-1), dim=0)  # shape (d*b, t, n, K)
original TF version: query = tf.concat(tf.split(query, K, axis=-1), axis=0)  # shape (K*b, t, n, d)
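To check that the fix reproduces the TF layout and not just the shape, here is a sketch with two hypothetical helpers, split_heads and merge_heads (names are illustrative, not from the repository). It assumes the same made-up sizes as before and verifies that row h*B + b holds head h of batch element b, the order produced by tf.concat(tf.split(query, K, axis=-1), axis=0):

```python
import torch

def split_heads(x: torch.Tensor, num_heads: int) -> torch.Tensor:
    """Hypothetical helper: (B, T, N, K*d) -> (K*B, T, N, d),
    mirroring tf.concat(tf.split(x, K, axis=-1), axis=0)."""
    head_dim = x.size(-1) // num_heads
    assert head_dim * num_heads == x.size(-1), "D must be divisible by K"
    # A chunk size of head_dim yields exactly num_heads chunks.
    return torch.cat(torch.split(x, head_dim, dim=-1), dim=0)

def merge_heads(x: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Hypothetical inverse: (K*B, T, N, d) -> (B, T, N, K*d)."""
    # A chunk size of batch_size along dim 0 yields one chunk per head.
    return torch.cat(torch.split(x, batch_size, dim=0), dim=-1)

B, T, N, K, d = 2, 3, 4, 8, 5
query = torch.randn(B, T, N, K * d)
heads = split_heads(query, K)

# Head h of batch element b sits at row h*B + b (head-major order).
h, b = 3, 1
assert torch.equal(heads[h * B + b], query[b, ..., h * d:(h + 1) * d])
# Splitting and merging is a lossless round trip.
assert torch.equal(merge_heads(heads, B), query)
```

The same chunk-size rule applies to the reverse step: splitting along dim 0 must use batch_size as the chunk size to get K pieces back.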