tskit-dev / tsdate

Infer the age of ancestral nodes in a tree sequence.
MIT License
19 stars 10 forks source link

AssertionError for big file while no error for small file #415

Closed qianfeng2 closed 4 months ago

qianfeng2 commented 4 months ago

To whom it may concern,

I hope everything is well with you.

I am running tsinfer + tsdate on a VCF file generated by SLiM. It works well for a small VCF file but raises an AssertionError when the input VCF file is somewhat larger. The two VCF files come from the same SLiM simulation; the difference is that the population size increases from 10 to 1000. Would you please have a look at this whenever it is convenient for you? Thanks in advance.

The Python script I used is below (with the large VCF file, it takes about 5 minutes to run). The two VCF files are also attached: small_file.vcf.zip, large_file.vcf.zip

# Reproduction script: build a tsinfer SampleData file from a VCF, infer a
# tree sequence, then date it with tsdate.
# Usage: python script.py <directory containing large_file.vcf>
import sys  # required for sys.argv below; it was missing from the original

import cyvcf2
import tsinfer
import tsdate

# Directory holding the input VCF, taken from the first command-line argument.
inputdir = sys.argv[1]
vcf_location = inputdir + '/large_file.vcf'

with tsinfer.SampleData(sequence_length=2e7) as samples:
    # Register one diploid individual per sample name listed in the VCF.
    for name in cyvcf2.VCF(vcf_location).samples:
        samples.add_individual(ploidy=2, metadata={"name": name})
    # Second pass over the VCF: add one site per variant record.
    for variant in cyvcf2.VCF(vcf_location):
        # Flatten the first two entries of every genotype row into one list
        # (presumably the two allele indexes of each diploid sample; confirm
        # against the cyvcf2 documentation for Variant.genotypes).
        genotypes = [g for row in variant.genotypes for g in row[0:2]]
        pos = variant.POS
        # Allele labels are hard-coded rather than read from the VCF record.
        alleles = ["A", "T"]
        samples.add_site(pos, genotypes, alleles, ancestral_allele=0)

# Do the inference
ts = tsinfer.infer(samples)
# Optional diagnostic, left disabled as in the original script:
# print("Inferred tree sequence: {} trees over {} Mb ({} edges)".format(ts.num_trees, ts.sequence_length / 1e6, ts.num_edges))
simplified_ts = tsdate.preprocess_ts(ts)
redated_ts = tsdate.date(simplified_ts, mutation_rate=1.44e-8)

The error I get is shown below.

WARNING:tsdate.util:Could not set 'unsplit_node_id' on node metadata
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[29], line 2
      1 simplified_ts = tsdate.preprocess_ts(ts)
----> 2 redated_ts = tsdate.date(simplified_ts, mutation_rate=1.44e-8)

File ~/anaconda3/lib/python3.11/site-packages/tsdate/core.py:1773, in date(tree_sequence, mutation_rate, recombination_rate, time_units, method, constr_iterations, return_posteriors, return_likelihood, progress, record_provenance, **kwargs)
   1770 if method not in estimation_methods:
   1771     raise ValueError(f"method must be one of {list(estimation_methods.keys())}")
-> 1773 return estimation_methods[method](
   1774     tree_sequence,
   1775     mutation_rate=mutation_rate,
   1776     recombination_rate=recombination_rate,
   1777     time_units=time_units,
   1778     progress=progress,
   1779     constr_iterations=constr_iterations,
   1780     return_posteriors=return_posteriors,
   1781     return_likelihood=return_likelihood,
   1782     record_provenance=record_provenance,
   1783     **kwargs,
   1784 )

File ~/anaconda3/lib/python3.11/site-packages/tsdate/core.py:1659, in variational_gamma(tree_sequence, mutation_rate, eps, max_iterations, rescaling_intervals, max_shape, match_central_moments, match_segregating_sites, regularise_roots, **kwargs)
   1653     raise ValueError(
   1654         "No mutations present: these are required for the variational_gamma method"
   1655     )
   1656 dating_method = VariationalGammaMethod(
   1657     tree_sequence, mutation_rate=mutation_rate, **kwargs
   1658 )
-> 1659 result = dating_method.run(
   1660     eps=eps,
   1661     max_iterations=max_iterations,
   1662     max_shape=max_shape,
   1663     match_central_moments=match_central_moments,
   1664     rescaling_intervals=rescaling_intervals,
   1665     match_segregating_sites=match_segregating_sites,
   1666     regularise_roots=regularise_roots,
   1667 )
   1668 return dating_method.parse_result(result, eps, {"parameter": ["shape", "rate"]})

File ~/anaconda3/lib/python3.11/site-packages/tsdate/core.py:1275, in VariationalGammaMethod.run(self, eps, max_iterations, max_shape, match_central_moments, rescaling_intervals, match_segregating_sites, regularise_roots)
   1273 min_kl = not match_central_moments
   1274 dynamic_prog = self.main_algorithm()
-> 1275 dynamic_prog.run(
   1276     ep_maxitt=max_iterations,
   1277     max_shape=max_shape,
   1278     min_kl=min_kl,
   1279     rescale_intervals=rescaling_intervals,
   1280     regularise=regularise_roots,
   1281     rescale_segsites=match_segregating_sites,
   1282     progress=self.pbar,
   1283 )
   1285 # TODO: use dynamic_prog.point_estimate
   1286 posterior_mean, posterior_vari = self.mean_var(
   1287     dynamic_prog.posterior, dynamic_prog.constraints
   1288 )

File ~/anaconda3/lib/python3.11/site-packages/tsdate/variational.py:701, in ExpectationPropagation.run(self, ep_maxitt, max_shape, min_step, min_kl, rescale_intervals, rescale_segsites, regularise, progress)
    699 if rescale_intervals > 0:
    700     rescale_timing = time.time()
--> 701     self.rescale(
    702         rescale_intervals=rescale_intervals, rescale_segsites=rescale_segsites
    703     )
    704     rescale_timing -= time.time()
    705     logging.info(f"Timescale rescaled in {abs(rescale_timing)} seconds")

File ~/anaconda3/lib/python3.11/site-packages/tsdate/variational.py:636, in ExpectationPropagation.rescale(self, rescale_intervals, rescale_segsites, use_median, quantile_width)
    626 nodes_time = self._point_estimate(self.posterior, self.constraints, use_median)
    627 original_breaks, rescaled_breaks = mutational_timescale(
    628     nodes_time,
    629     self.likelihoods,
   (...)
    634     rescale_intervals,
    635 )
--> 636 self.posterior[:] = piecewise_scale_posterior(
    637     self.posterior,
    638     original_breaks,
    639     rescaled_breaks,
    640     quantile_width,
    641     use_median,
    642 )
    643 self.mutations_posterior[:] = piecewise_scale_posterior(
    644     self.mutations_posterior,
    645     original_breaks,
   (...)
    648     use_median,
    649 )

File ~/anaconda3/lib/python3.11/site-packages/tsdate/rescaling.py:233, in piecewise_scale_posterior()
    230     midpt[i] /= beta
    232 # rescale quantiles
--> 233 assert np.all(np.diff(rescaled_breaks) > 0)
    234 assert np.all(np.diff(original_breaks) > 0)
    235 scalings = np.append(np.diff(rescaled_breaks) / np.diff(original_breaks), 0)

AssertionError: 

I have difficulty in understanding this error. Looking forward to hearing from you:)

Thanks, Qian

nspope commented 4 months ago

Hi, could you try setting the argument rescaling_intervals to a smaller number, like 100 or 10? What's happening is that there aren't enough mutations for one of the steps of the algorithm under the default setting. We'll add a more informative error message.

qianfeng2 commented 4 months ago

Hi nspope,

Thanks for your quick response. Importantly, your suggestion solves my problem! That's so helpful!!

Kind regards, Qian