A good place to start would be to substantially beat the current CPU JAX INLA implementation.
Things to do here:
[ ] Get a CUDA GPU cloud instance or local machine and set up a smooth dev environment on it. I think we should be willing to invest substantial time here. Good profiling and remote dev tools will be valuable.
[ ] Try running the existing JAX implementation on the CUDA backend. What's the performance like? What improvements could be made? Are there operations that are poorly suited to the CUDA backend?
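As a first sanity check for the CUDA item above, it's worth confirming which backend JAX actually picked up and timing a jitted op correctly. The Cholesky-of-a-random-SPD-matrix workload below is just a stand-in for the INLA linear algebra; the key detail is `block_until_ready()`, since JAX dispatches asynchronously and naive wall-clock timing only measures dispatch:

```python
import time

import jax
import jax.numpy as jnp

# On a working CUDA install this reports "gpu"; otherwise JAX
# silently falls back to "cpu", which is easy to miss.
print(jax.default_backend())
print(jax.devices())

@jax.jit
def f(x):
    # Stand-in workload (assumption): a Cholesky factorization of a
    # symmetric positive-definite matrix, roughly INLA-shaped.
    return jnp.linalg.cholesky(x @ x.T + jnp.eye(x.shape[0]))

x = jnp.ones((512, 512))
f(x).block_until_ready()  # warm-up call: excludes compile time

t0 = time.perf_counter()
f(x).block_until_ready()  # block so we time the kernel, not dispatch
print(f"steady-state time: {time.perf_counter() - t0:.4f}s")
```

For anything beyond quick timings, `jax.profiler.trace` plus TensorBoard gives per-op breakdowns.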
I think if we find the Hessian to be poorly conditioned, we could probably bail out to a first-order optimization method. We could plug in any of the standard NN optimizers, or see here for fancier ones (I'd start with Adam). JAX also has a BFGS solver, which I haven't tried.
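For the untried BFGS route, a minimal sketch of JAX's built-in solver, `jax.scipy.optimize.minimize`. The badly scaled quadratic here is a toy stand-in for the actual INLA inner objective (an assumption, not the real thing); Adam via `optax` would be the analogous choice among the standard NN optimizers but adds a dependency:

```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize

def objective(x):
    # Toy stand-in (assumption): badly scaled quadratic, i.e. a
    # deliberately ill-conditioned Hessian.
    scales = jnp.array([1.0, 1e-4])
    return jnp.sum(scales * x**2)

x0 = jnp.array([3.0, 5.0])
# "BFGS" is the method jax.scipy.optimize.minimize supports.
res = minimize(objective, x0, method="BFGS")
print(res.x, res.fun, res.success)
```

Since BFGS builds its own curvature estimate from gradients, it never factors the true Hessian, which is exactly the appeal when conditioning is bad.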
I also wonder if there's an analytic solution as sigma -> 0.
If you want to do something really fun, it's probably possible to regularize the Newton steps. See e.g. this for a new Newton algorithm that claims global convergence.
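The simplest regularized-Newton flavor, sketched below, is Tikhonov/Levenberg-style damping: solve (H + λI) d = −g instead of H d = −g. This is not the linked paper's algorithm, just the classical baseline it improves on; the Rosenbrock objective is a standard test function, not our actual problem:

```python
import jax
import jax.numpy as jnp

def damped_newton_step(f, x, lam=1e-2):
    # Solve (H + lam * I) d = -g. As lam -> 0 this is a pure Newton
    # step; large lam shrinks it toward a scaled gradient step, which
    # is one cheap way to survive a poorly conditioned Hessian.
    g = jax.grad(f)(x)
    H = jax.hessian(f)(x)
    d = jnp.linalg.solve(H + lam * jnp.eye(x.shape[0]), -g)
    return x + d

def rosen(x):
    # Classic curved-valley test function with minimum at (1, 1).
    return (1.0 - x[0]) ** 2 + 100.0 * (x[1] - x[0] ** 2) ** 2

x = jnp.array([-1.2, 1.0])
for _ in range(50):
    x = damped_newton_step(rosen, x)
print(x)
```

A real implementation would adapt λ per step (grow it when the step fails to decrease f, shrink it otherwise) rather than fix it.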
[ ] Explore other JIT tools that might fit our use case. Does Numba work well? CuPy? ArrayFire?