acharkq / MolCA

Code for EMNLP2023 paper "MolCA: Molecular Graph-Language Modeling with Cross-Modal Projector and Uni-Modal Adapter".
71 stars 1 forks source link

Failed to reproduce the retrieval results #9

Closed DukeZhou3 closed 5 months ago

DukeZhou3 commented 5 months ago

Thank you for your excellent work! I would like to ask for your help in reproducing the Molecule-Text Retrieval results on PCDes, as I am running into some issues.

I downloaded the checkpoint (ckpt) file from https://huggingface.co/acharkq/MolCA/tree/main and ran the stage1.py script. However, I obtained an accuracy of 0.

Could you kindly provide some guidance on how I can resolve this issue? Thank you in advance for your help!

image

And this is my full script:

import os
# Environment knobs must be set before the libraries that read them are imported:
# CUDA_VISIBLE_DEVICES is read when CUDA is first initialised, and OpenBLAS reads
# OPENBLAS_NUM_THREADS when it is loaded (possibly during `import torch` / numpy),
# so assigning it after those imports may silently do nothing.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # hard-coded GPU; `--device cuda:0` then refers to this GPU
os.environ['OPENBLAS_NUM_THREADS'] = '1'
import argparse
import torch
import warnings
import pytorch_lightning as pl
from pytorch_lightning import Trainer, strategies
import pytorch_lightning.callbacks as plc
from pytorch_lightning.loggers import CSVLogger
from blip2_stage1 import Blip2Stage1
from stage1_kvplm_dm import Stage1KVPLMDM
from tqdm import tqdm
from retrieval_metrics import recall_at_k

## for pyg bug
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
## for A5000 gpus
torch.set_float32_matmul_precision('medium') # can be medium (bfloat16), high (tensorfloat32), highest (float32)

def pad_and_concat(tensor_list):
    '''
    Concatenate tensors along dim 0, right-padding dim 1 with zeros to a common length.

    tensor_list: non-empty sequence of tensors shaped [B_i, N_i, *rest]. B_i and N_i may
        differ between tensors; all tensors must share the same rank (>= 2) and the same
        trailing dims ``rest``.

    Returns a tensor of shape [sum(B_i), max(N_i), *rest] on the device of the first
    tensor. It is allocated with ``torch.zeros`` and therefore uses torch's default
    floating dtype, so e.g. boolean masks come back as 0./1. floats.

    Raises ValueError for an empty list and NotImplementedError for tensors of rank < 2.
    '''
    if not tensor_list:
        raise ValueError('pad_and_concat needs at least one tensor')
    first = tensor_list[0]
    # Check the rank before touching shape[1], which does not exist for 1-D tensors.
    if first.dim() < 2:
        raise NotImplementedError()

    max_dim1 = max(t.shape[1] for t in tensor_list)
    sum_dim0 = sum(t.shape[0] for t in tensor_list)
    out = torch.zeros((sum_dim0, max_dim1, *first.shape[2:]), device=first.device)
    offset = 0
    for t in tensor_list:
        out[offset:offset + t.shape[0], :t.shape[1]] = t
        offset += t.shape[0]
    return out

@torch.no_grad()
def eval_retrieval_inbatch_with_rerank(model, args, rerank_batch_size=64):
    '''
    In-batch molecule<->text retrieval on ``model.test_match_loader``, with and without
    graph-text-matching (GTM) reranking.

    Each batch ``(graph, text, text_mask)`` is encoded with
    ``model.blip2qformer.graph_forward`` / ``text_forward`` (inputs moved to
    ``args.device``). Every graph is ranked against the texts of the same batch, and
    vice versa, by ``max_q <graph_rep[i, q], text_rep[j]>``. The ranking is then
    repeated on that similarity plus ``compute_gtm`` evaluated on every (graph, text)
    pair of the batch.

    rerank_batch_size: number of (graph, text) pairs passed to one ``compute_gtm`` call.

    Returns (metrics are percentages rounded to 2 decimals):
        g2t_acc, t2g_acc, g2t_rec20, t2g_rec20,
        g2t_rerank_acc, t2g_rerank_acc, g2t_rerank_rec20, t2g_rerank_rec20,
        graph_rep_total, text_rep_total, graph_feat_total, graph_mask_total,
        text_total, text_mask_total
    The ``*_total`` values are the per-batch outputs/inputs concatenated on CPU; graph
    features and masks are padded with ``pad_and_concat``.

    Raises ValueError if the loader yields no batches.

    Runs under ``torch.no_grad()``: no gradients are needed, and without it the autograd
    graph of every forward pass would be kept alive.
    '''
    model.eval()

    def count_hits(sim):
        '''Return (#rows whose diagonal entry ranks 1st, #rows where it ranks in the top 20).'''
        sorted_ids = torch.argsort(sim, descending=True).cpu()
        rank = (sorted_ids == torch.arange(sim.shape[0]).reshape(-1, 1)).int().argmax(dim=-1)
        return float((rank == 0).sum()), float((rank < 20).sum())

    g2t_acc = t2g_acc = g2t_rec20 = t2g_rec20 = 0
    g2t_rerank_acc = t2g_rerank_acc = g2t_rerank_rec20 = t2g_rerank_rec20 = 0
    allcnt = 0

    graph_rep_total, text_rep_total = [], []
    graph_feat_total, graph_mask_total = [], []
    text_total, text_mask_total = [], []

    for batch in tqdm(model.test_match_loader):
        aug, text, text_mask = batch
        # keep the inputs as yielded by the loader (before .to(device)) for the totals
        text_total.append(text)
        text_mask_total.append(text_mask)

        aug = aug.to(args.device)
        text = text.to(args.device)
        text_mask = text_mask.to(args.device)
        graph_rep, graph_feat, graph_mask = model.blip2qformer.graph_forward(aug)  # graph_rep: [B, num_qs, D]
        text_rep = model.blip2qformer.text_forward(text, text_mask)  # [B, D]

        # [B, 1, num_qs, D] @ [B, D, 1] -> [B, B, num_qs, 1]. Squeeze only the trailing
        # dim: a bare .squeeze() also drops B for a single-sample batch (and num_qs for a
        # single query token), after which max(-1) reduces over the wrong axis.
        sim_q2t = (graph_rep.unsqueeze(1) @ text_rep.unsqueeze(-1)).squeeze(-1)
        sim_g2t, _ = sim_q2t.max(-1)  # [B, B]; row = graph, column = text
        B = sim_g2t.shape[0]

        top1, top20 = count_hits(sim_g2t)
        g2t_acc += top1
        g2t_rec20 += top20
        top1, top20 = count_hits(sim_g2t.T)
        t2g_acc += top1
        t2g_rec20 += top20
        allcnt += B

        graph_rep_total.append(graph_rep.detach().cpu())
        text_rep_total.append(text_rep.detach().cpu())
        graph_feat_total.append(graph_feat.detach().cpu())
        graph_mask_total.append(graph_mask.detach().cpu())

        ## reranking: score all B * B (graph, text) pairs of the batch with the GTM head.
        # repeat_interleave on the graph side and repeat on the text side enumerate the
        # pairs row-major, so flat index i * B + j is (graph i, text j) and
        # reshape(B, B) lines up with sim_g2t.
        pair_graph_feat = graph_feat.repeat_interleave(B, 0)
        pair_graph_mask = graph_mask.repeat_interleave(B, 0)
        pair_text = text.repeat(B, 1)  # [B * B, text_len]
        pair_text_mask = text_mask.repeat(B, 1)  # [B * B, text_len]
        gtm_sim = torch.cat([
            model.blip2qformer.compute_gtm(
                pair_graph_feat[i:i + rerank_batch_size],
                pair_graph_mask[i:i + rerank_batch_size],
                pair_text[i:i + rerank_batch_size],
                pair_text_mask[i:i + rerank_batch_size],
            )
            for i in range(0, pair_graph_feat.shape[0], rerank_batch_size)
        ], dim=0).reshape(B, B)
        rerank_sim = sim_g2t + gtm_sim

        top1, top20 = count_hits(rerank_sim)
        g2t_rerank_acc += top1
        g2t_rerank_rec20 += top20
        top1, top20 = count_hits(rerank_sim.T)
        t2g_rerank_acc += top1
        t2g_rerank_rec20 += top20

    if allcnt == 0:
        raise ValueError('model.test_match_loader yielded no batches')

    def pct(count):
        return round(count / allcnt * 100, 2)

    return pct(g2t_acc), pct(t2g_acc), pct(g2t_rec20), pct(t2g_rec20), \
        pct(g2t_rerank_acc), pct(t2g_rerank_acc), pct(g2t_rerank_rec20), pct(t2g_rerank_rec20), \
        torch.cat(graph_rep_total, dim=0), torch.cat(text_rep_total, dim=0), \
        pad_and_concat(graph_feat_total), pad_and_concat(graph_mask_total), \
        torch.cat(text_total, dim=0), torch.cat(text_mask_total, dim=0)

def _rank_of_match(sim, device, chunk_size):
    '''Return, for each row i of square ``sim``, the 0-based rank of column i in that row sorted descending.'''
    ranks = []
    for i in range(0, sim.shape[0], chunk_size):
        sorted_ids = torch.argsort(sim[i:i + chunk_size].to(device), descending=True)
        targets = torch.arange(i, i + sorted_ids.shape[0], device=device).reshape(-1, 1)
        ranks.append((sorted_ids == targets).int().argmax(dim=-1))
    return torch.cat(ranks, dim=0)


@torch.no_grad()
def eval_retrieval_fullset(graph_rep, text_rep, device, chunk_size=8, progress=True):
    '''
    Full-set molecule<->text retrieval: rank every graph against every text and back.

    graph_rep: [N, num_qs, D] graph query representations; graph i matches text i.
    text_rep: [N, D] text representations.
    device: device used for the similarity and sorting work.
    chunk_size: number of rows processed per step (bounds peak memory).
    progress: show a tqdm bar while computing similarities (the previous behavior).

    Returns ``(g2t_acc, g2t_rec20, t2g_acc, t2g_rec20, sim_g2t)``. The metrics are
    percentages rounded to 2 decimals. ``sim_g2t`` is the [N, N] CPU similarity matrix
    (row = graph, column = text, entry = max over query tokens).

    Raises ValueError for empty input or mismatched graph/text counts.
    '''
    N = graph_rep.shape[0]
    if N == 0:
        raise ValueError('eval_retrieval_fullset needs at least one graph/text pair')
    if text_rep.shape[0] != N:
        raise ValueError(f'expected as many texts as graphs, got {text_rep.shape[0]} texts for {N} graphs')

    text_rep = text_rep.to(device)
    starts = range(0, N, chunk_size)
    rows = []
    for i in (tqdm(starts) if progress else starts):
        chunk = graph_rep[i:i + chunk_size].to(device)
        # [b, 1, num_qs, D] @ [N, D, 1] -> [b, N, num_qs, 1]. Squeeze only the trailing
        # dim so a final chunk holding one graph (N % chunk_size == 1) or num_qs == 1
        # keeps its shape instead of breaking torch.cat below.
        sim_q2t = (chunk.unsqueeze(1) @ text_rep.unsqueeze(-1)).squeeze(-1)
        rows.append(sim_q2t.max(-1).values)  # [b, N]
    sim_g2t = torch.cat(rows, dim=0).cpu()  # [N, N]

    rank_g2t = _rank_of_match(sim_g2t, device, chunk_size)
    rank_t2g = _rank_of_match(sim_g2t.T, device, chunk_size)

    def pct(hit):
        return round(float(hit.float().mean()) * 100, 2)

    return pct(rank_g2t == 0), pct(rank_g2t < 20), pct(rank_t2g == 0), pct(rank_t2g < 20), sim_g2t

def main(args):
    '''
    Evaluate PCDes (KV-PLM split) molecule<->text retrieval for a Blip2Stage1 model.

    Loads the model from ``args.init_checkpoint``, runs the in-batch evaluation (with GTM
    reranking), then the full-set evaluation on the collected representations, and
    prints all metrics.
    '''
    pl.seed_everything(args.seed)

    # model
    if args.init_checkpoint:
        # NOTE(review): `--devices` is not declared by this script's parser (only
        # `--device` is); fall back to it instead of raising AttributeError.
        devices = getattr(args, 'devices', args.device)
        model = Blip2Stage1.load_from_checkpoint(args.init_checkpoint, device=devices, args=args)
        print(f"loading model from {args.init_checkpoint}")
        print(model)
    else:
        # Without a checkpoint the model is randomly initialised and every retrieval
        # metric is near chance level, which is easy to mistake for a reproduction bug.
        warnings.warn('--init_checkpoint is not set: evaluating a randomly initialized '
                      'Blip2Stage1 model, so retrieval metrics will be near chance level.')
        model = Blip2Stage1(args)

    print('total params:', sum(p.numel() for p in model.parameters()))

    # data
    if 'kv' not in args.root:
        raise ValueError(f"only KV-PLM data roots (a path containing 'kv') are supported, got {args.root!r}")
    # NOTE(review): the second positional argument is hard-coded to 2 rather than taken
    # from args.batch_size / args.match_batch_size -- confirm which loader it sizes.
    dm = Stage1KVPLMDM(args.num_workers, 2, args.root, args.text_max_len, args.graph_aug, args)
    model.val_match_loader = dm.val_match_loader
    model.test_match_loader = dm.test_match_loader
    model = model.to(args.device)

    g2t_acc, t2g_acc, g2t_rec20, t2g_rec20, \
    g2t_rerank_acc, t2g_rerank_acc, g2t_rerank_rec20, t2g_rerank_rec20, \
    graph_rep_total, text_rep_total, _, _, _, _ = \
        eval_retrieval_inbatch_with_rerank(model, args)
    print(f'in-batch        g2t acc {g2t_acc} rec@20 {g2t_rec20} | t2g acc {t2g_acc} rec@20 {t2g_rec20}')
    print(f'in-batch rerank g2t acc {g2t_rerank_acc} rec@20 {g2t_rerank_rec20} | '
          f't2g acc {t2g_rerank_acc} rec@20 {t2g_rerank_rec20}')

    # note: eval_retrieval_fullset returns its metrics in (g2t_acc, g2t_rec20, ...) order
    g2t_acc, g2t_rec20, t2g_acc, t2g_rec20, _ = \
        eval_retrieval_fullset(graph_rep_total, text_rep_total, args.device)
    print(f'full set        g2t acc {g2t_acc} rec@20 {g2t_rec20} | t2g acc {t2g_acc} rec@20 {t2g_rec20}')

if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--filename', type=str, default="pcdes_evaluation")
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument('--root', type=str, default="data/kv_data")
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--match_batch_size", type=int, default=8)
    parser.add_argument("--text_max_len", type=int, default=256)
    parser.add_argument('--graph_aug', type=str, default="dnodes")
    # GPU
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    # MM settings
    # `store_true` with `default=True` can never be switched off, so each objective also
    # gets a `--no_*` flag writing False to the same dest; `--gtm`/`--lm` are kept so
    # existing command lines still parse.
    parser.add_argument('--gtm', action='store_true', help='use graph-text matching or not', default=True)
    parser.add_argument('--no_gtm', dest='gtm', action='store_false', help='disable graph-text matching')
    parser.add_argument('--lm', action='store_true', help='use language modeling or not', default=True)
    parser.add_argument('--no_lm', dest='lm', action='store_false', help='disable language modeling')
    parser.add_argument('--mode', type=str, default='train')
    parser.add_argument('--accelerator', type=str, default='gpu')
    parser.add_argument('--precision', type=str, default='bf16')
    parser.add_argument('--max_epochs', type=int, default=50)
    parser.add_argument('--check_val_every_n_epoch', type=int, default=1)
    parser = Blip2Stage1.add_model_specific_args(parser)  # add model args

    args = parser.parse_args()

    print("=========================================")
    for k, v in sorted(vars(args).items()):
        print(k, '=', v)
    print("=========================================")
    main(args)
acharkq commented 5 months ago

Hi, can you check the auto-saved log file under the all_checkpoints/pcdes_evaluation directory?

I think the number 0 refers to something other than the retrieval performance.

DukeZhou3 commented 5 months ago

Thank you for your response.

Since I did not use the trainer, there is no log file available. Instead, I loaded the checkpoint, extracted the text and graph representations, and then called the eval_retrieval_inbatch_with_rerank and eval_retrieval_fullset functions, which are used in your on_validation_epoch_end method.

Running the eval_retrieval_inbatch_with_rerank function, I obtained an accuracy of 11.97, which I have printed for reference.

image

DukeZhou3 commented 5 months ago

I also want to extract features using your checkpoint, with text_max_len=256 and graph_aug="dnodes". Could you kindly confirm whether these hyperparameters are correct for MolCA? My script is included below for reference:

import os
# Environment knobs must be set before the libraries that read them are imported:
# CUDA_VISIBLE_DEVICES is read when CUDA is first initialised, and OpenBLAS reads
# OPENBLAS_NUM_THREADS when it is loaded (possibly during `import torch` / numpy),
# so assigning it after those imports may silently do nothing.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # hard-coded GPU; `--device cuda:0` then refers to this GPU
os.environ['OPENBLAS_NUM_THREADS'] = '1'
import argparse
import torch
import warnings
import pytorch_lightning as pl
from pytorch_lightning import Trainer, strategies
import pytorch_lightning.callbacks as plc
from pytorch_lightning.loggers import CSVLogger
from blip2_stage1 import Blip2Stage1
from stage1_kvplm_dm import Stage1KVPLMDM
from tqdm import tqdm

## for pyg bug
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
## for A5000 gpus
torch.set_float32_matmul_precision('medium') # can be medium (bfloat16), high (tensorfloat32), highest (float32)

@torch.no_grad()
def extract_feat(model, args, progress=True):
    '''
    Encode the whole test split and return the graph-to-text similarity matrix.

    Every batch ``(graph, text, text_mask)`` of ``model.test_match_loader`` is moved to
    ``args.device`` and encoded with ``model.blip2qformer.graph_forward`` /
    ``text_forward``. The representations are gathered on CPU, and
    ``score[i, j] = max_q <graph_rep[i, q], text_rep[j]>``.

    progress: wrap the loader in a tqdm progress bar (the previous behavior).

    Returns the CPU tensor ``score`` of shape [N, N] (row = graph, column = text).
    Runs under ``torch.no_grad()`` so no autograd graph is retained.
    '''
    model.eval()
    mol_rep_total, text_rep_total = [], []
    batches = tqdm(model.test_match_loader) if progress else model.test_match_loader
    for aug, text, text_mask in batches:
        graph_rep, _, _ = model.blip2qformer.graph_forward(aug.to(args.device))  # [B, num_qs, D]
        text_rep = model.blip2qformer.text_forward(text.to(args.device), text_mask.to(args.device))  # [B, D]
        mol_rep_total.append(graph_rep.detach().cpu())
        text_rep_total.append(text_rep.detach().cpu())

    mol_rep = torch.cat(mol_rep_total, dim=0)
    text_rep = torch.cat(text_rep_total, dim=0)
    # [N, 1, num_qs, D] @ [N, D, 1] -> [N, N, num_qs, 1]. Squeeze only the trailing dim:
    # a bare .squeeze() also drops N when N == 1 (or num_qs when num_qs == 1).
    sim_q2t = (mol_rep.unsqueeze(1) @ text_rep.unsqueeze(-1)).squeeze(-1)
    score, _ = sim_q2t.max(-1)  # [N, N]
    return score

def main(args):
    '''Load a Blip2Stage1 model and print its graph-to-text similarity matrix on the KV-PLM test split.'''
    pl.seed_everything(args.seed)

    # model
    if args.init_checkpoint:
        # NOTE(review): `--devices` is not declared by this script's parser (only
        # `--device` is); fall back to it instead of raising AttributeError.
        devices = getattr(args, 'devices', args.device)
        model = Blip2Stage1.load_from_checkpoint(args.init_checkpoint, device=devices, args=args)
        print(f"loading model from {args.init_checkpoint}")
        print(model)
    else:
        # Without a checkpoint the model is randomly initialised and the similarities
        # are meaningless, which is easy to mistake for a reproduction bug.
        warnings.warn('--init_checkpoint is not set: extracting features with a randomly '
                      'initialized Blip2Stage1 model.')
        model = Blip2Stage1(args)

    print('total params:', sum(p.numel() for p in model.parameters()))

    # NOTE(review): the second positional argument is hard-coded to 2 rather than taken
    # from args.batch_size / args.match_batch_size -- confirm which loader it sizes.
    dm = Stage1KVPLMDM(args.num_workers, 2, args.root, args.text_max_len, args.graph_aug, args)
    model.val_match_loader = dm.val_match_loader
    model.test_match_loader = dm.test_match_loader
    model = model.to(args.device)

    score = extract_feat(model, args)
    print(score)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--filename', type=str, default="pcdes_extract")
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument('--root', type=str, default="./data/kv_data")
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--match_batch_size", type=int, default=2)
    parser.add_argument("--text_max_len", type=int, default=256)
    parser.add_argument('--graph_aug', type=str, default="dnodes")
    # GPU
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    # MM settings
    # `store_true` with `default=True` can never be switched off, so each objective also
    # gets a `--no_*` flag writing False to the same dest; `--gtm`/`--lm` are kept so
    # existing command lines still parse.
    parser.add_argument('--gtm', action='store_true', help='use graph-text matching or not', default=True)
    parser.add_argument('--no_gtm', dest='gtm', action='store_false', help='disable graph-text matching')
    parser.add_argument('--lm', action='store_true', help='use language modeling or not', default=True)
    parser.add_argument('--no_lm', dest='lm', action='store_false', help='disable language modeling')
    parser.add_argument('--mode', type=str, default='train')
    parser.add_argument('--accelerator', type=str, default='gpu')
    parser.add_argument('--precision', type=str, default='bf16')
    parser.add_argument('--max_epochs', type=int, default=50)
    parser.add_argument('--check_val_every_n_epoch', type=int, default=1)
    parser = Blip2Stage1.add_model_specific_args(parser)  # add model args
    args = parser.parse_args()

    print("=========================================")
    for k, v in sorted(vars(args).items()):
        print(k, '=', v)
    print("=========================================")
    main(args)
acharkq commented 5 months ago

Hi, from the image in your first post, it seems that you did not set the init_checkpoint parameter. Can you set this parameter and try again?

Screenshot 2024-03-25 at 20 23 54
DukeZhou3 commented 5 months ago

I have tried it and successfully resolved the issue. Thank you very much for your valuable response!

acharkq commented 5 months ago

Thanks!