chaoshangcs / GTS

Discrete Graph Structure Learning for Forecasting Multiple Time Series, ICLR 2021.
Apache License 2.0

Some problems with the code #20

Open perveil opened 2 years ago

perveil commented 2 years ago
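```python
import numpy as np
import torch

def encode_onehot(labels):
    # Map each distinct label to a row of the identity matrix (a one-hot vector).
    classes = sorted(set(labels))  # sorted, so column k always stands for label k
    classes_dict = {c: np.identity(len(classes))[i, :] for i, c in
                    enumerate(classes)}
    labels_onehot = np.array(list(map(classes_dict.get, labels)),
                             dtype=np.int32)
    return labels_onehot

# Inside the model class, where self.num_nodes and device are defined.
# np.ones keeps the diagonal, so all num_nodes**2 (i, j) pairs are included.
off_diag = np.ones([self.num_nodes, self.num_nodes])
rel_rec = np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32)  # one-hot of each pair's row index i
rel_send = np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32)  # one-hot of each pair's column index j
self.rel_rec = torch.FloatTensor(rel_rec).to(device)
self.rel_send = torch.FloatTensor(rel_send).to(device)
```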
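To see what these two arrays contain, here is a small illustration with a hypothetical 3-node graph. It assumes node labels are the integers 0..num_nodes-1, in which case picking rows of the identity matrix with `np.eye(num_nodes)[idx]` gives the same one-hot rows as `encode_onehot`; `num_nodes = 3` is chosen only for readability.

```python
import numpy as np

num_nodes = 3  # hypothetical tiny graph, for illustration only
off_diag = np.ones([num_nodes, num_nodes])
recv_idx, send_idx = np.where(off_diag)  # row i and column j of every (i, j) pair

# Assumes labels are 0..num_nodes-1, so identity-matrix rows equal the
# one-hot rows that encode_onehot would produce.
rel_rec = np.eye(num_nodes, dtype=np.float32)[recv_idx]
rel_send = np.eye(num_nodes, dtype=np.float32)[send_idx]

print(recv_idx)       # [0 0 0 1 1 1 2 2 2]
print(send_idx)       # [0 1 2 0 1 2 0 1 2]
print(rel_rec.shape)  # (9, 3): one row per (i, j) pair, one column per node
```

Row e of `rel_rec` marks node i of the e-th pair and row e of `rel_send` marks node j. The diagonal pairs (i, i) are included because `off_diag` is built with `np.ones`.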

Hi! I have some questions about the code above: what do rel_rec and rel_send mean? Thank you for your reply!

perveil commented 2 years ago

When the number of nodes is large, this can't be computed.
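A rough estimate of why this breaks down: `rel_rec` and `rel_send` each have num_nodes² rows (one per node pair, diagonal included) and num_nodes columns, so each dense float32 matrix needs 4·num_nodes³ bytes, and multiplying it by the node embeddings costs on the order of num_nodes³ × hidden multiply-adds. The sketch below only does that arithmetic; `onehot_relation_bytes` is a hypothetical helper name, and it ignores the temporary int32 copy that `encode_onehot` also allocates.

```python
def onehot_relation_bytes(num_nodes: int, bytes_per_value: int = 4) -> int:
    """Bytes for ONE of rel_rec / rel_send stored densely (float32 by default)."""
    num_pairs = num_nodes * num_nodes  # np.ones(...) keeps the diagonal too
    return num_pairs * num_nodes * bytes_per_value

for n in (100, 500, 1000, 2000):
    gib = onehot_relation_bytes(n) / 2**30
    print(f"num_nodes={n}: {gib:.2f} GiB per matrix (x2 for rel_rec and rel_send)")
```

Because the size grows with the cube of the node count, doubling the number of nodes multiplies the memory for these two matrices by eight.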

chaoshangcs commented 2 years ago

Hi, thanks for your questions. rel_rec and rel_send hold the index (one-hot) vectors used to extract the embeddings of the receivers and senders:

```python
receivers = torch.matmul(self.rel_rec, x)
senders = torch.matmul(self.rel_send, x)
```
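Multiplying by a one-hot row simply copies one row of `x`, so this matmul is a row lookup. A minimal check, assuming `x` holds one embedding row per node with shape (num_nodes, hidden); the sizes below are hypothetical and the one-hot matrices are rebuilt with `torch.eye` in the same pair order as `np.where(np.ones([N, N]))`.

```python
import torch

num_nodes, hidden = 4, 8            # hypothetical sizes
x = torch.randn(num_nodes, hidden)  # assumed: one embedding row per node

# Pair order matches np.where(np.ones([N, N])): (0,0), (0,1), ..., (N-1, N-1).
recv_idx = torch.arange(num_nodes).repeat_interleave(num_nodes)
send_idx = torch.arange(num_nodes).repeat(num_nodes)
rel_rec = torch.eye(num_nodes)[recv_idx]   # (N*N, N) one-hot rows
rel_send = torch.eye(num_nodes)[send_idx]

receivers = torch.matmul(rel_rec, x)  # (N*N, hidden)
senders = torch.matmul(rel_send, x)

# Each output row is just the embedding of the node selected by the one-hot row.
print(torch.allclose(receivers, x[recv_idx]))  # True
print(torch.allclose(senders, x[send_idx]))    # True
```

Row e of `receivers` and row e of `senders` together give the two node embeddings for the e-th (i, j) pair.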

Each sender-receiver pair consists of two sensor nodes from the graph. You could explore other ways to get the embeddings of all nodes.
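One possible alternative (a sketch, not the GTS authors' code) is to keep only the integer index vectors and gather rows with `Tensor.index_select`, which stores 2·num_nodes² indices instead of two dense num_nodes² × num_nodes float matrices. `PairGather` is a hypothetical name, and `x` is assumed to have shape (num_nodes, hidden).

```python
import torch
import torch.nn as nn

class PairGather(nn.Module):
    """Hypothetical helper: selects (receiver, sender) embeddings by index
    instead of multiplying by dense one-hot rel_rec / rel_send matrices."""

    def __init__(self, num_nodes: int):
        super().__init__()
        # Same pair order as np.where(np.ones([N, N])): (0,0), (0,1), ..., (N-1, N-1).
        recv_idx = torch.arange(num_nodes).repeat_interleave(num_nodes)
        send_idx = torch.arange(num_nodes).repeat(num_nodes)
        # Buffers move with the module on .to(device) but are not trainable.
        self.register_buffer("recv_idx", recv_idx)
        self.register_buffer("send_idx", send_idx)

    def forward(self, x: torch.Tensor):
        # x: (num_nodes, hidden) -> receivers, senders: (num_nodes**2, hidden)
        receivers = x.index_select(0, self.recv_idx)
        senders = x.index_select(0, self.send_idx)
        return receivers, senders

gather = PairGather(num_nodes=4)
x = torch.randn(4, 8)
receivers, senders = gather(x)
print(receivers.shape)  # torch.Size([16, 8])
```

The outputs still have num_nodes² rows, so the pair features themselves grow quadratically; this only removes the cubic cost of the one-hot matrices.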