cshjin closed this issue 1 year ago.
Does the `forward` method of the `HGTConv` layer only update nodes of the `dst_type`?
```python
# Iterate over edge-types:
for edge_type, edge_index in edge_index_dict.items():
    src_type, _, dst_type = edge_type
    edge_type = '__'.join(edge_type)

    a_rel = self.a_rel[edge_type]
    k = (k_dict[src_type].transpose(0, 1) @ a_rel).transpose(1, 0)

    m_rel = self.m_rel[edge_type]
    v = (v_dict[src_type].transpose(0, 1) @ m_rel).transpose(1, 0)

    # propagate_type: (k: Tensor, q: Tensor, v: Tensor, rel: Tensor)
    out = self.propagate(edge_index, k=k, q=q_dict[dst_type], v=v,
                         rel=self.p_rel[edge_type], size=None)
    # Messages flow src_type -> dst_type, so only dst_type collects output here.
    out_dict[dst_type].append(out)
```
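To see why only destination types get updated, here is a dependency-free sketch that mirrors just the bookkeeping of the loop: for each `(src_type, relation, dst_type)` key, output is appended to `out_dict[dst_type]` only. The helper `updated_node_types` and the `author`/`paper` types are hypothetical, chosen for illustration; this is not PyG code and it ignores the attention computation entirely.

```python
from typing import Dict, List, Tuple

# (src_type, relation, dst_type), the same shape as HeteroData edge-type keys
EdgeType = Tuple[str, str, str]


def updated_node_types(edge_types: List[EdgeType],
                       node_types: List[str]) -> Dict[str, bool]:
    """Hypothetical helper: report which node types receive any messages."""
    out_dict: Dict[str, List[str]] = {nt: [] for nt in node_types}
    for src_type, rel, dst_type in edge_types:
        # Like HGTConv, each edge type only contributes to its destination type.
        out_dict[dst_type].append('__'.join((src_type, rel, dst_type)))
    return {nt: len(msgs) > 0 for nt, msgs in out_dict.items()}


# One-directional edges: 'author' is never a destination, so it gets nothing.
print(updated_node_types([('author', 'writes', 'paper')], ['author', 'paper']))
```

With only `('author', 'writes', 'paper')`, the `author` entry stays empty, which is the behavior the question describes.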
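Resolved by adding a bidirectional link, i.e. a reverse edge type for every edge type. Once edges run in both directions, every node type appears as a `dst_type`, so the conv layers update all node types and this is no longer a problem.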
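A minimal sketch of the fix, assuming edge indices are stored as a pair of source/destination id lists (a stand-in for a `2 x E` `edge_index` tensor): for every `(src, rel, dst)` edge type, add a reverse type `(dst, 'rev_' + rel, src)` with the rows swapped. The function name `add_reverse_edges` and the `rev_` prefix are illustrative assumptions here; in PyG, the `torch_geometric.transforms.ToUndirected` transform serves this purpose for `HeteroData`, and this sketch is not its implementation.

```python
from typing import Dict, List, Tuple

EdgeType = Tuple[str, str, str]          # (src_type, relation, dst_type)
EdgeIndex = Tuple[List[int], List[int]]  # (source ids, destination ids)


def add_reverse_edges(
        edge_index_dict: Dict[EdgeType, EdgeIndex]) -> Dict[EdgeType, EdgeIndex]:
    """Hypothetical helper: return a copy with a reverse edge type per edge type."""
    out = dict(edge_index_dict)  # do not mutate the caller's dict
    for (src, rel, dst), (row, col) in edge_index_dict.items():
        rev_type = (dst, f'rev_{rel}', src)
        if rev_type not in out:
            # Swap source and destination ids so messages can flow back.
            out[rev_type] = (list(col), list(row))
    return out


edges = {('author', 'writes', 'paper'): ([0, 1], [2, 2])}
both_ways = add_reverse_edges(edges)
print(sorted(both_ways))
```

After the reverse type is added, both `author` and `paper` appear as destination types, so a layer that only writes to `out_dict[dst_type]` now produces output for each of them.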