Extracted the `log_joint` and `grad_hess` functions from the `numpy_inference` pathway through `FastINLA`. This makes it possible to swap in other models: they still need the hierarchical sigma2-InvGamma + theta-Normal structure, but they can use different outcome distributions. In particular, in the other repo, I have examples using this for survival analysis.
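As a rough illustration of what a pluggable model might look like under that structure, here's a minimal sketch using a binomial outcome. Everything here is assumed for illustration: the function names, signatures, data layout, and prior parameters are hypothetical, not the actual `FastINLA` interface.

```python
import numpy as np
from scipy.stats import invgamma

def custom_log_joint(theta, data, sigma2, mu0=0.0, a=0.0005, b=0.000005):
    """Hypothetical pluggable log joint: binomial outcomes with the shared
    sigma2 ~ InvGamma(a, b), theta_i | sigma2 ~ Normal(mu0, sigma2) structure."""
    y, n = data[..., 0], data[..., 1]
    # Binomial log-likelihood in the logit parameter theta (constants dropped).
    loglik = np.sum(theta * y - n * np.log1p(np.exp(theta)), axis=-1)
    # Independent Normal prior on each theta_i given sigma2.
    logprior_theta = np.sum(
        -0.5 * np.log(2 * np.pi * sigma2) - (theta - mu0) ** 2 / (2 * sigma2),
        axis=-1,
    )
    # InvGamma hyperprior on sigma2.
    return loglik + logprior_theta + invgamma.logpdf(sigma2, a, scale=b)

def custom_grad_hess(theta, data, sigma2, mu0=0.0):
    """Gradient and (diagonal) Hessian of custom_log_joint w.r.t. theta,
    as an inner Newton optimization would use."""
    y, n = data[..., 0], data[..., 1]
    p = 1.0 / (1.0 + np.exp(-theta))
    grad = y - n * p - (theta - mu0) / sigma2
    hess_diag = -n * p * (1.0 - p) - 1.0 / sigma2
    return grad, hess_diag
```

A survival model would keep the prior terms identical and only change the `loglik` line and the corresponding gradient/Hessian of the outcome term.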
Modified `FastINLA` to take a `data` parameter instead of `y` and `n`. This helps abstract the concept of data across different model types.
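For context, a data bundle like this might be built by stacking the old inputs. The layout below (`y` and `n` along the trailing axis) and the commented-out calls are guesses for illustration, not the layout or API `FastINLA` actually uses.

```python
import numpy as np

y = np.array([6, 5, 3, 7])        # successes per arm (example values)
n = np.array([35, 35, 35, 35])    # trials per arm
data = np.stack([y, n], axis=-1)  # shape (4, 2): one (y, n) pair per arm

# Hypothetical usage; the real constructor/method signatures may differ:
# fi = FastINLA(n_arms=4)
# fi.numpy_inference(data)
```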
Currently, the tests are failing because of the JAX float32 issue. I don't like the precedent of merging PRs with failing tests, so I'd prefer to either:

1. Wait until @constinit's PR is ready to merge before merging this.
2. Switch JAX to use float64 for this PR (see the snippet below); @constinit can then revert this change in his PR.
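For option 2, enabling 64-bit mode is a standard JAX flag; where exactly it would go in this codebase is up to the PR:

```python
import jax

# Must run before any JAX computations so arrays default to float64.
jax.config.update("jax_enable_x64", True)
```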
Fine with me either way.