Open — jasperzelu opened this issue 8 months ago
When I set `cfg.dataset.task=node`, `cfg.model.type=gnn`, and `cfg.gnn.stage_type=stack`, execution reaches `self.post_mp = GNNHead(dim_in=d_in, dim_out=dim_out)` in `gnn`, which uses:
class GNNNodeHead(nn.Module):
    """Head of GNN, node prediction.

    Runs a post-message-passing MLP over the batch, then returns the
    predictions (and matching labels) for the nodes listed in
    ``batch.node_label_index``.
    """

    def __init__(self, dim_in, dim_out):
        """Build the post-message-passing MLP.

        Args:
            dim_in: Input feature dimension forwarded to ``MLP``.
            dim_out: Output dimension forwarded to ``MLP``.
        """
        super().__init__()
        # Number of MLP layers is taken from the global config.
        self.layer_post_mp = MLP(dim_in, dim_out,
                                 num_layers=cfg.gnn.layers_post_mp,
                                 bias=True)

    def _apply_index(self, batch):
        """Select predictions and labels for the labelled nodes.

        Args:
            batch: Object exposing ``node_feature``, ``node_label`` and
                ``node_label_index`` (indexable, with ``.shape``).

        Returns:
            ``(pred, label)`` where ``pred`` is
            ``batch.node_feature[batch.node_label_index]``. If
            ``batch.node_label`` already has exactly one entry per index it
            is returned unchanged; otherwise it is indexed with
            ``node_label_index`` too.
        """
        index = batch.node_label_index
        # Labels may arrive either already aligned with node_label_index or
        # covering every node; only the latter needs to be indexed.
        # NOTE(review): this is a length-based heuristic. If node_label
        # covers all nodes AND node_label_index is a non-identity
        # permutation of all nodes, labels are returned un-permuted and
        # will not line up with pred. Confirm which form the loader emits.
        if index.shape[0] == batch.node_label.shape[0]:
            return batch.node_feature[index], batch.node_label
        return batch.node_feature[index], batch.node_label[index]

    def forward(self, batch):
        """Apply the post-MP MLP to ``batch`` and return ``(pred, label)``.

        The MLP's return value replaces ``batch`` before indexing, so it is
        presumably expected to return the batch with ``node_feature``
        updated — TODO confirm against ``MLP``.
        """
        batch = self.layer_post_mp(batch)
        pred, label = self._apply_index(batch)
        return pred, label
I want to know where the `batch.node_label_index` and `batch.node_label` properties are set.
Sorry to bother you, but I think I've found the answer.
When I set `cfg.dataset.task=node`, `cfg.model.type=gnn`, and `cfg.gnn.stage_type=stack`, execution reaches `self.post_mp = GNNHead(dim_in=d_in, dim_out=dim_out)` in `gnn`, which uses:
I want to know where the `batch.node_label_index` and `batch.node_label` properties are set.