tskit-dev / tskit

Population-scale genomics
MIT License

Windowed Genealogical Nearest Neighbours (GNN) #665

Open awohns opened 4 years ago

awohns commented 4 years ago

@hyanwong, Gil and I have recently discussed details of a "windowed" GNN, which would use either span- or time-based windows to calculate the nearest neighbours of a set of focal nodes. Here's a proposal of what we hope to implement.

Background: The GNN function is currently defined as genealogical_nearest_neighbours(focal, sample_sets, num_threads=0).

The function returns "An 𝑛 by 𝑚 array of focal nodes by GNN proportions. Every focal node corresponds to a row. The numbers in each row corresponding to the GNN proportion for each of the passed-in reference sets. Rows therefore sum to one."

As detailed in the mathematical definition of a GNN here, the GNN is calculated on a per-tree basis, and the per-tree values are then averaged along the genome, with each tree weighted by its span.
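To make the "per tree, then span-weighted average" structure concrete, here is a minimal sketch. It assumes the per-tree GNN proportions for a single focal node have already been computed (the hypothetical `per_tree` array, one row per tree) along with each tree's span. It is not the tskit implementation, only the weighted average the definition describes; trees in which the focal node has no MRCA with any reference node would be left out of both the sum and the total span.

```python
import numpy as np


def span_weighted_gnn(per_tree, spans):
    """Average per-tree GNN proportions, weighting each tree by its span.

    per_tree: (num_trees, m) array; row t holds one focal node's GNN
        proportions in tree t (each row sums to one).
    spans: (num_trees,) array of tree spans (right - left).
    """
    per_tree = np.asarray(per_tree, dtype=float)
    spans = np.asarray(spans, dtype=float)
    # Span-weighted sum of the per-tree proportions, divided by the total span.
    return (spans[:, np.newaxis] * per_tree).sum(axis=0) / spans.sum()


# Two trees: the first spans 30 bp, the second 70 bp.
gnn = span_weighted_gnn([[1.0, 0.0], [0.5, 0.5]], [30, 70])
print(gnn)  # [0.65 0.35]
```

Because every per-tree row sums to one, the weighted average also sums to one.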

Proposal: We aim to generalise the GNN function to calculate the GNN using two types of windows: span-based windows (for example, each tree in the tree sequence) and time-based windows.

This change will not modify the per-tree calculation of the GNN.

New function definition: genealogical_nearest_neighbours(focal, sample_sets, windows=None, time_windows=None, num_threads=0)

If windows and time_windows are both None, the function returns an 𝑛 by 𝑚 array (𝑛 focal nodes by 𝑚 reference sets), exactly as in the current implementation. If windows is not None and defines r span-based windows, an r by m by n array (windows by reference sets by focal nodes) is returned, so every focal node corresponds to an r by m 2d array. For each focal node in each window, the 1d array of GNN proportions for the passed-in reference sets sums to one. Similarly, if time_windows is not None and defines s time windows, an s by m array is returned for each focal node. If both windows and time_windows are used, a 4d, r by s by m by n array is returned.

Unresolved points and questions:
- We were aiming for the syntax to be consistent with the stats framework. @petrelharp, do you have any thoughts on this?

jeromekelleher commented 4 years ago

Some quick answers here @awohns

I think the side-by-side windows thing is pretty well understood and we have existing machinery for making it work which we can hook into. Can you give an example of what you're thinking in terms of the time windowing please?

hyanwong commented 4 years ago

Also, how do we ask for tree-wise positions? Presumably we could simply pass ts.breakpoints() as the windows, but should we stick to how the stats API does it? Also, should we allow ts.breakpoints() (i.e. an iterator) or require an actual array to be passed in?

jeromekelleher commented 4 years ago

Everything to do with (span) windows should be done exactly as it's done for the stats API @hyanwong - there are mechanisms in there for all sorts of stuff.

petrelharp commented 4 years ago

I agree with Jerome; I imagine you've looked back at the documentation, btw? (I ask since this is a good chance to make sure that what's written there is clear!)

This will be a good test case for seeing how to do time windows more generally in the future (I hope). Happy to help interpret the code!

hyanwong commented 4 years ago

I've just been chatting to Wilder. What is the preferred thing to return when the window has no information - e.g. a time window from 10-20 when the oldest root in the ts is at time 1? This is rather like the issue of what (normalised) stat to return for a window of the tree sequence that is entirely within a region where all the edges have been deleted (e.g. with delete_intervals). I think we should return NaNs for the GNN proportions in these regions, as it doesn't make sense to return 0. Other suggestions @petrelharp?

jeromekelleher commented 4 years ago

I'm having trouble understanding what the output is - can we see some simple example code here?

jeromekelleher commented 4 years ago

Just the simplest, most naive version of time-aware GNN in Python would be really helpful now.

hyanwong commented 4 years ago

I think @awohns has something mocked up. He can post it as a draft PR for discussion.

jeromekelleher commented 4 years ago

Could just plop the Python code in here also - it should only be 20 lines long or so.

awohns commented 4 years ago

Just cleaning it up a bit and will post as draft PR momentarily!

awohns commented 4 years ago

Ok I'll just put the rough version here then

petrelharp commented 4 years ago

What is the preferred thing to return when the window has no information - e.g. a time window from 10-20 when the oldest root in the ts is at time 1?

Either 0 or nan, as you say! In general, if it's something you add up across individuals or windows or something then 0, while if it has a denominator that's zero, then nan - I'd have to double-check the GNN definition to make sure, but it has a denominator, right? So nan, as you say?
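A minimal sketch of that rule, assuming the per-window sums and the normalising spans are held in NumPy arrays (`normalise_or_nan` and its argument names are illustrative, not part of any tskit API): `np.divide` with `where=` only divides where the denominator is non-zero, so windows with no information keep the pre-filled nan instead of raising a divide warning.

```python
import numpy as np


def normalise_or_nan(numerators, denominators):
    """Divide numerators by denominators, giving nan wherever the denominator is 0."""
    numerators = np.asarray(numerators, dtype=float)
    denominators = np.asarray(denominators, dtype=float)
    out = np.full(np.broadcast(numerators, denominators).shape, np.nan)
    # Only divide where there is information; everything else stays nan.
    np.divide(numerators, denominators, out=out, where=denominators != 0)
    return out


# The second time window has no span contributing to it, so its row is nan.
result = normalise_or_nan([[3.0, 1.0], [0.0, 0.0]], [[4.0], [0.0]])
print(result)
```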

awohns commented 4 years ago

Here's a version with time-aware windowing only (i.e. no positional, span-based windows). This version returns nan when a time window has no information.

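import numpy as np
import tskit


def genealogical_nearest_neighbours_time(ts, focal, reference_sets, time_windows=None):
    # Map each node to the index of the reference set containing it (-1 if none).
    reference_set_map = np.zeros(ts.num_nodes, dtype=int) - 1
    for k, reference_set in enumerate(reference_sets):
        for u in reference_set:
            if reference_set_map[u] != -1:
                raise ValueError("Duplicate value in reference sets")
            reference_set_map[u] = k

    if time_windows is None:
        # A single window [0, inf) reproduces the existing, unwindowed GNN.
        time_windows = [0]
    time_windows = np.asarray(time_windows, dtype=float)

    # Each breakpoint starts a window and the last window always extends to
    # infinity, i.e. an extra window for nodes older than the greatest breakpoint.
    W = len(time_windows)
    K = len(reference_sets)
    A = np.zeros((W, len(focal), K))
    parent = np.zeros(ts.num_nodes, dtype=int) - 1
    sample_count = np.zeros((ts.num_nodes, K), dtype=int)
    node_ages = ts.tables.nodes.time
    # Total span contributing to each (time window, focal node) pair.
    time_normalisations = np.zeros((W, len(focal)))

    # Set the initial conditions.
    for j in range(K):
        sample_count[reference_sets[j], j] = 1

    for (left, right), edges_out, edges_in in ts.edge_diffs():
        # Update per-node reference-set counts incrementally as edges change.
        for edge in edges_out:
            parent[edge.child] = -1
            v = edge.parent
            while v != -1:
                sample_count[v] -= sample_count[edge.child]
                v = parent[v]
        for edge in edges_in:
            parent[edge.child] = edge.parent
            v = edge.parent
            while v != -1:
                sample_count[v] += sample_count[edge.child]
                v = parent[v]

        # Process this tree.
        for j, u in enumerate(focal):
            focal_reference_set = reference_set_map[u]
            # Don't count the focal node itself if it is in a reference set.
            delta = int(focal_reference_set != -1)
            # Walk up to the first ancestor p that subtends another reference node.
            p = u
            while p != tskit.NULL:
                total = np.sum(sample_count[p])
                if total > delta:
                    break
                p = parent[p]
            if p != tskit.NULL:
                span = right - left
                scale = span / (total - delta)
                # side="right" puts a node whose age equals breakpoint t_i into the
                # window starting at t_i (with the default side="left", a node at
                # time 0 would get index -1 and silently wrap to the last window).
                time_index = np.searchsorted(time_windows, node_ages[p], side="right") - 1
                if time_index < 0:
                    raise ValueError("Node time is below the first time window")
                time_normalisations[time_index, j] += span
                for k in range(K):
                    n = sample_count[p, k] - int(focal_reference_set == k)
                    A[time_index, j, k] += n * scale
    # Windows with no information have zero span, so 0 / 0 gives nan there.
    with np.errstate(invalid="ignore"):
        A /= time_normalisations.reshape((W, len(focal), 1))
    return A
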
jeromekelleher commented 4 years ago

Thanks @awohns, this is very helpful. This does look nice and simple. I've tweaked the code a bit to take the time_index lookup out of the inner loop. I'm not entirely certain that we're doing the normalisation right, but let's see what @petrelharp has to say about that.

awohns commented 4 years ago

Great, thanks @jeromekelleher. I think the way to check the normalisation is that for each focal node/time window value, the GNN values for the m reference sets should sum to 1, but interested to hear @petrelharp's thoughts!
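That check can be written as a small helper. This is a sketch assuming the array layout of the draft genealogical_nearest_neighbours_time() above (time windows by focal nodes by reference sets); the helper name is hypothetical. Each (time window, focal node) row must either sum to one or be entirely nan because that window had no information.

```python
import numpy as np


def check_gnn_normalisation(A, atol=1e-9):
    """Return True if every (time window, focal node) row sums to one.

    A: array of shape (num_time_windows, num_focal, num_reference_sets).
    Rows for windows without information may be entirely nan.
    """
    A = np.asarray(A, dtype=float)
    empty = np.isnan(A).all(axis=-1)
    # A row that is only partly nan sums to nan and therefore fails the check.
    sums = A.sum(axis=-1)
    return bool(np.all(empty | np.isclose(sums, 1.0, atol=atol)))


# One focal node, two time windows; the second window is empty.
A = np.array([[[0.25, 0.75]], [[np.nan, np.nan]]])
print(check_gnn_normalisation(A))  # True
```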

hyanwong commented 4 years ago

Also we should decide what to do when the uppermost time window is younger than ts.max_root_time. In this case, I think the searchsorted() call could return an index off the top of the array. You could argue that we should create an extra bin for everything above the oldest time window in this case, but then the number of time windows could be the number passed in plus one, which is a bit confusing. Or we could drop these bits of the GNN, but if so perhaps we should issue a warning that the GNN doesn't cover all possible times. Thoughts, @awohns?
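A minimal illustration of the problem, assuming time windows are given stats-API style as s+1 breakpoints whose last value is the upper end of the oldest window (the variable names are illustrative): any node older than the last breakpoint is mapped to an index one past the last window.

```python
import numpy as np

# Two time windows, [0, 1) and [1, 10), given as breakpoints.
time_windows = np.array([0.0, 1.0, 10.0])
num_windows = len(time_windows) - 1
# The last node is older than the top breakpoint.
node_times = np.array([0.5, 3.0, 12.0])
index = np.searchsorted(time_windows, node_times, side="right") - 1
print(index)  # [0 1 2]: index 2 does not exist when there are only 2 windows
```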

hyanwong commented 4 years ago

I agree with Jerome; I imagine you've looked back at the documentation, btw? (I ask since this is a good chance to make sure that what's written there is clear!)

@petrelharp: just to clarify a (very minor) point in the stats docs, off the back of what I wrote above about ts.breakpoints(). The docs say "windows should be a list of n+1 increasing numbers beginning with 0 and ending with the sequence_length", but you also say "windows = 'trees' is equivalent to passing windows=ts.breakpoints()". That implies that windows need not be a list but could also be an iterator, e.g. as returned by ts.breakpoints(). If iterators are allowed, perhaps you should say something like:

"The canonical way to specify windows is by providing a list of n+1 increasing numbers beginning with 0 and ending with the sequence_length"

?
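A sketch of how an API could accept either a list or an iterator such as ts.breakpoints() while still enforcing the documented rule (n+1 increasing numbers from 0 to sequence_length). The function name and error messages are hypothetical; this is not tskit's own window parsing.

```python
import numpy as np


def parse_span_windows(windows, sequence_length):
    """Turn any iterable of window breakpoints into a validated NumPy array."""
    # np.array() on a generator gives a 0-d object array, so materialise it first.
    breaks = np.array(list(windows), dtype=float)
    if breaks.ndim != 1 or len(breaks) < 2:
        raise ValueError("Need at least two breakpoints (one window)")
    if breaks[0] != 0 or breaks[-1] != sequence_length:
        raise ValueError("Windows must start at 0 and end at sequence_length")
    if np.any(np.diff(breaks) <= 0):
        raise ValueError("Windows must be strictly increasing")
    return breaks


# A generator, standing in for ts.breakpoints(), is accepted just like a list.
print(parse_span_windows((x for x in [0, 25, 100]), 100))
```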

awohns commented 4 years ago

You could argue that we should create an extra bin for everything above the oldest time window in this case, but then the number of time windows could be the number passed in plus one, which is a bit confusing. Or we could drop these bits of the GNN, but if so perhaps we should issue a warning that the GNN doesn't cover all possible times.

I personally think that not returning an extra time bin and just leaving it up to the user to give complete windows is the best way to go. The windowed stats return a _tskit.LibraryError if windows don't cover the whole tree sequence, we could do the same? However, bombing out like this is not as straightforward for the user because the maximum node time is not as easily obtainable as the sequence_length. On the other hand, I'm not sure it makes sense to return a different number of time windows than the user specified. On balance I would go with the approach that's analogous to the windowed stats, though. Interested to hear other opinions!

jeromekelleher commented 4 years ago

I personally think that not returning an extra time bin and just leaving it up to the user to give complete windows is the best way to go. The windowed stats return a _tskit.LibraryError if windows don't cover the whole tree sequence, we could do the same?

We should treat the two windowing dimensions as analogously as possible. For now, let's assume that the user provides the correct time windows, and we throw an error if we end up with a time value outside of this. If it seems like a good idea, we can add an option to say "ignore any values outside the windows" or something later, but let's keep it simple and assume the user knows everything.
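Following that decision, here is a minimal sketch of a strict lookup. It assumes time windows are given as breakpoints [t_0, ..., t_s] defining half-open windows [t_i, t_{i+1}), mirroring the span case; the helper name is hypothetical. Any node time outside [t_0, t_s) raises an error rather than being silently binned.

```python
import numpy as np


def time_window_index(time_windows, times):
    """Map node times to time-window indexes, raising if any fall outside."""
    time_windows = np.asarray(time_windows, dtype=float)
    times = np.asarray(times, dtype=float)
    if np.any(np.diff(time_windows) <= 0):
        raise ValueError("Time windows must be strictly increasing")
    # side="right" means a time equal to t_i belongs to window [t_i, t_{i+1}).
    index = np.searchsorted(time_windows, times, side="right") - 1
    outside = (index < 0) | (index >= len(time_windows) - 1)
    if np.any(outside):
        raise ValueError(
            f"Node time(s) {times[outside]} outside time windows "
            f"[{time_windows[0]}, {time_windows[-1]})"
        )
    return index


# Using np.inf as the top breakpoint guarantees every node time is covered.
print(time_window_index([0, 1, 10, np.inf], [0.0, 1.0, 12.0]))  # [0 1 2]
```

With this convention, a user who wants complete coverage without knowing the oldest node time can simply end the time windows with np.inf.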

petrelharp commented 4 years ago

Well, whether it's right or not depends on your goal, but as currently written, here's what I think you are computing: suppose for simplicity that a single focal node is not in any reference set, and that there are no samples other than these. Then, it is: the proportion of modern genome descended from MRCAs of the focal node in each of the reference sets, split out by time period of the MRCA. (Here MRCA is the MRCA of the focal sample with any other sample.) As Wilder says, this will produce proportions that sum to one for each (time period, window, focal node) combination. (I'm sure you know this is what you're computing, but I'm writing it out for confirmation?)

The other possible normalization, I suppose, is if you wanted proportions that summed to 1 across all time windows. But the end user could obtain this if they could just skip the last division by time_normalisations, and then normalize by (window, focal node).

For other statistics, what I had in mind was to produce numbers that, if you added them up across time windows, would give the value for the time window [0, Inf), which would I suppose be this second normalization. This is like a time_normalise argument, I guess, and you've written the time_normalise=True version, while I'm imagining doing the False version by default? This statistic is different enough that I don't think they necessarily have to agree. But I could imagine wanting to do it the other way - it depends on whether you want to look at the breakdown of nearest neighbours separately by time period, or also see how the MRCA is spread out over time. (Oh, that's a good use of time_normalise=False: you could then normalise separately by reference set, and obtain the distribution of TMRCAs separately by identity of the reference set, I think?)

So, in summary, I guess I vote for implementing a time_normalise option, that skips that normalisation, since that'll let you retain all the relevant information. I don't have strong opinions about what the defaults should be.
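A sketch of what such a time_normalise option could do, assuming the draft function were changed to hand back its raw span-weighted sums A (time windows by focal nodes by reference sets) and time_normalisations before the final division. The function name and flag are hypothetical. With time_normalise=True every (time window, focal node) row sums to one; with False each focal node is divided by its total span, so summing over time windows recovers the unwindowed GNN.

```python
import numpy as np


def normalise_time_windowed_gnn(A, time_normalisations, time_normalise=True):
    """Normalise raw time-windowed GNN sums.

    A: (W, n, K) span-weighted sums; time_normalisations: (W, n) spans.
    """
    A = np.asarray(A, dtype=float)
    spans = np.asarray(time_normalisations, dtype=float)
    if time_normalise:
        # Per (time window, focal node): rows sum to one, nan for empty windows.
        denom = spans[:, :, np.newaxis]
    else:
        # Per focal node over all time windows: summing across windows then
        # gives the unwindowed GNN proportions.
        denom = spans.sum(axis=0)[np.newaxis, :, np.newaxis]
    with np.errstate(invalid="ignore", divide="ignore"):
        return A / denom


# One focal node, two time windows (spans 40 and 60), two reference sets.
A = np.array([[[30.0, 10.0]], [[15.0, 45.0]]])
spans = np.array([[40.0], [60.0]])
per_window = normalise_time_windowed_gnn(A, spans, time_normalise=True)
raw = normalise_time_windowed_gnn(A, spans, time_normalise=False)
# The other suggestion: normalise over time windows separately for each
# reference set, showing how that set's MRCA span is spread over time.
by_time = raw / raw.sum(axis=0, keepdims=True)
```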

hyanwong commented 3 years ago

I personally think that not returning an extra time bin and just leaving it up to the user to give complete windows is the best way to go. The windowed stats return a _tskit.LibraryError if windows don't cover the whole tree sequence, we could do the same?

Linking to https://github.com/tskit-dev/tskit/issues/202 as we might not keep this behaviour for partial windows.

percyfal commented 2 years ago

@hyanwong As requested I'm following up on the discussion in #2026, where I was asking about the function local_gnn. I did go through the code, and now that I've reread the methods on the GNN definition, I think I got it - most of it anyway.

With regards to chipping in, it seems that this issue is partly resolved by #683 and that it mainly is the C API that needs updating (#1237) - is that correct?

hyanwong commented 2 years ago

I think so; @awohns, you know more about this than I do. Would it be helpful if @percyfal were to finish it off somehow (assuming he's willing)?

hyanwong commented 2 months ago

@sinanshi has implemented something very similar to time-windowed GNNs, on the ARG structures output by ARG-needle / THREADs, and called it "local ancestry". I'm hoping that we will be able to make the implementation here compatible with his definitions. I'm syncing up with him about this.