mjo22 opened this issue 8 months ago
Are you aiming for something like this example?
In the above example, we compute gradients with respect to parameters living in an unconstrained Euclidean parameter space (because otherwise our gradients might take us out of the constrained space we want to stay in), but as soon as we pass through the grad boundary we perform the transformation that takes us from the unconstrained space into the constrained space from which we actually want our model to draw parameter values.
So in your example, perhaps you want to record `log_power` (an unconstrained value taking values anywhere in R) in your model. Then in your tree-map you would exponentiate to get the power you actually want to apply to the model. (Or in this simple example, just do it directly: `x ** jnp.exp(self.log_power)`.)
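For instance, a minimal sketch of the inlined version might look like this (the exact fields of your model are just a guess on my part):

```python
import equinox as eqx
import jax
import jax.numpy as jnp


class PowerLaw(eqx.Module):
    log_amplitude: jax.Array  # unconstrained; amplitude = exp(log_amplitude)
    log_power: jax.Array  # unconstrained; power = exp(log_power)

    def __call__(self, x):
        # Map back to the constrained (positive) space at call time,
        # inside the grad boundary.
        return jnp.exp(self.log_amplitude) * x ** jnp.exp(self.log_power)
```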
Not sure if I've correctly grasped what you're trying to do, let me know if that helps!
I suppose the equivalent would be something like
```python
import equinox as eqx
import jax
import jax.numpy as jnp


class ExpTransform(eqx.Module):
    log_array: jax.Array

    def __init__(self, array: jax.Array):
        self.log_array = jnp.log(array)

    def get(self):
        return jnp.exp(self.log_array)


def is_exp(x):
    return isinstance(x, ExpTransform)


def maybe_exp(x):
    if is_exp(x):
        return x.get()
    else:
        return x  # leave everything else unchanged


def resolve_exp(model):
    return jax.tree_util.tree_map(maybe_exp, model, is_leaf=is_exp)


model = PowerLaw(…)
get_params = lambda m: (m.amplitude, m.power)
exp_model = eqx.tree_at(get_params, model, replace_fn=ExpTransform)


@eqx.filter_grad
def loss_fn(model, x, y):
    model = resolve_exp(model)
    pred_y = jax.vmap(model)(x)
    return jnp.sum((y - pred_y) ** 2)


grads = loss_fn(exp_model, ...)
```
What do you think? This seems like a reasonable solution to me. I will say that I don't necessarily want to just enforce constraints (though the two can be very related); mainly I am looking for a flexible way to re-parameterize my parameter space into one with good posterior geometry, which is of course a very common task in the physical sciences. If you have any suggestions on this front that would be well suited for a PR, I could help with this. However, the above seems like a good pattern.
Yup, I think that looks good to me! What you've written above is essentially the generic version, for when you don't control the definition of `PowerLaw`. (If you do control its definition, then you could just inline everything into it directly.)
I think the above is my currently-recommended pattern for handling reparameterisations. I have wondered about adding some kind of support for this directly into Equinox (`eqx.field(parameterisation=...)`?). This is basically a generalised version of #621 (one could imagine setting `parameterisation=jax.lax.stop_gradient`), but it bumps into the same issues as over there -- it only works when you control the module definition, and then it's easy to inline whatever behaviour you want.
This makes sense. Yes, it seems to me that adding something in `eqx.field` would be somewhat equivalent to just controlling the definition of `PowerLaw`, whereas this is a bit more flexible.
I actually am curious to also use an approach like this to more generally enforce constraints. For example, let's say `power` is only allowed to be positive.

One approach to enforcing this is the reparameterization above. Another is to do some kind of run-time error checking with `error_if`. I suppose one could add this in `__check_init__` (correct me if I'm wrong), but it seems to me this doesn't cover the case of `tree_at` updates.
Now, of course this error check can always be accomplished by wrapping an array in a module. In my code this would get into the territory of making some kind of abstract interface for every kind of parameter I can have (unconstrained, positive, angles, etc.), which I think would be a little too cumbersome. I am wondering if the `eqx.field` approach could be useful for something like this case, i.e. `field(…, constraint=lambda x: eqx.error_if(x, x < 0, "x must be positive!"))`.
Reparameterisation would definitely be the standard way of doing this.

Indeed, `__check_init__` wouldn't catch anything with `tree_at` (which is a good thing!).

In terms of having `constraint=...`, I think this use case is probably doable today with the existing `converter=...` API.
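For example, something like this might work (a sketch; this assumes `converter` is applied to the value when the field is set in `__init__`):

```python
import equinox as eqx
import jax


class PowerLaw(eqx.Module):
    amplitude: jax.Array
    # error_if returns its first argument with the runtime check attached,
    # so the stored value itself carries the positivity check.
    power: jax.Array = eqx.field(
        converter=lambda x: eqx.error_if(x, x < 0, "power must be positive!")
    )
```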
To make sure I understand: if I want to make sure a quantity stays positive (for example), would you recommend `__check_init__` with `error_if` to make sure it is initialized correctly? And instead of re-parameterizing, if I want to add checks beyond what `__check_init__` will catch, would you do something like
```python
import equinox as eqx
import jax
import jax.numpy as jnp


class ConstraintChecker(eqx.Module):
    quantity: jax.Array

    def get(self):
        eqx.error_if(self.quantity, self.quantity < 0, "quantity must be positive!")
        return self.quantity


def is_constrained(x):
    return isinstance(x, ConstraintChecker)


def maybe_constrained(x):
    if is_constrained(x):
        return x.get()
    else:
        return x  # leave everything else unchanged


def resolve_constraints(model):
    return jax.tree_util.tree_map(maybe_constrained, model, is_leaf=is_constrained)


model = PowerLaw(…)
get_params = lambda m: (m.power,)
constrained_model = eqx.tree_at(get_params, model, replace_fn=ConstraintChecker)


@eqx.filter_grad
def loss_fn(model, x, y):
    model = resolve_constraints(model)
    pred_y = jax.vmap(model)(x)
    return jnp.sum((y - pred_y) ** 2)


grads = loss_fn(constrained_model, ...)
```
I don't think `__check_init__` can be used in conjunction with `error_if`, as the instance is already immutable in the former case but the latter requires re-assigning the variable. (Otherwise the check gets DCE'd.) FWIW, initialisation usually happens outside of JIT, so you could probably just use a normal Python `if` statement.
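For example, a sketch of that (assuming initialisation really does happen outside of JIT, so the field is a concrete array):

```python
import equinox as eqx
import jax


class PowerLaw(eqx.Module):
    amplitude: jax.Array
    power: jax.Array

    def __check_init__(self):
        # Outside of JIT, `self.power` is a concrete array, so a plain
        # Python check works here and nothing gets DCE'd.
        if (self.power < 0).any():
            raise ValueError("power must be positive!")
```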
Yup.
For your final example: the one issue with this is that the `error_if` will get DCE'd, as you don't use its output. You want `return eqx.error_if(self.quantity, ...)`. Other than that, yup!
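That is, the corrected method would be something like:

```python
def get(self):
    # Return the output of error_if so the runtime check isn't DCE'd.
    return eqx.error_if(
        self.quantity, self.quantity < 0, "quantity must be positive!"
    )
```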
Okay, I see. Thank you for the help!

I will definitely have instances where I initialize inside a JIT, because I make heavy use of model ensembling, where I think the recommended pattern is to initialize inside a `vmap`.

It seems to me that if I want the `__check_init__` behavior I described, I should then do it with the `converter`? Or explicitly in the `__init__`?
Yup!
In case it's useful (or if anyone has any suggestions/criticisms), I made a little package, paramax, for constraints/parameterizations of JAX PyTrees. The gist is essentially a generalized version of the above, using `AbstractUnwrappable` objects, which contain the unbounded parameters and the logic for unwrapping, and an `unwrap` function that maps across the PyTree, unwrapping any (possibly nested) `AbstractUnwrappable` objects. Most parameterizations are simple to initialize by passing a function to apply when unwrapping, along with the (unbounded) arguments; e.g. `scale=paramax.Parameterize(jnp.exp, jnp.log(jnp.ones(3)))` will apply `exp` when the model is unwrapped.
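Roughly, the API looks like:

```python
import jax.numpy as jnp
import paramax

# Wrap an unbounded parameter together with the function to apply on unwrap.
scale = paramax.Parameterize(jnp.exp, jnp.log(jnp.ones(3)))

# unwrap maps across an arbitrary PyTree, resolving any AbstractUnwrappable.
paramax.unwrap(("abc", 1, scale))  # -> ("abc", 1, Array([1., 1., 1.]))
```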
Hello! Let’s say I have something like the following
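```python
import equinox as eqx
import jax


# (Reconstructed sketch: the exact snippet from the original post isn't
# shown; the field names and call signature follow the rest of the thread.)
class PowerLaw(eqx.Module):
    amplitude: jax.Array
    power: jax.Array

    def __call__(self, x):
        return self.amplitude * x**self.power
```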
Now, let's say I want to optimize over the log of `power` and `amplitude`, but my loss function directly calls `PowerLaw.__call__`. What is the recommended pattern for using `eqx.apply_updates` here?

Of course I could handle this toy example with a single call to `tree_map`, but I am interested in the case where I optimize over an arbitrary set of model parameters in an arbitrary transformed space. To be more clear, I would have a pytree of functions to apply to my model before an optimization loop, to go into the transformed parameter space; I would take gradients in this transformed parameter space, but apply the inverse pytree of functions before ultimately evaluating my loss function.

Sorry if this isn't completely clear; I'm struggling a bit with the pseudocode on this one.
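My best attempt at a sketch of what I mean (probably not quite right; `optax.sgd` and the log/exp transforms are just for illustration):

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import optax

model = PowerLaw(amplitude=jnp.array(1.0), power=jnp.array(2.0))

# Forward transforms applied before the optimization loop...
where = lambda m: (m.amplitude, m.power)
t_model = eqx.tree_at(where, model, replace_fn=jnp.log)
# ...and the inverse transforms applied before evaluating the loss.
inverse = lambda m: eqx.tree_at(where, m, replace_fn=jnp.exp)


@eqx.filter_grad
def loss_fn(t_model, x, y):
    model = inverse(t_model)  # back to the original parameter space
    pred_y = jax.vmap(model)(x)
    return jnp.sum((y - pred_y) ** 2)


optim = optax.sgd(1e-3)
opt_state = optim.init(eqx.filter(t_model, eqx.is_array))

x = jnp.linspace(1.0, 2.0, 8)
y = 2.0 * x**3.0

# Gradients are taken in the transformed (log) space...
grads = loss_fn(t_model, x, y)
updates, opt_state = optim.update(grads, opt_state)
# ...and eqx.apply_updates steps in that same transformed space.
t_model = eqx.apply_updates(t_model, updates)
```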