Closed pmelchior closed 7 months ago
Aye, perhaps we can start writing some docs? Happy to get a start on that to fill up some busy work time.
[Draft (based on my current understanding of the issue, so this summary will likely be subject to edits)]
Yesterday, Peter and I further discussed the nuances of using both constraints and priors. The story is a little different from Peter's initial description above. In short, the gradients are evaluated in terms of the untransformed variables $x$ but are then applied to the transformed variables $y$.
For a little summary:
When a constraint is applied to the parameter $x$, it performs the transformation $x \rightarrow y = f(x)$ (for example, $f=\log$ if a positivity constraint is applied).
The transformation is initially performed when `scene.fit` is called:
```python
# transform to unconstrained parameters
constraint_fn = {name: biject_to(info["constraint"]) for name, (value, info) in parameters.items() if info["constraint"] is not None}
scene = _constraint_replace(self, constraint_fn, inv=True)
```
at this point $x \rightarrow y$ for all parameters with a constraint.
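To make the convention concrete, here is a toy sketch (plain Python with stdlib `math`, not scarlet2 code) of the bijection for a positivity constraint, where $f=\log$:

```python
import math

# Toy sketch (not scarlet2 code): for a positivity constraint, the
# bijection to unconstrained space is y = log(x), with inverse x = exp(y).
def to_unconstrained(x):   # f: constrained -> unconstrained
    return math.log(x)

def to_constrained(y):     # f^{-1}: unconstrained -> constrained
    return math.exp(y)

x = 2.5                    # constrained parameter (must stay positive)
y = to_unconstrained(x)    # what fit() stores and optimizes
```

The optimizer is then free to move $y$ anywhere on the real line; mapping back through $f^{-1}$ always lands in the constrained domain.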
The gradients for optimization are evaluated inside `_make_step`:
```python
loss, grads = eqx.filter_value_and_grad(loss_fn)(model)
```
at this point $x \rightarrow y$ for all parameters with a constraint inside `model`.
Inside `_make_step.loss_fn` the first step is to transform $y \rightarrow x$:
```python
model = _constraint_replace(model, constraint_fn)
```
the likelihood and prior are then evaluated on $x$.
This creates a mismatch, because the gradients returned by `loss_fn` are in terms of $x$, but these gradients are then applied to `model`, which is in terms of $y$:
```python
updates, opt_state = optim.update(grads, opt_state, model)
```
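A toy illustration of why this mismatch matters (the quadratic loss and log transform here are assumptions for the example, not scarlet2 code): applying an $x$-space gradient to the $y$-space state produces a different update than the correct $y$-space gradient.

```python
import math

# Toy loss l(x) = 0.5 * (x - 1)^2 on a positive parameter x, stored as y = log(x).
grad_x = lambda x: x - 1.0                            # dl/dx
grad_y = lambda y: (math.exp(y) - 1.0) * math.exp(y)  # dl/dy = dl/dx * dx/dy, with dx/dy = e^y

y = math.log(2.0)                      # current state in unconstrained space
lr = 0.1
wrong = y - lr * grad_x(math.exp(y))   # x-space gradient applied to y (the bug)
right = y - lr * grad_y(y)             # correct update in y-space
```

Here `wrong` and `right` disagree by the missing Jacobian factor $\partial x/\partial y$.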
@jaredcsiegel this is almost correct. Let me simplify it a little. In the case of an unconstrained parameter $x$, the optimization (by `optax.optimizer.update`) simply carries out this sequence:
$$x^{t+1} = x^t - \lambda^t \nabla_x l(x)$$
with some loss function $l$. In our case with a prior, this is $l(x) = \log\mathcal{L}(x\mid\mathcal{D}) + \log p(x)$. The important part is that this is all consistently in the space $\mathcal{X}$.
Now, when we transform a variable, $x\rightarrow f^{-1}(x)\equiv y$, we reparameterize the model and its loss function, and this happens inside of `fit`:
$$\begin{align} &y^0 = f^{-1}(x^0)\\ &y^{t+1} = y^t - \lambda^t \nabla_y l(f(y))\\ &\dots\\ &x^T = f(y^T) \end{align}$$
The key part is that the gradients are now evaluated as a function of $y$, not $x$. For the likelihood, this doesn't matter because we're using autodiff, and the transformation $f(y)$ is in the forward path. But the prior gradients are direct evaluations of the score network, which has been trained to produce $s(x) = \nabla_x \log p(x)$. So, the problem isn't the argument to the score function, it's the return value being in the space $\mathcal{X}$, while the optimizer operates on the space $\mathcal{Y}$.
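The chain-rule relation behind this can be checked numerically with a toy example (standard normal prior and $y=\log x$, both assumptions for illustration, not scarlet2 code): the derivative of $\log p(x(y))$ with respect to $y$ equals $s(x)\cdot\partial x/\partial y$, not $s(x)$ alone.

```python
import math

score = lambda x: -x                    # score of a standard normal prior
log_p = lambda x: -0.5 * x * x          # its log-density (up to a constant)

y = 0.3                                 # unconstrained value, with x = e^y
chain_rule = score(math.exp(y)) * math.exp(y)   # s(x) * dx/dy

h = 1e-6                                # central finite difference in y
numeric = (log_p(math.exp(y + h)) - log_p(math.exp(y - h))) / (2 * h)
```

The finite-difference derivative matches `chain_rule`, confirming that the raw score $s(x)$ alone would be the wrong gradient in $\mathcal{Y}$.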
There are two solutions to this problem. The second is to apply the chain rule to the score inside `fit`: $\nabla_y \log p(x(y)) = \nabla_x \log p(x)\cdot \frac{\partial x}{\partial y} = s(x) \cdot J_{f^{-1}}$, where $J_{f^{-1}}$ refers to the Jacobian matrix of the inverse function. Because of the problems with option 1, I much prefer option 2. This can be done with another call to `jax.jacfwd` on the inverse method of the transformation.
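A minimal sketch of what option 2 could look like, assuming a positivity constraint with inverse $f^{-1}=\exp$ (the function names here are hypothetical, not scarlet2's actual implementation):

```python
import jax
import jax.numpy as jnp

# Positivity constraint: y = log(x), so the inverse is f_inv(y) = exp(y) = x.
f_inv = jnp.exp

def corrected_score(score_fn, y):
    """Score in y-space: s(x) * J_{f^{-1}}, with the Jacobian from jax.jacfwd."""
    x = f_inv(y)
    J = jax.jacfwd(f_inv)(y)    # here just exp(y), but works for any bijection
    return score_fn(x) * J

score = lambda x: -x            # toy stand-in for the score network
g = corrected_score(score, jnp.array(0.0))
```

The same pattern generalizes to vector parameters, where `J` becomes a full Jacobian matrix and the product a matrix-vector multiply.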
Note: this looks like the same problem @SampsonML solved when training the morphology prior. He applied the transformation $\log(x+1)$ to compress the dynamic range. But he did that during training, so both the domain and the co-domain were transformed, and at test time he applies the transformation consistently to the argument and the return value, without the user even knowing that the score model internally operates on something different from $x$.
Yep, it is relatively straightforward to solve using method 2, and in fact I have done exactly as Peter described. If you look here https://github.com/pmelchior/scarlet2/blob/90061b8e1e7653f83df80d98ac519c9f7c92f2e2/scarlet2/nn.py#L114 you can see from this, and line 132, that you can easily pass in a transformation function and all the work is taken care of.
So the user just needs to know that any transformation performed, either via constraints or other methods, can, and should, be passed into the construction of `ScorePrior()`. If using more than one transformation, i.e. the training-space one I have and another constraint-based one, you should be able to just define a new composite transformation function $m = f \circ g$ and pass $m$ to `ScorePrior(transform=m)`.
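For illustration, a composite transform could be built like this (both `f` and `g` here are made-up placeholders, not the actual transforms used in scarlet2):

```python
import math

# Hypothetical transforms: g maps to the network's training space,
# f stands in for an additional constraint transformation.
g = lambda x: math.log(x + 1) / 0.1   # training-space transform (as above)
f = lambda x: 2.0 * x                 # placeholder constraint transform
m = lambda x: f(g(x))                 # composite: apply g first, then f
```

The composite `m` is then a single callable that can be handed to `ScorePrior(transform=m)`.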
So in short, I think just clearer docs are needed; there shouldn't be a need for any code adaptation as far as I can see.
@SampsonML I saw that you've already implemented it, and I think this is precisely the mechanism we can use to make a prior evaluate constrained parameters.
However, can you show here how you specify the prior when you run scarlet? I think that you have to tell it to use something like `transform=lambda x: jnp.log(1+x)` when you instantiate the `ScorePrior` class?
An example is below; you could indeed use a lambda function as you have shown, I just explicitly defined my transform here:
```python
import jax.numpy as jnp
from scarlet2 import nn

# load in the model you wish to use
from galaxygrad import HSC_ScoreNet64

# choose model
prior_model = HSC_ScoreNet64
# specify model size and temperature
model_size = 64
temp = 0.01
prior_model = nn.TempScore(prior_model, temp=temp)

# define transform
def transform(x):
    sigma_y = 0.10
    return jnp.log(x + 1) / sigma_y

# construct prior
prior = nn.ScorePrior(
    model=prior_model, transform=transform, model_size=model_size
)
```
Note, you can also skip the temperature model and hence have:
```python
import jax.numpy as jnp
from scarlet2 import nn

# load in the model you wish to use
from galaxygrad import HSC_ScoreNet64

# construct prior
prior = nn.ScorePrior(
    model=HSC_ScoreNet64, transform=lambda x: jnp.log(x + 1) / 0.1, model_size=64
)
```
That's what I suspected. I think having to provide the `transform` argument at instantiation is dangerous. What happens if an unsuspecting user forgets it? Or changes it to something else?
To me this should be treated as part of `HSC_ScoreNet64`, because it's a transformation you applied while training the network. A user should not need to know. So, I suggest you create a wrapper in which `HSC_ScoreNet64` knows that it needs to apply this transformation on the inputs and the associated Jacobian on the gradients.
This would also allow us to leverage your existing `transform` mechanism from `scarlet2.nn` behind the scenes for the constraint transformation.
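A sketch of the kind of wrapper being suggested (the class, its toy score function, and the $\log(x+1)$ training transform stand-ins are illustrative assumptions, not galaxygrad's actual API):

```python
import math

class WrappedScore:
    """Hypothetical wrapper: hides the training-time transform from the user.

    The wrapped model was trained on z = t(x); at evaluation we feed it t(x)
    and multiply the returned score by dt/dx (chain rule), so callers see the
    gradient of log p(t(x)) in the original x-space.
    """
    def __init__(self, score_fn, t, dt_dx):
        self.score_fn, self.t, self.dt_dx = score_fn, t, dt_dx

    def __call__(self, x):
        return self.score_fn(self.t(x)) * self.dt_dx(x)

# Toy check: model "trained" on z = log(x + 1), standard normal score in z.
t = lambda x: math.log(x + 1)
dt = lambda x: 1.0 / (x + 1)
wrapped = WrappedScore(lambda z: -z, t, dt)
```

With this in place, users call the model on $x$ directly and never see the internal training-space transform.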
Ok, no problem, I can sort that.
Alright, all set @pmelchior. `galaxygrad==0.1.1` now has all the data transforms and corresponding gradient transforms handled inside. I agree this is nicer: this should be handled by the model trainer/maker, not the model user.
The transform function is now free to be used for the constraints.
My description above was also incorrect. It turns out that autodiff follows the constraint transformation not just into the likelihood evaluation (where it applies the Jacobians correctly); it also does the same for custom gradients like the ones we get from `ScorePrior`. This means that we don't need to handle priors differently. In particular, as long as autodiff can see the transformation before its result is given to the prior (and it does see it), it will automatically apply the chain rule for the gradients. It's called automatic differentiation for a reason...
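This behavior can be demonstrated with a toy custom-gradient function in JAX (a stand-in for the score network, not the actual `ScorePrior` code): autodiff applies the chain rule through `jnp.exp` even though the gradient of `log_prior` is hand-coded.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log_prior(x):
    return -0.5 * x**2          # standard normal log-density, up to a constant

@log_prior.defjvp
def log_prior_jvp(primals, tangents):
    (x,), (dx,) = primals, tangents
    score = -x                  # hand-coded gradient, like a score network
    return log_prior(x), score * dx

# the constraint transform sits in the forward path: y -> x = exp(y)
loss = lambda y: log_prior(jnp.exp(y))
g = jax.grad(loss)(0.5)         # autodiff chains: score(e^y) * e^y = -e^{2y}
```

Because the transformation is in the traced forward path, JAX composes the custom gradient with the Jacobian of `jnp.exp` automatically.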
closed with #35
It's easy to get into trouble when a parameter specifies both a constraint and a prior. The problem stems from a transformation that is applied to the variable to make it appear (to the sampler or the optimizer) to be unconstrained. E.g. a positivity constraint amounts to inserting $f = \exp$ into the path, which means that the original variable is reparameterized as $y=\log(x)$. For the likelihood, this transformation is transparent because autograd will just apply one more chain rule, but for a prior (e.g. a score model), we compute the gradient $\nabla \log p(x)$, not $\nabla \log p(y)$. This will get people into trouble without them even knowing it.
So, we need a warning for those cases that the same transformation $f$ needs to be applied when training and testing the prior network.