VincLee8188 / GMAN-PyTorch

Implementation of Graph Multi-Attention Network with PyTorch

temporalAttention #2

Open wanzhixiao opened 3 years ago

wanzhixiao commented 3 years ago

Hi, thanks for your implementation. I found a bug in model.py: `attention = torch.where(mask, attention, -2 ** 15 + 1)`. When I run the code, this line raises `TypeError: where(): argument 'other' (position 3) must be Tensor, not int`.

VincLee8188 commented 3 years ago

Please check your PyTorch version; mine is 1.4.0, and the argument (at position 3) of `torch.where()` can be a tensor or a scalar.
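A version-agnostic workaround (a sketch, not the maintainer's fix; the shapes here are hypothetical stand-ins for the attention tensors in model.py) is to pass the fill value as a tensor, which satisfies older `torch.where` signatures that require a `Tensor` for the third argument:

```python
import torch

# Hypothetical shapes standing in for the attention scores in model.py.
attention = torch.randn(2, 4, 4)
mask = torch.rand(2, 4, 4) > 0.5

# Wrap the scalar fill value in a tensor of matching shape/dtype so that
# PyTorch versions requiring a Tensor 'other' argument also accept it.
fill = torch.full_like(attention, -2 ** 15 + 1)
masked = torch.where(mask, attention, fill)
```

Masked positions end up at a large negative value (-32767), so a subsequent softmax drives their weights toward zero.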