Starlien95 / GraphPrompt

GraphPrompt: Unifying Pre-Training and Downstream Tasks for Graph Neural Networks

Question for Training #14

Open DuanhaoranCC opened 6 days ago

DuanhaoranCC commented 6 days ago

Pre-training section

    model.train()
    total_time=0
    for batch_id, batch in enumerate(data_loader):
        ids, graph_label, graph, graph_len= batch
        # print(batch)
        graph=graph.to(device)
        print(graph_label.size())
        print(graph.ndata["feature"].size())

        graph_label=graph_label.to(device)
        graph_len = graph_len.to(device)
        s=time.time()
        x,pred = model(graph, graph_len)

        #########################################################################
        # To avoid the loss becoming 0 on NCI1, the embedding must be constrained
        # to positive values, which can be done by applying relu or sigmoid.
        pred=F.sigmoid(pred)

        # Modified computation: add self_weight * I to the adjacency matrix,
        # then aggregate the node embeddings over each neighborhood.
        adj = graph.adjacency_matrix()
        # adj = adj.to(device)
        # adj = graph.adjacency_matrix()
        # temp = adj.to_dense()
        # temp+= (self_weight*torch.eye(temp.size(0)))
        # adj = temp.to_sparse().to(device)
        size_ = adj.size(0)
        p = [i for i in range(size_)]
        x = torch.tensor([p,p])
        q = [self_weight for i in range(size_)]
        tt = torch.sparse_coo_tensor(x,q,(size_,size_))
        adj = (adj + tt).to(device)
        '''print('---------------------------------------------------')
        print('adj: ',adj.size())
        print('pred: ',adj.size())
        print('---------------------------------------------------')'''

        pred = torch.matmul(adj, pred)
        #print(pred.size())
        _pred=split_and_batchify_graph_feats(pred, graph_len)[0]
        sample = graph.ndata['sample']
        _sample=split_and_batchify_graph_feats(sample, graph_len)[0]
        sample_=_sample.reshape(_sample.size(0),-1,1)
        #print(_pred.size())
        #print(sample_.size())
        _pred=torch.gather(input=_pred,dim=1,index=sample_)
        #print(_pred.size())
        _pred=_pred.resize_as(_sample)
        #print(_pred.size())
        print(_pred.size())

        reg_loss = compareloss(_pred,train_config["temperature"])
        reg_loss.requires_grad_(True)

Could you please explain why the model aggregates the node representations to obtain pred, and then the training loop performs another aggregation with the self-loop-augmented adjacency matrix? I don't quite understand the reason for this step.
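For context, here is a minimal sketch of what the quoted aggregation step appears to do, based on my own reading of the snippet (the function name, the dense adjacency, and the toy graph below are illustrative assumptions, not code from the repository): a weighted self-loop is added to the adjacency matrix, and each node's embedding is then summed with its neighbors'.

    import torch

    def aggregate_with_self_loops(adj, node_emb, self_weight=1.0):
        # adj:      dense [N, N] adjacency matrix of the batched graph
        # node_emb: [N, D] node embeddings produced by the encoder
        # Adding self_weight * I mirrors the sparse self-loop tensor built in the
        # training loop above, so each node also keeps its own embedding.
        adj_with_self = adj + self_weight * torch.eye(adj.size(0), device=adj.device)
        # One round of neighborhood aggregation: row i becomes the weighted sum
        # of node i's embedding and its neighbors' embeddings.
        return adj_with_self @ node_emb

    # Toy usage: a 3-node path graph with 2-dimensional embeddings.
    adj = torch.tensor([[0., 1., 0.],
                        [1., 0., 1.],
                        [0., 1., 0.]])
    node_emb = torch.randn(3, 2)
    print(aggregate_with_self_loops(adj, node_emb).shape)  # torch.Size([3, 2])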

Prompt tuning section

    for batch_id, batch in enumerate(data_loader):
        ids, graph_label, graph, graph_len = batch
        # print(batch)
        cnt = graph_label.shape[0]
        total_cnt+=cnt
        batchcnt+=1
        graph = graph.to(device)
        graph_label=graph_label.to(device)
        graph_len = graph_len.to(device)
        s = time.time()

        x,embedding=pretrain_model(graph,graph_len)
        graph_label_onehot=l2onehot(graph_label)
        embedding = model(embedding, graph_len)*train_config["scalar"]
        embedding=F.dropout(embedding,p=train_config["downstream_dropout"])
        c_embedding=center_embedding(embedding,graph_label,label_num)
        distance=distance2center(embedding,c_embedding)
        #print(distance)

        distance = 1/F.normalize(distance, dim=1)

        #distance=distance2center2(embedding,c_embedding)

        #print('distance: ',distance )
        pred=F.log_softmax(distance,dim=1)
        #----------------------------------------
        #reg_loss = reg_crit(pred, graph_label_onehot)
        # Use this formula with NLL loss; otherwise use the one above.
        reg_loss = reg_crit(pred, graph_label.squeeze().type(torch.LongTensor).to(device))
        #------------------------------------------------
        reg_loss.requires_grad_(True)
        _pred = torch.argmax(pred, dim=1, keepdim=True)
        accuracy = correctness_GPU(_pred, graph_label)
        total_acc+=accuracy
        if isinstance(config["bp_loss_slp"], (int, float)):
            neg_slp = float(config["bp_loss_slp"])
        else:
            bp_loss_slp, l0, l1 = config["bp_loss_slp"].rsplit("$", 3)
            neg_slp = anneal_fn(bp_loss_slp, batch_id + epoch * epoch_step, T=total_step // 4, lambda0=float(l0),
                                lambda1=float(l1))

        #bp_loss = bp_crit(pred.float(), graph_label_onehot.float(), neg_slp)
        # Use this formula with NLL loss; otherwise use the one above.
        bp_loss = bp_crit(pred.float(), graph_label.squeeze().type(torch.LongTensor).to(device),neg_slp)
        bp_loss.requires_grad_(True)

        # float
        reg_loss_item = reg_loss.item()
        bp_loss_item = bp_loss.item()
        total_reg_loss += reg_loss_item
        total_bp_loss += bp_loss_item

        if writer:
            writer.add_scalar("%s/REG-%s" % (data_type, config["reg_loss"]), reg_loss_item,
                              epoch * epoch_step + batch_id)
            writer.add_scalar("%s/BP-%s" % (data_type, config["bp_loss"]), bp_loss_item, epoch * epoch_step + batch_id)

        # if logger and (batch_id % config["print_every"] == 0 or batch_id == epoch_step - 1):
        #     logger.info(
        #         "epoch: {:0>3d}/{:0>3d}\tdata_type: {:<5s}\tbatch: {:0>5d}/{:0>5d}\treg loss: {:0>5.8f}\tbp loss: {:0>5.8f}".format(
        #             epoch, config["epochs"], data_type, batch_id, epoch_step,
        #             reg_loss_item, bp_loss_item))
        # for name, para in pre_train_model.named_parameters():
        #     print("xxx")
        #     print(para)
        # bp_loss.backward()
        # for name, para in pre_train_model.named_parameters():
        #     print("yyy")
        #     print(para)
        '''for name, parms in model.named_parameters():
            print('-->name:', name, '-->grad_requirs:', parms.requires_grad, \
                  ' -->grad_value:', parms.grad)'''
        '''for name, parms in model.named_parameters():
            print('-->name:', name, ' -->value:', parms)'''
        if (config["update_every"] < 2 or batch_id % config["update_every"] == 0 or batch_id == epoch_step - 1):
            if config["max_grad_norm"] > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), config["max_grad_norm"])
            if scheduler is not None:
                scheduler.step(epoch * epoch_step + batch_id)
            optimizer.step()
            optimizer.zero_grad()
        e=time.time()
        total_time+=e-s

    mean_reg_loss = total_reg_loss / total_cnt
    mean_bp_loss = total_bp_loss / total_cnt
    mean_acc=total_acc/batchcnt
    if writer:
        writer.add_scalar("%s/REG-%s-epoch" % (data_type, config["reg_loss"]), mean_reg_loss, epoch)
        writer.add_scalar("%s/BP-%s-epoch" % (data_type, config["bp_loss"]), mean_bp_loss, epoch)
    # if logger:
    #     logger.info("epoch: {:0>3d}/{:0>3d}\tdata_type: {:<5s}\treg loss: {:0>5.8f}\tbp loss: {:0>5.8f}\tmean_acc: {:0>1.3f}".format(
    #         epoch, config["epochs"], data_type, mean_reg_loss, mean_bp_loss,mean_acc))
    gc.collect()

Also, there are two loss functions, reg_loss and bp_loss, but I couldn't find the backward() call. Did I miss something? Thank you for your patience!
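For reference, a conventional way to close such a loop (an assumption about typical PyTorch usage, not code taken from this repository; bp_loss is taken to be the loss that is back-propagated) places backward() before gradient clipping and the optimizer step:

    import torch

    def update_step(model, optimizer, bp_loss, max_grad_norm=0.0):
        # Accumulate gradients for the back-propagated loss.
        bp_loss.backward()
        # Optional gradient clipping, matching the max_grad_norm check above.
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        # Apply the update and clear gradients for the next batch.
        optimizer.step()
        optimizer.zero_grad()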

Starlien95 commented 18 hours ago
  1. We use subgraph similarity calculation as the task template, which requires aggregating node embeddings to obtain subgraph embeddings (see the sketch after this list).
  2. We pre-train the graph encoder only once and reuse it for both node and graph classification tasks. The pre-training code can be found in \nodedownstream\pre-train.py.
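As an illustration of the task template described in point 1 (a sketch under my own assumptions; the readout choice, function names, and toy tensors are not taken from the repository), a subgraph embedding can be obtained by a sum readout over its nodes and two subgraphs can be compared by cosine similarity:

    import torch
    import torch.nn.functional as F

    def subgraph_embedding(node_emb, node_ids):
        # Sum readout over the nodes that belong to one subgraph.
        return node_emb[node_ids].sum(dim=0)

    def subgraph_similarity(emb_a, emb_b):
        # Cosine similarity between two subgraph embeddings.
        return F.cosine_similarity(emb_a.unsqueeze(0), emb_b.unsqueeze(0)).squeeze(0)

    # Toy usage: 5 nodes with 4-dimensional embeddings, two overlapping subgraphs.
    node_emb = torch.randn(5, 4)
    sub_a = subgraph_embedding(node_emb, torch.tensor([0, 1, 2]))
    sub_b = subgraph_embedding(node_emb, torch.tensor([2, 3, 4]))
    print(subgraph_similarity(sub_a, sub_b))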