Closed DukeZhou3 closed 5 months ago
Hi, can you check the auto-saved log file under the all_checkpoints/pcdes_evaluation directory?
I think the number 0 refers to something else, not the retrieval performance.
Thank you for your response.
Since I have not used the trainer, there is no log file available. However, I have loaded the ckpt and extracted the text and graph representations. I then used the eval_retrieval_inbatch_with_rerank and eval_retrieval_fullset functions from your on_validation_epoch_end module.
Upon executing the eval_retrieval_inbatch_with_rerank function, I obtained an accuracy of 11.97, which I printed for reference.
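For context, here is a minimal sketch of how I compute in-batch retrieval accuracy from a graph-to-text similarity matrix; the helper name is my own, not the repository's actual API, and it assumes the matching text for molecule i sits on the diagonal:

```python
import torch

def inbatch_retrieval_accuracy(score: torch.Tensor) -> float:
    """Given an [N, N] similarity matrix where row i scores molecule i
    against every text, accuracy is the fraction of rows whose highest
    score lands on the matching (diagonal) text."""
    preds = score.argmax(dim=-1)           # best text index per molecule
    targets = torch.arange(score.size(0))  # ground truth: the diagonal
    return (preds == targets).float().mean().item()

# toy example: molecules 0 and 1 rank their own text first, molecule 2 does not
score = torch.tensor([[0.9, 0.1, 0.0],
                      [0.2, 0.8, 0.1],
                      [0.7, 0.1, 0.2]])
print(inbatch_retrieval_accuracy(score))  # 2 of 3 diagonals win -> ~0.667
```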
I also want to extract features based on your ckpt, and I set text_max_len=256 and graph_aug="dnodes".
Could you kindly confirm whether these hyperparameters are correct for MolCA?
Additionally, I have included my script below for reference:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import argparse
import torch
import warnings
import pytorch_lightning as pl
from pytorch_lightning import Trainer, strategies
import pytorch_lightning.callbacks as plc
from pytorch_lightning.loggers import CSVLogger
from blip2_stage1 import Blip2Stage1
from stage1_kvplm_dm import Stage1KVPLMDM
from tqdm import tqdm
os.environ['OPENBLAS_NUM_THREADS'] = '1'
## for pyg bug
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
## for A5000 gpus
torch.set_float32_matmul_precision('medium') # can be medium (bfloat16), high (tensorfloat32), highest (float32)
def extract_feat(model, args):
    mol_rep_total, text_rep_total = [], []
    model.eval()
    with torch.no_grad():  # inference only; avoid building autograd graphs
        for batch in tqdm(model.test_match_loader):
            aug, text, text_mask = batch
            aug = aug.to(args.device)
            text = text.to(args.device)
            text_mask = text_mask.to(args.device)
            graph_rep, graph_feat, graph_mask = model.blip2qformer.graph_forward(aug)  # shape = [B, num_qs, D]
            text_rep = model.blip2qformer.text_forward(text, text_mask)  # shape = [B, D]
            mol_rep_total.append(graph_rep.detach().cpu())
            text_rep_total.append(text_rep.detach().cpu())
    mol_rep = torch.cat(mol_rep_total, dim=0)    # shape = [N, num_qs, D]
    text_rep = torch.cat(text_rep_total, dim=0)  # shape = [N, D]
    # [N, 1, num_qs, D] @ [N, D, 1] broadcasts to [N, N, num_qs, 1]; squeeze -> [N, N, num_qs]
    sim_q2t = (mol_rep.unsqueeze(1) @ text_rep.unsqueeze(-1)).squeeze()
    score, _ = sim_q2t.max(-1)  # max over query tokens; shape = [N, N]
    return score
def main(args):
    pl.seed_everything(args.seed)
    # model
    if args.init_checkpoint:
        model = Blip2Stage1.load_from_checkpoint(args.init_checkpoint, device=args.devices, args=args)
        print(f"loading model from {args.init_checkpoint}")
        print(model)
    else:
        model = Blip2Stage1(args)
    print('total params:', sum(p.numel() for p in model.parameters()))
    tokenizer = model.blip2qformer.tokenizer
    dm = Stage1KVPLMDM(args.num_workers, 2, args.root, args.text_max_len, args.graph_aug, args)
    model.val_match_loader = dm.val_match_loader
    model.test_match_loader = dm.test_match_loader
    model = model.to(args.device)
    score = extract_feat(model, args)
    print(score)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--filename', type=str, default="pcdes_extract")
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument('--root', type=str, default="./data/kv_data")
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--match_batch_size", type=int, default=2)
    parser.add_argument("--text_max_len", type=int, default=256)
    parser.add_argument('--graph_aug', type=str, default="dnodes")
    # GPU
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    # MM settings
    parser.add_argument('--gtm', action='store_true', help='use graph-text matching or not', default=True)
    parser.add_argument('--lm', action='store_true', help='use language modeling or not', default=True)
    parser.add_argument('--mode', type=str, default='train')
    parser.add_argument('--accelerator', type=str, default='gpu')
    parser.add_argument('--precision', type=str, default='bf16')
    parser.add_argument('--max_epochs', type=int, default=50)
    parser.add_argument('--check_val_every_n_epoch', type=int, default=1)
    parser = Blip2Stage1.add_model_specific_args(parser)  # add model args
    args = parser.parse_args()
    print("=========================================")
    for k, v in sorted(vars(args).items()):
        print(k, '=', v)
    print("=========================================")
    main(args)
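If it helps, the [N, N] score matrix printed at the end can also be turned into ranking metrics. Below is a minimal Recall@K sketch; the helper is my own, not the repository's eval code, and it assumes the correct text for molecule i sits on the diagonal:

```python
import torch

def recall_at_k(score: torch.Tensor, k: int = 1) -> float:
    """Fraction of rows whose matching (diagonal) column appears among
    the row's k highest scores; k=1 reduces to top-1 accuracy."""
    n = score.size(0)
    topk = score.topk(k, dim=-1).indices     # [N, k] best text indices per row
    targets = torch.arange(n).unsqueeze(-1)  # [N, 1] ground-truth diagonal
    return (topk == targets).any(dim=-1).float().mean().item()

# toy example: row 2's match (index 2) is never ranked first
score = torch.tensor([[3., 2., 1.],
                      [1., 3., 2.],
                      [3., 2., 1.]])
print(recall_at_k(score, 1), recall_at_k(score, 3))
```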
Hi, judging from the image in your first post, it seems that you did not set the init_checkpoint parameter. Can you set this parameter and try again?
I have tried it and successfully resolved the issue. Thank you very much for your valuable response!
Thanks!
Thank you for your excellent work! I would like to ask for your assistance in reproducing the Molecule-Text Retrieval results on PCDes. However, I seem to be encountering some issues.
I have downloaded the ckpt file from the following link: https://huggingface.co/acharkq/MolCA/tree/main and executed the stage1.py script. However, I obtained an accuracy of 0.
Could you kindly provide some guidance on how I can resolve this issue? Thank you in advance for your help!
Here is my full script: