malariagen / malariagen-data-python

Analyse MalariaGEN data from Python
https://malariagen.github.io/malariagen-data-python/latest/
MIT License

Add functions for pairwise FST comparisons between cohorts #387

Closed KellyLBennett closed 11 months ago

KellyLBennett commented 1 year ago

Add functions to calculate FST between a pair of cohorts, a set of cohorts and to plot these values as a heat map. Draft functions are as follows:

import allel  # scikit-allel
# Assumes `ag3` is an existing malariagen_data.Ag3() instance.
def average_fst(
    region,
    admin_level,
    cohort1,
    cohort2,
    blen=10000,
    site_mask=None,
    site_class=None,
    cohort_size=10,
    random_seed=None,
    ):
    count_params = dict(
        region=region,
        sample_sets=None,
        site_mask=site_mask,
        site_class=site_class,
        cohort_size=cohort_size,
        random_seed=random_seed,
    )
    cohort1_counts = ag3.snp_allele_counts(
        sample_query=f"{admin_level} == '{cohort1}'", **count_params
    )
    cohort2_counts = ag3.snp_allele_counts(
        sample_query=f"{admin_level} == '{cohort2}'", **count_params
    )
    fst_hudson, se_hudson, _, _ = allel.blockwise_hudson_fst(cohort1_counts, cohort2_counts, blen=blen)
    return fst_hudson, se_hudson
KellyLBennett commented 1 year ago

By default this produces a full pairwise comparison table. If diag==True, it produces a triangular table with the upper triangle (including the diagonal) masked. If both diag==True and annotate_se==True, it produces a table with FST in the lower triangle and SE in the upper triangle. These tables can be passed to the plotting function.

def pairwise_average_fst(
    region,
    admin_level,
    cohort_list,
    blen=10000,
    site_mask=None,
    site_class=None,
    cohort_size=10,
    random_seed=None,
    diag=False,
    annotate_se=False
    ):
    #if wish to sort list by species, will place sample location samples together before year.
    #cohort_list.sort(key = lambda x: x.split("_")[1])
    fst_dict = {key:[] for key in cohort_list}
    se_dict = {key:[] for key in cohort_list}
    for cohort1 in cohort_list:
        for cohort2 in cohort_list:
            fst_hudson, se_hudson = average_fst(region=region, admin_level=admin_level, cohort1=cohort1, cohort2=cohort2, blen=blen, site_mask=site_mask, site_class=site_class, cohort_size=cohort_size, random_seed=random_seed)
            #convert minus numbers to 0
            if fst_hudson < 0:
                fst_hudson = 0
            # put values in dict for df.
            fst_dict[cohort1].append(round(fst_hudson,3))
            se_dict[cohort1].append(round(se_hudson,3))
            print(f"calculating FST between {cohort1} and {cohort2}")
    #put dictionary into pd df and add index.
    fst_df = pd.DataFrame.from_dict(fst_dict)
    fst_df.index = cohort_list
    se_df = pd.DataFrame.from_dict(se_dict)
    se_df.index = cohort_list
    #format as string to get correct decimals.
    for col in fst_df:
        new_col= fst_df[col].map('{:.3f}'.format).astype(str)
        fst_df[col]=new_col
    for col in se_df:
        new_col= se_df[col].map('{:.3f}'.format).astype(str)
        se_df[col]=new_col
    # If display options are not required, return the data frames here.
    # Set diag=True to return a triangular table with the upper triangle masked.
    if diag:
        # Mask the upper triangle (including the diagonal) of the FST table.
        fst_df = fst_df.mask(np.triu(np.ones(fst_df.shape)).astype(bool))
        if annotate_se:
            # Show FST in the lower triangle and SE in the upper triangle.
            # se_df is still returned for this option even though it is not really needed.
            se_df = se_df.mask(np.tril(np.ones(se_df.shape)).astype(bool))
            fst_df = fst_df.fillna(se_df)
            for i in range(min(fst_df.shape)):
                fst_df.iloc[i, i] = 'nan'
        else:
            se_df = se_df.mask(np.triu(np.ones(se_df.shape)).astype(bool))
    #replace null values with empty string so doesn't display na in plots. Option to move this to plots function.
    fst_df=fst_df.astype(str).replace('nan','')
    se_df=se_df.astype(str).replace('nan','')    
    return fst_df, se_df
KellyLBennett commented 1 year ago
def plot_pairwise_average_fst(
    fst_df,
    annotate=True,
    zmin=0,
    zmax=1,
    width=None,
    height=None,
    text_auto=False,
    color_continuous_scale='gray',
    aspect="auto",
    title=None,
    plot_bgcolor='rgba(0,0,0,0)',
    showgrid=False,
    linecolor='black'
    ):
    #if want nan in table move replace here
    #fst_df=fst_df.astype(str).replace('nan','')
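    # Pass width and height through so that these parameters take effect.
    fig = px.imshow(
        img=fst_df,
        zmin=zmin,
        zmax=zmax,
        text_auto=text_auto,
        color_continuous_scale=color_continuous_scale,
        title=title,
        aspect=aspect,
        width=width,
        height=height,
    )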
    fig.update_traces(text=fst_df.values,texttemplate="%{text}")
    fig.update_layout(plot_bgcolor=plot_bgcolor)
    fig.update_yaxes(showgrid=showgrid, linecolor=linecolor)
    fig.update_xaxes(showgrid=showgrid, linecolor=linecolor)
    return fig
alimanfoo commented 1 year ago

Thanks @KellyLBennett.

For the average_fst() function, here are a few suggestions to get API consistency with other functions.


For specifying which cohorts to use, the relevant functions are h1x_gwss() and fst_gwss(). In these functions, the two cohorts are specified via two parameters, cohort1_query and cohort2_query, each of which takes a pandas sample query. These support a sample query rather than just the name of a cohort in order to have the full flexibility to specify any cohort you like, not just the predefined cohorts.

For consistency, for average_fst() suggest we provide parameters cohort1_query and cohort2_query instead of admin_level, cohort1 and cohort2.

So, e.g., a user would then do something like:

fst, fst_se = ag3.average_fst(
    cohort1_query="cohort_admin2_year == 'ML-2_Kati_colu_2014'",
    cohort2_query="cohort_admin2_year == 'ML-2_Kati_gamb_2014'",
    # ... other params ...
)

I'm conscious this is a bit less convenient for the case where the user just wants to use predefined cohorts. Suggest we figure out how to support that in a later PR, providing the same support across all functions that take a pair of cohorts.
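In the meantime, a user who only has a cohort column and a predefined cohort label could build the query string themselves. The helper below is a minimal sketch of that idea; the name cohort_query is hypothetical and is not part of the malariagen_data API.

```python
def cohort_query(column: str, label: str) -> str:
    """Build a pandas sample query selecting one predefined cohort.

    Hypothetical helper for illustration only; it is not part of
    malariagen_data.
    """
    if "'" in label:
        # Keep the sketch simple: refuse labels that would break the quoting.
        raise ValueError(f"cohort label must not contain a single quote: {label!r}")
    return f"{column} == '{label}'"


# Example: the same queries as in the call shown above.
cohort1_query = cohort_query("cohort_admin2_year", "ML-2_Kati_colu_2014")
cohort2_query = cohort_query("cohort_admin2_year", "ML-2_Kati_gamb_2014")
```

The resulting strings could then be passed as the cohort1_query and cohort2_query arguments.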


Regarding management of the cohort size, we've also recently transitioned to preferring a combination of min_cohort_size and max_cohort_size parameters, rather than a fixed cohort_size. Comparable functions like fst_gwss() support all three parameters, but cohort_size defaults to None, and there are default values for min_cohort_size and max_cohort_size. The rationale is that we do want to ensure a minimum cohort size, as we know from experience that average Fst values are definitely biased below 10 samples. We also don't want too many samples, as that is unnecessary and can slow down the computation. But otherwise anything between 10 and 50 is fine, and we'd prefer to use up to 50 where we have them. There is no need to fix the same cohort size in both cohorts; Hudson's Fst is robust to that.

So suggest adding min_cohort_size and max_cohort_size params in addition to cohort_size, with default values as given in fst_gwss().
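The min/max logic described above could look roughly like the sketch below. This is an illustration only, not the library's implementation: select_cohort is a hypothetical name, and the defaults of 10 and 50 are taken from the numbers mentioned in this comment rather than checked against fst_gwss().

```python
import random


def select_cohort(sample_ids, min_cohort_size=10, max_cohort_size=50, random_seed=None):
    """Return the samples to use for one cohort (hypothetical helper).

    Raises if the cohort is too small, randomly downsamples if it is
    too large, and otherwise returns all samples unchanged.
    """
    sample_ids = list(sample_ids)
    n = len(sample_ids)
    if n < min_cohort_size:
        raise ValueError(
            f"cohort has {n} samples, fewer than min_cohort_size={min_cohort_size}"
        )
    if n > max_cohort_size:
        # Downsample without replacement; the seed makes the choice reproducible.
        rng = random.Random(random_seed)
        return sorted(rng.sample(sample_ids, max_cohort_size))
    return sample_ids
```

Because each cohort is handled independently, the two cohorts in a comparison may end up with different sizes, which is consistent with the point above about Hudson's Fst.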


One other parameter to look at is blen. This is the name of the parameter in the scikit-allel API, but it's a bit of a cryptic name. The purpose is to decide how many SNPs to include within each block for the block jackknife procedure. But we do something similar for the cohort_diversity_stats() function, i.e., use a block jackknife to estimate confidence intervals. In that function I opted to expose an n_jack parameter which is the number of blocks to divide the data up into (and hence the number of jackknife replicates). The actual block length is then computed from the number of sites.

So for consistency, suggest exposing an n_jack parameter instead of blen and computing blen internally.
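As a rough sketch of what computing blen internally might involve, the pure-Python example below derives a block length from n_jack and runs a delete-one-block jackknife on a ratio of sums, which is the general form of an averaged Fst estimate. The function names are hypothetical, and this is not the code used by scikit-allel or by cohort_diversity_stats(); in this sketch, trailing sites that do not fill a whole block are simply dropped.

```python
import math


def block_length(n_sites: int, n_jack: int) -> int:
    """Number of sites per block when splitting n_sites into n_jack blocks."""
    if n_jack < 2:
        raise ValueError("n_jack must be at least 2")
    if n_sites < n_jack:
        raise ValueError("need at least one site per jackknife block")
    return n_sites // n_jack


def jackknife_ratio(num, den, n_jack):
    """Estimate sum(num) / sum(den) with a block-jackknife standard error."""
    blen = block_length(len(num), n_jack)
    # Per-block sums of the numerator and denominator terms.
    num_blocks = [sum(num[i * blen:(i + 1) * blen]) for i in range(n_jack)]
    den_blocks = [sum(den[i * blen:(i + 1) * blen]) for i in range(n_jack)]
    num_total, den_total = sum(num_blocks), sum(den_blocks)
    estimate = num_total / den_total
    # Leave-one-block-out estimates.
    loo = [
        (num_total - nb) / (den_total - db)
        for nb, db in zip(num_blocks, den_blocks)
    ]
    mean_loo = sum(loo) / n_jack
    # Standard delete-one jackknife standard error.
    se = math.sqrt((n_jack - 1) / n_jack * sum((v - mean_loo) ** 2 for v in loo))
    return estimate, se
```

With this shape of API, the user chooses the number of jackknife replicates directly and does not need to know how many sites the region contains.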


If those suggestions sound OK then feel free to go ahead and start working on a PR. Maybe we could start with the average_fst() function, then add in the others.

alimanfoo commented 11 months ago

Resolved via #419.