jasonkyuyim / se3_diffusion

Implementation for SE(3) diffusion model with application to protein backbone generation
https://arxiv.org/abs/2302.02277
MIT License
305 stars 50 forks

Question about Table 1 FrameDiff sample metrics #32

Closed RinGhalSun closed 12 months ago

RinGhalSun commented 1 year ago

Hi Jason,

I am currently re-running inference based on your paper, focusing on the metrics in Table 1, specifically the FrameDiff sample metrics. Could you kindly clarify how the data from the sc_results.csv file is used in this context? I am having trouble replicating the percentages reported in your paper using your provided weights.

My approach is to iterate through the sc_results.csv files for lengths 100 to 500 with a step size of 5, considering all samples 0 through 9, as sketched below.
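Concretely, my loop looks roughly like this (the length_*/sample_* directory names are my assumption about the default inference output layout):

import os
import pandas as pd

results_dir = 'PATH_TO_INFERENCE_OUTPUTS'  # placeholder for my local output directory
all_csvs = []
for length in range(100, 505, 5):  # lengths 100 to 500, step 5
    for sample_id in range(10):  # samples 0 to 9
        csv_path = os.path.join(
            results_dir, f'length_{length}', f'sample_{sample_id}',
            'self_consistency', 'sc_results.csv')
        if os.path.exists(csv_path):
            all_csvs.append(pd.read_csv(csv_path, index_col=0))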

And by the way, according to your paper, pdbTM is used to measure the novelty of the model, and diversity also appears in Table 1. How are pdbTM and diversity calculated?

Your insights would be greatly appreciated.

Best regards,

Zheng

Z-MU-Z commented 1 year ago

same question!

jasonkyuyim commented 12 months ago

Here's the code I use

import os
import pandas as pd
import numpy as np
import plotnine as gg

def read_samples(results_dir):
    # Collect every sc_results.csv under results_dir into one dataframe.
    all_csvs = []
    print(f'Reading samples from {results_dir}')
    # Length directories are named like length_100.
    for sample_length in os.listdir(results_dir):
        if '.' in sample_length:
            # Skip files; only descend into directories.
            continue
        length_dir = os.path.join(results_dir, sample_length)
        length = int(sample_length.split('_')[1])
        for i, sample_name in enumerate(os.listdir(length_dir)):
            if '.' in sample_name:
                continue
            csv_path = os.path.join(
                length_dir, sample_name, 'self_consistency', 'sc_results.csv')
            if os.path.exists(csv_path):
                design_csv = pd.read_csv(csv_path, index_col=0)
                # Tag each row with its backbone length and sample index.
                design_csv['length'] = length
                design_csv['sample_id'] = i
                all_csvs.append(design_csv)
    results_df = pd.concat(all_csvs)
    return results_df

def sc_filter(raw_df, metric):
    # Pick the best self-consistency sequence for each backbone.
    if metric == 'tm_score':
        # Higher TM-score is better; designable if scTM > 0.5.
        df = raw_df.sort_values('tm_score', ascending=False)
        df['designable'] = df.tm_score.map(lambda x: x > 0.5)
    elif metric == 'rmsd':
        # Lower RMSD is better; designable if scRMSD < 2 angstroms.
        df = raw_df.sort_values('rmsd', ascending=True)
        df['designable'] = df.rmsd.map(lambda x: x < 2.0)
    else:
        raise ValueError(f'Unknown metric {metric}')
    # After sorting, the first row in each (length, sample_id) group is the best.
    df = df.groupby(['length', 'sample_id']).first().reset_index()
    percent_designable = df['designable'].mean()
    print(f'Percent designable: {percent_designable}')
    return df
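
As a quick sanity check of the selection logic, here's a toy dataframe with made-up numbers:

toy = pd.DataFrame({
    'length': [100, 100, 100, 100],
    'sample_id': [0, 0, 1, 1],
    'rmsd': [3.1, 1.4, 2.6, 2.2],
})
best = sc_filter(toy, 'rmsd')
# Backbone 0 keeps rmsd=1.4 (designable), backbone 1 keeps rmsd=2.2
# (not designable), so this prints 'Percent designable: 0.5'.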

Then in a notebook I run this

samples_df = read_samples(PATH_TO_INFERENCE_OUTPUTS)
samples_df = samples_df[samples_df.sample_id < 8] # Ensure we only consider 8 sequences per backbone.

scrmsd_results = sc_filter(samples_df, 'rmsd')
sctm_results = sc_filter(samples_df, 'tm_score')

I re-ran inference with the published model weights paper_weights.pth and got 25% scRMSD and 78% scTM for noise scale 0.1, N_steps 500, N_seq 8, which agrees with the results in Table 1. Note there will be some variance in the metrics.
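
If you want to gauge that variance without re-running inference, one option is a quick bootstrap over backbones (just a sketch on top of the scrmsd_results dataframe above, not something in the repo):

rng = np.random.default_rng(0)
des = scrmsd_results['designable'].to_numpy()
# Resample backbones with replacement and recompute the designable fraction.
boot_means = [rng.choice(des, size=des.size, replace=True).mean()
              for _ in range(1000)]
print(f'scRMSD designable: {des.mean():.3f} +/- {np.std(boot_means):.3f}')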

RinGhalSun commented 12 months ago

Hi Jason,

I want to express my gratitude for your assistance and the effort you've put into re-running the code with your paper's weights. I achieved similar results: 77.41%, which matches the reported outcomes. While reviewing the code, I came across the following line:

samples_df = samples_df[samples_df.sample_id < 8] # Ensure we only consider 8 sequences per backbone.

I was curious if there was a specific reason for setting the condition sample_id < 8. Could you please provide some insight into this choice?

Thank you once again for your help!

Best regards,

Zheng

jasonkyuyim commented 12 months ago

No problem. The choice of 8 sequences is a convention for how many "chances" we get to find a matching sequence for the sampled backbone. I think it started in one of David Baker's deep learning papers; I don't remember exactly which. It causes high variance in the success rates, so I also included results with 100 sequences in Table 1.

RinGhalSun commented 12 months ago

That makes sense. Thanks again for your invaluable help and time, Jason!