tskit-dev / tskit

Population-scale genomics
MIT License

Folded AFS between branch and site modes differs by factor of two #2925

Closed: nspope closed this issue 2 months ago

nspope commented 2 months ago

The following is a little unintuitive to me:

import msprime
import numpy as np

mu = 1e-8
ts = msprime.sim_ancestry(
    samples=10,
    recombination_rate=1e-8,
    sequence_length=1e6,
    population_size=1e4,
)
ts = msprime.sim_mutations(ts, rate=mu)

ss_site = ts.segregating_sites()
ss_branch = ts.segregating_sites(mode='branch') * mu
afs_site = np.sum(ts.allele_frequency_spectrum(polarised=True))
afs_branch = np.sum(ts.allele_frequency_spectrum(mode='branch', polarised=True) * mu)
foldafs_site = np.sum(ts.allele_frequency_spectrum(polarised=False))
foldafs_branch = np.sum(ts.allele_frequency_spectrum(mode='branch', polarised=False) * mu)

print(f"Ratio of site to branch, segsites: {ss_site/ss_branch}")
print(f"Ratio of site to branch, sum(AFS): {afs_site/afs_branch}")
print(f"Ratio of site to branch, sum(folded AFS): {foldafs_site/foldafs_branch}")

# Ratio of site to branch, segsites: 1.0018654284984811
# Ratio of site to branch, sum(AFS): 1.0018654284984798
# Ratio of site to branch, sum(folded AFS): 2.0037308569969605

The entries of the branch-mode folded AFS differ by a factor of 2 from those of the site-mode folded AFS. Shouldn't mode='branch' give the expectation exactly here, as it does for segregating sites and the unfolded AFS?

nspope commented 2 months ago

Ah, I see this quote in the documentation:

This means that the sum of an unpolarised AFS will be equal to the total number of alleles that are inherited by any of the samples in the tree sequence, divided by two.

So I think the branch-mode folded AFS satisfies this: the sum of its entries is half the number of segregating sites. But the site-mode folded AFS does not: the sum of its entries equals the number of segregating sites. Is this expected?
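The bookkeeping behind the quoted rule in the site-mode case can be written out, assuming every segregating site is biallelic (one ancestral and one derived allele, each carried by at least one sample), ignoring span normalisation, and writing $S$ for the number of segregating sites:

```latex
\sum_{k} \mathrm{AFS}^{\mathrm{site}}_{\mathrm{folded}}[k]
  = \frac{1}{2} \times \#\{\text{alleles inherited by samples}\}
  = \frac{1}{2} \times 2S
  = S
```

So, under that assumption, a site-mode folded AFS that sums to the number of segregating sites is exactly what the quoted sentence predicts.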

petrelharp commented 2 months ago

Well gee, this sure seems like a bug, but if so I'm surprised, since I thought we'd thought hard about all these permutations (but there were a lot of them...). Let's see: the one-half happens here, mirroring similar accounting here. It's possible the (correct) one-half from the site code got moved over into the branch code without enough thought. The situation is different because in the site code the ancestral state contributes to the calculation, while in the branch code accounting for the ancestral state would look like this (here):

x = [tree.num_tracked_samples(node) for tree in trees]
not_x = [len(s) - tree.num_tracked_samples(node) for s, tree in zip(sample_sets, trees)]
# Note x must be a tuple for indexing to work
if polarised:
    S[tuple(x)] += t.branch_length(node) * tr_len
else:
    # split the branch between the derived (x) and ancestral (not_x) sides,
    # folding both so they index the same unpolarised output
    x = fold(x, out_dim)
    not_x = fold(not_x, out_dim)
    S[tuple(x)] += 0.5 * t.branch_length(node) * tr_len
    S[tuple(not_x)] += 0.5 * t.branch_length(node) * tr_len

However, that would be redundant, since fold(x) == fold(not_x).
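The redundancy can be checked in a few lines of plain Python. This is a toy for a single sample set of n samples, not tskit code: `fold` and `unpolarised_branch_contribution` are hypothetical helpers, and folding is assumed to map a derived count x to min(x, n - x). It shows that giving half the branch to fold(x) and half to fold(n - x) always puts the whole branch into one bin, which is why the ancestral-side accounting adds nothing new.

```python
def fold(x: int, n: int) -> int:
    """Fold a derived-allele count x out of n samples onto the minor side."""
    # Assumption: single sample set, folding maps x to min(x, n - x).
    return min(x, n - x)


def unpolarised_branch_contribution(x: int, n: int, length: float) -> list[float]:
    """Add one branch to an unpolarised spectrum using the 0.5 + 0.5 split."""
    S = [0.0] * (n + 1)
    not_x = n - x
    # Half of the branch goes to the derived side, half to the ancestral side...
    S[fold(x, n)] += 0.5 * length
    S[fold(not_x, n)] += 0.5 * length
    return S


n = 5
for x in range(1, n):
    split = unpolarised_branch_contribution(x, n, length=1.0)
    # ...but both halves land in the same bin, so the branch counts in full.
    whole = [0.0] * (n + 1)
    whole[fold(x, n)] += 1.0
    assert split == whole
print("0.5 * fold(x) + 0.5 * fold(n - x) == 1.0 * fold(x) for every x")
```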

Oh, and here's an MWE (it doesn't show anything different from the above; it just makes the situation totally clear):

import tskit
t = tskit.TableCollection(sequence_length=4.0)
a = t.nodes.add_row(time=2, flags=0)

b = t.nodes.add_row(time=1, flags=0)
t.edges.add_row(left=0, right=1, parent=a, child=b)
s = t.sites.add_row(position=b-1, ancestral_state='A')
t.mutations.add_row(site=s, derived_state='C', node=b)

n = t.nodes.add_row(time=0, flags=1)
t.edges.add_row(left=0, right=1, parent=a, child=n)
s = t.sites.add_row(position=n-1, ancestral_state='A')
t.mutations.add_row(site=s, derived_state='C', node=n)

n = t.nodes.add_row(time=0, flags=1)
t.edges.add_row(left=0, right=1, parent=b, child=n)
s = t.sites.add_row(position=n-1, ancestral_state='A')
t.mutations.add_row(site=s, derived_state='C', node=n)

n = t.nodes.add_row(time=0, flags=1)
t.edges.add_row(left=0, right=1, parent=b, child=n)
s = t.sites.add_row(position=n-1, ancestral_state='A')
t.mutations.add_row(site=s, derived_state='C', node=n)

t.sort()
ts = t.tree_sequence()

for p in (True, False):
    for m in ('site', 'branch'):
        print(['unpolarised', 'polarised'][p], m,
              ts.allele_frequency_spectrum(polarised=p, mode=m, span_normalise=False))

which produces

polarised site [0. 3. 1. 0.]
polarised branch [0. 4. 1. 0.]
unpolarised site [0. 4. 0. 0.]
unpolarised branch [0.  2.5 0.  0. ]
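The branch-mode numbers can be reproduced by hand. The sketch below is plain Python, not tskit: the branch list is a manual transcription of the MWE tables (node IDs a=0, b=1, samples 2, 3 and 4; edges only on [0, 1)), and `fold` is an assumed single-sample-set fold, min(x, n - x).

```python
# Branches in the only tree with edges, on [0, 1), transcribed from the MWE:
# (child node, branch length = parent time - child time, samples below it)
branches = [
    (1, 2 - 1, 2),  # b -> a, subtends samples 3 and 4
    (2, 2 - 0, 1),  # sample 2 -> a
    (3, 1 - 0, 1),  # sample 3 -> b
    (4, 1 - 0, 1),  # sample 4 -> b
]
n = 3       # number of samples
span = 1.0  # the tree on [1, 4) has no edges, so it contributes nothing


def fold(x, n):
    # Assumed single-sample-set fold.
    return min(x, n - x)


# Polarised: each branch adds length * span to the bin of its sample count.
polarised = [0.0] * (n + 1)
for _, length, x in branches:
    polarised[x] += length * span

# Current unpolarised branch code: folded bin, but only half the length.
buggy = [0.0] * (n + 1)
for _, length, x in branches:
    buggy[fold(x, n)] += 0.5 * length * span

# Folding the polarised spectrum instead keeps the full length.
folded = [0.0] * (n + 1)
for x, value in enumerate(polarised):
    folded[fold(x, n)] += value

print(polarised)  # [0.0, 4.0, 1.0, 0.0]
print(buggy)      # [0.0, 2.5, 0.0, 0.0]
print(folded)     # [0.0, 5.0, 0.0, 0.0]
```

The halved version totals 2.5 against 5.0 for the polarised spectrum, the same factor of two seen in the simulation at the top of the issue.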
petrelharp commented 2 months ago

Currently thinking it's a bug; will see if I still think that on Monday.

petrelharp commented 2 months ago

Just talked this through with @nate. The conclusion is:

  1. This is a bug; we need that factor of 2.
  2. We made the mistake because the multiallelic case is Confusing and needs separate polarized and unpolarized implementations, including a factor of 0.5 in the unpolarized one.
  3. However, being multiallelic only applies to the site version, so
  4. we can fix this by computing the unpolarized branch version by first computing the polarized version and then folding it, instead of having separate algorithms.
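Point 4 can be sketched for a single sample set: compute the polarised spectrum, then merge bins x and n - x. This is an illustration under an assumed folding rule (bin min(x, n - x), upper bins left at zero, matching the layout of the MWE output), not the tskit implementation; `fold_spectrum` is a hypothetical helper. With several sample sets the bin index becomes a tuple and the tie-breaking rule for the middle bin matters, which this sketch does not attempt.

```python
def fold_spectrum(polarised: list[float]) -> list[float]:
    """Fold a polarised AFS for one sample set of n = len(polarised) - 1 samples."""
    n = len(polarised) - 1  # number of samples
    folded = [0.0] * (n + 1)
    for x, value in enumerate(polarised):
        # Derived counts x and n - x describe the same split of the samples,
        # so they share the bin min(x, n - x). For even n, x = n / 2 maps to
        # itself and is counted once.
        folded[min(x, n - x)] += value
    return folded


# The polarised branch AFS from the MWE folds to [0.0, 5.0, 0.0, 0.0].
print(fold_spectrum([0.0, 4.0, 1.0, 0.0]))
```

Because folding happens after the polarised pass, the unpolarised branch result needs no factor of 0.5 of its own.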

I probably owe @jeromekelleher an apology for not catching that the first time around!

jeromekelleher commented 2 months ago

No apologies required @petrelharp - this stuff is super confusing and I could also have spotted it!

I think we can just do a straight fix here and document the change as a Bug Fix, right?

petrelharp commented 2 months ago

We could

  1. remove the 0.5 in the unpolarised branch code, or
  2. reimplement it as folded(polarised), and
  3. add a test that checks whether unpolarised = folded(polarised) for branch AFS

Note that the test in (3) should also hold for site stats if the mutations are simulated under infinite sites.
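Options 1 and 3 together can be prototyped on toy inputs before touching tskit. The sketch below is plain Python, not the tskit test suite: branches are random (length, samples_below) pairs for one sample set, sites are biallelic derived-allele counts standing in for infinite-sites mutations, and `fold`, `fold_spectrum`, `toy_branch_afs` and `toy_site_afs` are hypothetical helpers. It checks that the unpolarised spectrum equals the folded polarised spectrum in both modes once the branch code adds the full branch length.

```python
import math
import random


def fold(x: int, n: int) -> int:
    # Assumption: single sample set, bin x merges with bin n - x.
    return min(x, n - x)


def fold_spectrum(afs: list[float]) -> list[float]:
    n = len(afs) - 1
    out = [0.0] * (n + 1)
    for x, value in enumerate(afs):
        out[fold(x, n)] += value
    return out


def toy_branch_afs(branches, n, polarised):
    """Branch-mode AFS from (length, samples_below) pairs, with option 1 applied."""
    S = [0.0] * (n + 1)
    for length, x in branches:
        # Option 1: the unpolarised case adds the full length, no 0.5 factor.
        S[x if polarised else fold(x, n)] += length
    return S


def toy_site_afs(derived_counts, n, polarised):
    """Site-mode AFS for biallelic sites, given each site's derived-allele count."""
    S = [0.0] * (n + 1)
    for x in derived_counts:
        if polarised:
            S[x] += 1.0
        else:
            # Both alleles at the site contribute, each with weight 0.5.
            S[fold(x, n)] += 0.5
            S[fold(n - x, n)] += 0.5
    return S


def all_close(a, b):
    return len(a) == len(b) and all(math.isclose(u, v) for u, v in zip(a, b))


def check_unpolarised_equals_folded(n=6, trials=200, seed=1):
    rng = random.Random(seed)
    for _ in range(trials):
        branches = [(rng.random(), rng.randint(1, n - 1)) for _ in range(10)]
        assert all_close(
            toy_branch_afs(branches, n, polarised=False),
            fold_spectrum(toy_branch_afs(branches, n, polarised=True)),
        )
        sites = [rng.randint(1, n - 1) for _ in range(10)]
        assert all_close(
            toy_site_afs(sites, n, polarised=False),
            fold_spectrum(toy_site_afs(sites, n, polarised=True)),
        )
    return True


print(check_unpolarised_equals_folded())
```

In the site version the 0.5 per allele stays, because two alleles share each site; the branch version has one split per branch, so its weight is 1.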

jeromekelleher commented 2 months ago

Option 1 + 3 seems good?