jthlab / momi3


TypeError after command momi.optimize() #3

Open quank opened 1 month ago

quank commented 1 month ago

Here is a toy example I just tried that uses momi3 to run an optimization.

The commands and the TypeError output are attached below; any suggestions are appreciated.

import demes
import demesdraw

# Build a three-population demography with a C -> B admixture pulse.
b = demes.Builder()
b.add_deme("ABC", epochs=[dict(start_size=1200, end_time=800)])
b.add_deme("AB", ancestors=["ABC"], epochs=[dict(start_size=1000, end_time=550)])
b.add_deme("C", ancestors=["ABC"], epochs=[dict(start_size=300)])
b.add_deme("A", ancestors=["AB"], epochs=[dict(start_size=800)])
b.add_deme("B", ancestors=["AB"], epochs=[dict(start_size=700)])
b.add_pulse(sources=["C"], dest="B", time=240, proportions=[0.28])
g = b.resolve()

demesdraw.tubes(g)

from momi3.MOMI import Momi

sampled_demes = ["A", "B", "C"]
sample_sizes = [20, 30, 15]

momi = Momi(g, sampled_demes=sampled_demes, sample_sizes=sample_sizes, low_memory=True)
params = momi._default_params
bounds = momi.bound_sampler(params, 1000, seed=108)
momi = momi.bound(bounds)

# Mark all eta parameters as trainable except eta_0.
params.set_train_all_etas(True)
params.set_train('eta_0', False)

# Simulate a JSFS to fit against.
jsfs = momi.simulate(200, seed=1124321)

momi.optimize(params=params, jsfs=jsfs, stepsize=0.5, maxiter=50)

The full TypeError traceback is below:


TypeError                                 Traceback (most recent call last)
Cell In [7], line 1
----> 1 momi.optimize(params=params, jsfs=jsfs, stepsize=0.5, maxiter=50)

File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/momi3-0.0.0-py3.10.egg/momi3/MOMI.py:284, in Momi.optimize(self, params, jsfs, stepsize, maxiter, theta_train_dict_0, htol, monitor_training)
    281 negative_loglik_with_gradient = self.negative_loglik_with_gradient
    282 sampled_demes = self.sampled_demes
--> 284 return ProjectedGradient_optimizer(
    285     negative_loglik_with_gradient=negative_loglik_with_gradient,
    286     params=params,
    287     jsfs=jsfs,
    288     stepsize=stepsize,
    289     maxiter=maxiter,
    290     theta_train_dict_0=theta_train_dict_0,
    291     sampled_demes=sampled_demes,
    292     htol=htol,
    293     monitor_training=monitor_training,
    294 )

File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/momi3-0.0.0-py3.10.egg/momi3/optimizers.py:121, in ProjectedGradient_optimizer(negative_loglik_with_gradient, params, jsfs, stepsize, maxiter, sampled_demes, theta_train_dict_0, htol, monitor_training)
    118     plt.xlabel("Iteration Number")
    120 else:
--> 121     opt_result = pg.run(theta_train_0, hyperparams_proj=(A, b, G, h))
    122     theta_train_hat = opt_result.params
    123     pg_state = opt_result.state

File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/projected_gradient.py:137, in ProjectedGradient.run(self, init_params, hyperparams_proj, *args, **kwargs)
    132 def run(self,
    133         init_params: Any,
    134         hyperparams_proj: Optional[Any] = None,
    135         *args,
    136         **kwargs) -> base.OptStep:
--> 137   return self._pg.run(init_params, hyperparams_proj, *args, **kwargs)

File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/base.py:359, in IterativeSolver.run(self, init_params, *args, **kwargs)
    352 decorator = idf.custom_root(
    353     self.optimality_fun,
    354     has_aux=True,
    355     solve=self.implicit_diff_solve,
    356     reference_signature=reference_signature)
    357 run = decorator(run)
--> 359 return run(init_params, *args, **kwargs)

File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/implicit_diff.py:251, in _custom_root.<locals>.wrapped_solver_fun(*args, **kwargs)
    249 args, kwargs = _signature_bind(solver_fun_signature, *args, **kwargs)
    250 keys, vals = list(kwargs.keys()), list(kwargs.values())
--> 251 return make_custom_vjp_solver_fun(solver_fun, keys)(*args, *vals)

    [... skipping hidden 5 frame]

File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/implicit_diff.py:207, in _custom_root.<locals>.make_custom_vjp_solver_fun.<locals>.solver_fun_flat(*flat_args)
    204 @jax.custom_vjp
    205 def solver_fun_flat(*flat_args):
    206     args, kwargs = _extract_kwargs(kwarg_keys, flat_args)
--> 207     return solver_fun(*args, **kwargs)

File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/base.py:321, in IterativeSolver._run(self, init_params, *args, **kwargs)
    303 # We unroll the very first iteration. This allows init_val and body_fun
    304 # below to have the same output type, which is a requirement of
    305 # lax.while_loop and lax.scan.
   (...)
    316 # of a lax.cond for now in order to avoid staging the initial
    317 # update and the run loop. They might not be staging compatible.
    319 zero_step = self._make_zero_step(init_params, state)
--> 321 opt_step = self.update(init_params, state, *args, **kwargs)
    322 init_val = (opt_step, (args, kwargs))
    324 unroll = self._get_unroll_option()

File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/proximal_gradient.py:305, in ProximalGradient.update(self, params, state, hyperparams_prox, *args, **kwargs)
    293 """Performs one iteration of proximal gradient.
    294
    295 Args:
   (...)
    302   (params, state)
    303 """
    304 f = self._update_accel if self.acceleration else self._update
--> 305 return f(params, state, hyperparams_prox, *args, **kwargs)

File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/proximal_gradient.py:266, in ProximalGradient._update_accel(self, x, state, hyperparams_prox, *args, **kwargs)
    263 stepsize = state.stepsize
    264 (y_fun_val, aux), y_fun_grad = self._value_and_grad_with_aux(y, *args, **kwargs)
--> 266 next_x, next_stepsize = self._iter(iter_num, y, y_fun_val, y_fun_grad, stepsize, hyperparams_prox, *args, **kwargs)
    268 next_t = 0.5 * (1 + jnp.sqrt(1 + 4 * t ** 2))
    269 diff_x = tree_sub(next_x, x)

File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/proximal_gradient.py:240, in ProximalGradient._iter(self, iter_num, x, x_fun_val, x_fun_grad, stepsize, hyperparams_prox, *args, **kwargs)
    238 else:
    239     next_stepsize = self.stepsize
--> 240 next_x = self._prox_grad(x, x_fun_grad, next_stepsize, hyperparams_prox)
    241 return next_x, next_stepsize

File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/proximal_gradient.py:209, in ProximalGradient._prox_grad(self, x, x_fun_grad, stepsize, hyperparams_prox)
    208 def _prox_grad(self, x, x_fun_grad, stepsize, hyperparams_prox):
--> 209     update = tree_add_scalar_mul(x, -stepsize, x_fun_grad)
    210     return self.prox(update, hyperparams_prox, stepsize)

File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/tree_util.py:91, in tree_add_scalar_mul(tree_x, scalar, tree_y)
     89 def tree_add_scalar_mul(tree_x, scalar, tree_y):
     90     """Compute tree_x + scalar * tree_y."""
---> 91     return tree_map(lambda x, y: x + scalar * y, tree_x, tree_y)

    [... skipping hidden 2 frame]

File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/tree_util.py:91, in tree_add_scalar_mul.<locals>.<lambda>(x, y)
     89 def tree_add_scalar_mul(tree_x, scalar, tree_y):
     90     """Compute tree_x + scalar * tree_y."""
---> 91     return tree_map(lambda x, y: x + scalar * y, tree_x, tree_y)

TypeError: unsupported operand type(s) for *: 'float' and 'dict'

That is the end of the error output.

enesdilber commented 3 weeks ago

Hi,

I’ve fixed the issue in the grad_dict_fix branch. The error occurred because we were passing the gradients as a dictionary instead of an array; the fix should resolve the issue.
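
To illustrate the failure mode, here is a minimal sketch (not momi3's or jaxopt's actual code): the projected-gradient update is essentially x + (-stepsize) * grad, which breaks when grad arrives as a plain dict rather than an array or a matching pytree. Flattening the gradient dict into a single array, e.g. with jax.flatten_util.ravel_pytree, is one way to make the update well-defined; the parameter names below are made up for illustration.

import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical flat parameter vector and a gradient returned as a dict
# (parameter names are illustrative only).
x = jnp.zeros(3)
grad_as_dict = {"eta_1": 0.1, "eta_2": -0.2, "tau_1": 0.05}
stepsize = 0.5

try:
    x + (-stepsize) * grad_as_dict  # float * dict -> same TypeError as above
except TypeError as err:
    print(err)  # unsupported operand type(s) for *: 'float' and 'dict'

# Flattening the gradient pytree into a single array makes the update valid.
grad_flat, _ = ravel_pytree(grad_as_dict)
print(x + (-stepsize) * grad_flat)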

You can install this branch by running:

pip install git+https://github.com/jthlab/momi3.git@0.0.1

If you're experimenting, you may want to monitor your training by passing monitor_training=True to the optimize method.
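
For reference, a sketch of the same call with monitoring enabled (assuming the rest of the original script is unchanged; judging from the traceback, this plots the objective against the iteration number):

# Same call as in the original report, with training monitoring turned on.
momi.optimize(
    params=params,
    jsfs=jsfs,
    stepsize=0.5,
    maxiter=50,
    monitor_training=True,
)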