jbloomAus / DecisionTransformerInterpretability

Interpreting how transformers simulate agents performing RL tasks
https://jbloomaus-decisiontransformerinterpretability-app-4edcnc.streamlit.app/
MIT License

Folding Layer Norm in Model Loading #71

Closed: jbloomAus closed this issue 1 year ago

jbloomAus commented 1 year ago

It's occurred to me that my failure to fold LayerNorm is probably making analysis harder, which isn't ideal. A quick attempt at solving this has presented the following problems:

Possible solutions:

I'll try one very quickly. ChatGPT can do it.
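For anyone skimming: a minimal sketch of what "folding LayerNorm" means here, in plain PyTorch rather than this repo's actual loading code (the function and variable names are illustrative, not the real ones):

```python
import torch
import torch.nn as nn

def fold_layer_norm(ln: nn.LayerNorm, linear: nn.Linear) -> None:
    """Fold an affine LayerNorm's scale/bias into the Linear layer that reads from it.

    Before:  linear(ln(x)), with ln(x) = normalise(x) * ln.weight + ln.bias
    After:   linear'(normalise(x)), where
             linear'.weight = linear.weight * ln.weight   (per input column)
             linear'.bias   = linear.bias + linear.weight @ ln.bias
    so the LayerNorm is left with no learned parameters ("LNPre" style).
    """
    with torch.no_grad():
        # Push the LN bias through the *unscaled* linear weight first.
        linear.bias.add_(linear.weight @ ln.bias)
        # Then absorb the LN gain into each input column of the weight.
        linear.weight.mul_(ln.weight)
        # The LN now only centres and rescales; its affine part is the identity.
        ln.weight.fill_(1.0)
        ln.bias.fill_(0.0)

# Quick sanity check that folding preserves the output.
ln, head = nn.LayerNorm(8), nn.Linear(8, 4)
with torch.no_grad():
    ln.weight.normal_()  # give the LN non-trivial affine params so the test means something
    ln.bias.normal_()
x = torch.randn(2, 8)
before = head(ln(x))
fold_layer_norm(ln, head)
assert torch.allclose(before, head(ln(x)), atol=1e-5)
```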

jbloomAus commented 1 year ago

Monkey patch working.

Seems like final_ln does in fact get used in the forward pass, meaning I should fold it into the model's weights. The issue is that the prediction head (a linear layer projecting onto logits) isn't part of the transformer that this patching code gets to see. The result is that I need to:

Very ugly. Very ugly.
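Roughly, the ugly part is that the fold has to reach outside the transformer to touch the prediction head. A sketch of the shape of the fix, reusing the `fold_layer_norm` helper above; `dt.transformer` and `dt.action_predictor` are illustrative names, not this repo's actual attributes:

```python
# Illustrative attribute names; the real model class in this repo differs.
trunk = dt.transformer          # the transformer the patching/loading code can see
head = dt.action_predictor      # nn.Linear projecting the residual stream onto logits;
                                # it lives *outside* the trunk, so the trunk's loading
                                # code never gets a chance to fold into it.

# Fold final_ln's scale/bias into the external prediction head by hand...
fold_layer_norm(trunk.ln_final, head)
# ...then make the trunk treat final_ln as parameter-free ("LNPre") so it stops
# expecting ln.w / ln.b entries in its state dict. Note that TransformerLens's own
# LayerNorm stores its parameters as .w / .b rather than .weight / .bias, so the
# helper above would need that small adaptation if the trunk is a HookedTransformer.
```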

jbloomAus commented 1 year ago

Ok, got it to work, and wrote a test for it.

moodlep commented 1 year ago

@jbloomAus I was just about to ask this question on the TransformerLens repo and then I found your issue! I am having the same problem - I created a from_pretrained() to load a saved decision transformer model, adapted the mingpt converter for my model and it works if I don't fold layernorm but when I do, I receive an error message about missing keys, all of the ln.w and ln.b keys.

Looking at the code I can see these keys are deleted after folding, so I don't know why I am getting the error if this is part of the procedure... I thought maybe I was doing something wrong.
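In case it helps anyone hitting the same missing-key error: one plausible cause is building the HookedTransformer with learned LayerNorm ("LN") while loading a state dict that has already had its LN parameters folded away. A minimal sketch of the loading pattern that avoids this, assuming the state dict is already in TransformerLens naming (as with an adapted mingpt converter); the hyperparameters and the `folded_state_dict` name are placeholders:

```python
from transformer_lens import HookedTransformer, HookedTransformerConfig

# Placeholder hyperparameters: match whatever your trained model actually used.
cfg = HookedTransformerConfig(
    n_layers=3,
    d_model=128,
    n_ctx=64,
    d_head=32,
    n_heads=4,
    d_vocab=256,
    act_fn="gelu",
    # Folding removes ln.w / ln.b from the state dict, so the model must be built
    # with the parameter-free LayerNorm variant, otherwise load_state_dict will
    # complain about exactly those missing keys.
    normalization_type="LNPre",
)
model = HookedTransformer(cfg)
model.load_state_dict(folded_state_dict)  # no ln.w / ln.b keys expected now
```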

Anyway I am grateful to find you have resolved this; really appreciate your efforts. I have been hitting lots of the same issues and then I realise you have solved them. Makes me feel a bit more sane, especially as I am new to this area!

I hope to be able to contribute at some stage, but I am writing up and have to finish by the end of the year.

jbloomAus commented 1 year ago

Hey Perusha,

I'm so glad you found this useful! I started writing up these cards in a very low-effort way, so I wasn't sure they'd be useful to anyone but me! I'm continuing to work on this project, though my focus has been much more on analysis / interpretability lately. I expect an engineering sprint will happen after I finish my current analytical project.

I'd be happy to discuss contributions, especially since the issues on this repo are likely a bit stale!

moodlep commented 1 year ago

That's great! I look forward to it! Happy to have a chat, also about the analysis side of things. I feel very much a beginner there but I am working through the material - definitely a few months behind!

Just some background on what I am trying to do: I was hoping to use some of the MI techniques to analyse my model. I am running Vizdoom (HGS and Deadly Corridor for now) on the decision transformer. I have a 3-layer model trained for MI purposes and would like to perform logit attribution, look at the attention heatmaps, do activation patching, etc., comparing two models with different designs. I am not sure I will get this done before I have to submit, but I will try!

jbloomAus commented 1 year ago

Hey Perusha,

That sounds really interesting and very similar to the stuff I've been working on, meaning there could be a lot to learn from each other's work. Unfortunately, I don't currently have any capacity to meet (possibly in a month or two), but I'm happy to correspond here or over email ( @.***). The biggest questions I have are:

  1. How are you generating your training data?
  2. How do the tasks work? How well does the DT solve them? How effectively does RTG modulate the behaviour?
  3. How do you encode the observations? What are the actions?
  4. Do you use Layernorm? Which activation functions?
  5. What other techniques are you using to analyse your model?
  6. Generally what are your goals? What are the other investigations you're using for these models?

moodlep commented 1 year ago

Hi Joseph, I was thinking the same and I am very happy to share details and discuss offline via email! PS I can't see your address above (just @.***) but I found it on LessWrong! I will pop you an email this evening!

moodlep commented 1 year ago

PS let me know if you did not receive anything - I may have the wrong email.

jbloomAus commented 1 year ago

I got it! Will respond as soon as I can :)
