lightaime / deep_gcns_torch

PyTorch repo for DeepGCNs (ICCV'2019 Oral, TPAMI'2021), DeeperGCN (arXiv'2020) and GNN1000 (ICML'2021): https://www.deepgcns.org
MIT License

May be a serious bug #58

Closed jiyanbio closed 4 years ago

jiyanbio commented 4 years ago

The highlighted lines below look suspicious: `target = train_y[inter_idx]` indexes `train_y` (already sliced by `dataset.train_idx`) with *global* node ids from `inter_idx`, which may be out of the maximal index range.

```python
def train(data, dataset, model, optimizer, criterion, device):
    loss_list = []
    model.train()
    sg_nodes, sg_edges, sg_edges_index, _ = data

    train_y = dataset.y[dataset.train_idx]  # <-- highlighted
    idx_clusters = np.arange(len(sg_nodes))
    np.random.shuffle(idx_clusters)

    for idx in idx_clusters:
        x = dataset.x[sg_nodes[idx]].float().to(device)
        sg_nodes_idx = torch.LongTensor(sg_nodes[idx]).to(device)

        sg_edges_ = sg_edges[idx].to(device)
        sg_edges_attr = dataset.edge_attr[sg_edges_index[idx]].to(device)

        mapper = {node: idx for idx, node in enumerate(sg_nodes[idx])}

        inter_idx = intersection(sg_nodes[idx], dataset.train_idx.tolist())  # <-- highlighted
        training_idx = [mapper[t_idx] for t_idx in inter_idx]

        optimizer.zero_grad()

        pred = model(x, sg_nodes_idx, sg_edges_, sg_edges_attr)

        target = train_y[inter_idx].to(device)  # <-- inter_idx may be out of the maximal index range

        loss = criterion(pred[training_idx].to(torch.float32), target.to(torch.float32))  # <-- highlighted
        loss.backward()
        optimizer.step()
        loss_list.append(loss.item())

    return statistics.mean(loss_list)
```
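The failure mode described above can be reproduced with a minimal NumPy sketch (toy labels and indices chosen for illustration, not taken from the dataset): when `train_idx` does not start at 0, indexing the sliced `train_y` with a global node id goes out of range, while indexing the full label array stays correct.

```python
import numpy as np

# Toy labels for 6 nodes; suppose only nodes 3, 4, 5 are training nodes.
y = np.array([10, 11, 12, 13, 14, 15])
train_idx = np.array([3, 4, 5])        # NOT starting at 0

train_y = y[train_idx]                 # length 3: [13, 14, 15]

# inter_idx holds *global* node ids from the subgraph, e.g. node 4.
inter_idx = [4]

# Indexing train_y (length 3) with a global id of 4 is out of range.
err = None
try:
    _ = train_y[inter_idx]
except IndexError as e:
    err = e

# Indexing the full label array with the global id is safe.
safe_target = y[inter_idx]             # [14]
```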
Elizabeth1997 commented 4 years ago

Hi @jiyanbio, thanks for your interest in our work. Your concern is valid in general, but the potential issue you mention doesn't arise for this dataset, because the training indices are contiguous and start from 0:

```python
{'train': tensor([    0,     1,     2,  ..., 86616, 86617, 86618]),
 'valid': tensor([ 86619,  86620,  86621,  ..., 107852, 107853, 107854]),
 'test': tensor([107855, 107856, 107857,  ..., 132531, 132532, 132533])}
```

Therefore, the statements `train_y = dataset.y[dataset.train_idx]` and `train_y = dataset.y` produce the same result for any training-node index. Please let us know if this resolves your concern, and feel free to ask more questions.
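The equivalence claimed here can be checked with a small NumPy sketch (toy values, not the real dataset): when `train_idx` is `arange(n)`, slicing by it is the identity on the training range, so both indexing schemes agree for every training-node id.

```python
import numpy as np

# When train_idx is contiguous and starts at 0 (as in this dataset),
# y[train_idx] is the identity on the training range.
y = np.arange(100, 110)      # toy labels for 10 nodes
train_idx = np.arange(7)     # [0, 1, ..., 6] -- contiguous from 0

train_y = y[train_idx]
inter_idx = [2, 5, 6]        # global ids of some training nodes

# Both indexing schemes give the same targets.
assert (train_y[inter_idx] == y[inter_idx]).all()
```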