tskit-dev / tsviz

Visualisation tools for tree sequences
MIT License

Spatial equivalent of GNN #7

Open hyanwong opened 2 years ago

hyanwong commented 2 years ago

I discussed with @percyfal some ideas of how to plot a GNN equivalent when we don't have a set of defined subpopulations. One possibility is to look at the nearest neighbours and to summarize their x/y locations, perhaps using an average location and some sort of variance.
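
As a rough illustration of that idea, the sketch below summarises the locations of a focal sample's nearest neighbours in a single tree by their mean and per-axis variance. The toy tree (a `parent` dictionary), the `locations` table and the helper names are made up for this example and are not part of tskit or tsviz. A real implementation would walk the trees of a tree sequence and weight each tree by its span. Climbing past nodes with no other samples below them is also just one possible choice.

```python
from statistics import fmean, pvariance

# Hypothetical toy tree: child -> parent. Samples are 0..3; nodes 4 and 5 are internal.
parent = {0: 4, 1: 4, 2: 4, 3: 5, 4: 5}
samples = [0, 1, 2, 3]
# Hypothetical x/y locations for each sample.
locations = {0: (0.0, 0.0), 1: (2.0, 0.0), 2: (4.0, 4.0), 3: (12.0, 8.0)}


def samples_below(node):
    """Return the samples whose path to the root passes through `node`."""
    found = []
    for s in samples:
        v = s
        while v is not None and v != node:
            v = parent.get(v)
        if v == node:
            found.append(s)
    return found


def neighbour_summary(focal):
    """Mean and per-axis variance of the locations of focal's nearest neighbours."""
    node = parent.get(focal)
    neighbours = []
    # Climb until the clade contains at least one other sample.
    while node is not None and not neighbours:
        neighbours = [s for s in samples_below(node) if s != focal]
        node = parent.get(node)
    if not neighbours:
        return None
    xs = [locations[s][0] for s in neighbours]
    ys = [locations[s][1] for s in neighbours]
    return (fmean(xs), fmean(ys)), (pvariance(xs), pvariance(ys))


mean_loc, var_loc = neighbour_summary(0)
print(mean_loc, var_loc)
```

Here sample 0 has neighbours 1 and 2, so the summary is the mean and spread of just those two locations. Sample 3 only joins at the root, so its neighbours are all the other samples.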

Here's a SLiM script we could use for testing (try running it in the SLiM GUI: needs the world_map_540x217.png file from the SLiM docs):

// Keywords: continuous space, continuous spatial landscape, spatial map, reprising boundaries

initialize() {
    initializeSLiMOptions(dimensionality="xy");
    initializeTreeSeq();
    initializeMutationRate(1e-7);
    initializeMutationType("m1", 0.5, "f", 0.0);
    initializeGenomicElementType("g1", m1, 1.0);
    initializeGenomicElement(g1, 0, 99999);
    initializeRecombinationRate(1e-8);

    // spatial competition
    initializeInteractionType(1, "xy", reciprocal=T, maxDistance=30.0);
    i1.setInteractionFunction("n", 5.0, 10.0);

    // spatial mate choice
    initializeInteractionType(2, "xy", reciprocal=T, maxDistance=30.0);
    i2.setInteractionFunction("n", 1.0, 10.0);
}
1 late() {
    sim.addSubpop("p1", 1000);

    p1.setSpatialBounds(c(0.0, 0.0, 539.0, 216.0));

    // this file is in the recipe archive at http://benhaller.com/slim/SLiM_Recipes.zip
    mapImage = Image("~/Downloads/SLiM_Recipes/world_map_540x217.png");
    p1.defineSpatialMap("world", "xy", 1.0 - mapImage.floatK,
        valueRange=c(0.0, 1.0), colors=c("#0000CC", "#55FF22"));

    // start near a specific map location
    for (ind in p1.individuals) {
        ind.x = rnorm(1, 300.0, 1.0);
        ind.y = rnorm(1, 100.0, 1.0);
    }
}
1: late() {
    i1.evaluate();
    inds = sim.subpopulations.individuals;
    competition = i1.totalOfNeighborStrengths(inds) / size(inds);
    competition = pmin(competition, 0.99);
    inds.fitnessScaling = 1.0 - competition;
}
first() {
    i2.evaluate();
}
mateChoice() {
    return i2.strength(individual);
}
modifyChild() {
    do pos = parent1.spatialPosition + rnorm(2, 0, 0.5);
    while (!p1.pointInBounds(pos));

    // prevent dispersal into water
    if (p1.spatialMapValue("world", pos) == 0.0)
        return F;

    child.setSpatialPosition(pos);
    return T;
}
3000 late() {
    sim.treeSeqOutput("~/Downloads/SLiM_Recipes/spatial.trees");
}

And here's a rough, inefficient way of plotting out the spatial GNN (lots could be improved, and I think there's a bug, but I'm putting it down here for reference):

import tskit
import numpy as np
import matplotlib.pyplot as plt

ts = tskit.load("/Users/Yan/Downloads/SLiM_Recipes/spatial.trees")

GNN_x = np.zeros(ts.num_samples)
GNN_y = np.zeros(ts.num_samples)
for u in ts.samples():
    loc = ts.individual(ts.node(u).individual).location
    GNN_x[u] = loc[0]
    GNN_y[u] = loc[1]

GNN_u = np.zeros(ts.num_samples)
GNN_v = np.zeros(ts.num_samples)

for tree in ts.trees():
    for u in ts.samples():
        x = 0
        y = 0
        n = 0
        focal_node = tree.parent(u)
        for s in tree.samples(focal_node):
            if s != u:
                loc = ts.individual(ts.node(s).individual).location
                n += 1
                x += GNN_x[u] - loc[0]
                y += GNN_y[u] - loc[1]
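        # Average over the n neighbours so that large clades don't dominate;
        # skip trees where u has no sample neighbours to avoid dividing by zero
        if n > 0:
            GNN_u[u] += x * tree.span / n
            GNN_v[u] += y * tree.span / n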
GNN_u /= ts.sequence_length
GNN_v /= ts.sequence_length

fig, ax = plt.subplots(figsize=(20, 15))
q = ax.quiver(GNN_x, GNN_y, GNN_u, GNN_v, units='xy')
plt.show()

percyfal commented 2 years ago

Hi @hyanwong, I finally found the time to give your code a go! In preparing for an analysis of real data, I decided to apply your function to a data set simulated in slendr and plot the result in geoviews, just to get the hang of the entire workflow. I'm posting the code I wrote for reference.

To begin with, I used the slendr tutorial example, a toy model of the history of modern humans in West Eurasia after the Out of Africa migration. I used the following parts of the code (NB: requires tibble >= 3.1.6):

library(slendr)
library(dplyr)

# Define world and regions
map <- world(
  xrange = c(-15, 60), # min-max longitude
  yrange = c(20, 65),  # min-max latitude
  crs = "EPSG:3035"    # coordinate reference system (CRS) for West Eurasia
)
africa <- region(
  "Africa", map,
  polygon = list(c(-18, 20), c(40, 20), c(30, 33),
                 c(20, 32), c(10, 35), c(-8, 35))
)
europe <- region(
  "Europe", map,
  polygon = list(
    c(-8, 35), c(-5, 36), c(10, 38), c(20, 35), c(25, 35),
    c(33, 45), c(20, 58), c(-5, 60), c(-15, 50)
  )
)
anatolia <- region(
  "Anatolia", map,
  polygon = list(c(28, 35), c(40, 35), c(42, 40),
                 c(30, 43), c(27, 40), c(25, 38))
)

# Define populations
afr <- population( # African ancestral population
  "AFR", parent = "ancestor", time = 52000, N = 3000,
  map = map, polygon = africa
)
ooa <- population( # population of the first migrants out of Africa
  "OOA", parent = afr, time = 51000, N = 500, remove = 25000,
  center = c(33, 30), radius = 400e3
) %>%
  move(
    trajectory = list(c(40, 30), c(50, 30), c(60, 40)),
    start = 50000, end = 40000, snapshots = 20
  )
ehg <- population( # Eastern hunter-gatherers
  "EHG", parent = ooa, time = 28000, N = 1000, remove = 6000,
  polygon = list(
    c(26, 55), c(38, 53), c(48, 53), c(60, 53),
    c(60, 60), c(48, 63), c(38, 63), c(26, 60))
)
eur <- population( # European population
  name = "EUR", parent = ehg, time = 25000, N = 2000,
  polygon = europe
)
ana <- population( # Anatolian farmers
  name = "ANA", time = 28000, N = 3000, parent = ooa, remove = 4000,
  center = c(34, 38), radius = 500e3, polygon = anatolia
) %>%
  expand_range( # expand the range by 2,500 km
    by = 2500e3, start = 10000, end = 7000,
    polygon = join(europe, anatolia), snapshots = 20
  )
yam <- population( # Yamnaya steppe population
  name = "YAM", time = 7000, N = 500, parent = ehg, remove = 2500,
  polygon = list(c(26, 50), c(38, 49), c(48, 50),
                 c(48, 56), c(38, 59), c(26, 56))
) %>% move(trajectory = list(c(15, 50)), start = 5000, end = 3000, snapshots = 10)

# Define gene flow events
gf <- list(
  gene_flow(from = ana, to = yam, rate = 0.5, start = 6500, end = 6400, overlap = FALSE),
  gene_flow(from = ana, to = eur, rate = 0.5, start = 8000, end = 6000),
  gene_flow(from = yam, to = eur, rate = 0.75, start = 4000, end = 3000)
)

# Compile model
model <- compile_model(
  populations = list(afr, ooa, ehg, eur, ana, yam), # populations defined above
  gene_flow = gf,
  generation_time = 30,
  resolution = 10e3, # resolution in meters per pixel
  competition = 130e3, mating = 100e3, # spatial interaction parameters
  dispersal = 70e3, # how far can offspring end up from their parents
)

# Run model in SLiM
slim(model, sequence_length = 10e6, recombination_rate = 1e-8,
     save_locations = TRUE, method = "batch", random_seed = 314159)

After slim has been run, we need to save the model coordinates corresponding to each sample. This information will later be used to update the tree sequence metadata.

ts <- ts_load(model, file=file.path(model$path, "output_slim.trees"))
cnames = c("pedigree_id", "location_x", "location_y")
individuals <- ts_data(ts) %>%
  dplyr::distinct(ind_id, .keep_all = TRUE) %>%
  dplyr::mutate(
    location_x = as.vector(sf::st_coordinates(location)[, 1]),
    location_y = as.vector(sf::st_coordinates(location)[, 2])
  ) %>%
  dplyr::select(pedigree_id, location_x, location_y) %>%
  as.data.frame()
write.table(individuals[, cnames], sep="\t", row.names=FALSE, file="individuals.metadata.tsv")

From within the simulation output directory, we load the tree sequence and the exported table, then add the coordinates to the individual metadata.

import tskit
import pandas as pd
ts = tskit.load("output_slim.trees")

# Dump tables
tables = ts.dump_tables()
metadata = pd.read_table("individuals.metadata.tsv")
metadata.set_index(['pedigree_id'], inplace=True)

schema_dict = tables.individuals.metadata_schema.asdict()
tables.individuals.clear()
from collections import OrderedDict
schema_dict["properties"]["location_x"] = OrderedDict({'binaryFormat': 'f', 'description': 'x coordinate (longitude) in EPSG:3035', 'index': 8, 'type': 'number', 'default': 0.0})
schema_dict["properties"]["location_y"] = OrderedDict({'binaryFormat': 'f', 'description': 'y coordinate (latitude) in EPSG:3035', 'index': 9, 'type': 'number', 'default': 0.0})
tables.individuals.metadata_schema = tskit.MetadataSchema(schema_dict)

for row in list(ts.tables.individuals):
    md = row.metadata
    d = metadata.loc[md['pedigree_id']].to_dict()
    md.update(**d)
    row = row.replace(metadata=md)
    tables.individuals.append(row)

# Make new tree sequence
ts = tables.tree_sequence()

Next we calculate the GNN coordinates and plot them using geoviews:

import geoviews as gv
gv.extension('bokeh')
import geoviews.tile_sources as gts
from geoviews import opts
# Needed to project EPSG:3035 coordinates correctly
from cartopy import crs
import numpy as np
from tqdm import tqdm
np.random.seed(79)

# Select a subset of the 10000 samples; otherwise takes 1-2h
# to calculate gnn
samplesidx = sorted(np.random.choice(np.arange(ts.num_samples), 50, replace=False))

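# Initialise with 0.0 so the columns are float and can hold coordinates without a dtype change
d = pd.DataFrame(0.0, index=np.arange(len(samplesidx)), columns=['id', 'pedigree_id', 'x', 'y', 'u', 'v'])
d['id'] = samplesidx
d.set_index('id', inplace=True)
for u in samplesidx:
    md = ts.individual(ts.node(u).individual).metadata
    # .loc[row, col] assigns into d itself; chained d.loc[u]['x'] = ... may write to a copy
    d.loc[u, 'x'] = md["location_x"]
    d.loc[u, 'y'] = md["location_y"]
    d.loc[u, 'pedigree_id'] = md["pedigree_id"]
d['color'] = 'blue'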

# Calculate GNN coordinates
for tree in tqdm(ts.trees()):
    for u in samplesidx:
        x = 0
        y = 0
        n = 0
        focal_node = tree.parent(u)
        for s in tree.samples(focal_node):
            if s != u:
                loc_x = ts.individual(ts.node(s).individual).metadata["location_x"]
                loc_y = ts.individual(ts.node(s).individual).metadata["location_y"]
                n += 1
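                x += d.loc[u, 'x'] - loc_x
                y += d.loc[u, 'y'] - loc_y
        # Skip trees where u has no sample neighbours, to avoid dividing by zero
        if n > 0:
            # .loc[row, col] updates d itself; chained d.loc[u]['u'] += ... may update a copy
            d.loc[u, 'u'] += x * tree.span / n
            d.loc[u, 'v'] += y * tree.span / n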

# Average coordinate of GNN
d['u'] /= ts.sequence_length
d['v'] /= ts.sequence_length
d.reset_index(inplace=True)

# Make gnn data frame that holds the gnn coordinates
gnn = pd.DataFrame(-1, index=np.arange(len(samplesidx)), columns=['id', 'pedigree_id', 'x', 'y', 'u', 'v'])
gnn['id'] = d['id'] + ts.num_samples
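# u and v hold (sample - mean neighbour location), so subtract them
# from the sample position to recover the neighbours' mean location
gnn['x'] = d['x'].values - d['u'].values
gnn['y'] = d['y'].values - d['v'].values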
gnn['color'] = 'green'

# Make geoviews points object of samples and gnns
p = pd.concat([d, gnn])
neighbours = pd.DataFrame(dict(sample_id=d['id'], gnn_id=gnn['id']))
points = gv.Points(p, ['x', 'y'], crs=crs.epsg(3035))
tiles = gv.tile_sources.Wikipedia
nodes = gv.Nodes(points, ['x', 'y', 'id'], ['pedigree_id', 'color'])
graph = gv.Graph((neighbours, nodes), ['sample_id', 'gnn_id'], [])
(tiles * graph).opts(opts.Graph(node_size=8, width=800, height=800, directed=True, arrowhead_length=0.01, node_color='color', edge_line_alpha=0.5, node_alpha=.5, edge_line_width=2))

There must be some glaring error still as one of the gnns is in Madagascar when all should be either Northern Africa or Europe (see attached image)...

[Image: spatialplot]

petrelharp commented 2 years ago

Very nice! There is I think a minor bug - I think that this bit

        d.loc[u]['u'] += x * tree.span
        d.loc[u]['v'] += y * tree.span

should be

        d.loc[u]['u'] += x * tree.span / n
        d.loc[u]['v'] += y * tree.span / n

percyfal commented 2 years ago

Thanks @petrelharp that seems to do the trick; at least the Madagascan gnns have migrated northward :)
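
To see why the division by n matters, here is a toy calculation with made-up one-dimensional offsets and spans (none of these numbers come from the simulation). Without the division, a tree in which the focal sample has many neighbours contributes the sum of all their offsets, so the span-weighted result can land well outside the range of the neighbours themselves.

```python
# Hypothetical data: for each tree, its span and the x offsets between the
# focal sample and each of its nearest neighbours in that tree.
trees = [
    (60.0, [1.0, 3.0]),                 # recent coalescence, 2 neighbours
    (40.0, [2.0, 2.0, 2.0, 2.0, 2.0]),  # deeper coalescence, 5 neighbours
]
sequence_length = sum(span for span, _ in trees)

# Sum of offsets per tree, weighted by span (no division by n).
unnormalised = sum(span * sum(offsets) for span, offsets in trees) / sequence_length

# Mean offset per tree, weighted by span (divided by n).
normalised = (
    sum(span * sum(offsets) / len(offsets) for span, offsets in trees) / sequence_length
)

print(unnormalised)  # 6.4, larger than any single neighbour offset
print(normalised)    # 2.0, a span-weighted mean of the per-tree mean offsets
```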

percyfal commented 2 years ago

[Image: gnnmap]

petrelharp commented 2 years ago

Very nice!

And, hm, the fact that that one sample had GNN in Madagascar tells us that something was different about that sample - the error was worse for samples with lots of nearest neighbors, i.e., samples that don't coalesce into the tree very recently. So - that sample seems to be an outlier? Perhaps that could be communicated somehow? Any ideas?
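
One possible way to communicate this, offered only as a sketch: record the span-weighted mean number of nearest neighbours for each sample alongside its displacement, and flag samples whose count is unusually large. The per-tree counts, the `flag_outliers` helper and the median-based cutoff below are all assumptions made for illustration, not something proposed or tested in this thread.

```python
from statistics import median

# Hypothetical per-sample records: a list of (tree span, number of nearest neighbours).
neighbour_counts = {
    "s0": [(50.0, 1), (50.0, 2)],
    "s1": [(70.0, 1), (30.0, 3)],
    "s2": [(20.0, 2), (80.0, 40)],  # rarely coalesces recently
    "s3": [(100.0, 2)],
}


def mean_neighbours(records):
    """Span-weighted mean number of nearest neighbours across trees."""
    total_span = sum(span for span, _ in records)
    return sum(span * n for span, n in records) / total_span


def flag_outliers(counts, factor=5.0):
    """Return samples whose mean neighbour count exceeds `factor` times the median."""
    means = {s: mean_neighbours(r) for s, r in counts.items()}
    cutoff = factor * median(means.values())
    return {s: m for s, m in means.items() if m > cutoff}


outliers = flag_outliers(neighbour_counts)
print(outliers)
```

The flagged samples could then be drawn with a different marker or colour on the map, or the mean neighbour count could be mapped directly to node size.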

bodkan commented 2 years ago

Hey @percyfal, thanks for sharing the link to your results here over in the slendr repo.

In case someone decides to follow up on your slendr code, I would suggest running the SLiM simulation simply as:

slim(model, sequence_length = 10e6, recombination_rate = 1e-8, method = "batch", random_seed = 314159)

without the save_locations = TRUE. This option is mostly useful for debugging and for making cool GIFs for slides but not much else (it records a table of locations of everyone who ever lived, regardless of whether they are among the ancestors of the final set of samples, so it makes things extremely inefficient). The locations are saved in the tree sequence metadata even without this option.

Also, for a simple simulation like this, the tree sequence can be loaded simply by ts <- ts_load(model) (no path specification is necessary, because the tree sequence is saved with the model data by default, so the path is known in this case).

I've been thinking about spatial GNN things lately, so I'll be following this issue closely! Cool stuff!