context
To achieve full compatibility with JAX, we need to make sure that all ensemble gradient estimators actually work with nested parameter containers (pytrees).
ideas
The main problem lies in building covariance matrices for the model parameters. With nested parameters, this amounts to building a covariance matrix with block structure (one block per pair of leaves): flattening the parameters, assembling the full covariance matrix, and disassembling results back into the nested container. Perhaps we can avoid this round trip by leveraging JAX's ability to compute on pytrees directly?
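A minimal sketch of the assemble/disassemble route, assuming the ensemble is stored as one pytree whose leaves all carry a leading ensemble axis (an assumption about how the estimators hold their members). `params`, `ensemble`, `n_ensemble` and `block` are hypothetical names; `ravel_pytree`, `jax.tree_util.tree_flatten` and `jnp.cov` are real JAX APIs.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical nested parameter container for a single model.
params = {"dense": {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)}, "scale": jnp.array(1.0)}

# Hypothetical ensemble: same tree structure, every leaf gets a leading axis of size n_ensemble.
n_ensemble = 8
leaves, treedef = jax.tree_util.tree_flatten(params)
keys = jax.random.split(jax.random.PRNGKey(0), len(leaves))
ensemble = jax.tree_util.tree_unflatten(
    treedef,
    [x + 0.1 * jax.random.normal(k, (n_ensemble,) + x.shape) for x, k in zip(leaves, keys)],
)

# Assemble: flatten every member into one vector -> matrix of shape (n_ensemble, n_params).
_, unravel = ravel_pytree(params)
flat = jax.vmap(lambda p: ravel_pytree(p)[0])(ensemble)

# Full parameter covariance; rows are ensemble members, hence rowvar=False.
cov = jnp.cov(flat, rowvar=False)

# Block structure: leaf i occupies a contiguous slice of the flat vector.
sizes = [x.size for x in leaves]
offsets = [sum(sizes[:i]) for i in range(len(sizes))]

def block(i, j):
    """Covariance block between leaf i and leaf j (in tree_flatten order)."""
    return cov[offsets[i]:offsets[i] + sizes[i], offsets[j]:offsets[j] + sizes[j]]

# Disassemble: map a flat vector (here the ensemble mean) back to the nested container.
mean_params = unravel(flat.mean(axis=0))
```

`ravel_pytree` flattens leaves in the same order as `jax.tree_util.tree_flatten`, so the offsets line up with `leaves` (here `dense/b`, `dense/w`, `scale`). Note that this route materializes the full `n_params × n_params` matrix, which is only practical for small models.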
next step
deep dive into the pytree data structure and computing with pytrees in JAX, then recast addition and matrix multiplication in terms of pytrees
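One possible pytree-native recasting, sketched under the same assumption (leading ensemble axis on every leaf): addition becomes a leafwise `tree_map`, and the covariance-vector product uses C = AᵀA / (n − 1) with A the ensemble anomalies, so C v = Aᵀ(A v) can be evaluated leaf by leaf without ever forming C. `tree_add`, `tree_scale` and `cov_matvec` are hypothetical helpers, not existing estimator code; the last lines check the result against the dense route.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

tree_map = jax.tree_util.tree_map
tree_leaves = jax.tree_util.tree_leaves

def tree_add(a, b):
    """Leafwise a + b for two pytrees with identical structure."""
    return tree_map(jnp.add, a, b)

def tree_scale(c, a):
    """Leafwise c * a for a scalar c."""
    return tree_map(lambda x: c * x, a)

def cov_matvec(ensemble, v):
    """C @ v with C = A^T A / (n - 1), computed leaf by leaf without building C.

    `ensemble` has a leading ensemble axis on every leaf; `v` is shaped like one member.
    """
    A = tree_map(lambda x: x - x.mean(axis=0), ensemble)  # ensemble anomalies
    n = tree_leaves(A)[0].shape[0]
    # A @ v: contract each leaf over its parameter axes, then sum over leaves -> shape (n,)
    Av = sum(a.reshape(n, -1) @ x.reshape(-1) for a, x in zip(tree_leaves(A), tree_leaves(v)))
    # A^T @ (A v): weight each member's anomaly by Av and sum over the ensemble axis
    return tree_map(lambda a: jnp.tensordot(Av, a, axes=1) / (n - 1), A)

# Hypothetical ensemble of 8 members of a nested container.
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(1), 3)
ensemble = {
    "dense": {"w": jax.random.normal(k1, (8, 3, 2)), "b": jax.random.normal(k2, (8, 2))},
    "scale": jax.random.normal(k3, (8,)),
}
v = tree_map(lambda x: x[0], ensemble)  # any pytree shaped like one member

Cv = cov_matvec(ensemble, v)
step = tree_add(v, tree_scale(0.1, Cv))  # e.g. an update of the form v + 0.1 * C v

# Check against the dense route: flatten, build C explicitly, multiply.
flat = jax.vmap(lambda p: ravel_pytree(p)[0])(ensemble)
dense_Cv = jnp.cov(flat, rowvar=False) @ ravel_pytree(v)[0]
```

The design choice here is to keep everything as two passes over the ensemble, costing O(n · n_params) instead of O(n_params²) memory for an explicit covariance; whether the actual estimators only ever need C through such products is an open question for the deep dive.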