tjhladish opened this issue 8 months ago
Is it possible that under the hood there's a mutable default problem here? https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments
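For reference, here is the gotcha from the linked guide. The function name `init_chain_params` and the `chain_id` key are hypothetical. They were chosen only to mirror the symptom in this issue (chains overwriting a shared value), and this is not code from our repository.

```python
def init_chain_params(chain_id, params={}):
    # BUG (hypothetical example): the default dict is created once, when the
    # function is defined, so every call that omits `params` mutates the same object.
    params["chain_id"] = chain_id
    return params


def init_chain_params_fixed(chain_id, params=None):
    # Usual fix: use None as a sentinel and build a fresh dict on each call.
    if params is None:
        params = {}
    params["chain_id"] = chain_id
    return params


p0 = init_chain_params(0)
p1 = init_chain_params(1)
# p0 and p1 are the same dict, so chain 0's value has been overwritten with 1.

q0 = init_chain_params_fixed(0)
q1 = init_chain_params_fixed(1)
# q0 and q1 are independent dicts.
```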
I am looking into one possible source of the issue: the functions we pass in within `get_parameters()`, such as `seasonality(t)`, may not be functionally pure (see https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions).

As noted in the JAX docs: "Impure functions are dangerous because under JAX transformations they are likely not to behave as intended; they might fail silently, or produce surprising downstream errors like leaked Tracers. Moreover, JAX often can’t detect when side effects are present."

This may be the source of the issue.
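The following minimal sketch shows how impurity can cause trouble under `jax.jit`. It assumes a seasonality function that reads a module-level variable. The names `seasonality_impure`, `seasonality_pure`, and `amplitude`, and the 365-day sine form, are hypothetical and are not our actual `seasonality(t)`. JAX traces the function once and bakes the global's value into the compiled computation, so later changes to that global are silently ignored on cached calls.

```python
import jax
import jax.numpy as jnp

# Hypothetical module-level state read by an impure function.
amplitude = 1.0


def seasonality_impure(t):
    # Impure: depends on the global `amplitude`, which jit captures at trace time.
    return amplitude * jnp.sin(2 * jnp.pi * t / 365.0)


def seasonality_pure(t, amp):
    # Pure: every input the result depends on is an explicit argument.
    return amp * jnp.sin(2 * jnp.pi * t / 365.0)


impure_jit = jax.jit(seasonality_impure)
pure_jit = jax.jit(seasonality_pure)

before = impure_jit(10.0)  # traced and compiled with amplitude = 1.0
amplitude = 2.0            # mutate the global after tracing
after = impure_jit(10.0)   # cache hit: still computed with amplitude = 1.0

pure_result = pure_jit(10.0, amplitude)  # amp is passed in, so 2.0 is used
```

The JAX docs recommend passing every dependency as an explicit argument. The open question is whether our passed-in functions actually read mutable state in this way.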
Currently, we use deep copies of parameters to avoid a shared-state bug in which MCMC chains overwrote the value of a particular parameter that was inappropriately shared between them. This seems to fix the bug, but @jasonasher and @ekr-cfa have said that it is likely not the right way to solve the problem, and that some restructuring of the code could probably avoid it altogether.
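One restructuring that could replace the defensive deep copies is to store parameters in an immutable container. This sketch assumes the parameters fit in a frozen dataclass whose nested values are tuples. `ChainParams` and its fields are hypothetical. Chains can then share the base object safely, because each per-chain change produces a new instance via `dataclasses.replace` rather than mutating shared state.

```python
from dataclasses import FrozenInstanceError, dataclass, replace
from typing import Tuple


@dataclass(frozen=True)
class ChainParams:
    # Hypothetical parameter container; tuples keep nested values immutable too.
    beta: float
    seasonality_amp: Tuple[float, ...]
    chain_id: int = 0


base = ChainParams(beta=0.3, seasonality_amp=(0.1, 0.2))

# Derive per-chain parameters without mutating (or copying) `base`.
chain_params = [replace(base, chain_id=i) for i in range(4)]

# Attempting to mutate raises an error instead of silently leaking into other chains.
try:
    base.beta = 1.0
    mutation_blocked = False
except FrozenInstanceError:
    mutation_blocked = True
```

Sharing the `seasonality_amp` tuple between instances is harmless because it cannot be mutated. A `NamedTuple` would give similar immutability and is already treated as a pytree by JAX.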
We need to dig into this a bit more to determine what other approaches may be viable.