Chloe-Liu33 opened 1 year ago
It seems that lines 361-367 of the following code in model.py should be deleted:
```python
x = torch.einsum('bnft, knm -> bkmft', x, wavelets)
x = torch.einsum('bkmft, kfo -> bk')
x = torch.einsum('bknft, k -> bnft', x, self.Gate)
a = torch.einsum('fs, bknf -> bkns', self.Att_W, xs)
a = torch.einsum('bkns, ks -> bkn', xs, self.Att_U)
```
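One visible problem in the quoted lines is that the second call passes an equation with two input terms but no tensors, so einsum cannot run it. The sketch below is a hypothetical stdlib-only helper, `einsum_problems`, which is not part of the repository or of PyTorch. It checks only the operand count and whether each output label appears among the inputs, not tensor shapes.

```python
def einsum_problems(equation: str, num_operands: int) -> list[str]:
    """Hypothetical helper: list simple problems in an explicit-mode einsum equation.

    Only the number of input terms vs. operands and the output labels are
    checked; tensor shapes are not inspected.
    """
    problems = []
    # Split "in1, in2 -> out" into input terms and the output term (spaces ignored).
    lhs, _, rhs = equation.replace(" ", "").partition("->")
    inputs = lhs.split(",")
    if len(inputs) != num_operands:
        problems.append(
            f"{len(inputs)} input term(s) but {num_operands} operand(s) passed"
        )
    # Every output label must appear in at least one input term.
    seen = set("".join(inputs))
    missing = sorted(set(rhs) - seen)
    if missing:
        problems.append(f"output labels {missing} not found in inputs")
    return problems


# The second quoted call supplies the equation but no tensors.
print(einsum_problems('bkmft, kfo -> bk', 0))    # prints ['2 input term(s) but 0 operand(s) passed']
# The first quoted call passes x and wavelets, matching its two input terms.
print(einsum_problems('bnft, knm -> bkmft', 2))  # prints []
```

The checker does not catch the other issues in the snippet, such as the undefined name xs or the last line reusing xs instead of the just-computed a. Those are variable errors, not equation errors.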
There is also an additional argument, _adjfile, which gives the location of the adjacency matrix data. It actually has to be specified.
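If the training script reads its options with argparse (an assumption; the issue does not say how _adjfile is defined), marking the option as required makes a missing path fail at startup instead of later during data loading. In the sketch below, the `--adjfile` flag spelling, `build_parser`, and the example path are hypothetical. Only the attribute name `_adjfile` comes from the issue.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser sketch; not the repository's actual argument list."""
    parser = argparse.ArgumentParser(description="Training options (sketch)")
    # Hypothetical flag spelling; dest="_adjfile" stores the value as args._adjfile.
    # required=True makes argparse print an error and exit if the path is omitted.
    parser.add_argument(
        "--adjfile",
        dest="_adjfile",
        required=True,
        help="path to the adjacency matrix data",
    )
    return parser


# Parse an explicit argument list so the example does not depend on sys.argv.
args = build_parser().parse_args(["--adjfile", "path/to/adj_mx.pkl"])
print("adjacency matrix file:", args._adjfile)
```

Calling `build_parser().parse_args([])` exits with status 2 and a usage message naming the missing option.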
Hello, when I ran the code on the METR-LA dataset following your README.md, I encountered this bug. When I ran the code on my own dataset, the same bug occurred again.