lucidrains / enformer-pytorch

Implementation of Enformer, DeepMind's attention network for predicting gene expression, in PyTorch
MIT License
434 stars, 81 forks

Computing Contribution Scores #13

Status: Open. Prakash2403 opened this issue 1 year ago.

Prakash2403 commented 1 year ago

From the paper:

To better understand what sequence elements Enformer is utilizing when making predictions, we computed two different gene expression contribution scores: input gradients (gradient × input) and attention weights...

I was wondering how to compute the input gradients and extract the attention matrices for a given input. I'm not well versed in PyTorch, so I'm sorry if this is a noob question.

SebastienLemaire commented 1 year ago

Hi, I have the same question: how do I compute the contribution scores (input × gradient) and the attention weights? If I missed a response on this issue, I apologize in advance.
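
The attention half of the question can be handled with a generic PyTorch forward hook on whichever submodule outputs the attention matrix. This is a minimal sketch, not a documented enformer-pytorch feature. It assumes, from my reading of the enformer-pytorch source (verify against your installed version), that each `Attention` block passes its softmaxed attention matrix through a dropout submodule named `attn_dropout`. In eval mode dropout is the identity, so that submodule's output is the attention matrix itself. `ToyAttention` and `capture_outputs` are hypothetical names, and the toy block exists only so the code runs without the real weights.

```python
import torch
import torch.nn as nn


class ToyAttention(nn.Module):
    """Hypothetical stand-in for an attention block: the softmaxed attention
    matrix is passed through a submodule named `attn_dropout`."""

    def __init__(self, dim=8, heads=2, dropout=0.1):
        super().__init__()
        self.heads = heads
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.attn_dropout = nn.Dropout(dropout)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x):
        b, n, d = x.shape
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        # split into heads: (batch, heads, positions, dim_per_head)
        q, k, v = (t.reshape(b, n, self.heads, -1).transpose(1, 2) for t in (q, k, v))
        logits = q @ k.transpose(-1, -2) * q.shape[-1] ** -0.5
        attn = self.attn_dropout(logits.softmax(dim=-1))  # (batch, heads, n, n)
        out = (attn @ v).transpose(1, 2).reshape(b, n, d)
        return self.to_out(out)


def capture_outputs(model, name_suffix):
    """Hook every submodule whose qualified name ends with `name_suffix`.

    Returns (store, handles): store maps module name -> its last output.
    Call handle.remove() on each handle when done.
    """
    store, handles = {}, []
    for name, module in model.named_modules():
        if name.endswith(name_suffix):
            def hook(mod, inputs, output, name=name):
                store[name] = output.detach()
            handles.append(module.register_forward_hook(hook))
    return store, handles


model = nn.Sequential(ToyAttention(), ToyAttention()).eval()  # eval(): dropout is the identity
store, handles = capture_outputs(model, "attn_dropout")
with torch.no_grad():
    model(torch.randn(1, 5, 8))
for handle in handles:
    handle.remove()
# store["0.attn_dropout"] and store["1.attn_dropout"] hold (1, 2, 5, 5) attention matrices
```

Matching on a name suffix rather than an exact path means the same helper works however deeply the attention blocks are nested. Remember to call `.eval()` first; in training mode the hooked tensor would have dropout applied.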

jianghao-zhang commented 7 months ago

Similar confusion here.

jstjohn commented 7 months ago

Check out Captum. It lets you do all kinds of things like this. For reference, what you're trying to do is called “saliency analysis”, if you want to look up other methods for doing it. Input × gradient is one way to do it, and Captum has that method available. https://captum.ai
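
Input × gradient itself only needs autograd. This is a minimal sketch that assumes the model takes a float one-hot tensor of shape (batch, length, 4) and returns a dict of heads, as enformer-pytorch does. I believe enformer-pytorch accepts one-hot input directly, but check this against your version. `ToyEnformer` and `input_x_gradient` are hypothetical names, and the toy model exists only so the code runs without downloading weights. If you prefer a library call, Captum's `InputXGradient` class (in `captum.attr`) wraps the same product.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyEnformer(nn.Module):
    """Hypothetical stand-in for Enformer: one-hot (batch, length, 4) in,
    {"human": (batch, bins, tracks)} out, ending in a softplus."""

    def __init__(self, bin_size=8, tracks=6):
        super().__init__()
        self.bin_size = bin_size
        self.conv = nn.Conv1d(4, 16, kernel_size=5, padding=2)
        self.head = nn.Linear(16, tracks)

    def forward(self, x):
        h = self.conv(x.transpose(1, 2))     # (batch, 16, length)
        h = F.avg_pool1d(h, self.bin_size)   # (batch, 16, bins)
        return {"human": F.softplus(self.head(h.transpose(1, 2)))}


def input_x_gradient(model, one_hot, score_fn):
    """Input x gradient attribution for a scalar score.

    score_fn maps the model output to a single scalar tensor.
    The result has the same shape as one_hot.
    """
    x = one_hot.detach().clone().float().requires_grad_(True)
    score = score_fn(model(x))
    (grad,) = torch.autograd.grad(score, x)
    return (grad * x).detach()


torch.manual_seed(0)
model = ToyEnformer().eval()
one_hot = F.one_hot(torch.randint(0, 4, (1, 64)), num_classes=4).float()
# explain the summed prediction of track 3 over output bins 2..4
attr = input_x_gradient(model, one_hot, lambda out: out["human"][0, 2:5, 3].sum())
per_base = attr[0].sum(-1)  # (length,): one contribution score per position
```

Because the input is one-hot, the product is zero everywhere except at the observed base. That is why summing over the four base channels gives one score per position. With the real model you would swap in the loaded Enformer and a one-hot encoding of your 196,608 bp input.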

Here’s a ChatGPT session where I reminded myself how to do this by hand: https://chat.openai.com/share/d048a71d-effe-4114-ba4a-6f296c550f40

The key thing is to think about what you are trying to explain. The example where I take the absolute value of the sum of a slice could correspond to first finding a predicted peak in some channel of the model's output and then explaining that peak. I forget whether the output is in log space; if it is, you'll want to exponentiate it first and then take the sum of the slice. Does that make sense? The input × gradient of the summed area under a peak then tells you which parts of the input explain that peak.
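
Finding the peak to explain can be automated. This is a minimal sketch that assumes `pred` holds one sequence's predictions as (bins, tracks), which is the layout enformer-pytorch returns once the batch axis is dropped. `peak_window` and `half_width` are hypothetical names. Enformer's output heads end in a softplus, so predictions are non-negative; in that case taking the absolute value of the summed slice changes nothing.

```python
import torch


def peak_window(pred, track, half_width):
    """Bin range [start, end) centred on the highest prediction of one track.

    pred: (bins, tracks) predictions for a single sequence (assumed layout).
    The window is clipped at the edges of the prediction.
    """
    centre = int(pred[:, track].argmax())
    start = max(centre - half_width, 0)
    end = min(centre + half_width + 1, pred.shape[0])
    return start, end


pred = torch.zeros(10, 3)
pred[6, 1] = 5.0                       # a toy "peak" in track 1 at bin 6
start, end = peak_window(pred, track=1, half_width=2)
score = pred[start:end, 1].sum()       # the scalar to attribute, e.g. with input x gradient
```

With the real model, `pred` would be something like `model(one_hot)["human"][0]`. The score must then be recomputed from an input that requires gradients, so that autograd can trace it back to the sequence.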

Good luck!

On Fri, Apr 26, 2024 at 6:56 AM Jianghao Zhang @.***> wrote:

Similar confusion here.

— Reply to this email directly, view it on GitHub https://github.com/lucidrains/enformer-pytorch/issues/13#issuecomment-2079448038, or unsubscribe https://github.com/notifications/unsubscribe-auth/AADQCBWFQBAAOVGQEJGNGUTY7JMITAVCNFSM6AAAAAASS7G2GCVHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMZDANZZGQ2DQMBTHA . You are receiving this because you are subscribed to this thread.Message ID: @.***>

jstjohn commented 7 months ago

The other thing is that you need to turn your question into a single score out of the model. That's why, for example, I'm talking about the sum of a slice of the output, after making sure it's in the right space. If the values are logs, a very large negative number stands for something near 0, so summing them won't measure “this peak is big” like you want. In that case, add a .exp() to the calculation, like score = model(input)["human"][0, 43:54, 25].exp().sum(), or something along those lines. That index means batch element 0, output positions 43 to 53, track 25, since the output is indexed batch, position, track. That gives you a scalar score, and its gradient tells you which small changes to the input would make the score bigger or smaller, if you remember that from calculus.
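
The scalar score described above can be written as a small function. This is a sketch that assumes `output[head]` has shape (batch, bins, tracks), so the bin slice comes before the track index. `peak_score` and its arguments are hypothetical names. Because Enformer's heads end in a softplus, its predictions should already be on a linear, non-negative scale, and `from_log` can stay `False`. The flag is only there for models that do predict log values.

```python
import torch


def peak_score(output, head, track, start, end, from_log=False, batch_index=0):
    """Sum one track's predictions over output bins [start, end).

    Assumes output[head] has shape (batch, bins, tracks). Set from_log=True
    only if the predictions are log values; they are then exponentiated
    before summing, so the score means "total signal under the peak".
    """
    window = output[head][batch_index, start:end, track]
    if from_log:
        window = window.exp()
    return window.sum()


# d(score)/d(prediction) is 1 inside the window and 0 elsewhere
pred = torch.rand(1, 10, 3, requires_grad=True)
peak_score({"human": pred}, "human", track=1, start=2, end=5).backward()
```

Backpropagating the same score one step further, through the model to the one-hot input, gives the input gradient; multiplying it by the input gives input × gradient.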

jstjohn commented 7 months ago

Last thing: please take the time to understand the channel (track) definitions of the model's output. The model predicts a number of quantitative tracks (5,313 for the “human” head in enformer-pytorch) at a number of output positions (896 bins), and it does this for every element of your input batch. Take time to understand what the output means, so that the scalar “question” score you want to explain is properly crafted and makes sense.
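
Relating an output position back to the input sequence is simple arithmetic under Enformer's published setup. The input is 196,608 bp, and the output is 896 bins of 128 bp, centre-cropped. The output therefore covers the middle 114,688 bp and starts 40,960 bp into the input. This is a sketch under those assumptions: the constant names and `bin_to_bp` are my own, so check the values against the config of the model you load.

```python
SEQ_LEN = 196_608  # input length in bp (Enformer's published setup; assumed)
BIN_SIZE = 128     # bp per output bin (assumed)
N_BINS = 896       # output bins kept after centre-cropping (assumed)


def bin_to_bp(bin_index, seq_len=SEQ_LEN, bin_size=BIN_SIZE, n_bins=N_BINS):
    """Half-open [start, end) bp range of an output bin, relative to the input start.

    The output covers the central n_bins * bin_size bp of the input, so the
    first bin starts (seq_len - n_bins * bin_size) // 2 bp into the sequence.
    """
    if not 0 <= bin_index < n_bins:
        raise IndexError(f"bin_index {bin_index} outside 0..{n_bins - 1}")
    offset = (seq_len - n_bins * bin_size) // 2
    start = offset + bin_index * bin_size
    return start, start + bin_size
```

For example, `bin_to_bp(0)` is `(40960, 41088)`. This makes it easy to see which stretch of the input a peak in the predictions refers to before deciding what score to explain.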
