datamol-io / safe

A single model for all your molecular design tasks
https://safe-docs.datamol.io/
Apache License 2.0

Validity Calculation #57

Closed. Anri-Lombard closed this issue 3 weeks ago.

Anri-Lombard commented 1 month ago

The paper claims a validity very close to 1 (and, for fragment-constrained generation, exactly 1 on average). Was this computed as the fraction of requested samples that turned out to be valid, or as the fraction of valid molecules among the molecules that were eventually returned?

For example: if I generate 1000 molecules and 952 are valid, but running the generated molecules through RDKit shows that 100% of them are valid, is the validity 0.952 or 1.0? The reason I ask is that after training the model from scratch and replicating your de novo generation results, my fragment-constrained generation results do not seem to have as high a validity as claimed, and I'm wondering whether this is my mistake.

For some more context, here is a snippet of my notebook:

import os
import sys

import numpy as np
import pandas as pd
from datasets import load_dataset
from rdkit import Chem, DataStructs, RDConfig
from rdkit.Chem import rdMolDescriptors
from tqdm.auto import tqdm

# SA score lives in RDKit's contrib directory
sys.path.append(os.path.join(RDConfig.RDContribDir, "SA_Score"))
import sascorer

ds = load_dataset("datamol-io/safe-drugs")
benchmark_df = pd.DataFrame(ds["train"])
benchmark_df.info()

def calculate_diversity(mols):
    # Mean pairwise Tanimoto distance (1 - similarity) over Morgan fingerprints
    fps = [rdMolDescriptors.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048) for mol in mols]
    similarities = []
    for i in range(len(fps)):
        for j in range(i + 1, len(fps)):
            similarities.append(1 - DataStructs.TanimotoSimilarity(fps[i], fps[j]))
    return np.mean(similarities)

def calculate_distance_to_original(generated_mols, original_mol):
    # Mean Tanimoto distance of each generated molecule to the original drug
    original_fp = rdMolDescriptors.GetMorganFingerprintAsBitVect(original_mol, 2, nBits=2048)
    generated_fps = [rdMolDescriptors.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048) for mol in generated_mols]
    distances = [1 - DataStructs.TanimotoSimilarity(original_fp, gen_fp) for gen_fp in generated_fps]
    return np.mean(distances)

def run_fragment_constrained_benchmark(designer, benchmark_df, n_samples=1000, n_trials=1):
    results = []

    for _, row in tqdm(benchmark_df.iterrows(), total=len(benchmark_df)):
        original_mol = Chem.MolFromSmiles(row['smiles'])

        # Linker design
        linkers = designer.linker_generation(*row['morphing'].split('.'), n_samples_per_trial=n_samples, n_trials=n_trials, sanitize=True)
        linker_mols = [Chem.MolFromSmiles(smi) for smi in linkers if smi]

        # Motif extension
        motifs = designer.motif_extension(row['motif'], n_samples_per_trial=n_samples, n_trials=n_trials, sanitize=True)
        motif_mols = [Chem.MolFromSmiles(smi) for smi in motifs if smi]

        # Scaffold decoration
        scaffolds = designer.scaffold_decoration(row['scaffold'], n_samples_per_trial=n_samples, n_trials=n_trials, sanitize=True)
        scaffold_mols = [Chem.MolFromSmiles(smi) for smi in scaffolds if smi]

        # Scaffold morphing
        morphs = designer.scaffold_morphing(side_chains=row['morphing'].split('.'), n_samples_per_trial=n_samples, n_trials=n_trials, sanitize=True)
        morph_mols = [Chem.MolFromSmiles(smi) for smi in morphs if smi]

        # Superstructure generation
        superstructures = designer.super_structure(row['superstructure'], n_samples_per_trial=n_samples, n_trials=n_trials, sanitize=True)
        superstructure_mols = [Chem.MolFromSmiles(smi) for smi in superstructures if smi]

        tasks = ['Linker design', 'Motif extension', 'Scaffold decoration', 'Scaffold morphing', 'Superstructure']
        mol_lists = [linker_mols, motif_mols, scaffold_mols, morph_mols, superstructure_mols]

        for task, mols in zip(tasks, mol_lists):
            if mols:
                validity = len(mols) / (n_samples * n_trials)  # fraction of requested samples that came back valid
                uniqueness = len({Chem.MolToSmiles(mol) for mol in mols}) / len(mols)
                diversity = calculate_diversity(mols)
                distance = calculate_distance_to_original(mols, original_mol)
                sa_scores = [sascorer.calculateScore(mol) for mol in mols]
                sa_mean = np.mean(sa_scores)

                results.append({
                    'Drug': row['pref_name'],
                    'Task': task,
                    'Validity': validity,
                    'Uniqueness': uniqueness,
                    'Diversity': diversity,
                    'Distance': distance,
                    'SA score': sa_mean
                })
            else:
                results.append({
                    'Drug': row['pref_name'],
                    'Task': task,
                    'Validity': 0,
                    'Uniqueness': 0,
                    'Diversity': 0,
                    'Distance': 0,
                    'SA score': 0
                })

    return pd.DataFrame(results)
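
For completeness, this is roughly how I drive the benchmark. It is only a sketch: sf.SAFEDesign.load_default() loads the released checkpoint, whereas my actual runs substitute my own trained model in the designer.

import safe as sf

# load_default() loads the released SAFE-GPT model; in my experiments the designer
# wraps my own 20M model instead.
designer = sf.SAFEDesign.load_default(verbose=False)

results_df = run_fragment_constrained_benchmark(designer, benchmark_df, n_samples=1000, n_trials=1)
print(results_df.groupby("Task")[["Validity", "Uniqueness", "Diversity", "Distance", "SA score"]].mean())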

(Disclaimer: this is with the small 20M model, so, now that I think about it, that could be the cause.)

maclandrol commented 1 month ago

> If I generate 1000 molecules and 952 are valid, but running the generated molecules through RDKit shows that 100% of them are valid, is the validity 0.952 or 1.0? The reason I ask is that after training the model from scratch and replicating your de novo generation results, my fragment-constrained generation results do not seem to have as high a validity as claimed, and I'm wondering whether this is my mistake.

The validity would be 0.952. By default, the sanitize option is set to True, so validity can be computed as (# valid compounds returned) / (# compounds requested). The other metrics were then computed on the valid molecules only. You can also set sanitize to False and compute the metrics the "normal" way.
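
Roughly, the difference looks like this (just a sketch, with scaffold standing in for whatever fragment input you pass):

from rdkit import Chem

n_samples, n_trials = 1000, 1

# sanitize=True: only molecules that pass sanitization are returned,
# so validity is measured against what was requested.
generated = designer.scaffold_decoration(scaffold, n_samples_per_trial=n_samples, n_trials=n_trials, sanitize=True)
validity = len(generated) / (n_samples * n_trials)  # e.g. 952 / 1000 = 0.952

# sanitize=False: keep everything the model emits and check validity with RDKit afterwards ("normal way").
raw = designer.scaffold_decoration(scaffold, n_samples_per_trial=n_samples, n_trials=n_trials, sanitize=False)
validity_normal = sum(Chem.MolFromSmiles(smi) is not None for smi in raw) / len(raw)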

Also note that for the paper we ran the sampling in chunks with different seeds, so not all of it was done in a single block. I doubt, however, that this would explain any disparity you might see.
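
Something along these lines (the seeding via transformers.set_seed is only an illustration of the idea, not the exact script we used):

from transformers import set_seed

# Sample in several smaller chunks with different seeds instead of one big block.
all_generated = []
for seed in (0, 1, 2, 3):
    set_seed(seed)
    all_generated += designer.scaffold_decoration(scaffold, n_samples_per_trial=250, n_trials=1, sanitize=True)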

I am quite curious about your results, if you could share them. The most important difference is likely the massive amount of training data plus model capacity. If your training data does not cover a chemical space similar to the benchmark drugs, then you might see a difference in performance, which can be fixed by finetuning on compounds from that chemical space.

As a sanity check, I would suggest picking a molecule from your training set, defining a scaffold, motif, fragments, etc. from it, and running a small-scale test on those. You should normally expect good performance on that molecule.

Also explore the do_not_fragment_further argument; it can really change performance on some fragment-based tasks.
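
A minimal sketch of that sanity check (the SMILES and the hand-added attachment points are placeholders, and I'm assuming the scaffold is passed with explicit [*] attachment points):

from rdkit.Chem.Scaffolds import MurckoScaffold

# Pick one molecule you know is in the training set (placeholder SMILES here).
train_smiles = "CC(=O)Oc1ccccc1C(=O)O"
print(MurckoScaffold.MurckoScaffoldSmiles(train_smiles))  # -> c1ccccc1

# Add attachment points by hand where decoration should happen (assumed format).
scaffold = "[*]c1ccccc1[*]"

out = designer.scaffold_decoration(
    scaffold,
    n_samples_per_trial=100,
    n_trials=1,
    sanitize=True,
    do_not_fragment_further=True,  # the argument mentioned above
)
print(f"{len(out)} valid molecules out of 100 requested")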

Anri-Lombard commented 4 weeks ago

Thank you @maclandrol; I'd be happy to share the results:

[Screenshot: table of benchmark results per task]

Keep in mind that I accidentally multiplied validity by 10; beyond that, these numbers should still be taken with a heavy pinch of salt. I'm going to take your advice and run some experiments to check whether my evaluation is set up correctly.

Just to address a few of the points you brought up: I am using the benchmark dataset you provide, although, if I'm not mistaken, it has a few more molecules than were used in the paper. I'll report back after some additional experimentation, so we could keep this open for now if you don't mind.

Anri-Lombard commented 3 weeks ago

@maclandrol unfortunately I was not able to replicate the exact findings, but then I did not follow the exact experimental setup either. It would be awesome to see different architectures of a similar size to the large model compared to SAFE-GPT, so perhaps we could run that experiment at a later date 🙌 😄