popgenmethods / momi2

Infer demographic history with the Moran model
GNU General Public License v3.0
47 stars 11 forks source link

simulate_data() doesn't work with lambda params #60

Open jackkamm opened 1 year ago

jackkamm commented 1 year ago

When a parameter is set via a lambda function, as described here: https://momi2.readthedocs.io/en/latest/tutorial.html#Using-functions-of-model-parameters-as-demographic-parameters

then simulate_data() does not work correctly. The problem is that DemographicModel.parameters is an OrderedDict instead of a ParamsDict, so attribute-style access inside the lambda (e.g. `params.t_div`) fails with an AttributeError.

Here's an example:

# Minimal reproduction: simulate_data() fails when an event time is given as
# a function of the model parameters (a lambda that receives `params`).
import momi
import autograd.numpy as anp

mod = momi.DemographicModel(10000)
mod.add_leaf("pop1")
mod.add_leaf("pop2")

mod.add_time_param("t_div")
mod.move_lineages("pop1", "pop2", t="t_div")

# Event time defined via a lambda over the parameters. The lambda reads
# `params.t_div` by attribute, which is what fails inside simulate_data().
# Population names are passed as strings, matching the calls above.
mod.move_lineages("pop1", "pop2", t=lambda params: params.t_div / 2, p=.5)

dat = mod.simulate_data(
    length=1000000,
    num_replicates=10,
    recoms_per_gen=1e-8,
    muts_per_gen=1e-8,
    sampled_n_dict={"pop1": 10, "pop2": 10}
)

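This gives the errors below. The first traceback (a ValueError mentioning 'north' and 'south') comes from a run with different population names. The second traceback (an AttributeError) is the failure this issue is about: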

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<string> in <module>

~/dev/momi2/momi/demo_model.py in simulate_data(self, length, recoms_per_gen, num_replicates, muts_per_gen, sampled_n_dict, **kwargs)
    480         :rtype: :class:`SnpAlleleCounts`
    481         """
--> 482         demo = self._get_demo(sampled_n_dict)
    483         if muts_per_gen is None:
    484             if not self.muts_per_gen:

~/dev/momi2/momi/demo_model.py in _get_demo(self, sampled_n_dict)
    526     # TODO rename to get_multipop_moran?
    527     def _get_demo(self, sampled_n_dict):
--> 528         sampled_n_dict = self._get_sample_sizes(sampled_n_dict)
    529 
    530         params_dict = self.get_params()

~/dev/momi2/momi/demo_model.py in _get_sample_sizes(self, sampled_n_dict)
    800             leaf_set = set(self.leafs)
    801             if not sampled_pops_set <= leaf_set:
--> 802                 raise ValueError("{} not in leaf populations".format(
    803                     sampled_pops_set - leaf_set))
    804             # make sure it is sorted in correct order

ValueError: {'north', 'south'} not in leaf populations

In [136]: 
Second traceback (the lambda-parameter failure):
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<string> in <module>

~/dev/momi2/momi/demo_model.py in simulate_data(self, length, recoms_per_gen, num_replicates, muts_per_gen, sampled_n_dict, **kwargs)
    485                 raise ValueError("Need to provide mutation rate")
    486             muts_per_gen = self.muts_per_gen
--> 487         return demo.simulate_data(
    488             length=length,
    489             recombination_rate=4*self.N_e*recoms_per_gen,

~/dev/momi2/momi/demography.py in simulate_data(self, length, num_replicates, **kwargs)
    309 
    310     def simulate_data(self, length, num_replicates=1, **kwargs):
--> 311         treeseq = self.simulate_trees(length=length, num_replicates=num_replicates,
    312                                       **kwargs)
    313         try:

~/dev/momi2/momi/demography.py in simulate_trees(self, **kwargs)
    419         demographic_events = []
    420         for e in self._G.graph["events"]:
--> 421             e = e.get_msprime_event(self._G.graph["params"], pops)
    422             if e is not None:
    423                 demographic_events.append(e)

~/dev/momi2/momi/events.py in get_msprime_event(self, params_dict, pop_ids_dict)
    295 
    296     def get_msprime_event(self, params_dict, pop_ids_dict):
--> 297         t = self.t(params_dict)
    298         i = _get_pop_id(self.pop1, pop_ids_dict)
    299         j = _get_pop_id(self.pop2, pop_ids_dict)

~/dev/momi2/momi/events.py in __call__(self, params_dict, scaled)
    435             x = params_dict[self.x]
    436         elif callable(self.x):
--> 437             x = self.x(params_dict)
    438         else:
    439             x = self.x

<string> in <lambda>(params)

AttributeError: 'collections.OrderedDict' object has no attribute 't_div'