jbloomAus / DecisionTransformerInterpretability

Interpreting how transformers simulate agents performing RL tasks
https://jbloomaus-decisiontransformerinterpretability-app-4edcnc.streamlit.app/
MIT License
69 stars, 16 forks

Verify Initialization of Transformer Model Components is good/appropriate. #65

Closed by jbloomAus 1 year ago

jbloomAus commented 1 year ago

I don't think I've looked into this adequately, so I will do so briefly.

If I have reason to think this will improve my model:

Edit: Adding to this card a task to optimize weight decay. We'll parameterize this.

jbloomAus commented 1 year ago

What Neel says in the T-lens docstring for HookedTransformer.init_weights():

        Initialize weights matrices with a normal of std=initializer_range (default=0.02). This roughly follows the GPT-2 paper's scheme (but with truncation, and not halving the std for W_pos).

        LayerNorm weights are already initialized to 1.0, and all biases are initialized to 0.0 (including LayerNorm), so this just initializes weight matrices.

        Weight matrices are set to empty by default (to save space + compute, since they're the bulk of the parameters), so it is important to call this if you are not loading in pretrained weights! Note that this function assumes that weight names begin with W_

        Set seed here to ensure determinism.

        This does NOT follow the PyTorch scheme, which as far as I can tell is super out of date but no one has gotten round to updating it?
        https://github.com/pytorch/pytorch/issues/18182

        PyTorch Transformers are especially bad - TransformerEncoder initializes all layers to the exact same weights?! https://github.com/pytorch/pytorch/issues/72253

        The best paper I've found on transformer initialization is the muP paper, but haven't integrated those ideas yet: https://arxiv.org/abs/2203.03466
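
For reference, a minimal plain-Python sketch of the scheme the docstring describes: zero-mean normal samples with std=0.02, truncated, with a seed for determinism. The function name `init_weight_matrix` and the truncation bound of 2 standard deviations are assumptions for illustration; the docstring does not state the bound, and this is not the library's actual code.

```python
import random

def init_weight_matrix(rows, cols, std=0.02, trunc=2.0, seed=0):
    """Fill a rows x cols matrix with zero-mean truncated-normal samples."""
    rng = random.Random(seed)  # seeding keeps the init deterministic
    limit = trunc * std        # ASSUMED truncation bound, not taken from the docstring

    def sample():
        # Rejection sampling: redraw until the value falls inside the bound.
        while True:
            x = rng.gauss(0.0, std)
            if abs(x) <= limit:
                return x

    return [[sample() for _ in range(cols)] for _ in range(rows)]

W = init_weight_matrix(64, 64)
flat = [x for row in W for x in row]
# Root-mean-square of the samples; truncation pulls this a little below 0.02.
empirical_std = (sum(x * x for x in flat) / len(flat)) ** 0.5
```

Biases and LayerNorm parameters are left out here because, per the docstring, they already start at 0.0 and 1.0 respectively.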

jbloomAus commented 1 year ago

Related idea: plot L2 norms of residual stream during training or check for saturation. Look at how these shift during training...
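
A rough sketch of what that check could look like, in plain Python with nested lists standing in for tensors. The helper names `l2_norms` and `saturated_fraction` are hypothetical, and the saturation test assumes bounded (tanh-like) activations with an arbitrary 0.99 threshold; the real model may call for a different criterion.

```python
import math

def l2_norms(resid):
    """L2 norm of each position's residual vector (resid: list of float lists)."""
    return [math.sqrt(sum(x * x for x in vec)) for vec in resid]

def saturated_fraction(activations, threshold=0.99):
    """Fraction of activations with magnitude above `threshold`.

    ASSUMPTION: activations are bounded in [-1, 1], e.g. tanh outputs.
    """
    flat = [a for vec in activations for a in vec]
    return sum(abs(a) > threshold for a in flat) / len(flat)

# Toy residual stream: 2 positions, d_model = 2.
resid = [[3.0, 4.0], [0.0, 0.0]]
norms = l2_norms(resid)
```

Logging `norms` (or their mean) every few hundred steps would give the curve to plot over training.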

jbloomAus commented 1 year ago

I've decided I want to emulate MinGPT / Othello GPT as this seems most reasonable. There are some subtleties I've been getting wrong (theoretically?) which I want to rectify.

What is the initialization strategy we want? here.

What is the decay strategy we want? here.

(I'll make an arg to toggle this so we can compare to old strategy?)

Since my naive embedding strategy uses a linear layer to embed the observation, I will need to override this to be true to MinGPT's strategy.

I'll quickly work out why he's doing things this way.

jbloomAus commented 1 year ago

Image

The old init looked very messy. This is better. (The horizontal line in the std plot was at -log(0.02) rather than -ln(0.02), which I've fixed.)
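
To make the size of that offset concrete, here is a quick check, assuming "-log" in the comment above meant the base-10 logarithm (the thread does not say which base was used):

```python
import math

std = 0.02
line_base10 = -math.log10(std)   # about 1.70 (ASSUMED to be the old, misplaced line)
line_natural = -math.log(std)    # about 3.91 (the -ln(0.02) reference)
```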

jbloomAus commented 1 year ago

I don't think we need this plot in wandb, but I will add it to the GitHub repo anyway.

jbloomAus commented 1 year ago

Adding weight decay groups. This might break reporting for some arg combinations, but I'll deal with that as it comes.
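
A minimal sketch of the grouping idea, following the MinGPT convention of decaying only the weight matrices of linear maps while exempting biases, LayerNorm parameters, and embeddings. The function `split_decay_groups` and the parameter names are hypothetical, modeled on the W_ / b_ naming mentioned in the docstring quoted earlier; the repo's actual grouping may differ.

```python
def split_decay_groups(param_names):
    """Partition parameter names into (decay, no_decay) lists by name alone."""
    decay, no_decay = [], []
    for name in param_names:
        leaf = name.split(".")[-1]
        # ASSUMPTION: embedding matrices are named W_E / W_pos and are exempt.
        is_embedding = leaf in ("W_E", "W_pos")
        if leaf.startswith("W_") and not is_embedding:
            decay.append(name)      # weight matrices of linear maps
        else:
            no_decay.append(name)   # biases, LayerNorm params, embeddings
    return decay, no_decay

# Hypothetical parameter names for illustration only.
names = [
    "embed.W_E",
    "pos_embed.W_pos",
    "blocks.0.attn.W_Q",
    "blocks.0.attn.b_Q",
    "blocks.0.ln1.w",
    "blocks.0.mlp.W_in",
    "ln_final.b",
]
decay, no_decay = split_decay_groups(names)
```

The two lists would then map onto two optimizer parameter groups, one with the chosen weight decay and one with weight decay 0.0.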