Starlien95 / GraphPrompt

GraphPrompt: Unifying Pre-Training and Downstream Tasks for Graph Neural Networks

Prompt loss function #12

Closed DuanhaoranCC closed 2 weeks ago

DuanhaoranCC commented 1 month ago
    for batch_id, batch in enumerate(data_loader):
        ids, graph_label, graph, graph_len = batch
        # print(batch)
        cnt = graph_label.shape[0]
        total_cnt+=cnt
        batchcnt+=1
        graph = graph.to(device)
        graph_label=graph_label.to(device)
        graph_len = graph_len.to(device)
        s = time.time()

        x,embedding=pretrain_model(graph,graph_len)
        graph_label_onehot=l2onehot(graph_label)
        embedding = model(embedding, graph_len)*train_config["scalar"]
        embedding=F.dropout(embedding,p=train_config["downstream_dropout"])
        c_embedding=center_embedding(embedding,graph_label,label_num)
        distance=distance2center(embedding,c_embedding)
        #print(distance)

        distance = 1/F.normalize(distance, dim=1)

        #distance=distance2center2(embedding,c_embedding)

        #print('distance: ',distance )
        pred=F.log_softmax(distance,dim=1)
        #----------------------------------------
        #reg_loss = reg_crit(pred, graph_label_onehot)
        # Use this formula for NLL loss; otherwise use the one above
        reg_loss = reg_crit(pred, graph_label.squeeze().type(torch.LongTensor).to(device))
        #------------------------------------------------
        reg_loss.requires_grad_(True)
        _pred = torch.argmax(pred, dim=1, keepdim=True)
        accuracy = correctness_GPU(_pred, graph_label)
        total_acc+=accuracy
        if isinstance(config["bp_loss_slp"], (int, float)):
            neg_slp = float(config["bp_loss_slp"])
        else:
            bp_loss_slp, l0, l1 = config["bp_loss_slp"].rsplit("$", 3)
            neg_slp = anneal_fn(bp_loss_slp, batch_id + epoch * epoch_step, T=total_step // 4, lambda0=float(l0),
                                lambda1=float(l1))

        #bp_loss = bp_crit(pred.float(), graph_label_onehot.float(), neg_slp)
        # Use this formula for NLL loss; otherwise use the one above
        bp_loss = bp_crit(pred.float(), graph_label.squeeze().type(torch.LongTensor).to(device),neg_slp)
        bp_loss.requires_grad_(True)

        # float
        reg_loss_item = reg_loss.item()
        bp_loss_item = bp_loss.item()
        total_reg_loss += reg_loss_item
        total_bp_loss += bp_loss_item

        if writer:
            writer.add_scalar("%s/REG-%s" % (data_type, config["reg_loss"]), reg_loss_item,
                              epoch * epoch_step + batch_id)
            writer.add_scalar("%s/BP-%s" % (data_type, config["bp_loss"]), bp_loss_item, epoch * epoch_step + batch_id)

The above code is located in prompt_fewshot.py.

I find that the prompt loss function (reg_loss) is not consistent with the original paper (Eq. 14). The paper applies a softmax over pairs of graph representations and class prototypical subgraph representations, whereas the code computes a cross-entropy between the predicted class and the ground truth. This discrepancy does not seem to unify the pre-training and fine-tuning loss functions.
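
For reference, here is a minimal sketch of the loss I would expect from Eq. 14 as I read it: each graph is scored against per-class prototype embeddings, and the NLL is taken over the softmax of those scores. The function name, the cosine similarity, and the temperature tau below are my own assumptions, not code from this repository.

    import torch
    import torch.nn.functional as F

    def prototype_softmax_loss(graph_emb, labels, num_classes, tau=1.0):
        # Class prototypes: mean embedding of the graphs belonging to each class.
        protos = torch.stack([graph_emb[labels == c].mean(dim=0)
                              for c in range(num_classes)])
        # Similarity of every graph to every class prototype: shape (N, C).
        sim = F.cosine_similarity(graph_emb.unsqueeze(1), protos.unsqueeze(0), dim=-1)
        # Negative log of the softmax over classes, evaluated at the true class.
        return F.nll_loss(F.log_softmax(sim / tau, dim=1), labels)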

Additionally, which equation in the paper does bp_loss correspond to?

Starlien95 commented 2 weeks ago

Only bp_loss is used for the backward pass, and it is the same as Eq. 14. reg_loss isn't used in the code.
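
Roughly: center_embedding builds per-class prototypes and distance2center scores each graph against them, so taking NLLLoss on the log_softmax of those scores gives the negative log of a softmax over (graph, class prototype) pairs, which is the Eq. 14 form. A quick check of that identity (a toy sketch, not repository code):

    import torch
    import torch.nn.functional as F

    # NLLLoss on log_softmax of per-class scores equals the negative log of the
    # class-wise softmax evaluated at the true class, averaged over the batch.
    scores = torch.randn(4, 3)            # 4 graphs, 3 class-prototype scores
    labels = torch.tensor([0, 2, 1, 0])
    lhs = F.nll_loss(F.log_softmax(scores, dim=1), labels)
    rhs = -torch.log(F.softmax(scores, dim=1)[torch.arange(4), labels]).mean()
    assert torch.allclose(lhs, rhs)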