tskit-dev / tsinfer

Infer a tree sequence from genetic variation data.
GNU General Public License v3.0

Add `ancestral_allele` and `variant_time` parameters to SgkitSampleData creator, with ability to provide numpy arrays #923

Closed hyanwong closed 1 month ago

hyanwong commented 5 months ago

@jeromekelleher said on slack

I would like to see SampleData (maybe renamed VariantData) being a thin wrapper around a VCF Zarr data set, like

sd = SampleData(path_or_url, *,
    variant_mask: Union[npt.ArrayLike, str],  # if str, this refers to a named array in the Zarr
    sample_mask: Union[npt.ArrayLike, str],  # ditto
    ancestral_allele: Union[npt.ArrayLike, str],  # ditto
    variant_time: Union[npt.ArrayLike, str],  # ditto - although should this be "time", "date", or "age"?
)

I guess the simplest thing would be to insist that the variant_time and ancestral_allele arrays are with respect to the original (unmasked) Zarr coordinates?

This would make it easy to pass in new ages for reinference without messing with the datafile on-disk:

data = VariantData("myfile.vcz")
inferred_ts = tsinfer.infer(data)
dated_ts = tsdate.variational_gamma(tsdate.preprocess_ts(inferred_ts), mutation_rate=1e-8)
new_data = VariantData("myfile.vcz", variant_time=tsdate.util.sites_time_from_ts(dated_ts))
reinferred_ts = tsinfer.infer(new_data)
related_ts = tsdate.variational_gamma(tsdate.preprocess_ts(reinferred_ts), mutation_rate=1e-8)

And as Jerome says, it adds flexibility in the case that you don't have write access to the original Zarr.
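To make the "original (unmasked) coordinates" convention concrete, here is a toy sketch. Note that VariantData and its keyword parameters are the proposed API from this thread, not an existing one; the point is simply that all per-variant arrays share the full, unmasked length:

```python
import numpy as np

# Suppose the original Zarr stores 5 variants.
num_variants = 5

# A variant_mask excluding variant 3 still has length 5...
variant_mask = np.array([False, False, False, True, False])

# ...and ancestral_allele / variant_time are also length 5, indexed by
# the original (unmasked) variant coordinates; masked entries are ignored.
ancestral_allele = np.array(["A", "C", "G", "T", "A"])
variant_time = np.array([120.0, 85.0, 300.0, np.nan, 42.0])

for arr in (variant_mask, ancestral_allele, variant_time):
    assert len(arr) == num_variants

# Proposed (hypothetical) usage:
# data = VariantData("myfile.vcz", variant_mask=variant_mask,
#                    ancestral_allele=ancestral_allele,
#                    variant_time=variant_time)
```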

jeromekelleher commented 5 months ago

Wouldn't be quite that simple though, because of singletons etc. How would you get the array you pass in from the tree sequence timed sites to have the same length as the original Zarr? Easy enough to add utility functions to do so I guess.

hyanwong commented 5 months ago

I'm assuming (for didactic/teaching purposes) that we are not masking out any sites, and that singletons are phased.

I presume that pipelines with real data would indeed need some sort of wrapper function. A VariantData method that padded out sites with a "missing data" value would certainly be convenient, e.g.

site_times_from_ts = tsdate.util.sites_time_from_ts(dated_ts)
all_site_times = data.fill_variant_mask(site_times_from_ts, fill_value=np.nan)  # new function
new_data = VariantData("myfile.vcz", variant_time=all_site_times)
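A minimal sketch of how the proposed fill_variant_mask helper could work, using only numpy and a boolean mask (the function name and semantics are assumptions taken from this thread; here True means "masked out"):

```python
import numpy as np

def fill_variant_mask(values, variant_mask, fill_value=np.nan):
    """Expand `values` (one entry per *unmasked* variant) to cover every
    variant in the original Zarr, inserting `fill_value` at masked-out
    positions. `variant_mask` is True where a variant is excluded."""
    variant_mask = np.asarray(variant_mask, dtype=bool)
    full = np.full(variant_mask.shape[0], fill_value, dtype=float)
    full[~variant_mask] = values  # scatter into the unmasked slots, in order
    return full

# Variants 1 and 3 masked out; three inferred site times remain.
mask = np.array([False, True, False, True, False])
times = np.array([10.0, 20.0, 30.0])
print(fill_variant_mask(times, mask))  # nan at the two masked positions
```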
jeromekelleher commented 5 months ago

I think it also needs the positions, though, or else it can't merge. Something like

site_times = data.pad_variants_array(dated_ts.sites_position.astype(int), site_times_from_ts, fill=np.nan)
hyanwong commented 5 months ago

I think it also needs the positions, though, or else it can't merge. Something like

site_times = data.pad_variants_array(dated_ts.sites_position.astype(int), site_times_from_ts, fill=np.nan)

I was thinking that the data object should know which sites were masked out. But I totally agree that it's much less error-prone if we use the positions. So, picking up on your idea, I think the neatest thing would be to have an alternative tsdate function that also returns the positions, e.g. "mut_node_time_from_ts". Then we could simply do:

pos, times = tsdate.util.mut_node_time_from_ts(dated_ts)
data.pad_variants_array(pos, times, fill=np.nan)  # maybe try the (int) conversion for positions within the pad_ method?

I think we want the time of the node below the mutation (as we are really trying to use the best estimate of the time for a node, rather than the time of the mutation above that node), hence the suggested method name. It's worth keeping this as a method of tsdate, as we ideally want to use the unconstrained times in the tsdate-encoded metadata, rather than the times from ts.nodes_time.
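A sketch of what the position-based pad_variants_array could look like (the method name and signature are this thread's proposal, not an existing API); np.searchsorted maps each supplied position back to its index in the full, sorted variant_position coordinate system:

```python
import numpy as np

def pad_variants_array(variant_position, positions, values, fill=np.nan):
    """Scatter `values` (one per entry of `positions`) into an array
    covering every variant in `variant_position` (the full, unmasked,
    sorted Zarr positions), using `fill` everywhere else."""
    variant_position = np.asarray(variant_position)
    idx = np.searchsorted(variant_position, positions)
    # Guard against positions that are absent from the dataset.
    if np.any(idx >= len(variant_position)) or not np.all(
        variant_position[idx] == positions
    ):
        raise ValueError("some positions are not present in the dataset")
    full = np.full(len(variant_position), fill, dtype=float)
    full[idx] = values
    return full

all_pos = np.array([100, 250, 400, 900])
print(pad_variants_array(all_pos, np.array([250, 900]), np.array([1.5, 2.5])))
# nan everywhere except the two matched positions
```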

hyanwong commented 5 months ago

Another thing: we relatively often want to provide a mask that is calculated from the Zarr file, e.g. "mask if variant_quality < 20". Should we recommend that this be done in sgkit, or (as below) by creating a numpy array in a preprocessing step, or by allowing the mask parameters to be a lambda function?

mask = zarr.load("demo.vcz")["variant_quality"] < 20
sd = SampleData("demo.vcz", variant_mask=mask)

# or allow a function that takes the zarr object as the only param
# probably not worth the complexity, unless it is greatly more efficient
sd = SampleData("demo.vcz", variant_mask=lambda z: z.variant_quality < 20)
jeromekelleher commented 5 months ago

I don't think we want to recommend doing any real compute in the constructor - in practice these masks will probably involve QC on call-level fields that'll need to be done in advance and saved somewhere.

So, in practice, I think allowing the mask to be a function would just lead to complexity and confusion.
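Following that recommendation, the QC mask could be precomputed once in a separate step and saved (to a file, or back into the Zarr as a named array), then referred to later; this sketch uses a toy numpy stand-in for the variant_quality field, and the commented-out SampleData call is the proposed string-valued variant_mask API from this thread:

```python
import numpy as np

# In a real pipeline this would come from the .vcz store, e.g.
# zarr.load("demo.vcz")["variant_quality"]; here, a toy stand-in:
variant_quality = np.array([35.0, 12.0, 48.0, 7.0])

# Precompute the QC mask once, in advance of inference...
mask = variant_quality < 20
np.save("variant_low_qual_mask.npy", mask)  # or write back into the Zarr

# ...then later runs load the saved array (or, under the proposed API,
# pass the name of an array stored in the Zarr itself):
# sd = SampleData("demo.vcz", variant_mask=np.load("variant_low_qual_mask.npy"))
# sd = SampleData("demo.vcz", variant_mask="variant_low_qual_mask")
print(mask)  # [False  True False  True]
```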