jbloomAus / DecisionTransformerInterpretability

Interpreting how transformers simulate agents performing RL tasks
https://jbloomaus-decisiontransformerinterpretability-app-4edcnc.streamlit.app/

Better encode/embed MiniGrid state to speed up training in DTs. #61

Closed jbloomAus closed 1 year ago

jbloomAus commented 1 year ago

I have two main ideas for this:

Implement a variation of the BOW encoding used by BabyAI, but add positional information to avoid building a convnet.

Though not discussed in the BabyAI paper itself, it seems they actually used a much better tokenization scheme than I am using, and this could plausibly be causing many of my problems.

```python
import torch
import torch.nn as nn

from babyai.model import initialize_parameters  # BabyAI's weight-init helper


class ImageBOWEmbedding(nn.Module):
    def __init__(self, max_value, embedding_dim):
        super().__init__()
        self.max_value = max_value
        self.embedding_dim = embedding_dim
        # One shared table; the three input channels are disambiguated
        # by the offsets added in forward().
        self.embedding = nn.Embedding(3 * max_value, embedding_dim)
        self.apply(initialize_parameters)

    def forward(self, inputs):
        # inputs: (batch, 3, height, width), channels = (object, color, state).
        offsets = torch.Tensor([0, self.max_value, 2 * self.max_value]).to(inputs.device)
        inputs = (inputs + offsets[None, :, None, None]).long()
        # Embed each channel, sum over channels ("bag of words"), then
        # permute to (batch, embedding_dim, height, width) for the convnet.
        return self.embedding(inputs).sum(1).permute(0, 3, 1, 2)
```

This takes the contents of each grid position and represents it as a unique embedding that is independent of position. For example, key + yellow + closed (object, color, state) -> some index like 13 -> maps to a specific vector. They then pass this into a convolutional network (which I was hoping to avoid).
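For concreteness, a rough usage sketch (the 7x7 view and `max_value=32` here are illustrative, not the exact values in the repo):

```python
import torch

# Illustrative sizes: a 7x7 MiniGrid partial view with 3 channels
# (object, color, state); max_value=32 comfortably covers the ids.
bow = ImageBOWEmbedding(max_value=32, embedding_dim=128)
obs = torch.randint(0, 11, (16, 3, 7, 7))  # (batch, channels, height, width)
out = bow(obs)
print(out.shape)  # torch.Size([16, 128, 7, 7]) -- channels-first, convnet-ready
```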

I can do better than this, though, in the context of my model. I can have, bear with me, 5 separate embedding matrices: object, color, state, x-position, and y-position (sketched below). Consider the current model: it is effectively already 3 embedding matrices.

Then any given (object, color, state) at any given position starts off with a unique representation. We can then look at how weight regularization acts as a feature selector over these, and at how the embeddings of each evolve in response to the circuits which use them.
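Roughly, a sketch of that five-table scheme (module and argument names are illustrative, not the repo's actual code):

```python
import torch
import torch.nn as nn

class FiveTableStateEmbedding(nn.Module):
    # Five separate tables: object, color, state, x-position, y-position.
    def __init__(self, n_objects, n_colors, n_states, view_size, d_emb):
        super().__init__()
        self.object_emb = nn.Embedding(n_objects, d_emb)
        self.color_emb = nn.Embedding(n_colors, d_emb)
        self.state_emb = nn.Embedding(n_states, d_emb)
        self.x_emb = nn.Embedding(view_size, d_emb)
        self.y_emb = nn.Embedding(view_size, d_emb)

    def forward(self, obs):
        # obs: (batch, view_size, view_size, 3), channels = (object, color, state).
        h, w = obs.shape[1], obs.shape[2]
        cell = (
            self.object_emb(obs[..., 0])
            + self.color_emb(obs[..., 1])
            + self.state_emb(obs[..., 2])
        )
        # Broadcast the two positional tables across the grid and add them in.
        pos = (
            self.y_emb(torch.arange(h, device=obs.device))[:, None, :]
            + self.x_emb(torch.arange(w, device=obs.device))[None, :, :]
        )
        return cell + pos[None]  # (batch, view_size, view_size, d_emb)
```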

One complication here: for the positional embeddings, it seems like I should use something like sinusoidal embeddings, but I want to ensure they are orthogonal to the previous sinusoidal embeddings. I have some ideas about how to do this, but I will google it anyway.

E.g. this from the Gato paper: [figure omitted]
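As a sanity check on interference, something like the following sketch (my own construction, not repo code) builds a standard 2D sinusoidal embedding and inspects the Gram matrix of pairwise position dot products:

```python
import torch

def sinusoidal_2d(view_size: int, d_emb: int) -> torch.Tensor:
    # Half the channels encode x, half encode y, each using the usual
    # sin/cos construction from "Attention Is All You Need".
    assert d_emb % 4 == 0
    d_half = d_emb // 2
    pos = torch.arange(view_size, dtype=torch.float32)
    freqs = 10_000 ** (-torch.arange(0, d_half, 2, dtype=torch.float32) / d_half)
    angles = pos[:, None] * freqs[None, :]                    # (view_size, d_half // 2)
    enc_1d = torch.cat([angles.sin(), angles.cos()], dim=-1)  # (view_size, d_half)
    x_part = enc_1d[None, :, :].expand(view_size, -1, -1)     # varies along x
    y_part = enc_1d[:, None, :].expand(-1, view_size, -1)     # varies along y
    return torch.cat([x_part, y_part], dim=-1)                # (view, view, d_emb)

pe = sinusoidal_2d(7, 32).reshape(49, 32)
# Gram matrix of pairwise dot products between the 49 positions:
# a dominant diagonal with small off-diagonal terms = low interference.
print((pe @ pe.T).round(decimals=1))
```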

jbloomAus commented 1 year ago

Implicit in this card is the hypothesis that we just majorly screwed ourselves over the last 5 months by using a bad encoding of the state. If I'm right, this will be so good.

jbloomAus commented 1 year ago

Alrighty, I've written up a pretty cool state encoder which I'm hoping will work. I've essentially done the above and am now producing a tensor of shape (batch, view_size, view_size, d_emb) which is the sum of object, state, and color embeddings with an overlaid 2D sinusoidal positional embedding. I'm then going to flatten that and run it through a linear layer. I suspect it will be way more interpretable by default than the one-hot encoded stuff, which I think might have been making positional reasoning hard in a fairly obvious way.
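Roughly what the encoder does, as a sketch (names are illustrative; it reuses the `sinusoidal_2d` helper sketched above):

```python
import torch
import torch.nn as nn

class StateEncoder(nn.Module):
    # Sketch: summed categorical embeddings + fixed 2D sinusoidal overlay,
    # flattened and linearly projected into the residual stream.
    def __init__(self, n_objects, n_colors, n_states, view_size, d_emb, d_model):
        super().__init__()
        self.object_emb = nn.Embedding(n_objects, d_emb)
        self.color_emb = nn.Embedding(n_colors, d_emb)
        self.state_emb = nn.Embedding(n_states, d_emb)
        # Fixed (non-learned) positional overlay, registered as a buffer
        # so it follows the module across devices.
        self.register_buffer("pos", sinusoidal_2d(view_size, d_emb))
        self.proj = nn.Linear(view_size * view_size * d_emb, d_model)

    def forward(self, obs):
        # obs: (batch, view_size, view_size, 3), channels = (object, color, state).
        x = (
            self.object_emb(obs[..., 0])
            + self.color_emb(obs[..., 1])
            + self.state_emb(obs[..., 2])
            + self.pos  # broadcast over the batch dimension
        )
        return self.proj(x.flatten(1))  # (batch, d_model)
```

Keeping the positional overlay fixed (a buffer, not a parameter) means the categorical embeddings carry all the learned structure, which should make them easier to study.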

I've done some work to visualize the embeddings and check that they don't have too much interference, which seems true (see below). I'll add a linear projection of size 7*7*n_state_emb -> d_model so that the state gets projected into the residual stream. One challenge here is that this is a fair number of parameters: even with d_emb = 32, if view_size = 7 then we get 7*7*32 = 1568 neurons in that intermediate layer before we project down to d_model = 128. The linear layer that sits between this encoding and the rest of the model has to do a lot of work, but it can't reason about RTG, which I guess gives some assurances about what it can't do re: task solving. I can try a 0-layer BC clone to check how smart it is.

[embedding visualizations omitted]

jbloomAus commented 1 year ago

Once I get a basic version of this working, I think I should double check:

Right now I care about having a working model more than I care about other details but that may change in the future.

jbloomAus commented 1 year ago

I've done some thinking and come up with the following conclusions:

As such, the tasks I'm setting myself for now are:

Then, I'm going to evaluate each on the dynamic obstacles task and then the memory env task. If I'm still doing this beyond today because of unforeseen complexities, then I'll split both out into their own cards.

jbloomAus commented 1 year ago

Results on convnet in dynamic obstacles:

In order to use the CNN in practice, I will need to either move to using GPUs or optimize the code or hyperparameters. It will not be practical (I imagine) to accept an 8x reduction in speed.

jbloomAus commented 1 year ago

Onto BOW with ViT.

Managed to get a class running. Roughly a 3x slowdown (not as bad as the ConvNet). Also solves dynamic obstacles well. I still need to hook it up with args, though, and test both of these on the memory env.

jbloomAus commented 1 year ago

Parameter counts are very high. Will look into this and visualize parameter counts in the model to ensure I understand it.
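A quick, generic way to break counts down by top-level module (plain PyTorch; the `StateEncoder` sizes are just the illustrative numbers from earlier):

```python
from collections import defaultdict

def parameter_counts(model):
    # Group learnable-parameter counts by top-level submodule name.
    counts = defaultdict(int)
    for name, param in model.named_parameters():
        counts[name.split(".")[0]] += param.numel()
    return dict(counts)

# Using the illustrative sizes from earlier in the thread:
enc = StateEncoder(n_objects=11, n_colors=6, n_states=3,
                   view_size=7, d_emb=32, d_model=128)
for module, n in parameter_counts(enc).items():
    print(f"{module:12s} {n:>9,d}")
# The projection dominates: 7*7*32 * 128 weights (+128 biases) ≈ 200k parameters.
```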