patrick-kidger / optimistix

Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https://docs.kidger.site/optimistix/
Apache License 2.0

Would an exhaustive grid search have a place in `optimistix`? #55

Closed: mjo22 closed this issue 6 months ago

mjo22 commented 6 months ago

Hello! In my applications it is very common to optimize a function with an exhaustive grid search. This is because our loss functions are sharply peaked and poorly behaved within a manageable subset of parameter space, so it is often best to do an exhaustive search (perhaps in a clever way). I am planning on implementing this in JAX, and I am wondering whether this is within the scope of optimistix. It seems to me that the library has tools that would be useful for this task.

A rough outline of the implementation I am imagining is the following:

  1. The grid: The grid would be represented as an arbitrary pytree whose leaves have leading "batch" dimensions. Namely, leaf $i$ is an array (or a pytree of arrays) with leading dimension(s) $N_i$. For $m$ leaves, the grid is then represented as an $N_1 \times N_2 \times ... \times N_m$ Cartesian grid.
  2. The cost function: The cost function $f$ would take a pytree with the same structure as the grid (plus any additional arguments), holding the values at grid point $(i_1, i_2, ..., i_m)$. This function can return a single value of the cost function, or it can return a grid of cost-function evaluations. This would allow for cleverer grid searches than a simple exhaustive search. For example, the simplest case in my field is to return a grid of cost-function evaluations computed through Fourier convolution, e.g. a search over the space of translations. In general, this sub-grid returned by $f$ would explore a region of parameter space unrelated to the $N_1 \times N_2 \times ... \times N_m$ grid.
  3. The solution: There would need to be a flexible API for how the results of the grid search are stored. In a simple case, one could store, at every grid point, the best cost-function evaluation from the sub-grid returned by $f$. In a more complicated case, one might want to use the grid search to marginalize away a portion of parameter space.
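A minimal JAX sketch of points 1 and 3 for the simplest case (a scalar cost at each grid point, keeping only the overall best). It assumes every array leaf of the grid pytree is its own grid axis indexed by its leading dimension, so a leaf that is itself a pytree of arrays sharing one batch axis would need extra handling. `grid_search` and `loss` are hypothetical names, not an optimistix API; the cost follows the `fn(y, args)` calling convention that optimistix uses for its solvers.

```python
import jax
import jax.numpy as jnp
import numpy as np


def grid_search(fn, grid, args=None):
    """Hypothetical helper (not an optimistix API): exhaustive Cartesian grid search.

    Assumes every array leaf of `grid` is one grid axis, indexed by its leading
    dimension, and that `fn(point, args)` returns a scalar cost.
    """
    leaves, treedef = jax.tree_util.tree_flatten(grid)
    sizes = tuple(leaf.shape[0] for leaf in leaves)  # (N_1, ..., N_m)
    # Enumerate all N_1 * ... * N_m points as multi-indices (i_1, ..., i_m).
    flat = jnp.arange(int(np.prod(sizes)))
    multi = jnp.unravel_index(flat, sizes)

    def point_at(idx):
        # Pick entry i_k along the leading axis of leaf k and rebuild the pytree.
        return jax.tree_util.tree_unflatten(
            treedef, [leaf[i] for leaf, i in zip(leaves, idx)]
        )

    # Evaluate every grid point in one vectorised call.
    values = jax.vmap(lambda *idx: fn(point_at(idx), args))(*multi)
    best = jnp.argmin(values)
    best_point = point_at([i[best] for i in multi])
    # Row-major unravelling means values[i_1, ..., i_m] is the cost at that point.
    return best_point, values.reshape(sizes)


# Usage: a toy quadratic whose minimum (x=1, y=-2) lies on the grid.
def loss(p, args):
    return (p["x"] - 1.0) ** 2 + (p["y"] + 2.0) ** 2


grid = {"x": jnp.linspace(-3.0, 3.0, 7), "y": jnp.linspace(-3.0, 3.0, 7)}
best, values = grid_search(loss, grid)
```

Because `vmap` materialises all $N_1 \times ... \times N_m$ evaluations at once, swapping it for `jax.lax.map` would trade speed for lower memory use on large grids.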
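For points 2 and 3, a sketch of a cost function that returns a whole sub-grid: for each width on a hypothetical outer grid, an FFT-based circular cross-correlation (a convolution with the flipped template) scores every integer translation at once. The sub-grid is then reduced in the two ways the outline mentions: keeping the best shift per grid point, or marginalising the shifts away. The Gaussian template, the names `gaussian_template` and `shift_costs`, and treating the negative cost as a log-likelihood for marginalisation are all illustrative assumptions, not the author's method.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def gaussian_template(sigma, n):
    # Unit-norm Gaussian bump of width `sigma`, centred at index n // 2.
    x = jnp.arange(n) - n // 2
    t = jnp.exp(-0.5 * (x / sigma) ** 2)
    return t / jnp.linalg.norm(t)


def shift_costs(sigma, data):
    """Hypothetical cost returning a sub-grid: one value per circular shift s.

    corr[s] = sum_k data[k] * template[k - s], computed for all s at once via the FFT.
    """
    template = gaussian_template(sigma, data.shape[0])
    corr = jnp.fft.ifft(jnp.fft.fft(data) * jnp.conj(jnp.fft.fft(template))).real
    return -corr  # lower cost = better match


# Synthetic data: a width-3 bump circularly shifted by 10 samples.
n = 64
data = jnp.roll(gaussian_template(3.0, n), 10)
sigmas = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])  # outer grid

# One sub-grid of costs per outer grid point: shape (len(sigmas), n).
costs = jax.vmap(shift_costs, in_axes=(0, None))(sigmas, data)

# Strategy 1: keep only the best shift at every outer grid point.
best_cost = costs.min(axis=1)
best_shift = costs.argmin(axis=1)

# Strategy 2: marginalise the shifts away, treating -cost as an (unnormalised)
# log-likelihood; this interpretation is an assumption of the sketch.
marginal_cost = -logsumexp(-costs, axis=1)
```

Here the best width is `sigmas[jnp.argmin(best_cost)]`, and every entry of `best_shift` recovers the true shift of 10, since both template and data are symmetric bumps about the same centre.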

I'm not sure if this is within the scope of optimistix, and I would totally understand if it is not. If it were to be added to the library, I suppose it could be used as a method of ultra-last resort.

patrick-kidger commented 6 months ago

I think probably not, to be honest. The problem with optimisation is how many kinds of it there are! For example, Bayesian optimisation wouldn't be in scope here either.

Thank you for the offer nonetheless! (I am starting to think we should find a place to collect "extra" JAX scientific libraries, if for example you were to implement such a thing yourself.)

mjo22 commented 6 months ago

This completely makes sense! I didn’t think it would fit to be honest, but wanted to ask to be sure. Regardless, I will probably use optimistix-like ideas when writing an API!

I would also be very happy to learn more about what you're thinking along these lines; that would be very helpful.