m9ko closed this issue 2 years ago.
In `mwg_adapt`, we get an error at line 207 because a traced value is used where Python needs a concrete Boolean, as explained in: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-jit

https://github.com/mlysy/pfjax/blob/fff2a24e91f8de0350bf05bcaaaa2f4452f34452/src/pfjax/mcmc.py#L203-L207
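Here is a minimal sketch of the failure and one jit-compatible alternative. It assumes a simplified step-size update, and the functions `adapt_python_if` and `adapt_where` are hypothetical stand-ins, not the code in `mcmc.py`. Under `jax.jit`, `accept_rate < targ_acc` is a traced array, so a Python `if` on it raises a `ConcretizationTypeError` (recent JAX versions raise its subclass `TracerBoolConversionError`). `jnp.where` avoids this by selecting between both candidate values inside the traced computation.

```python
import jax
import jax.numpy as jnp


def adapt_python_if(rw_sd, accept_rate, targ_acc, delta):
    # Python `if` needs a concrete bool; under jit, the comparison is a tracer,
    # so this raises an error at trace time.
    if accept_rate < targ_acc:
        return jnp.exp(jnp.log(rw_sd) - delta)
    return jnp.exp(jnp.log(rw_sd) + delta)


def adapt_where(rw_sd, accept_rate, targ_acc, delta):
    # jnp.where chooses elementwise between traced values, so it works under jit:
    # shrink the step when acceptance is low, grow it otherwise.
    step = jnp.where(accept_rate < targ_acc, -delta, delta)
    return jnp.exp(jnp.log(rw_sd) + step)


try:
    jax.jit(adapt_python_if)(1.0, 0.2, 0.44, 0.01)
    failed = False
except jax.errors.ConcretizationTypeError:
    failed = True

print(failed)
print(jax.jit(adapt_where)(1.0, 0.2, 0.44, 0.01))
```

`jax.lax.cond` would also work here. For a choice between two scalar values, though, `jnp.where` is usually the simpler option.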
What seems to work is:

```python
# +1 if the acceptance rate is below target, -1 if above, 0 if equal
low_acc = jnp.sign(targ_acc - accept_rate)
return jnp.exp(jnp.log(rw_sd) - delta * low_acc)
```

This achieves essentially the same update but no longer involves a Boolean. Also, should line 205 use a minimum instead of a maximum?
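Below is a sketch of the sign-based update applied to several parameters at once. It is a hypothetical `mwg_adapt_sketch`, not pfjax's function. It assumes a target acceptance rate of 0.44 and the diminishing schedule `delta = min(0.01, n_iter ** -0.5)` from Roberts and Rosenthal's adaptive Metropolis-within-Gibbs. Whether pfjax intends this schedule at line 205 is exactly the open question. With `jnp.maximum`, `delta` would never fall below 0.01, so the adaptation would not diminish.

```python
import jax
import jax.numpy as jnp


@jax.jit
def mwg_adapt_sketch(rw_sd, accept_rate, n_iter, targ_acc=0.44):
    # Hypothetical schedule: delta = min(0.01, 1 / sqrt(n_iter)).
    # At n_iter = 100 this is 0.01; jnp.maximum would give 0.1 instead.
    delta = jnp.minimum(0.01, 1.0 / jnp.sqrt(n_iter))
    # +1 where acceptance is below target (shrink the proposal sd),
    # -1 where above (grow it), 0 where exactly on target (leave it unchanged).
    low_acc = jnp.sign(targ_acc - accept_rate)
    return jnp.exp(jnp.log(rw_sd) - delta * low_acc)


rw_sd = jnp.array([1.0, 1.0, 1.0])
accept_rate = jnp.array([0.2, 0.7, 0.44])
new_sd = mwg_adapt_sketch(rw_sd, accept_rate, 100.0)
print(new_sd)
```

`jnp.sign` differs from a two-way `if`/`else` in one case. It leaves the step size unchanged when the acceptance rate exactly equals the target. Otherwise the two agree.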
Thanks for catching all this. Can you please rebase onto the `main` branch and open a PR? It'll be easier this way than rebasing after the pending PR...