greenelab / tybalt

Training and evaluating a variational autoencoder for pan-cancer gene expression data
BSD 3-Clause "New" or "Revised" License

MAD: mean or median? #99

Closed: enricoferrero closed this issue 6 years ago

enricoferrero commented 6 years ago

Hi Greg,

Just a quick comment on something I spotted: in the manuscript you mentioned you've used median absolute deviation (MAD), but in the process_data.ipynb notebook you seem to use the pd.DataFrame.mad() method which, according to the docs, calculates the mean absolute deviation: https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.mad.html
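The two statistics differ in both the center and the aggregation. The mean absolute deviation averages |x - mean(x)|, while the median absolute deviation takes the median of |x - median(x)|, so a single extreme sample moves the former much more. A minimal NumPy sketch on a made-up five-sample vector (the helper names are hypothetical, and the median version is left unscaled):

```python
import numpy as np

def mean_abs_dev(x):
    """Mean absolute deviation: average distance from the mean (what pd.DataFrame.mad() computed)."""
    x = np.asarray(x, dtype=float)
    return np.mean(np.abs(x - x.mean()))

def median_abs_dev(x):
    """Median absolute deviation: median distance from the median (no consistency constant)."""
    x = np.asarray(x, dtype=float)
    return np.median(np.abs(x - np.median(x)))

# Toy expression values for one gene across five samples; the last one is an outlier
values = [1.0, 2.0, 3.0, 4.0, 100.0]
print(mean_abs_dev(values))    # 31.2
print(median_abs_dev(values))  # 1.0
```

Because the outlier inflates the mean-based score, genes with a few extreme samples can rank higher under the mean than under the median. This is one reason the two gene lists need not coincide.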

Hope this helps (and sorry if I got that wrong!), Enrico

gwaybio commented 6 years ago

Hi @enricoferrero - Fantastic catch, thank you!

Indeed, it does appear there is a discrepancy between the manuscript text and the processing code. I did some digging to determine the extent of the impact.

It looks like ~85% of the top 5,000 genes are the same if we use the median absolute deviation (from statsmodels.robust.scale.mad) instead.

[Figure (mean_median_mad): overlap of the top 5,000 genes selected by mean MAD vs. median MAD]
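The ~85% figure is the share of genes common to both top-ranked lists. A minimal sketch of that calculation, assuming two pandas Series of per-gene scores indexed by gene name. `top_k_overlap` and the toy scores are hypothetical and only for illustration; in the thread, k would be 5,000 and the reported result is roughly 0.85.

```python
import pandas as pd

def top_k_overlap(scores_a: pd.Series, scores_b: pd.Series, k: int) -> float:
    """Fraction of the k highest-scoring genes shared by two rankings (hypothetical helper)."""
    top_a = set(scores_a.nlargest(k).index)
    top_b = set(scores_b.nlargest(k).index)
    return len(top_a & top_b) / k

# Toy per-gene scores standing in for mean MAD and median MAD
mean_mad = pd.Series({'G1': 5.0, 'G2': 4.0, 'G3': 3.0, 'G4': 2.0, 'G5': 1.0})
median_mad = pd.Series({'G1': 4.5, 'G2': 1.0, 'G3': 3.5, 'G4': 3.0, 'G5': 0.5})
print(top_k_overlap(mean_mad, median_mad, k=3))  # 0.6666666666666666 (G1 and G3 shared)
```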

A couple of years ago I did some analyses on how much information is retained when subsetting genes. One example can be found in our Cognoma repository. This isn't the same evaluation, but it uses the same data. There, performance does not change much once a baseline number of genes is included.

It also appears that the most variable genes are selected by both the mean and the median MAD. Because these shared genes capture the most variation in the data, they are likely what the VAE keys in on.

[Figure (mean_median_mad_scatter): median MAD vs. mean MAD for genes in either top-5,000 set]

Subsetting loses little information because gene expression is highly correlated: measuring certain genes can predict the expression levels of others with high accuracy (Rudd et al. 2015).
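To make the correlation argument concrete, here is a small simulation on synthetic data (not TCGA). Fifty "genes" are generated from three shared latent factors plus noise, and ordinary least squares on just ten of them predicts the other forty. The sizes, noise level, and seed are arbitrary assumptions; real expression data are far noisier, so this only illustrates the mechanism, and the R^2 reported is in-sample.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_factors, n_genes, n_kept = 500, 3, 50, 10

# Synthetic expression: each gene is a linear mix of a few shared latent programs plus noise
latent = rng.normal(size=(n_samples, n_factors))
loadings = rng.normal(size=(n_factors, n_genes))
expr = latent @ loadings + 0.1 * rng.normal(size=(n_samples, n_genes))

# Keep only the first n_kept genes (plus an intercept column) as predictors
X = np.column_stack([np.ones(n_samples), expr[:, :n_kept]])
Y = expr[:, n_kept:]

# Ordinary least squares for all remaining genes at once
coef, *_ = np.linalg.lstsq(X, Y, rcond=None)
residuals = Y - X @ coef

# Per-gene in-sample R^2 for the genes that were not used as predictors
r2 = 1 - (residuals ** 2).sum(axis=0) / ((Y - Y.mean(axis=0)) ** 2).sum(axis=0)
print(f"median R^2 across {Y.shape[1]} predicted genes: {np.median(r2):.3f}")
```

When a few latent programs drive most of the variation, a modest subset of genes already spans them. That is the intuition for why the exact choice of 5,000 genes matters less than one might expect.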

Therefore, I do not believe the discrepancy will impact performance much. It is more likely to affect interpretation, because genes left out of the selected subset cannot appear in downstream pathway analyses.
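The genes whose interpretation could change are the ones selected by only one of the two statistics. A minimal sketch, assuming each selection is available as a collection of gene symbols. `selection_differences` is a hypothetical helper, and the gene names below are placeholders, not results from this analysis.

```python
def selection_differences(mean_selected, median_selected):
    """Split two gene selections into genes unique to each and genes shared by both."""
    mean_selected, median_selected = set(mean_selected), set(median_selected)
    return {
        'mean_only': sorted(mean_selected - median_selected),
        'median_only': sorted(median_selected - mean_selected),
        'shared': sorted(mean_selected & median_selected),
    }

diff = selection_differences(['G1', 'G2', 'G3'], ['G1', 'G3', 'G4'])
# 'G4' would be absent from a pathway analysis built on the mean-MAD genes
print(diff['median_only'])  # ['G4']
```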

I will add a note in the README about this discrepancy, so that it is clear to the inquisitive reader. Code to reproduce the above figures is pasted below.

Thanks again!

import os
from statsmodels import robust
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib_venn import venn2

# Input Files
rna_file = os.path.join('data', 'raw', 'HiSeqV2')
rnaseq_df = pd.read_table(rna_file, index_col=0)

# Process RNAseq file
rnaseq_df.index = rnaseq_df.index.map(lambda x: x.split('|')[0])
rnaseq_df.columns = rnaseq_df.columns.str.slice(start=0, stop=15)
rnaseq_df = rnaseq_df.drop('?').fillna(0).sort_index(axis=1)

# Gene is listed twice in RNAseq data, drop both occurrences
rnaseq_df.drop('SLC35E2', axis=0, inplace=True)
rnaseq_df = rnaseq_df.T

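# Determine most variably expressed genes and subset
# Note: pd.DataFrame.mad() returned the MEAN absolute deviation and was removed
# in pandas 2.0, so the same quantity is computed explicitly here
num_mad_genes = 5000
mean_mad_genes = (rnaseq_df - rnaseq_df.mean(axis=0)).abs().mean(axis=0)
mean_mad_genes = mean_mad_genes.sort_values(ascending=False)
top_mean_mad_genes = mean_mad_genes.iloc[0:num_mad_genes, ]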

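# Test MEDIAN absolute deviation
# statsmodels divides by ~0.6745 by default (normal-consistency constant);
# this constant rescaling does not change the gene ranking
median_mad_genes = robust.scale.mad(rnaseq_df, axis=0)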
median_mad_genes = pd.DataFrame(median_mad_genes, columns=['median_mad'],
                                index=rnaseq_df.columns).sort_values(by='median_mad',
                                                                     ascending=False)
top_median_mad_genes = median_mad_genes.iloc[0:num_mad_genes, ]

# Venn Diagram Overlaps
venn2([set(top_mean_mad_genes.index), set(top_median_mad_genes.index)],
      set_labels=('Mean MAD', 'Median MAD'))

# Mean vs. Median Scatter
mad_df = pd.concat([pd.DataFrame(median_mad_genes), pd.DataFrame(mean_mad_genes)], axis=1)
mad_df.columns = ['median_mad', 'mean_mad']
# Recent pandas versions do not accept a set as an indexer, so pass a list
mad_df = mad_df.loc[list(set(top_mean_mad_genes.index) | set(top_median_mad_genes.index)), ]

# Lowest score that still makes each top-num_mad_genes list (dashed cut-off lines)
median_min = top_median_mad_genes['median_mad'].min()
mean_min = top_mean_mad_genes.min()
g = sns.regplot(x='median_mad', y='mean_mad', data=mad_df,
                scatter_kws={'s': 2.5})
plt.axvline(median_min, linewidth=1, color='red', linestyle='--')
plt.axhline(mean_min, linewidth=1, color='red', linestyle='--')

stale[bot] commented 6 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.