shz9 / magenpy

Modeling and Analysis of (Statistical) Genetics data in python
https://shz9.github.io/magenpy/
MIT License
16 stars 5 forks source link

Explore using `bed-reader` as a backend #12

Open shz9 opened 2 years ago

shz9 commented 2 years ago

I was curious about the possibility of using bed-reader as a backend for some of the operations on the genotype matrix. From the preliminary testing that I have done, it seems fast and provides a lot of nice functionality.

Here's a sample implementation that inherits from GenotypeMatrix and implements some of the useful operations, such as linear scoring and reading the genotype file in chunks. However, in early experiments, it seems a bit slow compared to pandas-plink or using plink directly, as implemented in plinkBEDGenotypeMatrix. May need to improve this for it to be a practical competitor.

class rustBEDGenotypeMatrix(GenotypeMatrix):
    """
    A genotype matrix that reads PLINK BED files through the `bed-reader`
    package and processes the genotypes in row- or column-chunks.

    NOTE: Still experimental... lots more work to do...
    """

    def __init__(self, sample_table=None, snp_table=None, temp_dir='temp', bed_mat=None):
        """
        :param sample_table: The table of samples, forwarded to the parent class.
        :param snp_table: The table of variants, forwarded to the parent class.
        :param temp_dir: The temporary directory, forwarded to the parent class.
        :param bed_mat: The BED file handle (as returned by `bed_reader.open_bed`)
        that genotype chunks are read from.
        """
        super().__init__(sample_table=sample_table, snp_table=snp_table, temp_dir=temp_dir)

        # BED file handle, as returned by `bed_reader.open_bed`:
        self.bed_mat = bed_mat

    @classmethod
    def from_file(cls, file_path, temp_dir='temp'):
        """
        Construct the genotype matrix object from a BED file.

        :param file_path: The path to the BED file, passed to `bed_reader.open_bed`.
        :param temp_dir: The temporary directory, forwarded to the constructor.
        :return: A new instance of this class wrapping the opened BED file.
        """

        from bed_reader import open_bed

        # Any error raised while opening the file propagates to the caller unchanged.
        bed_mat = open_bed(file_path)

        # Set the sample table:
        sample_table = pd.DataFrame({
            'FID': bed_mat.fid,
            'IID': bed_mat.iid,
            'fatherID': bed_mat.father,
            'motherID': bed_mat.mother,
            'sex': bed_mat.sex,
            'phenotype': bed_mat.pheno
        }).astype({
            'FID': str,
            'IID': str,
            'fatherID': str,
            'motherID': str,
            'sex': float,
            'phenotype': float
        })

        # A phenotype value of -9 is treated as missing:
        sample_table['phenotype'] = sample_table['phenotype'].replace({-9.: np.nan})
        # Keep the original row position in an `index` column (used by `sample_index`):
        sample_table = sample_table.reset_index()

        # Set the snp table:
        snp_table = pd.DataFrame({
            'CHR': bed_mat.chromosome,
            'SNP': bed_mat.sid,
            'cM': bed_mat.cm_position,
            'POS': bed_mat.bp_position,
            'A1': bed_mat.allele_1,
            'A2': bed_mat.allele_2
        }).astype({
            'CHR': int,
            'SNP': str,
            'cM': float,
            # `np.int` was only an alias of the builtin `int` and has been removed
            # from recent NumPy releases, so the builtin is used directly:
            'POS': int,
            'A1': str,
            'A2': str
        })

        # Keep the original column position in an `index` column (used by `snp_index`):
        snp_table = snp_table.reset_index()

        return cls(sample_table=SampleTable(sample_table),
                   snp_table=snp_table,
                   temp_dir=temp_dir,
                   bed_mat=bed_mat)

    @property
    def sample_index(self):
        """
        :return: The values of the `index` column of the sample table, used as
        row indices when reading from the BED file.
        """
        return self.sample_table.table['index'].values

    @property
    def snp_index(self):
        """
        :return: The values of the `index` column of the SNP table, used as
        column indices when reading from the BED file.
        """
        return self.snp_table['index'].values

    def score(self, beta, standardize_genotype=False, skip_na=True):
        """
        Perform linear scoring on the genotype matrix, one column-chunk at a time.

        :param beta: A vector or matrix of effect sizes for each variant in the genotype matrix.
        :param standardize_genotype: If True, standardize the genotype when computing the polygenic score.
        :param skip_na: If True, skip missing values when computing the polygenic score
        (they are replaced with zeros). If False, missing values are replaced with the
        corresponding entries of `self.maf`. Ignored when `standardize_genotype` is True.
        :return: The accumulated scores, or None if there are no chunks to iterate over.
        """

        if standardize_genotype:
            from .stats.transforms.genotype import standardize

        pgs = None

        for (start, end), chunk in self.iter_col_chunks(return_slice=True):

            chunk_beta = beta[start:end]

            if standardize_genotype:
                chunk_pgs = standardize(chunk).dot(chunk_beta)
            elif skip_na:
                chunk_pgs = np.nan_to_num(chunk).dot(chunk_beta)
            else:
                chunk_pgs = np.where(np.isnan(chunk), self.maf[start:end], chunk).dot(chunk_beta)

            # The first chunk initializes the accumulator; later chunks add to it in place:
            if pgs is None:
                pgs = chunk_pgs
            else:
                pgs += chunk_pgs

        return pgs

    def perform_gwas(self, **gwa_kwargs):
        """
        Not implemented for this backend.

        :raises NotImplementedError: Always.
        """

        raise NotImplementedError

    def compute_allele_frequency(self):
        """
        Compute the `MAF` column of the SNP table as the NaN-ignoring column sum
        of the genotypes divided by twice the per-SNP sample size (`self.n_per_snp`).
        """
        self.snp_table['MAF'] = (np.concatenate([np.nansum(bed_chunk, axis=0)
                                                 for bed_chunk in self.iter_col_chunks()]) / (2. * self.n_per_snp))

    def compute_sample_size_per_snp(self):
        """
        Compute the `N` column of the SNP table as the total sample size (`self.n`)
        minus the number of missing (NaN) genotypes for each SNP.
        """

        self.snp_table['N'] = self.n - np.concatenate([np.sum(np.isnan(bed_chunk), axis=0)
                                                       for bed_chunk in self.iter_col_chunks()])

    def _auto_chunk_size(self, n_items):
        """
        Choose how many rows/columns to read per chunk along an axis of length `n_items`.

        :param n_items: The number of elements along the axis being chunked.
        :return: A chunk length of at least 1.
        """
        matrix_size = self.estimate_memory_allocation()
        # By default, we allocate 128MB per chunk
        # (assumes `estimate_memory_allocation` reports megabytes -- TODO confirm).
        # Clamp to at least one chunk so that small matrices do not divide by zero:
        n_chunks = max(1, int(matrix_size // 128))
        # Clamp to at least one element per chunk so that short axes still make progress:
        return max(1, int(n_items // n_chunks))

    def iter_row_chunks(self, chunk_size='auto', return_slice=False):
        """
        Iterate over the genotype matrix in chunks of rows (samples).

        :param chunk_size: The number of rows per chunk, or 'auto' to derive it
        from the estimated memory footprint of the matrix.
        :param return_slice: If True, yield `((start, end), chunk)` tuples instead
        of just the chunk.
        """

        if chunk_size == 'auto':
            chunk_size = self._auto_chunk_size(self.n)

        for i in range(int(np.ceil(self.n / chunk_size))):
            # The last chunk is truncated at the end of the axis:
            start, end = int(i * chunk_size), min(int((i + 1) * chunk_size), self.n)
            chunk = self.bed_mat.read(np.s_[self.sample_index[start:end], self.snp_index], num_threads=1)
            if return_slice:
                yield (start, end), chunk
            else:
                yield chunk

    def iter_col_chunks(self, chunk_size='auto', return_slice=False):
        """
        Iterate over the genotype matrix in chunks of columns (SNPs).

        :param chunk_size: The number of columns per chunk, or 'auto' to derive it
        from the estimated memory footprint of the matrix.
        :param return_slice: If True, yield `((start, end), chunk)` tuples instead
        of just the chunk.
        """

        if chunk_size == 'auto':
            chunk_size = self._auto_chunk_size(self.m)

        for i in range(int(np.ceil(self.m / chunk_size))):
            # The last chunk is truncated at the end of the axis:
            start, end = int(i * chunk_size), min(int((i + 1) * chunk_size), self.m)
            chunk = self.bed_mat.read(np.s_[self.sample_index, self.snp_index[start:end]], num_threads=1)
            if return_slice:
                yield (start, end), chunk
            else:
                yield chunk