deepmodeling / Uni-Mol

Official Repository for the Uni-Mol Series Methods
MIT License
706 stars 124 forks source link

用unimol+模型 inference出的pos_pred取同一个id的均值吗? #269

Open chenkk717 opened 1 month ago

chenkk717 commented 1 month ago

在make_pcq_test_dev_submission.py文件中gap_pred是取同一个id对应值的平均,大多数id是8个值的平均。然而在预测出的pos_pred(原子坐标信息中),为什么相同id得到的pos_pred形状会不一样?也就是说同一个分子的原子数量对不上?如下是某个id预测出的pos_pred test 形状为Shape mismatch for id group: [(19, 3), (19, 3), (19, 3), (19, 3), (19, 3), (19, 3), (19, 3), (15, 3)],第8个预测出来的pos_pred只有15个原子对应的坐标信息,而其它为19个原子。 想请问这是什么情况?以及想得到某个id的最终pos_pred值,是采用什么方法?除去有异常原子数量的预测值,然后其余取均值吗?

chenkk717 commented 1 month ago

同时想请问一下原paper(Data-driven quantum chemical property prediction leveraging 3D conformationswith Uni-Mol+)里Fig. 2 image 中predicted conformations的sdf文件是如何得来的?因为模型预测后只有原子的坐标信息,是根据smiles使用RDKit生成mol对象然后再把pos_pred代入吗? image

ShuqiLu commented 1 month ago

您好,请问第一个问题的id是什么,我可以先复现一下。然后也确认下这个id对应的构象是8个吗,因为不是所有的id都有8个构象;第二个问题确实是的,是把mol对象的坐标换成预测的坐标来生成的

chenkk717 commented 1 month ago

①第一个问题里的id在图片的上方:3379091,我是运行的test-dev split这个测试集,实际上不止这一个id有这个情况,截取部分预测原子数量不一致的id: image image ②想请问最终的final pos_pred是取同一个id的所有预测pos_pred的均值吗?还是按照某种规则挑选其一? ③smiles中如果含有H原子,那么在用预测的坐标替换掉由smiles结构生成的mol对象坐标时,氢原子的坐标信息应该怎么处理呢?因为在get_3d_lmdb.py对数据集的预处理文件中,我发现在def rdkit_mmff(mol)函数中return rdkit_remove_hs(mol)也就是input_pos或者label_pos都是不含有H原子信息的。在得到sdf文件时,是在smiles生成mol对象后再进行了H原子的移除吗?然后再代入预测坐标值

ShuqiLu commented 1 month ago

1:不确定这里是怎么输出的每个id对应的构象坐标shape,如果可以的话可以提供一下相应的代码。不过不是所有id对应的构象都是8个,如果按照给出的图里全都是按照8个来切分,可能混淆了不同id的分子,所以原子数不一样。

  1. 这篇论文里我们最终是为了获得预测的homo lumo gap,所以没有返回最终的预测坐标,在得到不同构象下的pos_pred之后没有再做处理了。
  2. rdkit_remove_hs(mol)返回的是一个不带H的mol对象,我们show case的时候是基于这个mol对象进行的坐标替换再展示画图,过程中都没有H。
chenkk717 commented 1 month ago

1.发现相同id对应的预测构象坐标shape不同后,我检查了模型预测后得到的直接输出test-dev_0.pkl,以防是我输出shape的代码有误,check预测的坐标值(以id3379091为例子)结果如下: image 其中第8个构象的原子数确实为15,而其它为19,和shape的输出[(19, 3), (19, 3), (19, 3), (19, 3), (19, 3), (19, 3), (19, 3), (15, 3)]吻合,详细的数据见 id3379091_pos_pred.txt 这部分的代码如下:

import numpy as np
import torch
import pickle
import glob
import pandas as pd

input_folder = "results"
subset = "test-dev"
split = torch.load("./scripts/pcqm4m-v2/split_dict.pt")
valid_index = split[subset]

def flatten(d, index):
    res = []
    for x in d:
        res.extend(x[index])
    return np.array(res)

# 提取并处理每个id的所有原子坐标
def one_ckp(folder, subset):
    s = f"{folder}/" + subset + "*.pkl"
    files = sorted(glob.glob(s))
    data = []
    for file in files:
        with open(file, "rb") as f:
            try:
                data.extend(pickle.load(f))
            except Exception as e:
                print("Error in file: ", file)
                raise e

    # 提取 id 和 pos_pred
    id = flatten(data, 0)  # 分子id
    pos_pred = flatten(data, 1)  # 该分子所有原子的三维坐标

    # 将数据放入 DataFrame,每个id保留完整的分子坐标信息
    df = pd.DataFrame({"id": id, "pos_pred": list(pos_pred)})

    # 按 id 分组,计算同一个 id 的分子坐标均值(多个预测结果的均值)
    df_grouped = df.groupby("id")
    # 调试函数:检查同一个 id 下的 pos_pred 是否具有相同的形状
    def check_shapes_and_mean(x):
        shapes = [arr.shape for arr in x]
        if not all(shape == shapes[0] for shape in shapes):
            print(f"Shape mismatch for id {x.name} group: {shapes}")
            return None
        return np.mean(np.stack(x), axis=0)
    df_mean = df_grouped["pos_pred"].apply(check_shapes_and_mean)
    return df_mean

#保存结果为 Parquet 格式
def save_pos_submission_parquet(df_grouped, output_file):
    # 将每个分子的平均坐标保存为 Parquet 文件
    df_grouped = pd.DataFrame(df_grouped.tolist(), index=df_grouped.index, columns=["pos_pred"])
    df_grouped.to_parquet(output_file, index=True, compression='snappy')
    print(f"Saved position predictions to {output_file}")

df_mean_pos = one_ckp(input_folder, subset)
output_file = "pos_pred_mean.parquet"
save_pos_submission_parquet(df_mean_pos, output_file)

我想代码中是以 df_grouped = df.groupby("id")按id分组的,并不是以8个为一组,如果是切分有误,那么问题可能发生在我的模型inference过程中,但是inference.py文件我并无改动,想知道您的test-dev测试集中该id的pos_pred结果原子数量有异吗? 2.关于final pos_pred的选择问题,我下载了本文提供的Supplementary_Data中conformation_compare_fig2的sdf文件,如下: image 分子id后的数字是表示选取的第几个conformation吗?对于Fig. 2中id为3388743是第4个而id为3428088是第2个,以此类推。所以想请问这里的选择标准是什么?是根据可视化之后的RMSD吗?然后选择RMSD最小的一个构象作为展示?

chenkk717 commented 1 month ago

我重新检查了unimol+模型的inference过程。在unimol_plus文件夹下的pcq.py中的load_dataset函数部分,涉及到pcq_dataset.py文件中的PCQDataset函数,其中以下代码:

        max_node_num = max([item["atom_mask"].shape[0] for item in items])
        max_node_num = (max_node_num + 1 + 3) // 4 * 4 - 1
        batched_data = {}
        for key in items[0].keys():
            samples = [item[key] for item in items]
            if key in pad_fns:
                batched_data[key] = pad_fns[key](samples, max_node_num)

max_node_num 是每个batch里最大分子原子数,然后将其调整到最接近的 4 的倍数减去 1。atom_mask对分子的真实原子位赋1,新添加的虚拟原子位赋0,包括后面涉及到的attn_mask对新添的虚拟原子位赋-inf。经过这种处理,新添的虚拟原子坐标值一开始都是0,但是在经过inference后,虚拟原子(mask标记为0)的坐标也有了预测值,导致我上面出现的同一个id(相同分子)对应的原子数不相同的情况(原子个数为每个batch里的max_node_num)。 因为有的分子不一定是生成8个conformers,batch_size是设定的为8的倍数,就会导致同一个batch里可能包含不同的分子id,然而预测过程中,mask即使为0的原子也有了坐标的预测输出,所以同一batch里面的分子原子数都是固定的max_node_num。不知道我的理解是否有偏差,希望能解答一下疑惑。

chenkk717 commented 1 week ago

您好,请问复现有结果了吗?不知道我上述猜想是否正确? @ShuqiLu

ShuqiLu commented 1 week ago

不好意思没有看到您的回复,你理解的其实也差不多,这里把所有分子的原子数都置为max_node_num是为了能用pytorch并行处理一个batch内的所有分子,需要所有tensor的shape一致,所以这里用padding操作,把batch的的分子的原子数补充成相同的数目;为了使得padding的内容不实际影响模型的运算结果,所以使用atom_mask和attn_mask让padding的内容不参与实际运算;因为这篇工作我们只预测分子的能量并不取出分子坐标独立研究,所以没有对返回的分子坐标处理padding的部分,所以看起来同一个分子生成的坐标在不同batch内shape不一样。实际上如果需要取出真实原子的坐标,去掉padding的部分,可以利用atom_mask,把每个分子的atom_mask=1的位置对应的坐标取出,就是所有真实原子的坐标; 或者假设真实原子数为k,可以取出前k个坐标,利用数据中的smiles生成rdkit初始构象,再将前k个坐标填入即得到预测的3d构象(需要原始数据中的smiles不然可能没法对应)。

至于show case中选择的标准是什么,这里其实我们就选择了几个能量预测误差相对小的case展示了一下,没有过多的特殊筛选。