patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

What is the recommended pattern for optimization loops in transformed parameter spaces? #676

Open mjo22 opened 7 months ago

mjo22 commented 7 months ago

Hello! Let’s say I have something like the following


import equinox as eqx
from jax import Array

class PowerLaw(eqx.Module):

    amplitude: Array
    power: Array

    def __call__(self, x: Array):
        # Fields must be accessed through self
        return self.amplitude * x**self.power

Now, let’s say I want to optimize over the log of power and amplitude, but my loss function directly calls PowerLaw.__call__. What is the recommended pattern for using eqx.apply_updates here?

Of course I could handle this toy example with a single call to tree_map, but I am interested in the case where I optimize over an arbitrary set of model parameters in an arbitrary transformed space. To be more clear, I would have a pytree of functions I want to apply to my model before an optimization loop to go into the transformed parameter space, take gradients in this transformed parameter space, but apply the inverse pytree of parameters before ultimately evaluating my loss function.

Sorry if this isn’t completely clear, struggling a bit with the pseudocode on this one.

patrick-kidger commented 7 months ago

Are you aiming for something like this example? In that example, we want to compute gradients with respect to parameters living in an unconstrained Euclidean parameter space (because otherwise gradient updates may take us out of the constrained space we want to live in). As soon as we pass through the grad boundary, we perform the transformation that takes us from the unconstrained space to the constrained space that we want our model to actually draw parameter values from.

So in your example, perhaps you want to record log_power (an unconstrained value taking values anywhere in R) in your model. Then in your tree-map you would exponentiate to get the power you actually want to apply to the model. (Or in this simple example, just do it directly: x**jnp.exp(self.log_power).)

Not sure if I've correctly grasped what you're trying to do, let me know if that helps!

mjo22 commented 7 months ago

I suppose the equivalent would be something like

class ExpTransform(eqx.Module):
  log_array: jax.Array

  def __init__(self, array: jax.Array):
    self.log_array = jnp.log(array)

  def get(self):
    return jnp.exp(self.log_array)

def is_exp(x):
    return isinstance(x, ExpTransform)

def maybe_exp(x):
    if is_exp(x):
        return x.get()
    else:
        return x  # leave everything else unchanged

def resolve_exp(model):
    return jax.tree_util.tree_map(maybe_exp, model, is_leaf=is_exp)

model = PowerLaw(…)
get_params = lambda m: (m.amplitude, m.power)
exp_model = eqx.tree_at(get_params, model, replace_fn=ExpTransform)

@eqx.filter_grad
def loss_fn(model, x, y):
  model = resolve_exp(model)
  pred_y = jax.vmap(model)(x)
  return jnp.sum((y - pred_y)**2)

grads = loss_fn(exp_model, ...)

What do you think? This seems like a reasonable solution to me. I will say that I don’t necessarily want to just enforce constraints (though this can be very related)—mainly I am looking for a flexible way to re-parameterize my parameter space to one with good posterior geometry, which is of course a very common task in the physical sciences. If you have any suggestions on this front that would be well suited for a PR, I could help with this. However, the above seems like a good pattern.

patrick-kidger commented 7 months ago

Yup, I think that looks good to me! What you've written above is essentially the generic version, for when you don't control the definition of PowerLaw. (If you do control its definition, then you could just inline everything into it directly.)

I think the above is my currently-recommended pattern for handling reparameterisations. I have wondered about adding some kind of support for this directly into Equinox (eqx.field(parameterisation=...)?) This is basically a generalised version of #621 (one could imagine setting parameterisation=jax.lax.stop_gradient), but it bumps into the same issues as over there -- it only works when you control the module definition, and then it's easy to inline whatever behaviour you want.

mjo22 commented 7 months ago

This makes sense. Yes it seems to me adding something in eqx.field would be somewhat equivalent to just controlling the definition of PowerLaw, whereas this is a bit more flexible.

I actually am curious to also use an approach like this to more generally enforce constraints. For example let’s say power is only allowed to be positive.

One approach to enforce that it is positive is to do this reparameterization. Another is to do some kind of run-time error checking with error_if. I suppose one could add this in __check_init__ (correct me if I’m wrong), but it seems to me this doesn’t cover the case of tree_at updates.

Now, of course this error check can always be accomplished by wrapping an array in a module. In my code this would get into the territory of making some kind of abstract interface for every kind of parameter I can have (unconstrained, positive, angles, etc.), which would be a little too cumbersome I think. I am wondering if the eqx.field approach could be useful for something like this case, i.e. field(…, constraint=lambda x: eqx.error_if(x, x < 0, "x must be positive!"))

patrick-kidger commented 7 months ago

Reparameterisation would definitely be the standard way of doing this.

Indeed __check_init__ wouldn't catch anything with tree_at (which is a good thing!)

In terms of having constraint=..., I think this use case is probably doable today with the existing converter=... API.

mjo22 commented 7 months ago

To make sure I understand, if I want to make sure a quantity stays positive (for example), would you recommend

  1. Check this in a __check_init__ with error_if to make sure it is initialized correctly.
  2. When doing optimization, reparameterize it in the above pattern.

And instead of re-parameterizing if I want to add checks beyond what __check_init__ will catch, would you do something like

class ConstraintChecker(eqx.Module):
  quantity: jax.Array

  def get(self):
    eqx.error_if(self.quantity, self.quantity < 0, "quantity must be positive!")
    return self.quantity

def is_constrained(x):
    return isinstance(x, ConstraintChecker)

def maybe_constrained(x):
    if is_constrained(x):
        return x.get()
    else:
        return x  # leave everything else unchanged

def resolve_constraints(model):
    return jax.tree_util.tree_map(maybe_constrained, model, is_leaf=is_constrained)

model = PowerLaw(…)
get_params = lambda m: (m.power,)
constrained_model = eqx.tree_at(get_params, model, replace_fn=ConstraintChecker)

@eqx.filter_grad
def loss_fn(model, x, y):
  model = resolve_constraints(model)
  pred_y = jax.vmap(model)(x)
  return jnp.sum((y - pred_y)**2)

grads = loss_fn(constrained_model, ...)

patrick-kidger commented 7 months ago

  1. I don't think __check_init__ can be used in conjunction with error_if, as the instance is already immutable in the former case but the latter requires re-assigning the variable. (Otherwise the check gets DCE'd.) FWIW initialisation usually happens outside of JIT, so you could probably just use a normal Python if statement.

  2. Yup.

For your final example: the one issue with this is that the error_if will get DCE'd, as you don't use its output. You want return eqx.error_if(self.quantity, ...). Other than that, yup!

mjo22 commented 7 months ago

Okay, I see. Thank you for the help!

I will definitely have instances where I init after a jit because I make heavy use of model ensembling, where I think the recommended pattern is to init after a vmap.

It seems to me that if I want the __check_init__ behavior I described, I should then do this with the converter? Or explicitly in the __init__?

patrick-kidger commented 7 months ago

Yup!

danielward27 commented 6 days ago

In case it's useful (or if anyone has any suggestions/criticisms), I made a little package, paramax, for constraints/parameterizations for JAX PyTrees. The gist is essentially a generalized version of the above, using AbstractUnwrappable objects, which contain the unbounded parameters and the logic for unwrapping, and an unwrap function that maps across the PyTree, unwrapping any (possibly nested) AbstractUnwrappable objects. Most parameterizations are simple to initialize by passing a function to apply when unwrapping, along with the (unbounded) arguments, e.g. scale=paramax.Parameterize(jnp.exp, jnp.log(jnp.ones(3))) will apply exp when the model is unwrapped.