mlysy / pfjax

Particle filtering in JAX
https://pfjax.readthedocs.io
MIT License

Avoid Boolean control flow in mwg_adapt #3

Closed: m9ko closed this issue 2 years ago

m9ko commented 2 years ago

In mwg_adapt, we get an error saying that traced values cannot be used in a Boolean context (at line 207), as explained in: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-jit

https://github.com/mlysy/pfjax/blob/fff2a24e91f8de0350bf05bcaaaa2f4452f34452/src/pfjax/mcmc.py#L203-L207
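For readers unfamiliar with this error, here is a minimal sketch of how it arises. The function `adapt_with_branch`, its arguments, and the 0.44 target are hypothetical stand-ins chosen for illustration; this is not the pfjax source at the lines linked above.

```python
import jax
import jax.numpy as jnp


def adapt_with_branch(rw_sd, accept_rate, delta, targ_acc=0.44):
    # Hypothetical stand-in for an adaptation step, not the pfjax code.
    # The Python `if` calls bool() on the comparison result. Under jit
    # that result is a tracer with no concrete value, so bool() raises.
    if accept_rate < targ_acc:
        return rw_sd * jnp.exp(-delta)
    return rw_sd * jnp.exp(delta)


# Eagerly (no jit) the branch is evaluated on concrete values and works.
eager_sd = adapt_with_branch(1.0, 0.2, 0.01)

# Under jit the same function fails while being traced.
try:
    jax.jit(adapt_with_branch)(1.0, 0.2, 0.01)
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    jit_failed = True
```

The same function runs fine outside `jit`, which is why this kind of bug tends to show up only once the caller is compiled.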

What seems to work is

low_acc = jnp.sign(targ_acc - accept_rate)
return jnp.exp(jnp.log(rw_sd) - delta * low_acc)

which achieves essentially the same result without involving a Boolean. Also, should line 205 use minimum instead of maximum?
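As a rough illustration of the proposal, the sketch below wraps the two lines in a jitted function and checks the direction of the update. The names `adapt_sign` and `adapt_where`, the 0.44 target, and the value of `delta` are assumptions made for this example rather than anything taken from pfjax. The `jnp.where` variant is shown only as another common way to avoid Python branching; unlike `jnp.sign`, it never returns 0 when the acceptance rate equals the target.

```python
import jax
import jax.numpy as jnp


@jax.jit
def adapt_sign(rw_sd, accept_rate, delta, targ_acc=0.44):
    # +1 where acceptance is below target, -1 where above, 0 where equal.
    low_acc = jnp.sign(targ_acc - accept_rate)
    return jnp.exp(jnp.log(rw_sd) - delta * low_acc)


@jax.jit
def adapt_where(rw_sd, accept_rate, delta, targ_acc=0.44):
    # Elementwise select instead of a Python `if`; ties are treated as "above".
    low_acc = jnp.where(accept_rate < targ_acc, 1.0, -1.0)
    return jnp.exp(jnp.log(rw_sd) - delta * low_acc)


rw_sd = jnp.array([1.0, 1.0])
accept_rate = jnp.array([0.10, 0.90])  # one below target, one above

sd_sign = adapt_sign(rw_sd, accept_rate, 0.01)
sd_where = adapt_where(rw_sd, accept_rate, 0.01)
```

In both variants the step size shrinks when the acceptance rate is below the target and grows when it is above, and both trace cleanly under `jit` because no Python-level Boolean is evaluated.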

mlysy commented 2 years ago

Thanks for catching all this. Can you please rebase the main branch and create a PR? It'll be easier this way than rebasing after the pending PR...