sgkit-dev / sgkit

Scalable genetics toolkit
https://sgkit-dev.github.io/sgkit
Apache License 2.0

Add variant/sample summary statistic methods #29

Closed: eric-czech closed this issue 3 years ago

eric-czech commented 4 years ago

At a minimum, we need to be able to calculate these axis-wise aggregations (for GWAS QC):

I would propose that we start by making methods for each that take a single Dataset and return Dataset instances (it will be easier to define the frequencies/rates when the counts are defined in the same functions).

This should consider whatever the solution to https://github.com/pystatgen/sgkit/issues/3 ends up being.

alimanfoo commented 4 years ago

Hi @eric-czech, sounds good. Could you propose an API for these functions? I.e., a proposed list of new function names, signatures and return types?

Also, would these return Dataset or DataArray?

eric-czech commented 4 years ago

Yeah, for sure. I had most of these in one place or another in my prototype, so here they are pulled together and adapted a bit to the new conventions:

import xarray as xr
from xarray import Dataset
from typing_extensions import Literal

Dimension = Literal['samples', 'variants']

def _swap(dim: Dimension) -> Dimension:
    return 'samples' if dim == 'variants' else 'variants'

def call_rate(ds: Dataset, dim: Dimension) -> Dataset:
    odim = _swap(dim)[:-1]
    n_called = (~ds['call/genotype_mask'].any(dim='ploidy')).sum(dim=dim)
    return xr.Dataset({
        f'{odim}/n_called': n_called,
        f'{odim}/call_rate': n_called / ds.dims[dim]
    })

def genotype_count(ds: Dataset, dim: Dimension) -> Dataset:
    odim = _swap(dim)[:-1]
    mask, gt = ds['call/genotype_mask'].any(dim='ploidy'), ds['call/genotype']
    n_het = (gt > 0).any(dim='ploidy') & (gt == 0).any(dim='ploidy')
    n_hom_ref = (gt == 0).all(dim='ploidy')
    n_hom_alt = (gt > 0).all(dim='ploidy')
    n_non_ref = (gt > 0).any(dim='ploidy')
    agg = lambda x: xr.where(mask, False, x).sum(dim=dim)
    return xr.Dataset({
        f'{odim}/n_het': agg(n_het),
        f'{odim}/n_hom_ref': agg(n_hom_ref),
        f'{odim}/n_hom_alt': agg(n_hom_alt),
        f'{odim}/n_non_ref': agg(n_non_ref)
    })

def allele_count(ds: Dataset) -> Dataset:
    # Collapse 3D calls into 2D array where calls are flattened into columns
    gt = ds['call/genotype'].stack(calls=('samples', 'ploidy'))
    mask = ds['call/genotype_mask'].stack(calls=('samples', 'ploidy'))

    # Count number of non-missing alleles (works with partial calls)
    an = (~mask).sum(dim='calls')
    # Count each individual allele
    ac = xr.concat([
        xr.where(mask, 0, gt == i).sum(dim='calls')
        for i in range(ds.dims['alleles'])
    ], dim='alleles').T

    return xr.Dataset({
        'variant/allele_count': ac,
        'variant/allele_total': an,
        'variant/allele_frequency': ac / an
    })

#######################
# Convenience functions

def variant_stats(ds: Dataset) -> Dataset:
    return xr.merge([
        call_rate(ds, dim='samples'),
        genotype_count(ds, dim='samples'),
        allele_count(ds)
    ])    

def sample_stats(ds: Dataset) -> Dataset:
    return xr.merge([
        call_rate(ds, dim='variants'),
        genotype_count(ds, dim='variants')
    ])

I think it makes sense for most of them to return a Dataset like this, but DataArrays behave a lot like Datasets, so there's no reason a summary with a single variable couldn't be returned as a DataArray instead.

That allele_count function seems to solve https://github.com/pystatgen/sgkit/issues/3, assuming it doesn't need to be more complicated than collapsing the ploidy and sample dimensions together and then doing a mask-aware sum for each allele index. That would work for differing numbers of alleles per locus too, though the loop wouldn't scale well if the maximum number of alleles at any one site were high, since each allele index takes another full pass over the genotype array.
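One way around the per-allele loop is to count every (variant, allele) pair in a single pass. The following is a minimal NumPy sketch, not sgkit code: it assumes a plain integer array of shape (variants, samples, ploidy), negative values marking missing alleles, and every called allele index below n_alleles. The helper name allele_counts_bincount is hypothetical. It encodes each (variant, allele) pair as one flat integer and counts them all with a single np.bincount call:

```python
import numpy as np


def allele_counts_bincount(gt: np.ndarray, n_alleles: int) -> np.ndarray:
    """Hypothetical helper: per-variant allele counts, shape (variants, n_alleles).

    Assumes gt has shape (variants, samples, ploidy), negative values mark
    missing alleles, and every called allele index is < n_alleles.
    """
    n_variants = gt.shape[0]
    # Flatten samples and ploidy into one "calls" axis, as allele_count does
    calls = gt.reshape(n_variants, -1)
    # Row (variant) index of each call, broadcast to the shape of calls
    variant_idx = np.broadcast_to(np.arange(n_variants)[:, np.newaxis], calls.shape)
    keep = calls >= 0  # drop missing alleles
    # Encode (variant, allele) as one integer and count all pairs at once
    flat = variant_idx[keep] * n_alleles + calls[keep]
    counts = np.bincount(flat, minlength=n_variants * n_alleles)
    return counts.reshape(n_variants, n_alleles)


gt = np.array([
    [[0, 1], [1, 1], [-1, 0]],
    [[2, 2], [0, -1], [1, 2]],
])
ac = allele_counts_bincount(gt, n_alleles=3)
an = (gt >= 0).reshape(gt.shape[0], -1).sum(axis=1)  # non-missing alleles per variant
print(ac.tolist())  # [[2, 3, 0], [1, 1, 3]]
print(an.tolist())  # [5, 5]
```

Memory here stays proportional to the number of calls rather than calls × alleles, so sites with many alleles no longer multiply the work. A chunked (e.g. Dask) version would need to apply this per block, which the sketch does not attempt.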

Invocation and some results would look like this:

import xarray as xr
import sgkit_plink

ds = sgkit_plink.read_plink(path)  # path: location of the PLINK fileset (placeholder)
xr.merge([
  variant_stats(ds),
  sample_stats(ds)
]).compute()

<xarray.Dataset>
Dimensions:                   (alleles: 2, samples: 165, variants: 1457897)
Dimensions without coordinates: alleles, samples, variants
Data variables:
    variant/n_called          (variants) int64 165 161 165 165 ... 162 165 164
    variant/call_rate         (variants) float64 1.0 0.9758 1.0 ... 1.0 0.9939
    variant/n_het             (variants) int64 0 0 0 46 37 46 ... 82 86 85 0 45
    variant/n_hom_ref         (variants) int64 165 160 165 116 ... 40 39 165 116
    variant/n_hom_alt         (variants) int64 0 1 0 3 3 3 3 ... 37 21 37 38 0 3
    variant/n_non_ref         (variants) int64 0 1 0 49 40 ... 103 123 123 0 48
    variant/allele_count      (variants, alleles) int64 330 0 320 2 ... 0 277 51
    variant/allele_total      (variants) int64 330 322 330 330 ... 324 330 328
    variant/allele_frequency  (variants, alleles) float64 1.0 0.0 ... 0.1555
    sample/n_called           (samples) int64 1453694 1437110 ... 1455881
    sample/call_rate          (samples) float64 0.9971 0.9857 ... 0.9835 0.9986
    sample/n_het              (samples) int64 399167 384755 ... 391392 397963
    sample/n_hom_ref          (samples) int64 959731 954114 ... 946913 962834
    sample/n_hom_alt          (samples) int64 94796 98241 98518 ... 95568 95084
    sample/n_non_ref          (samples) int64 493963 482996 ... 486960 493047

alimanfoo commented 4 years ago

> Invocation and some results would look like this:

Very neat :-)

alimanfoo commented 4 years ago

Just noting that the het and hom_alt implementations in genotype_count don't work for multiallelic variants: a 1/2 call has no reference allele, so it is counted as hom_alt rather than het.
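To make the failure concrete, here is a minimal NumPy sketch applying the original n_het and n_hom_alt expressions to a multiallelic 1/2 call. It assumes a plain integer array laid out as (variants, samples, ploidy), with NumPy's axis=-1 standing in for dim='ploidy':

```python
import numpy as np

# One variant, three diploid samples: 0/1, 1/2 (multiallelic het) and 2/2
gt = np.array([[[0, 1], [1, 2], [2, 2]]])

# The original expressions, reducing over the ploidy axis
n_het = (gt > 0).any(axis=-1) & (gt == 0).any(axis=-1)
n_hom_alt = (gt > 0).all(axis=-1)

print(n_het.tolist())      # [[True, False, False]]
print(n_hom_alt.tolist())  # [[False, True, True]]
```

The 1/2 call lands in hom_alt because the het test only looks for a mix of reference and non-reference alleles, and hom_alt only checks that every allele is non-reference.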

eric-czech commented 4 years ago

> Just noting implementation of het and hom_alt here doesn't work for multiallelic variants

Ah thank you, how does this look instead?

hom_alt = ((gt > 0) & (gt[..., 0] == gt)).all(dim='ploidy')
hom_ref = (gt == 0).all(dim='ploidy')
het = ~(hom_alt | hom_ref)

That would call any genotype whose alleles are all identical and non-zero (e.g. 1/1 or 2/2) homozygous alternate, and define heterozygous as anything that isn't homozygous.
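Here is a quick NumPy check of the revised definitions on a few multiallelic calls, under the same (variants, samples, ploidy) layout assumption. NumPy does not align arrays by dimension name, so the sketch uses gt[..., :1] to keep a length-1 ploidy axis for broadcasting, where the xarray version uses gt[..., 0]:

```python
import numpy as np

# One variant, four diploid samples: 0/0, 0/1, 1/2 and 2/2
gt = np.array([[[0, 0], [0, 1], [1, 2], [2, 2]]])

# Every allele non-zero and equal to the first allele of the call
hom_alt = ((gt > 0) & (gt[..., :1] == gt)).all(axis=-1)
hom_ref = (gt == 0).all(axis=-1)
het = ~(hom_alt | hom_ref)

print(hom_ref.tolist())  # [[True, False, False, False]]
print(het.tolist())      # [[False, True, True, False]]
print(hom_alt.tolist())  # [[False, False, False, True]]
```

The 1/2 call is now classified as het, and 2/2 as hom_alt.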

alimanfoo commented 4 years ago

Looks good for hom_alt and hom_ref, but het would also return true for calls where one or more alleles are missing. It's better to compare alleles within each call, e.g., as the scikit-allel implementation does.
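Below is a minimal NumPy sketch of comparing alleles within each call. It follows the spirit of scikit-allel's approach but is not a copy of its code, and it assumes negative values mark missing alleles. A call counts as het only if it has no missing alleles and at least one allele differs from the first. The "not homozygous" definition is computed alongside for comparison:

```python
import numpy as np

# One variant, five diploid samples: 0/1, 1/2, 2/2, ./0 and ./.
gt = np.array([[[0, 1], [1, 2], [2, 2], [-1, 0], [-1, -1]]])

# Het by within-call comparison
called = (gt >= 0).all(axis=-1)             # no missing alleles in the call
differs = (gt != gt[..., :1]).any(axis=-1)  # some allele differs from the first
het = called & differs

# Het defined as "not homozygous", for comparison
hom_alt = ((gt > 0) & (gt[..., :1] == gt)).all(axis=-1)
hom_ref = (gt == 0).all(axis=-1)
het_not_hom = ~(hom_alt | hom_ref)

print(het.tolist())          # [[True, True, False, False, False]]
print(het_not_hom.tolist())  # [[True, True, False, True, True]]
```

Without a separate mask, the "not homozygous" definition flags both partially and fully missing calls as het.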

eric-czech commented 4 years ago

> but het would also return true for calls where one or more alleles were missing

I meant those as changes within the larger function, though, where the mask zeroes out a call's contribution to all the counts if any of its alleles are missing:

def genotype_count(ds: Dataset, dim: Dimension) -> Dataset:
    odim = _swap(dim)[:-1]
    mask, gt = ds['call/genotype_mask'].any(dim='ploidy'), ds['call/genotype']
    non_ref = (gt > 0).any(dim='ploidy')
    hom_alt = ((gt > 0) & (gt[..., 0] == gt)).all(dim='ploidy')
    hom_ref = (gt == 0).all(dim='ploidy')
    het = ~(hom_alt | hom_ref)
    # Masking zeroes out every count (including het) for calls with any missing allele
    agg = lambda x: xr.where(mask, False, x).sum(dim=dim)
    return xr.Dataset({
        f'{odim}/n_het': agg(het),
        f'{odim}/n_hom_ref': agg(hom_ref),
        f'{odim}/n_hom_alt': agg(hom_alt),
        f'{odim}/n_non_ref': agg(non_ref)
    })

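# get_dataset is not defined in this thread; presumably a test helper that builds a
# Dataset from a variants x samples x ploidy list, with -1 marking a missing allele
ds = get_dataset([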
    [[2, 2], [1, 1], [0, 0]],
    [[0, 1], [1, 2], [2, 1]],
    [[-1, 0], [-1, 1], [-1, 2]],
    [[-1, -1], [-1, -1], [-1, -1]],
])
print(genotype_count(ds, dim='samples').to_dataframe().to_markdown())
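| variants | variant/n_het | variant/n_hom_ref | variant/n_hom_alt | variant/n_non_ref |
|---------:|--------------:|------------------:|------------------:|------------------:|
|        0 |             0 |                 1 |                 2 |                 2 |
|        1 |             3 |                 0 |                 0 |                 3 |
|        2 |             0 |                 0 |                 0 |                 0 |
|        3 |             0 |                 0 |                 0 |                 0 |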

If one sample at one variant had a partially missing call like [0, 0, -1], would you call that homozygous reference or just omit it from the counts? I was assuming the latter.
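The two options for a partially missing call such as [0, 0, -1] can be sketched in NumPy as follows, assuming negative values mark missing alleles. The first policy matches the masking used in genotype_count. The second is a hypothetical alternative shown only for comparison, not something proposed here:

```python
import numpy as np

# One variant, three triploid samples; sample 0 has a partially missing call
gt = np.array([[[0, 0, -1], [0, 0, 0], [0, 1, 1]]])
missing = gt < 0

# Policy 1: omit any call with a missing allele
hom_ref_strict = ((gt == 0).all(axis=-1) & ~missing.any(axis=-1)).sum(axis=-1)

# Policy 2 (hypothetical): judge only the called alleles,
# but require at least one allele to be called
hom_ref_partial = (((gt == 0) | missing).all(axis=-1) & ~missing.all(axis=-1)).sum(axis=-1)

print(hom_ref_strict.tolist())   # [1]
print(hom_ref_partial.tolist())  # [2]
```

Policy 1 keeps every count consistent with n_called. Policy 2 retains more data but makes the hom_ref total depend on how many alleles each call actually has.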

daletovar commented 4 years ago

At this point it looks like #102 has genotype_counts, allele_counts, allele_frequency and call_rate, and #76 added HWE. What still needs doing to merge #102, and are there any other aggregations we should add?

eric-czech commented 4 years ago

> What still needs doing to merge #102 and are there any other aggregations we should add?

It would be best to focus on getting those merged first. At a glance, some remaining todos are:

hammer commented 4 years ago

I believe https://github.com/pystatgen/sgkit/issues/282 is related. @eric-czech, as part of issue triage, is it worth enumerating what's left to close this issue out?

eric-czech commented 4 years ago

> I believe #282 is related; @eric-czech as part of issue triage is it worth enumerating what's left to close this issue out?

I think #282 can stay on its own since it's not an aggregation. I turned the original bullet points into a checklist. All that's left is to add functions that call some of the same internal functions with a different dimension for the sample-wise stats, i.e., we need a sample_stats function like variant_stats.