bischtob / Opterax

A gradient-free optimization suite written in JAX. We conform to the optax interface and provide ensemble-based optimizers.

Make EKI ensemble gradient approximator work with PyTrees #7

Open bischtob opened 3 years ago

bischtob commented 3 years ago

Context: To achieve full compatibility with JAX, we need to make sure that all ensemble gradient estimators actually work with nested parameter containers (pytrees).

Ideas: The main problem lies in building covariance matrices for the model parameters. This should be equivalent to building covariance matrices with block structure: assembling the larger covariance matrix, disassembling it, etc. Perhaps it is possible to leverage JAX's ability to compute on pytrees directly?
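One way to read the block-structure idea: flatten each ensemble member into a single vector, compute an ordinary dense covariance, and unflatten results afterwards. The sketch below assumes every leaf carries a leading ensemble axis of size J and uses jax.flatten_util.ravel_pytree; the helper names ensemble_to_matrix and ensemble_covariance are hypothetical, not part of Opterax.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def ensemble_to_matrix(ensemble):
    """Flatten a batched pytree (leading ensemble axis J on every leaf) into a (J, D) matrix.

    Also returns an `unravel` function mapping a length-D vector back to a
    single-member pytree. (Hypothetical helper for illustration.)
    """
    member0 = jax.tree_util.tree_map(lambda x: x[0], ensemble)
    _, unravel = ravel_pytree(member0)
    flat = jax.vmap(lambda member: ravel_pytree(member)[0])(ensemble)
    return flat, unravel


def ensemble_covariance(flat):
    """Empirical parameter covariance (normalized by J - 1) of a (J, D) ensemble matrix."""
    dev = flat - flat.mean(axis=0)
    return dev.T @ dev / (flat.shape[0] - 1)


# Usage: an ensemble of 8 members of a small two-leaf model.
k_w, k_b = jax.random.split(jax.random.PRNGKey(0))
J = 8
ensemble = {
    "w": jax.random.normal(k_w, (J, 3, 2)),
    "b": jax.random.normal(k_b, (J, 2)),
}
flat, unravel = ensemble_to_matrix(ensemble)  # (8, 8): 2 entries from "b" + 6 from "w"
cov = ensemble_covariance(flat)               # (8, 8) block matrix
mean_params = unravel(flat.mean(axis=0))      # back to {"b": (2,), "w": (3, 2)}
```

The diagonal blocks of cov are the per-leaf covariances (here cov[:2, :2] belongs to "b", because JAX orders dict leaves by sorted key), and the off-diagonal blocks couple different leaves. The catch is memory: the matrix is D x D in the total parameter count D.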

Next step: deep dive into the pytree data structure and computing with pytrees in JAX, then recast the problem of addition and matrix multiplication in terms of pytrees.
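Recasting addition and matrix multiplication in terms of pytrees can be sketched with jax.tree_util.tree_map alone. Assuming every leaf has a leading ensemble axis of size J, any matrix that acts on the ensemble index (for example an averaging or mixing matrix) can be applied leaf by leaf with jnp.tensordot. tree_add and tree_mix are hypothetical names used for illustration; the cross-check at the end confirms the leafwise result matches the flatten, multiply, unflatten route.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def tree_add(a, b):
    """Leafwise addition of two pytrees with identical structure."""
    return jax.tree_util.tree_map(jnp.add, a, b)


def tree_mix(A, ensemble):
    """Multiply a (K, J) matrix into the leading ensemble axis of every leaf.

    Equivalent to flatten -> A @ flat -> unflatten, without ever building
    the (J, D) matrix of flattened members.
    """
    return jax.tree_util.tree_map(lambda leaf: jnp.tensordot(A, leaf, axes=1), ensemble)


J = 4
ensemble = {"w": jnp.arange(J * 6.0).reshape(J, 3, 2), "b": jnp.ones((J, 2))}
doubled = tree_add(ensemble, ensemble)
A = jnp.full((J, J), 1.0 / J)  # averaging matrix: every output member is the ensemble mean
mixed = tree_mix(A, ensemble)

# Cross-check against the dense flatten -> matmul -> unflatten route.
flat = jax.vmap(lambda m: ravel_pytree(m)[0])(ensemble)
_, unravel = ravel_pytree(jax.tree_util.tree_map(lambda x: x[0], ensemble))
via_flat = jax.vmap(unravel)(A @ flat)
```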

bischtob commented 3 years ago

Ideas:

  1. Reuse JAX's built-in tree utilities (jax.tree_util) to flatten and unflatten
  2. See if the update step can be recast as a map over leaf nodes -> parcel out the complex bits and optimize them separately
  3. Implement an efficient covariance matrix calculation -> low-rank approximation?
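Taken together, these ideas suggest an EKI step that never builds the parameter covariance at all. A minimal sketch, assuming the standard perturbed-observation form of the EKI update, theta_j <- theta_j + C_thetaG (C_GG + Gamma)^{-1} (y + eta_j - G(theta_j)): the only dense matrices are the M x M data-space covariance and a J x J coefficient matrix, while the parameter-space cross-covariance C_thetaG (rank at most J - 1, since ensemble deviations sum to zero) is applied leaf by leaf through tree_map. eki_update and the linear forward model are hypothetical illustrations, not Opterax's API, and this is not necessarily how Opterax implements its update.

```python
import jax
import jax.numpy as jnp


def eki_update(ensemble, G, y, gamma, key):
    """One perturbed-observation EKI step on a batched pytree ensemble (hypothetical helper).

    ensemble: pytree whose leaves share a leading ensemble axis of size J.
    G: (J, M) forward-model outputs, one row per member.
    y: (M,) observations; gamma: (M, M) observation-noise covariance.
    """
    J = G.shape[0]
    G_dev = G - G.mean(axis=0)                             # (J, M)
    C_gg = G_dev.T @ G_dev / (J - 1)                       # (M, M) data-space covariance
    noise = jax.random.multivariate_normal(key, jnp.zeros_like(y), gamma, shape=(J,))
    W = jnp.linalg.solve(C_gg + gamma, (y + noise - G).T)  # (M, J)
    # coeff[k, j] = <G_dev[k], W[:, j]> / (J - 1). The cross-covariance
    # C_thetaG = dev_theta^T G_dev / (J - 1) is only ever applied through this (J, J) matrix.
    coeff = G_dev @ W / (J - 1)

    def update_leaf(leaf):
        dev = leaf - leaf.mean(axis=0)
        # Increment for member j is sum_k coeff[k, j] * dev[k].
        return leaf + jnp.tensordot(coeff.T, dev, axes=1)

    return jax.tree_util.tree_map(update_leaf, ensemble)


def forward(params):
    """Hypothetical linear forward model mapping each member to 3 observations."""
    a, b = params["a"], params["b"]
    return jnp.stack([a[:, 0] + b, a[:, 1] - b, 2.0 * b], axis=1)


k_a, k_b, k_noise = jax.random.split(jax.random.PRNGKey(0), 3)
J = 20
ensemble = {"a": jax.random.normal(k_a, (J, 2)), "b": jax.random.normal(k_b, (J,))}
y = jnp.array([3.0, -2.0, 1.0])
gamma = 0.01 * jnp.eye(3)

misfit_before = jnp.linalg.norm(y - forward(ensemble).mean(axis=0))
updated = eki_update(ensemble, forward(ensemble), y, gamma, k_noise)
misfit_after = jnp.linalg.norm(y - forward(updated).mean(axis=0))
```

Per leaf the cost is one (J, J) by (J, ...) contraction, so memory scales with J times the parameter count rather than with its square, which is what makes the low-rank structure attractive for large models.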