jbloomlab / polyclonal

Model mutational escape from polyclonal antibodies.

A regularization term that penalizes similarity between epitope escape profiles #124

Closed: timcyu closed this issue 1 year ago

timcyu commented 1 year ago

I'm proposing a new regularization term that penalizes the similarity between epitope escape profiles. This may be useful because @caelanradford and others are getting identical-looking epitope escape profiles with sera. These profiles sometimes show multiple peaks that look like they should separate into their own epitopes.

Here's an idea that I think can be applied when fitting the site-wise model. It penalizes the dot product of the site-wise escape values for each pair of epitopes, so it should encourage each escape mutation to have zero effect on escape at all but one of the epitopes. I remember this is similar to something that @matsen suggested a long time ago for torchdms too.
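Here is a minimal sketch of this site-wise penalty, summed over unique epitope pairs. It assumes the site-level coefficients are held in a sites x epitopes nested list `beta`. The function name `site_dot_penalty` and the weight `lam` are hypothetical and are not part of polyclonal.

```python
from itertools import combinations


def site_dot_penalty(beta, lam):
    """Site-wise similarity penalty: lam * sum over unique epitope pairs
    (i, j) with i < j of sum_s beta[s][i] * beta[s][j].

    beta: nested list with one row per site and one column per epitope.
    lam:  regularization weight (hypothetical name).
    """
    n_epitopes = len(beta[0])
    total = 0.0
    for i, j in combinations(range(n_epitopes), 2):
        # Dot product of the site-wise escape profiles of epitopes i and j.
        total += sum(row[i] * row[j] for row in beta)
    return lam * total


# Disjoint profiles (each site escapes only one epitope) are not penalized.
print(site_dot_penalty([[1.0, 0.0], [0.0, 2.0]], lam=1.0))  # 0.0
# Identical profiles are.
print(site_dot_penalty([[1.0, 1.0], [2.0, 2.0]], lam=1.0))  # 5.0
```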

[Screenshot, 2022-09-26: the proposed penalty term]

I think we should also consider trying to reproduce this phenomenon with the simulated dataset. My guess is that we might get it if we increase the noise. Or we could just try to work directly with @caelanradford's data, or Frances' cocktail selection data if it's available.

WSDeWitt commented 1 year ago

This is a nice idea. Can you clarify what $s_i$ and $s_j$ are (I thought $s$ was a site index)? I'm assuming you mean something like this: for each epitope, you compute an $S$-vector of norms of the $\beta$ coefficients over mutations at each site $1,\dots,S$ (measuring the overall escapiness of each site for that epitope), and then you take the dot products of these vectors for pairs of epitopes.

timcyu commented 1 year ago

Sorry, I could have made this clearer, @WSDeWitt! The $s$ is a site index. I believe the site-level model fits $\beta_{s,e}$ coefficients for each site $s$ at each epitope $e$. I was thinking of penalizing the dot products between $\beta_{s,i}$ and $\beta_{s,j}$ for all unique pairs of epitopes ($i$, $j$), not the site-wise norms of the $\beta_{m,e}$ coefficients. However, the latter should work too.

matsen commented 1 year ago

I like the dot product idea. If I understand correctly, it's the sum of $\beta_{s,i} \, \beta_{s,j}$ across all sites $s$ and pairs $i,j$ of epitopes.

It seems like ideally we want the dot product for each $i,j$ to be zero-- each pair of vectors is orthogonal. However, in theory it seems like with your original idea one could have some $i,j$ pairs be positive, and others negative, which could cancel each other out. This would be fixed by taking the absolute value for each $i,j$ pair, but then we'd be in more difficult gradient-descent land. I wouldn't suggest going there unless we see that this becomes a problem.
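The cancellation concern can be shown numerically with made-up values. In this example two epitopes have identical profiles, yet the signed pairwise dot products sum to zero. Summing absolute values exposes the overlap. `pairwise_dots` is a hypothetical helper.

```python
from itertools import combinations


def pairwise_dots(beta):
    """Dot product of site-wise escape profiles for each unique epitope pair."""
    n_epitopes = len(beta[0])
    return {(i, j): sum(row[i] * row[j] for row in beta)
            for i, j in combinations(range(n_epitopes), 2)}


# Two sites x three epitopes: epitopes 0 and 1 have identical profiles,
# while epitope 2 has small negative values at the same sites.
beta = [[1.0, 1.0, -0.5],
        [1.0, 1.0, -0.5]]
dots = pairwise_dots(beta)
signed_total = sum(dots.values())               # 2.0 - 1.0 - 1.0 = 0.0
abs_total = sum(abs(d) for d in dots.values())  # 2.0 + 1.0 + 1.0 = 4.0
print(dots, signed_total, abs_total)
```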

zorian15 commented 1 year ago

I really like this idea as well @timcyu !

Re: potential optimization challenges highlighted by @matsen -- could we use a differentiable approximation of the absolute value, $\sqrt{x_{ij}^2 + c}$, where $x_{ij}$ is the dot product for each $i,j$ pair and $c$ could be a super small positive number?
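Here is a minimal sketch of this smoothed penalty with a hand-coded gradient. It assumes the site-level `beta` layout (sites x epitopes) and uses an arbitrary illustrative default `c=1e-8`. The function name is hypothetical. By the chain rule, $\frac{d}{dx}\sqrt{x^2 + c} = x / \sqrt{x^2 + c}$ and $\partial x_{ij} / \partial \beta_{s,i} = \beta_{s,j}$.

```python
import math
from itertools import combinations


def smoothed_abs_dot_penalty(beta, lam, c=1e-8):
    """Return (value, grad) for lam * sum_{i<j} sqrt(x_ij ** 2 + c),
    where x_ij = sum_s beta[s][i] * beta[s][j] and beta is sites x epitopes.
    """
    n_sites, n_epi = len(beta), len(beta[0])
    grad = [[0.0] * n_epi for _ in range(n_sites)]
    total = 0.0
    for i, j in combinations(range(n_epi), 2):
        x = sum(row[i] * row[j] for row in beta)
        r = math.sqrt(x * x + c)  # smooth stand-in for |x|
        total += r
        dr_dx = x / r             # derivative of sqrt(x ** 2 + c)
        for s in range(n_sites):
            # Chain rule: d x_ij / d beta[s][i] = beta[s][j], and vice versa.
            grad[s][i] += lam * dr_dx * beta[s][j]
            grad[s][j] += lam * dr_dx * beta[s][i]
    return lam * total, grad
```

Unlike $|x|$, this stays differentiable at $x_{ij} = 0$, where its slope is 0. The trade-off is a small offset of about $\sqrt{c}$ per pair.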

timcyu commented 1 year ago

Thanks @matsen @zorian15! The differentiable approximation sounds good to me. I believe the updated term would look like this.

[Screenshot, 2022-09-27: the updated penalty term]

WSDeWitt commented 1 year ago

I see, thanks for clarifying—this is a site-level model (i.e. without an AA-wise dimension in the coefficients).

In addition to Erick's concern about possible additive cancellation over the different inner products, I also wanted to point out another possible issue: for each inner product, if the vector elements can be positive or negative, you can get orthogonality without a disjoint site-wise partitioning of effects into epitopes. If the parameter values are very strongly driven (by the data-fitting term) to all be positive (or all negative), then maybe this is less of a problem, and regularization would find orthogonal vectors that are also approximately site-wise disjoint, but I'm not sure about that.
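A toy illustration with made-up profiles (the helper names are hypothetical). With mixed signs, two profiles can be orthogonal while both act at the same sites. With nonnegative values, a zero dot product forces disjoint supports: every term is nonnegative, so each term must be zero.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def shared_sites(u, v):
    """Indices of sites where both epitopes have a nonzero escape value."""
    return [s for s, (a, b) in enumerate(zip(u, v)) if a != 0 and b != 0]


# Mixed signs: orthogonal, yet both epitopes act at both sites.
e1, e2 = [1.0, 1.0], [1.0, -1.0]
print(dot(e1, e2), shared_sites(e1, e2))  # 0.0 [0, 1]

# Nonnegative values: a zero dot product forces disjoint supports,
# because every term a * b is >= 0 and so each must be exactly 0.
f1, f2 = [1.0, 0.0, 2.0], [0.0, 3.0, 0.0]
print(dot(f1, f2), shared_sites(f1, f2))  # 0.0 []
```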

Some thoughts on yet another approach, using a matrix norm

With this $S\times E$ parameter matrix $\beta$, you're saying the columns should be orthogonal. Another way to say this is that the so-called Gram matrix (a matrix of inner products of a set of vectors) formed from the columns of $\beta$ is diagonal. E.g. if the columns of $\beta$ were also unit norm (orthonormal) you'd have

$$ \beta^\intercal\beta = I, $$

where $I$ is the identity matrix. A smooth L2 penalty approach would be:

$$ \lambda \Vert \beta^\intercal\beta - I \Vert_F^2 $$

The squared Frobenius norm amounts to squaring the summand in your original suggested penalty, which resolves Erick's concern about additive cancellation of the inner products.

Generalizing slightly to account for the diagonal not being all 1s (orthogonal but not orthonormal), this becomes

$$ \lambda \Vert (\beta^\intercal\beta) \odot (\mathbf{1} - I) \Vert_F^2, $$

where $\mathbf{1}$ denotes the matrix containing all 1s, so $\mathbf{1} - I$ is the "complement" of the identity matrix, and pulls out the off-diagonal elements when used in an element-wise product (indicated with $\odot$).

The advantage of a matrix norm penalty like this is that it's easy to evaluate the matrix derivative with respect to $\beta$ to get a gradient that's also expressed as efficient matrix operations.
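Below is a pure-Python sketch of this off-diagonal Gram penalty and its matrix gradient. It assumes `beta` is a sites x epitopes nested list, and the names are hypothetical. Writing $M = (\beta^\intercal\beta) \odot (\mathbf{1} - I)$, the gradient works out to $4\lambda\,\beta M$, because $M$ is symmetric. This form counts each unordered pair $(i,j)$ twice, so it equals $2\lambda\sum_{i<j}(\beta_{\cdot,i}^\intercal\beta_{\cdot,j})^2$.

```python
def gram(beta):
    """Gram matrix G = beta^T beta for a sites x epitopes nested list."""
    n_epi = len(beta[0])
    return [[sum(row[i] * row[j] for row in beta) for j in range(n_epi)]
            for i in range(n_epi)]


def offdiag_gram_penalty(beta, lam):
    """Return (value, grad) for lam * ||(beta^T beta) * (1 - I)||_F^2,
    where * is the element-wise product. grad has the same shape as beta."""
    g = gram(beta)
    n_epi = len(g)
    # M = G with its diagonal zeroed out, i.e. G * (1 - I).
    m = [[g[i][j] if i != j else 0.0 for j in range(n_epi)]
         for i in range(n_epi)]
    value = lam * sum(m[i][j] ** 2 for i in range(n_epi) for j in range(n_epi))
    # Matrix derivative: d/d(beta) = 4 * lam * beta @ M (M is symmetric).
    grad = [[4.0 * lam * sum(row[j] * m[j][k] for j in range(n_epi))
             for k in range(n_epi)]
            for row in beta]
    return value, grad
```

With NumPy arrays, the same gradient is a single matrix product, `4 * lam * beta @ M`.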

This paper explores several differentiable approaches for orthogonality regularization that purport to improve on the "soft" Frobenius penalty above.

matsen commented 1 year ago

That's awesome, @WSDeWitt -- thanks for thinking about this. That seems like the right approach.

@timcyu -- I'd suggest trying out your original idea, even though it might have the cancellation problem and the other problem that Will described, because it'll be easy (recalling that we need to hand-code gradients here). Maybe it's enough!

timcyu commented 1 year ago

Okay sounds good @matsen @WSDeWitt!

timcyu commented 1 year ago

An update on this: after chatting with @jbloom, we thought it would be best if the penalty were with respect to the $\beta_{m,e}$, not the site-wise $\beta_{s,e}$. The revised penalty I'm trying first looks like this:

$$ R_{\rm{similarity}} = \lambda_{\rm{similarity}} \sum_{(i,j) \in P} \sum_{k} \left\Vert\beta_{k,i}\right\Vert_2^2 \left\Vert\beta_{k,j}\right\Vert_2^2 $$

where $P$ is the set of all pairs $i$, $j$ of epitopes, $k$ is the site number, and $\beta_{k,i}$ is a vector of all escape values for mutations at site $k$ and epitope $i$. I'm also using the squared Euclidean norm.
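Here is a minimal sketch of evaluating this penalty over unique pairs $i < j$. It assumes mutation-level coefficients stored as one row per mutation (columns are epitopes), plus a parallel list `sites` giving each mutation's site. All names are hypothetical.

```python
from collections import defaultdict
from itertools import combinations


def site_sq_norms(beta, sites):
    """N[k][e]: sum of beta[m][e] ** 2 over mutations m at site k."""
    n_epi = len(beta[0])
    norms = defaultdict(lambda: [0.0] * n_epi)
    for row, k in zip(beta, sites):
        for e in range(n_epi):
            norms[k][e] += row[e] ** 2
    return dict(norms)


def similarity_penalty(beta, sites, lam):
    """lam * sum over sites k and unique pairs i < j of N[k][i] * N[k][j]."""
    norms = site_sq_norms(beta, sites)
    n_epi = len(beta[0])
    return lam * sum(n[i] * n[j] for n in norms.values()
                     for i, j in combinations(range(n_epi), 2))


# Three mutations (rows) at sites 1, 1, 2; two epitopes (columns).
beta = [[1.0, 1.0],
        [2.0, 0.0],
        [0.0, 3.0]]
print(similarity_penalty(beta, [1, 1, 2], lam=1.0))  # 5.0
```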

I think I'm close (?) but having some trouble with the gradient, which I don't think I have right. This was my attempt at deriving the two-epitope case (with just one pair $i$, $j$ of epitopes) using the product rule:

$$ \frac{\partial R_{\rm{similarity}}}{\partial \beta_{m,e}} = 2 \lambda_{\rm{similarity}} \sum_{k} \left( \left\Vert\beta_{k,j}\right\Vert_2^2 \beta_{k,i} + \left\Vert\beta_{k,i}\right\Vert_2^2 \beta_{k,j} \right) $$

In the summation, I get two vectors of shape (m_k, 1), where $m_k$ is the number of mutations at site $k$. I'm not sure if this is right, but I combine these to get a matrix of shape (m_k, 2) and then stack on the vectors that result from all the remaining mutations at other sites to end up with a matrix that is the same shape as the parameters $\beta_{m,e}$.

Where am I going wrong here? I figured that if I could get this to work, it would be easy to generalize to more than two epitopes. I hope this description makes sense, and sorry, I haven't taken a derivative in so long...

matsen commented 1 year ago

Here is Tim's equation that wasn't rendering for me:

[Image: Tim's equation, rendered]


I think things are just getting a little muddled with your intermediate use of vectors. The partial derivative of the penalty with respect to a single parameter should be a scalar.

If I understand correctly, you would like this penalty, where $S_{k,i}$ is the set of mutations for site $k$ with epitope $i$:

$$ R_{\rm{similarity}} = \lambda_{\rm{similarity}} \sum_{(i,j) \in P} \sum_{k} \left( \sum_{m \in S_{k,i}} \beta_{m,i}^2 \right) \left( \sum_{m \in S_{k,j}} \beta_{m,j}^2 \right) $$

Is that correct? If so, I think continuing with the product rule as you were before should get you to the solution.
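Continuing with the product rule under that reading, take $S_{k,i} = S_k$, the mutations at site $k$. For a mutation $m$ at site $k$, only the factor $N_{k,e} = \sum_{m' \in S_k} \beta_{m',e}^2$ depends on $\beta_{m,e}$, and $\partial N_{k,e} / \partial \beta_{m,e} = 2\beta_{m,e}$. Each partial derivative is therefore the scalar $2\lambda_{\rm{similarity}}\, \beta_{m,e} \sum_{j \ne e} N_{k,j}$. The sketch below collects these scalars into an array shaped like $\beta$. The names are hypothetical, and it uses one row per mutation with a parallel `sites` list.

```python
from collections import defaultdict


def site_sq_norms(beta, sites):
    """N[k][e]: sum of beta[m][e] ** 2 over mutations m at site k."""
    n_epi = len(beta[0])
    norms = defaultdict(lambda: [0.0] * n_epi)
    for row, k in zip(beta, sites):
        for e in range(n_epi):
            norms[k][e] += row[e] ** 2
    return norms


def similarity_penalty_grad(beta, sites, lam):
    """One scalar partial derivative per coefficient, same shape as beta:
    dR/dbeta[m][e] = 2 * lam * beta[m][e] * sum_{j != e} N[k(m)][j]."""
    norms = site_sq_norms(beta, sites)
    grad = []
    for row, k in zip(beta, sites):
        total_k = sum(norms[k])  # sum over all epitopes at site k
        grad.append([2.0 * lam * b * (total_k - norms[k][e])
                     for e, b in enumerate(row)])
    return grad


beta = [[1.0, 1.0], [2.0, 0.0], [0.0, 3.0]]
print(similarity_penalty_grad(beta, [1, 1, 2], lam=1.0))
# [[2.0, 10.0], [4.0, 0.0], [0.0, 0.0]]
```

A central finite-difference comparison against the penalty itself is a cheap way to check a hand-coded gradient like this one.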

jbloom commented 1 year ago

Also, to add to the above, it's important that the equation denotes that you are summing only over $i \ne j$. Right now it indicates that you are summing over all combinations of $i$ and $j$, but it should be only those with $i \ne j$.

In fact, I think it would be good to write out the two summations separately, because the quantitative value of the penalty will differ by a factor of two depending on whether you are doing $\sum_i \sum_{j \ne i}$ versus $\sum_i \sum_{j > i}$. I'd recommend the latter, as there isn't any double counting, although either is technically fine. But the equation should be written in a way that makes this clear.
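A quick check of the factor-of-two point, using an arbitrary symmetric matrix of pairwise terms:

```python
# A symmetric matrix of pairwise terms between three epitopes
# (e.g. the site-summed products for each epitope pair); values are made up.
G = [[0.0, 1.0, 2.0],
     [1.0, 0.0, 3.0],
     [2.0, 3.0, 0.0]]
n = len(G)

ordered = sum(G[i][j] for i in range(n) for j in range(n) if j != i)  # sum_i sum_{j != i}
unique = sum(G[i][j] for i in range(n) for j in range(i + 1, n))      # sum_i sum_{j > i}
print(ordered, unique)  # 12.0 6.0
```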

timcyu commented 1 year ago

Yes, $P$ was supposed to denote only unique pairs of $i$ and $j$, which avoids double counting. For clarity, I will use the double summation from now on!

jbloom commented 1 year ago

@timcyu (also @matsen @WSDeWitt), just thinking more about the regularization that Tim is implementing here.

Originally we had discussed just having the penalty be the dot product of the $\beta_{m,e}$ values. But we decided (correctly, I think) that we wanted it to operate at the site level. So instead we aggregated the values at each site, currently as the sum of the squares of the $\beta_{m,e}$ values, and are then penalizing the product of that across epitopes.

But effectively, we are now regularizing on what would be the square of the dot products if we went back to the mutation-level model. I worry that this is too strong of a penalty: it means that the penalty will grow very rapidly if there is any epitope overlap at a site. In general we want to avoid such overlap, but we shouldn't be overly dramatic in the penalty if there is some.

I wonder if we could instead regularize on something like the product of the L2-norms at each site, which would make it more like the original idea of penalizing the dot product. Would that still be differentiable?
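A toy comparison of the two site-level terms as the overlap at a single site is scaled by $t$ (made-up values, hypothetical names). The squared-norm product grows as $t^4$, while the norm product grows as $t^2$. By Cauchy-Schwarz, $\Vert a\Vert_2^2 \Vert b\Vert_2^2 \ge (a^\intercal b)^2$, with equality when $a$ and $b$ are parallel. This is the sense in which the current penalty acts like a squared dot product.

```python
import math


def sq_norm(v):
    return sum(x * x for x in v)


def squared_product_term(a, b):
    """Site term in the current penalty: ||a||^2 * ||b||^2."""
    return sq_norm(a) * sq_norm(b)


def norm_product_term(a, b):
    """Proposed alternative site term: ||a|| * ||b||."""
    return math.sqrt(sq_norm(a)) * math.sqrt(sq_norm(b))


# Both epitopes escape at the same site; scale that overlap by t.
# The squared term grows as t ** 4, the norm product as t ** 2.
base = [0.5, 0.5]
for t in (1.0, 2.0, 4.0):
    a = [t * x for x in base]
    b = [t * x for x in base]
    print(t, squared_product_term(a, b), norm_product_term(a, b))
```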

Additionally, I wonder if the L2-norm we are using should be normalized by the number of mutations observed at each site? This isn't as immediately obvious to me. But right now the penalty will apply more to sites with more mutations observed in the libraries (as there is no normalization by the number of mutation values observed per site), and I'm not sure that is desirable.

matsen commented 1 year ago

Good point, @jbloom .

I just want to make sure I understand your new idea.

Is it

$$ R_{\rm{similarity}} = \lambda_{\rm{similarity}} \sum_{(i,j) \in P} \sum_{k} \sqrt{ \sum_{m \in S_{k,i}} \beta_{m,i}^2 } \sqrt{ \sum_{m \in S_{k,j}} \beta_{m,j}^2 } $$

with P being the unique pairs?

jbloom commented 1 year ago

Yes, something like that is what I was thinking, @matsen. Is that still differentiable? If not, something with similar behavior that is.

(Also, side note that @timcyu and I figured out: $S_k$ does not depend on $i$ or $j$, so the $S_{k,i}$ values can just be written as $S_k$.)

matsen commented 1 year ago

Unfortunately not differentiable, but we could use @zorian15 's suggestion https://github.com/jbloomlab/polyclonal/issues/124#issuecomment-1259642760

timcyu commented 1 year ago

Okay, I can try using @zorian15's absolute value approximation and I agree this would ease the penalty. I also like the idea of normalizing the penalty applied to each site by the number of mutations so it doesn't bias towards sites with more mutations, so I will try dividing the products of the norms by $M_k$, the number of mutations at site $k$.
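Here is a sketch that combines both changes. Each site norm is smoothed as $\sqrt{\Vert\beta_{k,e}\Vert_2^2 + c}$, which applies @zorian15's approximation idea to the norm. Each site's contribution is then divided by $M_k$. The layout (one row per mutation plus a parallel `sites` list), the names, and the `c` default are hypothetical.

```python
import math
from collections import Counter, defaultdict


def smoothed_norm_penalty(beta, sites, lam, c=1e-8):
    """Return (value, grad) for
    R = lam * sum_k (1 / M_k) * sum_{i<j} r[k][i] * r[k][j],
    with r[k][e] = sqrt(N[k][e] + c), N[k][e] the sum of beta[m][e] ** 2
    over mutations m at site k, and M_k the number of mutations at site k.
    """
    n_epi = len(beta[0])
    counts = Counter(sites)  # M_k for each site k
    norms = defaultdict(lambda: [0.0] * n_epi)
    for row, k in zip(beta, sites):
        for e in range(n_epi):
            norms[k][e] += row[e] ** 2
    # Smoothed site norms; c > 0 keeps them away from zero.
    r = {k: [math.sqrt(n + c) for n in v] for k, v in norms.items()}

    value = 0.0
    for k, rk in r.items():
        s, s2 = sum(rk), sum(x * x for x in rk)
        # sum_{i<j} r_i * r_j == ((sum r) ** 2 - sum r ** 2) / 2
        value += (s * s - s2) / 2.0 / counts[k]

    grad = []
    for row, k in zip(beta, sites):
        rk, total = r[k], sum(r[k])
        # d r[k][e] / d beta[m][e] = beta[m][e] / r[k][e], times sum_{j != e} r[k][j]
        grad.append([lam / counts[k] * b / rk[e] * (total - rk[e])
                     for e, b in enumerate(row)])
    return lam * value, grad
```

Because $c > 0$, the gradient stays finite even at sites where an epitope's coefficients are all zero.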

matsen commented 1 year ago

This is where using JAX would be really nice, to be able to try out various things without having to write out the gradient of each thing.

I do wonder if there is something you could do to test out the idea without having to write out the gradient in the current context.

For example, imagine $\theta_1$ is the complete parameter set without any regularization (lots of overlap) and $\theta_2$ is the complete parameter set with the current "squared" regularization. If we take $\theta = w \theta_1 + (1-w) \theta_2$, where $w$ is a weight between 0 and 1, do you get a minimum for the L2 version of the loss for an intermediate value of $w$?
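Here is a sketch of that interpolation check. It assumes both fits are flattened into equal-length parameter lists and that `loss` is whatever L2-version objective you want to evaluate. `toy_loss` is a made-up stand-in that only shows the mechanics.

```python
def interpolation_scan(loss, theta1, theta2, n_grid=11):
    """Evaluate loss(w * theta1 + (1 - w) * theta2) on an even grid of w in [0, 1].

    Returns (best_w, results), where results is a list of (w, loss) pairs.
    """
    results = []
    for step in range(n_grid):
        w = step / (n_grid - 1)
        theta = [w * a + (1.0 - w) * b for a, b in zip(theta1, theta2)]
        results.append((w, loss(theta)))
    best_w, _ = min(results, key=lambda pair: pair[1])
    return best_w, results


def toy_loss(theta):
    """Made-up stand-in objective with its minimum at theta == [0.3]."""
    return sum((x - 0.3) ** 2 for x in theta)


best_w, results = interpolation_scan(toy_loss, [1.0], [0.0])
print(best_w)  # 0.3
```

With the real objective, a `best_w` strictly between 0 and 1 would be the encouraging signal described above.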

There are lots of reasons why this wouldn't work (we can't expect such combinations to be good) but if we got a lower value for the L2 version of the loss with an intermediate $w$ I'd feel more confident investing time into the gradient descent.

jbloom commented 1 year ago

I sort of feel like the gradient should not be that complicated (differentiating the function with the square roots looks to me like about the difficulty of a high-school calculus problem), and it may be easier just to do that than to try to implement the thing above.
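For reference, here is a sketch of that calculation for the square-root penalty above, writing $S_k$ for the mutations at site $k$. For a mutation $m \in S_k$ and an epitope $e$, only $\Vert\beta_{k,e}\Vert_2 = \sqrt{\sum_{m' \in S_k} \beta_{m',e}^2}$ depends on $\beta_{m,e}$, and $\partial \Vert\beta_{k,e}\Vert_2 / \partial \beta_{m,e} = \beta_{m,e} / \Vert\beta_{k,e}\Vert_2$, so

$$ \frac{\partial R_{\rm{similarity}}}{\partial \beta_{m,e}} = \lambda_{\rm{similarity}} \frac{\beta_{m,e}}{\Vert\beta_{k,e}\Vert_2} \sum_{j \ne e} \Vert\beta_{k,j}\Vert_2 $$

This is undefined wherever $\Vert\beta_{k,e}\Vert_2 = 0$. There, replacing $\Vert\beta_{k,e}\Vert_2$ with $\sqrt{\Vert\beta_{k,e}\Vert_2^2 + c}$ for a small $c > 0$, following @zorian15's suggestion, keeps the gradient finite.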

jbloom commented 1 year ago

Anyway, @timcyu, it's up to you if you want to work on this first or do one pull request with your current method, another with the new method, and we can compare and decide which to merge.

WSDeWitt commented 1 year ago

> But effectively, we are now regularizing on what would be the square of the dot products if we went back to the mutation-level model

Just noting that precisely this issue is discussed in the paper I linked to previously (the soft Frobenius penalty). This is commonly done, so is probably worth trying (although the paper suggests alternatives). As Erick notes, using norms (instead of squared norms) makes the penalty non-smooth, and requires fancier optimization methods.

> This is where using JAX would be really nice, to be able to try out various things without having to write out the gradient of each thing.

^Retweet. Automatic differentiation is more numerically accurate than hand-derived formulas, much more numerically efficient, requires no pen & paper work, and has a numpy-like API. I think hand deriving/coding the gradients only has drawbacks, if you want to iterate quickly on penalization approaches.

timcyu commented 1 year ago

I think the best way forward for now would be for me to make a separate PR implementing the absolute value approximation, which I don't think will be too difficult to adapt from the current PR. I think getting one of these merged soon would be a good idea, as the current method already seems to be helpful for understanding @caelanradford's data.

After merging that, I would definitely support switching to JAX. I just think it would take a lot more time, as I'm personally not familiar with it, but for the reasons @matsen and @WSDeWitt outlined above, I think it would be worth it.

jbloom commented 1 year ago

I agree with the above. I think considering a JAX implementation is a good idea, but it's clearly more of a long-term direction that may require some significant rewriting.

Although I don't feel super strongly, I also feel that a JAX implementation may be more efficiently led by someone who is already quite familiar with JAX than by me or Tim.

WSDeWitt commented 1 year ago

Here's a colab notebook with a demo implementation of the soft orthogonality regularizer in JAX.

matsen commented 1 year ago

Thanks, Will! Very nice.

I agree with everything said so far. My thoughts are:

I think this is the right idea scientifically, though I don't want to complicate things re publications, etc. Jesse, I'd be happy to hear your thoughts.

timcyu commented 1 year ago

Thanks Will! I will leave it up to @jbloom, but I'd be interested in working on this, either on helping bring the model into multidms or a "from scratch" JAX implementation of polyclonal that would keep the model separate. The latter won't be the most efficient but I think it'd be a good challenge for myself.

WSDeWitt commented 1 year ago

I think I shared this a few months ago, but here's a notebook that implements a minimal JAX global epistasis model, and may be fun/informative to play with. It uses Tyler's Ab-CGGnaive_DMS data. This was the inspiration for how things are being implemented in multidms.

jbloom commented 1 year ago

I think we should split the issue of re-implementing in JAX off from this epitope similarity regularization issue.

This regularization is pretty straightforward to implement into the existing codebase.

I agree a JAX implementation could be advantageous for the reasons discussed above, but it seems like it would require some extensive rewriting, or just de novo package creation, to implement. I am totally supportive if @matsen or @WSDeWitt or others want to do this re-implementation, but I think it should be separated from the minor extension of the current code in question. I am not sure that it is sensible / realistic for @timcyu to do this, as neither he nor I have any experience with JAX, and given his interests, I think it makes more sense right now for him to work on applying the existing package rather than re-implementing it. If the new JAX implementation can retain the current interface, it should be easy to swap it in at some future date. Note that the following issues and pull requests are already open with respect to JAX: #31, #32, #41,