matsengrp / multidms

Joint modeling of multiple deep mutational scanning experiments
https://matsengrp.github.io/multidms
MIT License

Optimization improvements #156

Closed wsdewitt closed 2 months ago

wsdewitt commented 5 months ago

This issue proposes fixes for a few interconnected technical problems with optimization:

A notebook prototyping the ideas below: https://github.com/matsengrp/multidms/blob/149_custom_optimization_loop/notebooks/convergence_testing_2024-03-13/wsd.ipynb

Modeling and transformations for better-conditioned gradients

Inadequate convergence is likely caused by data/parameter scaling pathologies in gradient moves. In non-reference conditions, bundle mutations are in the hot state for most variants, whereas non-bundle sites are mostly cold. This scaling discrepancy can be seen by plotting the row-sums of the encodings (left plot): [figure not reproduced]

This creates a ridge in parameter space defined by the subspace of non-bundle mutations: small changes in the bundle-associated parameters result in huge changes in loss, whereas small changes in the non-bundle parameters do not.
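To make the ridge concrete, here is a small numerical sketch (not from the prototype notebook). It assumes an identity link and squared-error loss, so the Hessian of the loss is proportional to $X^\intercal X$ with an intercept column. A mostly-hot column is nearly collinear with the intercept, which shows up as a large condition number. Flipping that column so it is mostly cold reduces the condition number. The hot frequencies 0.95 and 0.05 and the simulated data are illustrative assumptions.

```python
import numpy as np

def hessian_condition_number(columns):
    """Condition number of X^T X / n for a design with an intercept column.

    Under squared-error loss with an identity link (an assumption here),
    this matrix is proportional to the Hessian of the loss.
    """
    n = columns[0].shape[0]
    X = np.column_stack([np.ones(n)] + list(columns))
    return np.linalg.cond(X.T @ X / n)

rng = np.random.default_rng(0)
n = 10_000
# Hypothetical non-reference condition: a bundle mutation is hot in ~95% of
# variants, and an ordinary (non-bundle) mutation is hot in ~5% of variants.
x_bundle = (rng.random(n) < 0.95).astype(float)
x_other = (rng.random(n) < 0.05).astype(float)

cond_original = hessian_condition_number([x_bundle, x_other])
# Bit-flip the bundle column so it is encoded relative to this condition's WT.
cond_flipped = hessian_condition_number([1.0 - x_bundle, x_other])
print(f"original: {cond_original:.1f}, flipped: {cond_flipped:.1f}")
```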

Equivalent parameterization

Given any 1-hot encoding $X_d\in\{0,1\}^{n_d\times M}$, $d=1,\dots,D$, we pose the regression problem as

$$ y_{d} = g_\theta\left(\beta_{0,d} + \sum_m\beta_{m,d}\, x_{m,d}\right) + \epsilon_{d}. $$
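Here is a minimal sketch of the forward model for one condition. It uses a logistic sigmoid as a stand-in for the global-epistasis function $g_\theta$; the thread does not fix its form. The names `g_theta`, `scale`, `bias`, and `predict` are hypothetical.

```python
import numpy as np

def g_theta(z, scale=1.0, bias=0.0):
    """Hypothetical sigmoid stand-in for the global-epistasis nonlinearity."""
    return scale / (1.0 + np.exp(-z)) + bias

def predict(X_d, beta0_d, beta_d):
    """Predicted scores g_theta(beta_{0,d} + sum_m beta_{m,d} x_{m,d})."""
    latent = beta0_d + X_d @ beta_d  # latent phenotype for each variant
    return g_theta(latent)

# Three variants over M = 4 mutations, 1-hot encoded (row 0 is the WT).
X_d = np.array([[0, 0, 0, 0],
                [1, 0, 0, 0],
                [1, 0, 1, 1]], dtype=float)
beta_d = np.array([0.5, -1.0, -2.0, 0.25])
y_hat = predict(X_d, beta0_d=0.0, beta_d=beta_d)
```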

The shifts of mutation effects between conditions $d$ and $d'$ are

$$ \Delta_{m,d,d'} = \beta_{m,d} - \beta_{m,d'}, $$

and the shifts between WTs are

$$ \alpha_{d,d'} = \beta_{0,d} - \beta_{0,d'}. $$
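If the coefficients are stacked into an $M\times D$ array `beta` (rows are mutations, columns are conditions) and the intercepts into a length-$D$ vector `beta0`, both shift definitions become simple column differences. The array layout and function name are assumptions for illustration. Python indexes conditions from 0, so condition $d$ in the text is column `d - 1`.

```python
import numpy as np

def shifts(beta, beta0, d, d_prime):
    """Return (Delta_{., d, d'}, alpha_{d, d'}) for 0-based conditions d, d'.

    beta  : (M, D) array of mutation effects beta_{m,d}
    beta0 : (D,) array of intercepts beta_{0,d}
    """
    delta = beta[:, d] - beta[:, d_prime]  # Delta_{m,d,d'} for every m
    alpha = beta0[d] - beta0[d_prime]      # alpha_{d,d'}
    return delta, alpha

beta = np.array([[1.0, 1.0, 0.5],
                 [-2.0, -1.0, -2.0]])
beta0 = np.array([0.0, 0.3, -0.1])
delta, alpha = shifts(beta, beta0, d=1, d_prime=0)
print(delta, alpha)
```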

We will usually use a 1-hot encoding referenced to the WT sequence from one of the $D$ conditions, so that the mutation effects and shifts are interpreted with respect to that condition. This introduces parameter scaling problems in the non-reference conditions that confound gradient methods. These problems arise from the set of "bundle" sites that define the divergence from the reference encoding (more on this below). For now, note that the choice of reference condition for interpreting shifts can be made a posteriori by fixing $d'$, recovering the parameterization in the main text.

Data scale transformation

To appropriately scale regression parameters, we need condition-specific 1-hot encodings, each with respect to the WT in that condition. We get this by bit-flipping the "bundle" mutations in all conditions, so that each condition has data referenced to its own WT. The scaling is now comparable across conditions: [figure not reproduced]

Parameter transformation

Now we need to clarify how parameters for the rescaled data relate to those for the original data. Consider just one condition, so that we suppress the index $d$ below for simplicity. Let $\mathcal{B}$ be the set of indices of the columns of $X$ corresponding to the bundle. We want to work with rescaled data $1-x_m$ at sites $m\in\mathcal{B}$ in the bundle, and $x_m$ at sites $m\notin\mathcal{B}$ not in the bundle. Let $\tilde\beta_0$ and $\tilde\beta_m$ be the intercept and coefficients in the rescaled problem, and $\beta_0$ and $\beta_m$ be the intercept and coefficients in the original data. The equivalence of latent phenotype predictions between scales requires

$$ \tilde\beta_0 + \sum_{m\notin\mathcal{B}}\tilde\beta_m x_m + \sum_{m\in\mathcal{B}}\tilde\beta_m(1 - x_m) = \beta_0 + \sum_m\beta_m x_m. $$

Matching like terms, we find

$$ \beta_0 = \tilde\beta_0 + \sum_{m\in\mathcal{B}}\tilde\beta_m = \tilde\beta_0 - \sum_{m\in\mathcal{B}}\beta_m, $$

and

$$ \beta_m = \begin{cases} \tilde\beta_m, & m\notin\mathcal{B} \\ -\tilde\beta_m, & m\in\mathcal{B} \end{cases}. $$
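The transform can be checked numerically. This sketch maps $(\tilde\beta_0, \tilde\beta)$ to $(\beta_0, \beta)$ using the two identities above, on random data and made-up parameters. It then confirms that the latent predictions agree between the bit-flipped and original encodings. The function and variable names are hypothetical.

```python
import numpy as np

def transform(b0, b, bundle):
    """Map scaled params (tilde beta_0, tilde beta) to original (beta_0, beta).

    beta_0 = tilde beta_0 + sum_{m in B} tilde beta_m
    beta_m = -tilde beta_m for m in B, tilde beta_m otherwise.
    The same map also takes (beta_0, beta) back to (tilde beta_0, tilde beta).
    """
    new_b0 = b0 + b[bundle].sum()
    new_b = np.where(bundle, -b, b)
    return new_b0, new_b

rng = np.random.default_rng(1)
n, M = 50, 8
X = rng.integers(0, 2, size=(n, M)).astype(float)  # original encoding
bundle = np.zeros(M, dtype=bool)
bundle[:3] = True                                   # first 3 columns are bundle
X_tilde = np.where(bundle, 1.0 - X, X)              # bit-flipped encoding

b0_tilde, b_tilde = 0.7, rng.normal(size=M)
b0, b = transform(b0_tilde, b_tilde, bundle)

latent_scaled = b0_tilde + X_tilde @ b_tilde
latent_original = b0 + X @ b
assert np.allclose(latent_scaled, latent_original)
```

Applying `transform` a second time returns $(\tilde\beta_0, \tilde\beta)$, which is a direct check that the map is its own inverse.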

Note that $\tilde\beta_m^{\mathrm{ref}} = \beta_m^{\mathrm{ref}}$, since $\mathcal{B}^{\mathrm{ref}}=\emptyset$. Note also that the transform operation is its own inverse. ~With the above, we can transform into the original parameterization, perform proximal updates, and then transform back into the scaled parameterization.~ EDIT: no, we can't cavalierly treat the gradient and prox with different parameterizations; see later comment.

In the prototype notebook, using this scaling transformation for data and parameters improves convergence behavior dramatically: [figures not reproduced]

Here are the resulting correlations with observed functional scores: [figures not reproduced]

Reference-experiment equivariance during optimization

We currently choose a reference experiment a priori and penalize only the differences between each non-reference experiment and the reference experiment. For this section, consider one mutation, so we suppress the index $m$ below for simplicity. Our mutation effects across conditions are the vector $\beta\in\mathbb{R}^D$, and our shift penalty is

$$ \lambda\|\Delta\|_1 = \lambda\|\mathfrak{D} \beta\|_1, $$

where, assuming $d=1$ is the reference experiment and $D=3$, we use the difference matrix

$$ \mathfrak{D} = \begin{bmatrix} -1 & 1 & 0\\ 0 & 0 & 0\\ -1 & 0 & 1 \end{bmatrix}. $$

The above formulation is equivalent to the current model, but is expressed in the form of the generalized lasso problem (which considers an arbitrary $\mathfrak{D}$). In our $\mathfrak{D}$, the row of zeros captures the fact that we don't penalize shifts between non-reference experiments. This lack of symmetry limits the potential sparsity structure of the model. In particular, it's not possible for two non-reference experiments to have the same mutation effect that is different from the reference experiment. If an epistatic interaction has arisen only on the branch leading to the reference experiment, we can't discover that shift pattern. If we change the reference experiment, we get a different limitation on shift sparsity pattern discovery.

It would be more natural to infer a reference-equivariant model that fuses across all conditions symmetrically, after which an a posteriori choice of reference experiment can be made for mutation effect interpretation. We would use the difference matrix

$$ \mathfrak{D} = \begin{bmatrix} -1 & 1 & 0\\ 0 & -1 & 1\\ -1 & 0 & 1 \end{bmatrix}, $$

where all shifts are penalized symmetrically.
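For general $D$, both matrices can be built with one row per unordered pair of conditions, giving $D(D-1)/2$ rows. The reference-based version zeroes every row that does not involve the reference. This sketch uses 0-based condition indices and `itertools.combinations` order, so its rows are a permutation of the $3\times3$ matrices above. Row order does not affect $\|\mathfrak{D}\beta\|_1$. The function name is hypothetical.

```python
import itertools
import numpy as np

def difference_matrix(D, reference=None):
    """Rows are beta_j - beta_i for each pair i < j of conditions.

    If `reference` is given, rows for pairs not involving it are zeroed,
    reproducing the reference-based penalty; otherwise all pairs are fused.
    """
    pairs = list(itertools.combinations(range(D), 2))
    Dmat = np.zeros((len(pairs), D))
    for row, (i, j) in enumerate(pairs):
        if reference is None or reference in (i, j):
            Dmat[row, i], Dmat[row, j] = -1.0, 1.0
    return Dmat

beta = np.array([2.0, 1.0, 2.0])
D_ref = difference_matrix(3, reference=0)  # condition 1 in the text's indexing
D_sym = difference_matrix(3)
print(np.abs(D_ref @ beta).sum(), np.abs(D_sym @ beta).sum())
```

For this $\beta$, the reference-based penalty sees only the shift between conditions 1 and 2. The symmetric penalty additionally charges for the shift between conditions 2 and 3.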

~Finally, we can combine this with the parameter transformation above by transforming parameters back to the original scale, applying proximal steps in $\beta$, then retransforming to $\tilde\beta$ (taking care to update the intercepts correctly).~ EDIT: no, we can't cavalierly treat the gradient and prox with different parameters, see later comment on what we actually need to do.

ADMM approach

We next derive the proximal operator for the symmetric fusion problem above, using standard augmented Lagrangian methods for the generalized lasso (see citations below). Our penalized problem, a special case of the generalized lasso, is

$$ \min_{\beta} f(\beta) + \lambda\|\mathfrak{D}\beta\|_1, $$

where $f$ is the smooth piece of the objective. To use FISTA, we'll need the gradient $\nabla f$ and the proximal operator $\mathrm{prox}_{\lambda}$, defined as

$$ \mathrm{prox}_{\lambda}(x) = \arg\min_{y}\frac{1}{2}\|x - y\|_2^2 + \lambda\|\mathfrak{D}y\|_1, $$

which is equivalent to the constrained problem

$$ \mathrm{prox}_{\lambda}(x) = \arg\min_{y, z}\frac{1}{2}\|x - y\|_2^2 + \lambda\|z\|_1 \quad \text{s.t.} \quad \mathfrak{D}y - z = 0. $$

This can be evaluated by ADMM iterations with dual variable $u$:

$$ \begin{aligned} y^{(k+1)} &= (I + \rho \mathfrak{D}^\intercal \mathfrak{D})^{-1}(x + \rho \mathfrak{D}^\intercal(z^{(k)} - u^{(k)}))\\ z^{(k+1)} &= S_{\lambda/\rho}(\mathfrak{D}y^{(k+1)} + u^{(k)})\\ u^{(k+1)} &= u^{(k)} + \mathfrak{D}y^{(k+1)} - z^{(k+1)}, \end{aligned} $$

where $S_{\lambda/\rho}$ is the elementwise soft-thresholding operator, $S_t(v) = \mathrm{sign}(v)\max(|v| - t, 0)$, and $\rho>0$ is the ADMM step size. These iterations converge rapidly, as measured by the primal and dual residuals:

$$ \begin{aligned} \text{primal residual} &= \|\mathfrak{D}y^{(k+1)} - z^{(k+1)}\|_2\\ \text{dual residual} &= \rho\|\mathfrak{D}^\intercal(z^{(k+1)} - z^{(k)})\|_2. \end{aligned} $$
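Here is a minimal NumPy sketch of the ADMM prox above. It is vectorized over mutations by storing one mutation per column of $x$, $y$, $z$, $u$. This is not the notebook's prototype. The function name, the default $\rho$, the tolerance, and the iteration cap are assumptions. Because $I + \rho\mathfrak{D}^\intercal\mathfrak{D}$ is only $D\times D$, it is inverted once up front.

```python
import numpy as np

def soft_threshold(v, t):
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)

def prox_generalized_lasso(x, Dmat, lam, rho=1.0, tol=1e-10, max_iter=10_000):
    """ADMM for argmin_y 0.5 * ||x - y||^2 + lam * ||Dmat @ y||_1.

    x    : (D, M) array; each column is one mutation's effects across conditions
    Dmat : (K, D) difference matrix, shared by all columns
    """
    n_cond = Dmat.shape[1]
    # Invert (I + rho D^T D) once; it is reused at every y-update.
    A_inv = np.linalg.inv(np.eye(n_cond) + rho * Dmat.T @ Dmat)
    z = Dmat @ x
    u = np.zeros_like(z)
    for _ in range(max_iter):
        y = A_inv @ (x + rho * Dmat.T @ (z - u))
        z_old = z
        z = soft_threshold(Dmat @ y + u, lam / rho)
        u = u + Dmat @ y - z
        primal = np.linalg.norm(Dmat @ y - z)
        dual = rho * np.linalg.norm(Dmat.T @ (z - z_old))
        if primal < tol and dual < tol:
            break
    return y

D_sym = np.array([[-1.0, 1.0, 0.0],
                  [0.0, -1.0, 1.0],
                  [-1.0, 0.0, 1.0]])
x = np.array([[0.0, 1.0],
              [1.0, 1.0],
              [1.1, -3.0]])  # 3 conditions x 2 mutations
y = prox_generalized_lasso(x, D_sym, lam=100.0)
```

Two limits make easy sanity checks. With $\lambda = 0$ the prox returns $x$ unchanged. With a very large $\lambda$, each column is fused to its mean across conditions, which is the projection onto constant vectors.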

To parallelize across all mutations, we simply stack mutations along the column dimension of $x$, $y$, $z$, and $u$. This gives the proximal operator needed for the symmetrically penalized model. A prototype ADMM prox function is implemented in the notebook linked at the top.
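The outer loop only needs $\nabla f$ and $\mathrm{prox}_\lambda$. It can be sketched as a generic FISTA iteration, the standard Beck and Teboulle accelerated proximal gradient scheme with a fixed step size. This is an illustration, not the multidms implementation, and the function names are hypothetical. For an easy check, the test problem uses plain soft-thresholding ($\mathfrak{D}=I$) with an identity design, where the solution is known in closed form.

```python
import numpy as np

def fista(grad_f, prox, x0, step, lam, n_iter=500):
    """Minimize f(x) + lam * h(x) by FISTA with a fixed step size.

    grad_f(x)  : gradient of the smooth term f
    prox(v, t) : argmin_y 0.5 * ||v - y||^2 + t * h(y)
    step       : 1 / L, where L is a Lipschitz constant of grad_f
    """
    x = x0.copy()
    z = x0.copy()
    t = 1.0
    for _ in range(n_iter):
        # Proximal gradient step from the extrapolated point z.
        x_new = prox(z - step * grad_f(z), step * lam)
        t_new = (1.0 + np.sqrt(1.0 + 4.0 * t**2)) / 2.0
        z = x_new + ((t - 1.0) / t_new) * (x_new - x)  # momentum extrapolation
        x, t = x_new, t_new
    return x

def soft_threshold(v, t):
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)

# f(beta) = 0.5 * ||y - beta||^2, so grad_f = beta - y and L = 1.
y = np.array([3.0, -0.5, 1.2, -2.0])
lam = 1.0
beta_hat = fista(lambda b: b - y, soft_threshold, np.zeros_like(y), step=1.0, lam=lam)
```

With step $=1/L=1$ here, the first iterate is already the exact answer $S_\lambda(y)$, which makes the check easy to verify by hand. Swapping in an ADMM-based prox for $\|\mathfrak{D}\beta\|_1$ changes only the `prox` argument.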

See:

jgallowa07 commented 5 months ago

Thanks, @WSDeWitt!

I'm slowly working my way through this and will probably have more comments/questions, but here are a few I'd like to start with:

Convergence results

I'm trying to test this now, but it seems the results you posted used the reference-equivariant $\mathfrak{D}$ in prox_scale_TEST. Have you confirmed that convergence is well behaved using only the bit-flipping transformation and a lasso penalty?

Typos in $\mathfrak{D}$?

Above, you note that the difference matrix that would replicate the current model is:

$$ \mathfrak{D} = \begin{bmatrix} -1 & 1 & 0\\ 0 & 0 & 0\\ -1 & 0 & 1 \end{bmatrix}. $$

However, if $\beta = [2, 1, 2]$, for example, then $\mathfrak{D}\beta=[-1, 0, 0]$. This seems incorrect. Alternatively, if

$$ \mathfrak{D} = \begin{bmatrix} 1 & -1 & 0\\ 0 & 0 & 0\\ 0 & -1 & 1 \end{bmatrix}, $$

Then we get $\mathfrak{D}\beta=[1, 0, 1]$. Is this what you were meaning to write? Or am I missing something?

The need for a reference-equivariant $\mathfrak{D}$:

If an epistatic interaction has arisen only on the branch leading to the reference experiment, we can't discover that shift pattern.

I'm not sure I follow, would the epistatic effect not be picked up in the shifts calculated above?

If we change the reference experiment, we get a different limitation on shift sparsity pattern discovery.

To be clear, if we were to choose $d = 0$ as the reference, then our difference matrix would presumably be defined as

(edited):

$$ \mathfrak{D} = \begin{bmatrix} 0 & 0 & 0\\ -1 & 1 & 0\\ -1 & 0 & 1 \end{bmatrix}. $$

Then (keeping the same $\beta$ as above), $\mathfrak{D}\beta=[0, -1, 0]$. So indeed, the limitation changes, but wouldn't the sweep across $\lambda_{L1}$ correct for this, making the different choices of reference essentially equivalent?

wsdewitt commented 5 months ago

Thanks for the questions, @jgallowa07. If $\beta = [\beta_1, \beta_2, \beta_3]^\intercal$ for conditions 1, 2, 3, then with the first $\mathfrak{D}$ we get

$$ \mathfrak{D}\beta=\begin{bmatrix} \beta_2-\beta_1\\ 0\\ \beta_3-\beta_1 \end{bmatrix}, $$

and the penalty is $\|\mathfrak{D}\beta\|_1 = |\beta_2-\beta_1| + |\beta_3-\beta_1|$, which matches what we currently do, assuming index $1$ refers to the reference condition. Note that, in general, we would have $D(D-1)/2$ shifts, so the indices of the shifts can't generally align with the indices of $\beta$. However, I agree with you that it might be convenient to align them for the special case $D=3$ (for which $D(D-1)/2 = 3$).

For the second question, the problem is that it's impossible to find a sparsity pattern that has condition 1 shifted with respect to conditions 2 and 3 while 2 and 3 are not shifted with respect to each other, unless we lasso conditions 2 and 3 to each other the same way we lasso each to condition 1. We currently aren't doing that, so we can't find this pattern, and if we change the reference, this limitation changes with it.

I didn't follow the last point about the suggested difference matrix. If you multiply your matrix by $\beta = [\beta_1, \beta_2, \beta_3]^\intercal$, you get

$$ \mathfrak{D}\beta=\begin{bmatrix} 0\\ -\beta_1\\ -\beta_1 - \beta_2 + \beta_3 \end{bmatrix}. $$

None of the elements of the result seem to correspond to any shift we'd want to penalize (each should be a difference $\beta_i-\beta_j$ for some $i\ne j$).

wsdewitt commented 5 months ago

I updated various things in the notebook. It now handles a generalized lasso, defined by a $3\times3$ difference matrix $\mathfrak{D}$ in cell 26. This is currently set to use the reference-based penalties like in the current code. In the same cell there is the equivariant version that can be uncommented to use instead.

An important point that I hadn't appreciated before is that the parameter transformation gives us a different penalization, because we need to consider sign flips coming from the pattern of bundle/non-bundle identities across conditions for a mutation. I previously wrote that we could simply transform back to apply the prox in the standard parameters, then retransform. However, I don't think that's valid: FISTA isn't expected to converge if we use different parameterizations for the smooth term and the non-smooth term.

Instead, we need to specify $\tilde{\mathfrak{D}}$ such that $\tilde{\mathfrak{D}}\tilde\beta = \mathfrak{D}\beta$, which is achieved by setting $\tilde{\mathfrak{D}} = \mathfrak{D}(I - 2\mathrm{diag}(B))$, where $B$ is the 1-hot encoding of the bundleness across conditions at this site. The diagonal matrix $(I - 2\mathrm{diag}(B))$ has -1s for conditions in which the site is in the bundle, and 1s for conditions in which the site is not in the bundle. This also applies to the reference-based lasso penalty in the scaled parameterization: we end up with the generalized lasso problem with penalty $\|\tilde{\mathfrak{D}}\tilde\beta\|_1$. Thankfully, the prox can be obtained with the linear ADMM using standard packages.
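Here is a quick numerical check of $\tilde{\mathfrak{D}}\tilde\beta = \mathfrak{D}\beta$ for a single mutation. The bundle pattern and effect values are made up for illustration. Since $(I - 2\mathrm{diag}(B))$ is its own inverse, it maps $\tilde\beta$ to $\beta$ and back, matching the per-mutation sign flips in the parameter transformation.

```python
import numpy as np

# Reference-based difference matrix for D = 3 (condition 1 is the reference).
Dmat = np.array([[-1.0, 1.0, 0.0],
                 [0.0, 0.0, 0.0],
                 [-1.0, 0.0, 1.0]])
# Hypothetical bundle indicator for one mutation: in the bundle for
# condition 3 only (the reference condition never has bundle sites).
B = np.array([0.0, 0.0, 1.0])
S = np.eye(3) - 2.0 * np.diag(B)   # -1 on bundle conditions, +1 elsewhere
D_tilde = Dmat @ S

beta_tilde = np.array([0.4, 0.1, -0.7])  # effects in the scaled parameterization
beta = S @ beta_tilde                    # beta_m = -tilde beta_m in the bundle
assert np.allclose(D_tilde @ beta_tilde, Dmat @ beta)
```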

Here are some results using the reference-based $\mathfrak{D}$, transformed into $\tilde{\mathfrak{D}}$ using our bundle info, and applying FISTA with the ADMM-based prox (all in the transformed $\tilde\beta$ parameters): [figures not reproduced]

jgallowa07 commented 5 months ago

Oops, my CS is showing. I was 0-indexing, which confused a few things. I was also assuming we needed the resulting vector $\mathfrak{D}\beta$ to be ordered, but I guess that's not the case. Thanks for the clarification!

I updated various things in the notebook.

Thanks for the updates! Taking a closer look now.

jgallowa07 commented 4 months ago

Lots of discussion and results from my investigation into this approach have been communicated via Slack or Zoom, but to summarize:

For documentation, here are the summary figures from the working #4 objective approach: [figure not reproduced: generalized_lasso_objective_comparison]

@WSDeWitt The notebook that produced all the results discussed so far can be found here. It can be executed directly, or invoked across a grid of papermill parameters (in parallel) using the nb_papermill.py script paired with the desired params.json. For example:

$ python nb_papermill.py --nb fit_generalized_lasso.ipynb --params params.json --output results/stronger_ridge_manu --nproc 24

This would output a directory with all the individual resulting notebooks, and their respective outputs to be collected and analyzed using code like that in papermill_results.ipynb.

It would be great if you could:

  1. Review the implementation of the post-latent sigmoid to make sure it's correct.
  2. Based on our conversation last week, implement and play with the smooth piece of the objective function.
  3. Prove that we can get speedy-ish convergence across all the manuscript lasso weights and replicates (this is where the papermill pipeline may come in handy).

In the meantime, I'm going to start implementing this in the multidms source code. I'll be branching off of the non_linear_generalized_lasso branch that the notebooks above currently live in, so feel free to push there if you'd like.

Let me know if you have any questions - and thanks!