tskit-dev / msprime

Simulate genealogical trees and genomic sequence data using population genetic models
GNU General Public License v3.0

mean_coalescence_time(): convergence problems? #1775

Closed grahamgower closed 3 years ago

grahamgower commented 3 years ago

DemographyDebugger.mean_coalescence_time() seems to have difficulty with more complex population relationships. For the example below, the mean value does not appear to be converging (at least, not quickly enough). Increasing the max_iter parameter doesn't seem to be a viable option, because the number of time steps doubles in each iteration (so the CPU time also doubles each iteration). The rtol parameter doesn't help in this case either: the logged mean_diff value is smallest in the first iteration, so either we take the mean coalescence time estimated in that first iteration, or we run forever. Perhaps some fine-tuning could be done here, or the time steps could be chosen differently?
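The cost growth comes from the refinement step, which inserts a midpoint into every interval so that n steps become 2n - 1. A minimal sketch, mirroring the refine_steps() helper posted later in this thread, reproduces the num_steps column of the log below. The evenly spaced 122-point starting grid is a hypothetical stand-in for the real first-iteration grid; only its length matters here.

```python
import numpy as np

def refine(steps):
    # Insert the midpoint of every interval: n distinct points become 2n - 1.
    midpoints = steps[:-1] + np.diff(steps) / 2
    return np.sort(np.concatenate([steps, midpoints]))

# Hypothetical starting grid with 122 points (the count logged in iteration 1).
steps = np.linspace(0, 18000, 122)
counts = [len(steps)]
for _ in range(14):
    steps = refine(steps)
    counts.append(len(steps))

print(counts)
# [122, 243, 485, 969, 1937, 3873, 7745, 15489, 30977, 61953, 123905, 247809, 495617, 991233, 1982465]
```

Since each iteration evaluates the trajectory at every step, the wall-clock time between log lines roughly doubles as well.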

import demes
import msprime
import daiquiri

test_case = """\
time_units: generations
defaults:
  epoch: {start_size: 1000}
demes:
- name: A
- name: B
  ancestors: [A]
  start_time: 3000
- name: C
  ancestors: [B]
  start_time: 2000
- name: D
  ancestors: [C]
  start_time: 1000
migrations:
- demes: [A, D]
  rate: 1e-5
"""

graph = demes.loads(test_case)
dbg = msprime.Demography.from_demes(graph).debug()
print(dbg)
daiquiri.setup(level="DEBUG")
t = dbg.mean_coalescence_time({"A": 1, "C": 1}, max_iter=20)

Output before I killed the run.

DemographyDebugger
╠══════════════════════════════════╗
║ Epoch[0]: [0, 1e+03) generations ║
╠══════════════════════════════════╝
╟    Populations (total=4 active=4)
║    ┌─────────────────────────────────────────────────────────────┐
║    │   │    start│      end│growth_rate  │   A   │ B │ C │   D   │
║    ├─────────────────────────────────────────────────────────────┤
║    │  A│   1000.0│   1000.0│ 0           │   0   │ 0 │ 0 │ 1e-05 │
║    │  B│   1000.0│   1000.0│ 0           │   0   │ 0 │ 0 │   0   │
║    │  C│   1000.0│   1000.0│ 0           │   0   │ 0 │ 0 │   0   │
║    │  D│   1000.0│   1000.0│ 0           │ 1e-05 │ 0 │ 0 │   0   │
║    └─────────────────────────────────────────────────────────────┘
╟    Events @ generation 1e+03
║    ┌──────────────────────────────────────────────────────────────────────────────────┐
║    │  time│type            │parameters         │effect                                │
║    ├──────────────────────────────────────────────────────────────────────────────────┤
║    │  1000│Population      │derived=[D],       │Moves all lineages from the 'D'       │
║    │      │Split           │ancestral=C        │derived population to the ancestral   │
║    │      │                │                   │'C' population. Also set 'D' to       │
║    │      │                │                   │inactive, and all migration rates to  │
║    │      │                │                   │and from the derived population to    │
║    │      │                │                   │zero.                                 │
║    │┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈│
║    │  1000│Migration rate  │source=D, dest=A,  │Backwards-time migration rate from D  │
║    │      │change          │rate=0             │to A → 0                              │
║    │┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈│
║    │  1000│Migration rate  │source=A, dest=D,  │Backwards-time migration rate from A  │
║    │      │change          │rate=0             │to D → 0                              │
║    └──────────────────────────────────────────────────────────────────────────────────┘
╠══════════════════════════════════════╗
║ Epoch[1]: [1e+03, 2e+03) generations ║
╠══════════════════════════════════════╝
╟    Populations (total=4 active=3)
║    ┌─────────────────────────────────────────────────┐
║    │   │    start│      end│growth_rate  │ A │ B │ C │
║    ├─────────────────────────────────────────────────┤
║    │  A│   1000.0│   1000.0│ 0           │ 0 │ 0 │ 0 │
║    │  B│   1000.0│   1000.0│ 0           │ 0 │ 0 │ 0 │
║    │  C│   1000.0│   1000.0│ 0           │ 0 │ 0 │ 0 │
║    └─────────────────────────────────────────────────┘
╟    Events @ generation 2e+03
║    ┌─────────────────────────────────────────────────────────────────────────┐
║    │  time│type        │parameters    │effect                                │
║    ├─────────────────────────────────────────────────────────────────────────┤
║    │  2000│Population  │derived=[C],  │Moves all lineages from the 'C'       │
║    │      │Split       │ancestral=B   │derived population to the ancestral   │
║    │      │            │              │'B' population. Also set 'C' to       │
║    │      │            │              │inactive, and all migration rates to  │
║    │      │            │              │and from the derived population to    │
║    │      │            │              │zero.                                 │
║    └─────────────────────────────────────────────────────────────────────────┘
╠══════════════════════════════════════╗
║ Epoch[2]: [2e+03, 3e+03) generations ║
╠══════════════════════════════════════╝
╟    Populations (total=4 active=2)
║    ┌─────────────────────────────────────────────┐
║    │   │    start│      end│growth_rate  │ A │ B │
║    ├─────────────────────────────────────────────┤
║    │  A│   1000.0│   1000.0│ 0           │ 0 │ 0 │
║    │  B│   1000.0│   1000.0│ 0           │ 0 │ 0 │
║    └─────────────────────────────────────────────┘
╟    Events @ generation 3e+03
║    ┌─────────────────────────────────────────────────────────────────────────┐
║    │  time│type        │parameters    │effect                                │
║    ├─────────────────────────────────────────────────────────────────────────┤
║    │  3000│Population  │derived=[B],  │Moves all lineages from the 'B'       │
║    │      │Split       │ancestral=A   │derived population to the ancestral   │
║    │      │            │              │'A' population. Also set 'B' to       │
║    │      │            │              │inactive, and all migration rates to  │
║    │      │            │              │and from the derived population to    │
║    │      │            │              │zero.                                 │
║    └─────────────────────────────────────────────────────────────────────────┘
╠════════════════════════════════════╗
║ Epoch[3]: [3e+03, inf) generations ║
╠════════════════════════════════════╝
╟    Populations (total=4 active=1)
║    ┌─────────────────────────────────────┐
║    │   │    start│      end│growth_rate  │
║    ├─────────────────────────────────────┤
║    │  A│   1000.0│   1000.0│ 0           │
║    └─────────────────────────────────────┘

2021-07-19 12:01:41,242 [533989] DEBUG    msprime.demography: iter    mean    P_diff    mean_diff last_P    adjust_type  num_steps  last_step
2021-07-19 12:01:41,288 [533989] DEBUG    msprime.demography: 1 4679.1 0 0.000817941 0.000549623 extend 122 18000
2021-07-19 12:01:41,336 [533989] DEBUG    msprime.demography: 2 4904.1 1.33227e-15 0.0458792 0.000549623 refine 243 18000
2021-07-19 12:01:41,429 [533989] DEBUG    msprime.demography: 3 4754.1 1.55431e-15 0.0315519 0.000549623 refine 485 18000
2021-07-19 12:01:41,614 [533989] DEBUG    msprime.demography: 4 4716.6 2.10942e-15 0.0079507 0.000549623 refine 969 18000
2021-07-19 12:01:41,985 [533989] DEBUG    msprime.demography: 5 4697.85 6.66134e-15 0.0039912 0.000549623 refine 1937 18000
2021-07-19 12:01:42,724 [533989] DEBUG    msprime.demography: 6 4594.72 2.64233e-14 0.0224442 0.000549623 refine 3873 18000
2021-07-19 12:01:44,193 [533989] DEBUG    msprime.demography: 7 4793.94 5.87308e-14 0.0415563 0.000549623 refine 7745 18000
2021-07-19 12:01:47,130 [533989] DEBUG    msprime.demography: 8 4728.32 7.43849e-14 0.0138791 0.000549623 refine 15489 18000
2021-07-19 12:01:53,048 [533989] DEBUG    msprime.demography: 9 4593.55 1.23901e-13 0.029338 0.000549623 refine 30977 18000
2021-07-19 12:02:04,898 [533989] DEBUG    msprime.demography: 10 4714.26 2.08999e-13 0.0256039 0.000549623 refine 61953 18000
2021-07-19 12:02:29,464 [533989] DEBUG    msprime.demography: 11 4791.6 6.47593e-13 0.0161415 0.000549623 refine 123905 18000
2021-07-19 12:03:21,502 [533989] DEBUG    msprime.demography: 12 4535.47 1.95299e-12 0.0564722 0.000549623 refine 247809 18000
2021-07-19 12:05:02,021 [533989] DEBUG    msprime.demography: 13 4810.02 4.04887e-12 0.0570784 0.000549623 refine 495617 18000
2021-07-19 12:08:32,834 [533989] DEBUG    msprime.demography: 14 4753.29 4.00124e-12 0.0119341 0.000549623 refine 991233 18000
2021-07-19 12:15:21,837 [533989] DEBUG    msprime.demography: 15 4605.72 7.76867e-12 0.0320414 0.000549623 refine 1982465 18000
^C

FYI, I applied the following change to avoid recalculating the trajectory in consecutive iterations.

diff --git a/msprime/demography.py b/msprime/demography.py
index febe2fe2..613e67ea 100644
--- a/msprime/demography.py
+++ b/msprime/demography.py
@@ -4020,7 +4020,7 @@ class DemographyDebugger:
         step_type = "none"
         n = 0
         logger.debug(
-            "iter    mean    P_diff    mean_diff last_P    adjust_type"
+            "iter    mean    P_diff    mean_diff last_P    adjust_type  "
             "num_steps  last_step"
         )
         # The factors of 20 here are probably not optimal: clearly, we need to
@@ -4032,13 +4032,14 @@ class DemographyDebugger:
             last_P > rtol or p_diff > rtol / 20 or m_diff > rtol / 20
         ):
             last_steps = steps
-            _, P1 = self.coalescence_rate_trajectory(
-                steps=last_steps,
-                lineages=lineages,
-                min_pop_size=min_pop_size,
-                double_step_validation=False,
-            )
-            m1 = mean_time(last_steps, P1)
+            if n == 0:
+                _, P1 = self.coalescence_rate_trajectory(
+                    steps=last_steps,
+                    lineages=lineages,
+                    min_pop_size=min_pop_size,
+                    double_step_validation=False,
+                )
+                m1 = mean_time(last_steps, P1)
             if last_P > rtol:
                 step_type = "extend"
                 steps = np.concatenate(
@@ -4073,6 +4074,8 @@ class DemographyDebugger:
                 len(steps),
                 max(steps),
             )
+            P1 = P2
+            m1 = m2

         if n == max_iter:
             raise ValueError(
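The patch's idea, computing each trajectory once and carrying the result into the next iteration, can be sketched in isolation. This is a toy rather than msprime's code: it assumes a single population of constant diploid size N, for which Pr{two lineages not coalesced by t} = exp(-t / 2N) and the exact mean is 2N generations, and it approximates the mean as the trapezoid-rule integral of that curve. The function names and the single relative-change stopping rule are hypothetical simplifications of the last_P / P_diff / mean_diff checks in the real loop.

```python
import numpy as np

def survival(t, N=1000.0):
    # Toy trajectory: Pr{two lineages not yet coalesced by time t} for one
    # population of constant diploid size N (coalescence rate 1 / (2N)).
    return np.exp(-t / (2 * N))

def mean_time(steps, P):
    # E[T] = integral of Pr{T > t} dt, approximated with the trapezoid rule.
    return np.sum(np.diff(steps) * (P[:-1] + P[1:]) / 2)

def refine(steps):
    midpoints = steps[:-1] + np.diff(steps) / 2
    return np.sort(np.concatenate([steps, midpoints]))

def converge(steps, rtol=1e-4, max_iter=20):
    # The coarse trajectory is computed once, before the loop...
    m1 = mean_time(steps, survival(steps))
    for n in range(1, max_iter + 1):
        steps = refine(steps)
        # ...and each iteration evaluates only the new, finer trajectory.
        m2 = mean_time(steps, survival(steps))
        if abs(m2 - m1) < rtol * abs(m2):
            return m2, n
        m1 = m2  # carry this result forward instead of recomputing it
    raise ValueError("did not converge")

mean, iters = converge(np.linspace(0, 24000, 101))
print(f"{mean:.2f} generations after {iters} refinements")
```

Because the trapezoid rule's error is O(h^2), each halving of the step size should shrink the change in the mean by about a factor of four, so this toy stops after a few refinements; the mean_diff column in the logs above shows the real trajectory not behaving that way.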
grahamgower commented 3 years ago

~Oh and I think for this model the expected coalescence time between A and C should be the same as between A and B. Calculating the mean_coalescence_time between A and B converges after three iterations to 4364.~

EDIT: Nope. That's definitely wrong.

grahamgower commented 3 years ago

I found even more extreme behaviour by making deme B's start time a bit older (6000 instead of 3000 generations). Alternate iterations of mean_coalescence_time() produce trajectories that are totally bogus. The plots below show p(t) and r(t) from coalescence_rate_trajectory(), using steps that match the first four iterations of mean_coalescence_time() (code follows).

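[Figure (cr.png): demesdraw tubes plot of the model, plus "coalescence rate (lineages: A=1, C=1)" and "Pr{A and C not coalesced}" against time ago (generations), one line per len(steps).]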

2021-07-21 15:54:44,011 [221156] DEBUG    msprime.demography: iter    mean    P_diff    mean_diff last_P    adjust_type  num_steps  last_step
2021-07-21 15:54:44,057 [221156] DEBUG    msprime.demography: 1 899.37 0 0.00206987 0.000184327 extend 123 21600
2021-07-21 15:54:44,131 [221156] DEBUG    msprime.demography: 2 7049.69 1 0.872424 0.000406011 refine 245 21600
2021-07-21 15:54:44,277 [221156] DEBUG    msprime.demography: 3 899.37 1 6.83847 0.000184327 refine 489 21600
2021-07-21 15:54:44,568 [221156] DEBUG    msprime.demography: 4 7049.68 1 0.872424 0.000406011 refine 977 21600
2021-07-21 15:54:45,150 [221156] DEBUG    msprime.demography: 5 899.37 1 6.83847 0.000184327 refine 1953 21600
2021-07-21 15:54:46,301 [221156] DEBUG    msprime.demography: 6 7049.68 1 0.872424 0.000406011 refine 3905 21600
2021-07-21 15:54:48,620 [221156] DEBUG    msprime.demography: 7 899.37 1 6.83847 0.000184327 refine 7809 21600
2021-07-21 15:54:53,218 [221156] DEBUG    msprime.demography: 8 7049.68 1 0.872424 0.000406011 refine 15617 21600
2021-07-21 15:55:02,448 [221156] DEBUG    msprime.demography: 9 899.37 1 6.83847 0.000184327 refine 31233 21600
2021-07-21 15:55:02,449 [221156] CRITICAL root: Traceback (most recent call last):
  File "/home/grg/src/demes/demesdraw/misc/cr.py", line 132, in <module>
    t = dbg.mean_coalescence_time({"A": 1, "C": 1}, max_iter=9)
  File "/home/grg/src/msprime/msprime/demography.py", line 4078, in mean_coalescence_time
    raise ValueError(
ValueError: Did not converge on an adequate discretisation: Increase max_iter or rtol. Consult the log for debugging information
import logging
import itertools

import daiquiri
import numpy as np
import demes
import demesdraw
import msprime
import matplotlib
import matplotlib.patheffects  # needed for matplotlib.patheffects.withStroke below
import matplotlib.pyplot as plt

logger = logging.getLogger(__name__)

def get_steps(dbg):
    # Get initial steps. Copied from mean_coalescence_time().
    last_N = max(dbg.population_size_history[:, dbg.num_epochs - 1])
    last_epoch = dbg.epoch_start_time[-1]
    steps = sorted(
        list(
            set(np.linspace(0, last_epoch + 12 * last_N, 101)).union(
                set(dbg.epoch_start_time)
            )
        )
    )
    return steps

def refine_steps(steps):
    # Double the number of steps. Copied from mean_coalescence_time().
    inter = steps[:-1] + np.diff(steps) / 2
    steps = np.concatenate([steps, inter])
    steps.sort()
    return steps

def get_axes(aspect=9 / 16, scale=1.5, **subplot_kwargs):
    """Make a matplotlib axes."""
    figsize = scale * plt.figaspect(aspect)
    fig, ax = plt.subplots(figsize=figsize, **subplot_kwargs)
    fig.set_tight_layout(True)
    return fig, ax

def get_line_plot_styles():
    linestyles = ["solid", "dashed", "dashdot"]
    linewidths = [1, 2, 4, 8]
    path_effects_lists = [
        [matplotlib.patheffects.withStroke(linewidth=2, foreground="white", alpha=0.7)],
        [matplotlib.patheffects.withStroke(linewidth=3, foreground="white", alpha=0.7)],
        [matplotlib.patheffects.withStroke(linewidth=5, foreground="white", alpha=0.7)],
        [],
    ]
    z_top = 1000  # Top of the z order stacking.
    return (
        dict(
            linestyle=linestyle,
            linewidth=linewidth,
            zorder=z_top - linewidth,
            alpha=0.7,
            solid_capstyle="butt",
            path_effects=path_effects,
        )
        for linestyle, linewidth, path_effects in zip(
            *map(itertools.cycle, (linestyles, linewidths, path_effects_lists))
        )
    )

def plot_figure(graph):
    fig, axs = get_axes(nrows=2, ncols=2, gridspec_kw=dict(width_ratios=[1, 2]))

    axs[1, 0].set_axis_off()
    ax_tubes = axs[0, 0]
    ax_cr = axs[0, 1]
    ax_cp = axs[1, 1]

    w = 1.3 * demesdraw.utils.size_max(graph)
    positions = dict(C=0, B=w, D=2 * w, A=3 * w)
    demesdraw.tubes(graph, ax=ax_tubes, log_time=True, positions=positions)

    style_cr = get_line_plot_styles()
    style_cp = get_line_plot_styles()

    dbg = msprime.Demography.from_demes(graph).debug()
    steps = get_steps(dbg)

    for _ in range(4):
        r, p = dbg.coalescence_rate_trajectory(
            steps, lineages=dict(A=1, C=1), double_step_validation=False
        )
        ax_cr.plot(steps, r, label=f"{len(steps)}", **next(style_cr))
        ax_cp.plot(steps, p, label=f"{len(steps)}", **next(style_cp))
        steps = refine_steps(steps)

    ax_cr.set_title("coalescence rate (lineages: A=1, C=1)")
    ax_cp.set_title("Pr{A and C not coalesced}")
    ax_cr.set_ylabel("rate")
    ax_cp.set_ylabel("probability")
    for ax in (ax_cr, ax_cp):
        ax.set_xlabel("time ago (generations)")
        ax.legend(title="len(steps)")

    return fig

test_case = """\
time_units: generations
defaults:
  epoch: {start_size: 1000}
demes:
- name: A
- name: B
  ancestors: [A]
  start_time: 6000
- name: C
  ancestors: [B]
  start_time: 2000
- name: D
  ancestors: [C]
  start_time: 1000
migrations:
- demes: [A, D]
  rate: 1e-5
"""

graph = demes.loads(test_case)
fig = plot_figure(graph)
fig.savefig("/tmp/cr.png", dpi=200)

dbg = msprime.Demography.from_demes(graph).debug()
daiquiri.setup(level="DEBUG")
t = dbg.mean_coalescence_time({"A": 1, "C": 1}, max_iter=9)
grahamgower commented 3 years ago

How about we initialise the steps using a linspace for each epoch? This seems to work much better! (Converges after 6 iterations).

def get_steps(dbg):
    last_N = max(dbg.population_size_history[:, dbg.num_epochs - 1])
    last_epoch = dbg.epoch_start_time[-1]
    times = list(dbg.epoch_start_time) + [last_epoch + 12 * last_N]
    steps = set()
    for a, b in zip(times[:-1], times[1:]):
        steps.update(np.linspace(a, b, 101))
    steps = np.array(sorted(steps))
    return steps
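One way to see what changes with per-epoch spacing is to count how many steps each scheme puts inside the shortest epoch. A minimal sketch, assuming the second test case's epoch start times (0, 1000, 2000 and 6000 generations) and a final population size of 1000; global_steps() and per_epoch_steps() are hypothetical standalone versions of the two get_steps() variants in this thread, taking plain arguments instead of a DemographyDebugger.

```python
import numpy as np

def global_steps(epoch_start_times, last_N, num=101):
    # Original scheme: one linspace over the whole range, plus epoch boundaries.
    end = epoch_start_times[-1] + 12 * last_N
    points = set(np.linspace(0, end, num)) | set(epoch_start_times)
    return np.array(sorted(points))

def per_epoch_steps(epoch_start_times, last_N, num=101):
    # Proposed scheme: a separate linspace for every epoch.
    times = list(epoch_start_times) + [epoch_start_times[-1] + 12 * last_N]
    points = set()
    for a, b in zip(times[:-1], times[1:]):
        points.update(np.linspace(a, b, num))  # shared boundaries deduplicate
    return np.array(sorted(points))

epochs = [0.0, 1000.0, 2000.0, 6000.0]
g = global_steps(epochs, last_N=1000.0)
p = per_epoch_steps(epochs, last_N=1000.0)
for name, s in (("global", g), ("per-epoch", p)):
    print(f"{name}: {len(s)} steps, {np.sum(s < 1000)} in [0, 1000)")
```

With the single linspace, the first epoch (the only one in which A and D exchange migrants) gets just six points (0, 180, ..., 900), while the per-epoch version gives it 100; the trade-off is that the total number of steps now grows with the number of epochs.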
jeromekelleher commented 3 years ago

Nicely done tracking this down @grahamgower - let's see what @petrelharp thinks. Maybe @apragsdale could take a peek also?

grahamgower commented 3 years ago

I've tracked this down a bit further. It seems the _matrix_exponential() function is unstable. When replaced with scipy.linalg.expm, the behaviour is much improved.

petrelharp commented 3 years ago

Ah: _matrix_exponential() is based on an eigendecomposition, which isn't stable if there are very small eigenvalues. We didn't use scipy because we didn't want to introduce the dependency.
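Computing the matrix exponential through an eigendecomposition, exp(Qt) = V diag(exp(λt)) V⁻¹, loses accuracy when the eigenvector matrix V is close to singular, because V⁻¹ then amplifies rounding error. A minimal sketch of that approach, not msprime's actual _matrix_exponential(): the name expm_eig() and the max_cond threshold are hypothetical, and checking the condition number of V is just one possible diagnostic for the kind of early error asked about below.

```python
import numpy as np

def expm_eig(Q, t=1.0, max_cond=1e8):
    # exp(Qt) = V diag(exp(lam * t)) V^{-1}, from an eigendecomposition of Q.
    lam, V = np.linalg.eig(Q)
    cond = np.linalg.cond(V)
    if cond > max_cond:
        # Nearly parallel eigenvectors: inverting V would amplify rounding error.
        raise ValueError(f"ill-conditioned eigenvectors (cond(V) = {cond:.2e})")
    return (V @ np.diag(np.exp(lam * t)) @ np.linalg.inv(V)).real

# A well-behaved two-state rate matrix (rows sum to zero).
Q_ok = np.array([[-1.0, 1.0], [2.0, -2.0]])
P_ok = expm_eig(Q_ok)
print(P_ok.sum(axis=1))  # each row of exp(Qt) should sum to 1

# A nearly defective rate matrix: eigenvalues -1 and -1 - eps almost coincide.
eps = 1e-12
Q_bad = np.array([[-1.0, 1.0, 0.0],
                  [0.0, -1.0 - eps, 1.0 + eps],
                  [0.0, 0.0, 0.0]])
try:
    expm_eig(Q_bad)
    flagged = False
except ValueError as err:
    flagged = True
    print(err)
```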

If we don't want to depend on scipy, here's an alternative algorithm: https://github.com/petrelharp/expm-experiment/blob/main/expm-simple.R#L48

grahamgower commented 3 years ago

Thanks @petrelharp. I agree we should avoid using the scipy implementation if possible. Aside from the extra dependency, it seems quite a lot slower than the eigendecomposition version. I did read that some matrix exponential methods allow one to determine when the algorithm will produce a poor approximation, thus allowing an error to be raised (e.g. checking if the matrix is almost singular). Do you know anything about this? Specifically, I was skimming through https://www.cs.cornell.edu/cvResearchPDF/19ways+.pdf (but I understood very little). In any event, I'll take a look at your R version.

petrelharp commented 3 years ago

I did read that some matrix exponential methods allow one to determine when the algorithm will produce a poor approximation, thus allowing an error to be raised (e.g. checking if the matrix is almost singular)

Right - we could check if one of the eigenvalues is too small, and raise an error - but, the algorithm I've got there in R is very robust. I think it's in the "19 dubious ways", but perhaps not prominently, because it only applies to stochastic matrices (which we have here), and that paper focuses on the general case.
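For rate matrices specifically (non-negative off-diagonal entries, rows summing to zero), a standard robust method is uniformization: pick λ ≥ max_i |Q_ii|, set P = I + Q/λ (a stochastic matrix), and sum exp(Qt) = Σ_k e^{-λt} (λt)^k / k! P^k. Every term is non-negative, so nothing cancels. A minimal sketch combining it with scaling and squaring to keep the series short; whether the linked R code uses exactly this method is an assumption, and expm_uniformization() and its tol stopping rule are hypothetical.

```python
import numpy as np

def expm_uniformization(Q, t=1.0, tol=1e-17):
    """exp(Q t) for a rate matrix Q (off-diagonals >= 0, rows sum to 0)."""
    Q = np.asarray(Q, dtype=float)
    n = Q.shape[0]
    lam = np.max(-np.diag(Q))  # uniformization rate: the fastest exit rate
    if lam * t == 0:
        return np.eye(n)
    # Scaling and squaring: use h = t / 2**s with lam * h <= 1, so the
    # Poisson series below is short, then square the result s times.
    s = max(0, int(np.ceil(np.log2(lam * t))))
    h = t / 2**s
    P = np.eye(n) + Q / lam  # stochastic: non-negative entries, rows sum to 1
    x = lam * h              # mean number of Poisson jumps in time h
    term = np.exp(-x)        # Pr{0 jumps}
    Pk = np.eye(n)
    result = term * Pk
    k = 0
    while term > tol:
        # Add Pr{k jumps} * P**k; every term is non-negative (no cancellation).
        k += 1
        term *= x / k
        Pk = Pk @ P
        result += term * Pk
    for _ in range(s):
        result = result @ result
    return result

# A nearly defective rate matrix (two almost-equal eigenvalues), for which an
# eigendecomposition is ill-conditioned.
eps = 1e-12
Q_bad = np.array([[-1.0, 1.0, 0.0],
                  [0.0, -1.0 - eps, 1.0 + eps],
                  [0.0, 0.0, 0.0]])
P_bad = expm_uniformization(Q_bad)
print(P_bad)
print(P_bad.sum(axis=1))
```

The cost is a handful of matrix products per call, and the result stays a valid transition matrix (non-negative, rows summing to 1 up to rounding) even when eigenvalues nearly coincide.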

I've got a lot on my plate at the moment, but I could swap out the expm algorithm if you like? Maybe the right order is to first write some tests that trigger this behavior and then see if swapping the algorithm fixes it?

petrelharp commented 3 years ago

I could swap out the expm algorithm if you like?

I did this!

jeromekelleher commented 3 years ago

Closed in #1788