pmelchior / scarlet2

Scarlet, all new and shiny

correct behavior when constraint and prior are set #34

Closed · pmelchior closed this 7 months ago

pmelchior commented 7 months ago

It's easy to get into trouble when a parameter specifies both a constraint and a prior. The problem stems from a transformation that is applied to the variable to make it appear (to the sampler or the optimizer) to be unconstrained. E.g. a positivity constraint amounts to inserting $f(y) = \exp(y)$ into the forward path, which means that the original variable is reparameterized as $y=\log(x)$. For the likelihood, this transformation is transparent because autograd will just apply one more chain rule, but for a prior (e.g. a score model), we compute the gradient $\nabla \log p(x)$, not $\nabla \log p(y)$. This will get people into trouble without their even knowing it.

So, we need a warning for those cases that the same transformation $f$ needs to be applied when training and testing the prior network.
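A minimal sketch of the situation (toy likelihood and a hypothetical score function, not scarlet2's actual code):

import jax
import jax.numpy as jnp

# positivity constraint via reparameterization: optimize y, model sees x = exp(y)
neg_log_like = lambda x: 0.5 * (x - 3.0) ** 2   # toy likelihood term

# transparent for the likelihood: autodiff chains through exp automatically
grad_wrt_y = jax.grad(lambda y: neg_log_like(jnp.exp(y)))

# but a score network returns grad_x log p(x) directly; exp is not in its
# computational path, so its output lives in x-space, not y-space
score_x = lambda x: -x   # hypothetical trained score for a N(0, 1) prior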

SampsonML commented 7 months ago

Aye, perhaps we can start writing some docs? Happy to get a start on that to fill up some busy work time.

jaredcsiegel commented 7 months ago

[Draft (based on my current understanding of the issue, so this summary will likely be subject to edits)]

Yesterday, Peter and I further discussed the nuances of using both constraints and priors. The story is a little different from Peter's initial description above. In short, the gradients are evaluated in terms of the untransformed variables $x$ but are then applied to the transformed variables $y$.

For a little summary:

When a constraint is applied to the parameter $x$, it performs the transformation $x \rightarrow y = f(x)$ (for example, $f=\log$ if a positivity constraint is applied).

The transformation is initially performed when scene.fit is called:

# transform to unconstrained parameters
constraint_fn = {
    name: biject_to(info["constraint"])
    for name, (value, info) in parameters.items()
    if info["constraint"] is not None
}
scene = _constraint_replace(self, constraint_fn, inv=True)

At this point, $x \rightarrow y$ has happened for all parameters with a constraint.
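For reference, biject_to here matches numpyro's constraint-to-transform registry; a minimal sketch of the direction of the mapping, assuming that import:

import jax.numpy as jnp
from numpyro.distributions import constraints
from numpyro.distributions.transforms import biject_to

t = biject_to(constraints.positive)  # maps unconstrained reals to (0, inf)
x = jnp.array(2.0)                   # constrained value
y = t.inv(x)                         # unconstrained value, log(2) here
assert jnp.allclose(t(y), x)         # t itself goes y -> x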

The gradients for optimization are evaluated inside _make_step:

loss, grads = eqx.filter_value_and_grad(loss_fn)(model)

At this point, $x \rightarrow y$ still holds for all constrained parameters inside model.

Inside _make_step.loss_fn, the first step is to transform $y \rightarrow x$:

model = _constraint_replace(model, constraint_fn)

The likelihood and prior are then evaluated on $x$.

This creates a mismatch: the gradients returned by loss_fn are in terms of $x$, but they are then applied to model, which is parameterized in terms of $y$:

updates, opt_state = optim.update(grads, opt_state, model)
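A minimal sketch of why the two spaces disagree, with a toy log-probability (hypothetical, not scarlet2's actual loss):

import jax
import jax.numpy as jnp

log_prob = lambda x: -0.5 * x ** 2   # toy stand-in for likelihood + prior

x = jnp.array(2.0)
y = jnp.log(x)                       # transformed variable held by model

g_x = jax.grad(log_prob)(x)                          # gradient in x-space
g_y = jax.grad(lambda y: log_prob(jnp.exp(y)))(y)    # gradient in y-space
# g_y = g_x * dx/dy = g_x * exp(y): applying an x-space gradient
# as an update to y is off by exactly this Jacobian factor
assert jnp.allclose(g_y, g_x * jnp.exp(y))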
pmelchior commented 7 months ago

@jaredcsiegel this is almost correct. Let me simplify it a little. In the case of an unconstrained parameter $x$, the optimization (by optax's optimizer.update) simply carries out this sequence:

$$x^{t+1} = x^t - \lambda^t \nabla_x l(x)$$

with some loss function. In our case with a prior, this is $l(x) = -\log\mathcal{L}(x\mid\mathcal{D}) - \log p(x)$. The important part is that this all happens consistently in the space $\mathcal{X}$.

Now, when we transform a variable, $x\rightarrow f^{-1}(x)\equiv y$, we reparameterize the model and its loss function, and this happens inside fit:

$$
\begin{align}
&y^0 = f^{-1}(x^0)\\
&y^{t+1} = y^t - \lambda^t \nabla_y\, l(f(y))\\
&\dots\\
&x^T = f(y^T)
\end{align}
$$

The key part is that the gradients are now evaluated as a function of $y$, not $x$. For the likelihood, this doesn't matter because we're using autodiff and the transformation $f(y)$ is in the forward path. But the prior gradients are direct evaluations of the score network, which has been trained to produce $s(x) = \nabla_x \log p(x)$. So the problem isn't the argument to the score function; it's the return value being in the space $\mathcal{X}$, while the optimizer operates on the space $\mathcal{Y}$.
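A minimal sketch of this reparameterized loop for the likelihood-only case (toy loss and step size, not scarlet2's code):

import jax
import jax.numpy as jnp

f, f_inv = jnp.exp, jnp.log                 # positivity: x = f(y)
loss = lambda y: 0.5 * (f(y) - 2.0) ** 2    # l(f(y)); autodiff handles f in the path

y = f_inv(jnp.array(1.0))                   # y^0 = f^{-1}(x^0)
for _ in range(100):
    y = y - 0.1 * jax.grad(loss)(y)         # y^{t+1} = y^t - lambda^t grad_y l(f(y))
x = f(y)                                    # x^T = f(y^T), positive by construction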

There are two solutions to this problem:

  1. Train the score network to produce $\nabla_y \log p(x)$. Note that this is not $\nabla_y \log p(y)$: the domain and co-domain of this function are different; it goes from $\mathcal{X}$ to $\mathcal{Y}$. This is ugly because it requires that the person training the prior knows what constraints a person using it may later specify.
  2. Apply the chain rule in fit: $\nabla_y \log p(x(y)) = \nabla_x \log p(x)\cdot \frac{\partial x}{\partial y} = s(x) \cdot J_f$, where $J_f$ refers to the Jacobian matrix of the inverse of the transformation, i.e. of $f$, which maps $\mathcal{Y}$ back to $\mathcal{X}$.

Because of the problems with item 1, I much prefer option 2. This can be done with another call to jax.jacfwd on the inverse method of the transformation.
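A minimal sketch of option 2 with jax.jacfwd, for a scalar parameter (score_x is a hypothetical stand-in for the trained score network):

import jax
import jax.numpy as jnp

f = jnp.exp              # inverse of the unconstraining transform: y -> x
score_x = lambda x: -x   # hypothetical score s(x) = grad_x log p(x), N(0, 1) here

def score_y(y):
    x = f(y)
    J = jax.jacfwd(f)(y)   # dx/dy
    return score_x(x) * J  # grad_y log p(x(y)) by the chain rule

# sanity check against autodiff through an explicit log-prior:
log_p = lambda x: -0.5 * x ** 2
y = jnp.array(0.3)
assert jnp.allclose(score_y(y), jax.grad(lambda y: log_p(f(y)))(y))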

Note: this looks like the same problem @SampsonML solved when training the morphology prior. He applied the transformation $\log(x+1)$ to compress the dynamic range. But he did that during training, so both the domain and the co-domain were transformed, and at test time he applies the transformation consistently to the argument and the return value, without the user even knowing that the score model internally operates on something different from $x$.

SampsonML commented 7 months ago

Yep, it is relatively straightforward to solve using method 2, and in fact I have done exactly what Peter described. If you look here https://github.com/pmelchior/scarlet2/blob/90061b8e1e7653f83df80d98ac519c9f7c92f2e2/scarlet2/nn.py#L114 (and at line 132), you can see that you can simply pass in a transformation function and all the work is taken care of.

So the user just needs to know that any transformation performed, whether via constraints or other methods, can and should be passed into the construction of ScorePrior(). If using more than one transformation, i.e. the training-space one I have plus another constraint-based one, one should be able to define a composite transformation $m = f \circ g$ and pass $m$ in via ScorePrior(transform=m).

So in short, I think we just need clearer docs; there shouldn't be a need for any code adaptation as far as I can see.

pmelchior commented 7 months ago

@SampsonML I saw that you've already implemented it, and I think this is precisely the mechanism we can use to make a prior evaluate constrained parameters.

However, can you show here how you specify the prior when you run scarlet? I think you have to tell it to use something like transform=lambda x: jnp.log(1+x) when you instantiate the ScorePrior class.

SampsonML commented 7 months ago

An example is below. You could indeed use a lambda function as you have shown; I just explicitly defined my transform here:

# load in the model you wish to use
from galaxygrad import HSC_ScoreNet64

# choose model
prior_model = HSC_ScoreNet64

# specify model size and temperature
model_size = 64
temp = 0.01 
prior_model = nn.TempScore(prior_model, temp=temp)

# define transform
def transform(x):
    sigma_y = 0.10
    return jnp.log(x + 1) / sigma_y

# construct prior
prior = nn.ScorePrior(
    model=prior_model, transform=transform, model_size=model_size
)

Note that you can also skip the temperature model and simply have:

# load in the model you wish to use
from galaxygrad import HSC_ScoreNet64

# construct prior
prior = nn.ScorePrior(
    model=HSC_ScoreNet64, transform=lambda x: jnp.log(x + 1) / 0.1, model_size=64
)
pmelchior commented 7 months ago

That's what I suspected. I think having to provide the transform argument at instantiation is dangerous. What happens if an unsuspecting user forgets it? Or changes it to something else?

To me, this should be treated as part of HSC_ScoreNet64 because it's a transformation you applied while training the network. A user shouldn't need to know about it. So I suggest you create a wrapper in which HSC_ScoreNet64 knows that it needs to apply this transformation to the inputs and the associated Jacobian to the gradients.

This would also allow us to leverage your existing transform mechanism from scarlet2.nn behind the scenes for the constraint transformation.
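A minimal sketch of such a wrapper (the wrapper name and usage are hypothetical; the transform matches the one above):

import jax.numpy as jnp

sigma_y = 0.10

def wrap_score_model(score_fn):
    # hypothetical wrapper: evaluate the network in its training space and
    # map the returned score back through the Jacobian of the transform
    def wrapped(x):
        z = jnp.log(x + 1) / sigma_y      # training-space transform
        jac = 1.0 / (sigma_y * (x + 1))   # elementwise dz/dx
        return score_fn(z) * jac          # chain rule back to x-space
    return wrapped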

SampsonML commented 7 months ago

Ok, no problem, I can sort that.

SampsonML commented 7 months ago

Alright, all set @pmelchior: galaxygrad==0.1.1 now has all the data transforms and corresponding gradient transforms handled internally. I agree this is nicer; this should be handled by the model trainer/maker, not the model user.

The transform function is now free to be used for the constraints.

pmelchior commented 7 months ago

My description above was also incorrect. It turns out that autodiff follows the constraint transformation not just into the likelihood evaluation (where it applies the Jacobians correctly); it also does the same for custom gradients like the ones we get from ScorePrior. This means that we don't need to handle priors differently. In particular, as long as autodiff can see the transformation before the variable is given to the prior (and it does see that), it will automatically apply the chain rule for the gradients. It's called automatic differentiation for a reason...
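A minimal demonstration of that behavior (a hypothetical N(0, 1) prior with a custom VJP standing in for the score network):

import jax
import jax.numpy as jnp

@jax.custom_vjp
def log_prior(x):
    return -0.5 * x ** 2   # log p(x) up to a constant

def log_prior_fwd(x):
    return log_prior(x), x

def log_prior_bwd(x, g):
    return (g * -x,)       # custom gradient: the score s(x) = -x

log_prior.defvjp(log_prior_fwd, log_prior_bwd)

# constraint transform sits in the forward path: x = exp(y)
loss = lambda y: log_prior(jnp.exp(y))
y = jnp.array(0.3)
# autodiff composes the custom score with d exp(y)/dy automatically:
assert jnp.allclose(jax.grad(loss)(y), -jnp.exp(2 * y))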

pmelchior commented 7 months ago

closed with #35