understanding-search / maze-transformer

This repo is built to facilitate the training and analysis of autoregressive transformers on maze-solving tasks.

Backpropagate on the predicted shortest path rather than the entire input sequence #121

Open luciaquirke opened 1 year ago

luciaquirke commented 1 year ago


We want a model that has learned crisp and easily interpretable search algorithms. Such a model will solve mazes with high accuracy. However, our ability to train such models is hampered by noise in the signal we backpropagate.

Given a training sequence of tokens a b c d, the model is effectively trained on these separate samples:

a -> b
a b -> c
a b c -> d
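A minimal Python sketch of that decomposition (`next_token_samples` is a hypothetical helper for illustration, not part of the repo). In practice all of these predictions come out of a single forward pass with a causal attention mask; listing them separately just makes the targets explicit.

```python
def next_token_samples(tokens):
    """Return the (prefix, target) pairs a causal language model is trained on."""
    return [(tokens[:i], tokens[i]) for i in range(1, len(tokens))]


for prefix, target in next_token_samples(["a", "b", "c", "d"]):
    print(" ".join(prefix), "->", target)
# a -> b
# a b -> c
# a b c -> d
```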

Our current training sequences include both the adjacency list representing the maze and each step of the shortest path. This slows down training: for the first n positions of each sequence (those covering the adjacency list), the model makes predictions on, and we backpropagate the loss from, adjacency-list tokens, which are partially randomly generated.

I think it's inherent to the transformer architecture that the model predicts the next token at every position in the sequence, but we can make our loss and gradient updates more accurate by computing them only from the shortest-path predictions.
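One way to pick out the shortest-path predictions is a per-token mask that is 1 only after the path delimiter. This is a minimal sketch, assuming the path is introduced by a `<PATH_START>` token; that token name and the toy sequence are hypothetical stand-ins and should be checked against the actual tokenizer.

```python
PATH_START = "<PATH_START>"  # assumed delimiter token; verify against the real vocabulary


def path_target_mask(tokens, path_start=PATH_START):
    """Return a 0/1 list where mask[i] == 1 iff tokens[i] is a path token
    whose prediction should contribute to the loss.

    Everything up to and including the delimiter (the adjacency list and
    any other prompt tokens) is masked out.
    """
    start = tokens.index(path_start)  # raises ValueError if the delimiter is missing
    return [1 if i > start else 0 for i in range(len(tokens))]


# Toy sequence (hypothetical token format): one adjacency-list edge, then a path.
seq = ["<ADJLIST_START>", "(0,0)", "<-->", "(0,1)", ";", "<ADJLIST_END>",
       "<PATH_START>", "(0,0)", "(0,1)", "<PATH_END>"]
print(path_target_mask(seq))
# [0, 0, 0, 0, 0, 0, 0, 1, 1, 1]
```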

A parameter in the HookedTransformer forward method (loss_per_token=True, used together with return_type="loss") makes it return a tensor of per-token losses instead of the overall average loss. We can use this to compute the loss on the shortest-path predictions alone.
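In TransformerLens, `model(tokens, return_type="loss", loss_per_token=True)` returns a `[batch, pos - 1]` tensor in which entry `j` is the loss for predicting token `j + 1`, so a token-level mask has to be shifted by one before it is applied. The sketch below shows that reduction on plain Python lists for a single sequence; the helper name is hypothetical and the loss values are made up for illustration.

```python
def path_only_loss(per_token_loss, target_mask):
    """Mean of per-token losses over path positions only.

    per_token_loss[j] is the loss for predicting token j + 1 (length seq_len - 1);
    target_mask[i] marks whether token i is a path token (length seq_len).
    """
    shifted = target_mask[1:]  # align the token mask with next-token predictions
    if len(shifted) != len(per_token_loss):
        raise ValueError("mask must have exactly one more entry than the loss")
    n_path = sum(shifted)
    if n_path == 0:
        raise ValueError("sequence has no path tokens to train on")
    return sum(loss * m for loss, m in zip(per_token_loss, shifted)) / n_path


losses = [2.0, 2.0, 2.0, 0.5, 0.5, 0.5]  # made-up values: 6 predictions for 7 tokens
mask = [0, 0, 0, 0, 1, 1, 1]
print(path_only_loss(losses, mask))  # 0.5
print(sum(losses) / len(losses))     # 1.25 (the current, unmasked average)
```

For a batch of torch tensors the same reduction is `(loss * mask[:, 1:]).sum() / mask[:, 1:].sum()` with `mask` as a float tensor; calling `.backward()` on that scalar gives the adjacency-list positions zero weight in the gradient.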

Definition of Done

Future Work

valedan commented 1 year ago

Should we maintain support for the old backprop strategy after we introduce this one? I seem to recall some suggestion that there might be some benefit to the model learning the general maze structure by training on the adj_list too.

valedan commented 1 year ago

I think @afspies said he tried this in one of his experiments and it didn't make much difference. I think it would still be good for us to add the option of computing the loss this way to the training script.
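If both strategies are kept, the choice could be a single training-config option. A sketch with a hypothetical `loss_on` field (not an existing option in the repo's config), defaulting to the current behaviour; the `<PATH_START>` delimiter is again an assumption about the tokenizer.

```python
from dataclasses import dataclass
from typing import Literal


@dataclass
class LossConfig:
    # Hypothetical option: "all" = current behaviour (loss on every next-token
    # prediction, adjacency list included); "path" = shortest-path tokens only.
    loss_on: Literal["all", "path"] = "all"


def loss_mask(tokens, cfg, path_start="<PATH_START>"):
    """0/1 mask over tokens selecting which targets contribute to the loss.

    Position 0 is never a prediction target, so its value is irrelevant.
    """
    if cfg.loss_on == "all":
        return [1] * len(tokens)
    if cfg.loss_on == "path":
        start = tokens.index(path_start)
        return [1 if i > start else 0 for i in range(len(tokens))]
    raise ValueError(f"unknown loss_on: {cfg.loss_on!r}")
```

Keeping "all" as the default leaves existing training runs unchanged while making the path-only loss available for comparison.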

afspies commented 1 year ago

Given @canrager's experiments on smaller-maze generalization, I am now inclined to believe that a good implementation of this would be valuable (i.e., one that uses padding and masked gradients so as not to incur the 10% slowdown of my "schnell and dirty"™️ implementation).
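A minimal sketch of the padding-plus-mask idea, assuming right padding with a hypothetical `<PADDING>` token and the same assumed `<PATH_START>` delimiter. Mazes of different sizes can then share a batch, and padded or adjacency-list positions get a mask value of 0, so they add nothing to the loss or its gradient.

```python
PAD = "<PADDING>"            # assumed padding token name
PATH_START = "<PATH_START>"  # assumed path delimiter


def pad_with_loss_mask(seqs, pad=PAD, path_start=PATH_START):
    """Right-pad token sequences to a common length and build per-token loss masks.

    With right padding and causal attention, real tokens never attend to the
    padding that follows them, so only the loss needs masking.
    """
    width = max(len(s) for s in seqs)
    tokens, masks = [], []
    for s in seqs:
        start = s.index(path_start)
        n_pad = width - len(s)
        tokens.append(s + [pad] * n_pad)
        masks.append([1 if i > start else 0 for i in range(len(s))] + [0] * n_pad)
    return tokens, masks


small = ["e1", PATH_START, "p1", "p2"]              # shorter adjacency list
large = ["e1", "e2", "e3", PATH_START, "p1", "p2"]
toks, masks = pad_with_loss_mask([small, large])
print(masks)  # [[0, 0, 1, 1, 0, 0], [0, 0, 0, 0, 1, 1]]
```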

The reason is that all adjacency lists for a given maze size are the same length, so at least some of the trained models learned to recognise the beginning of the path from its position in the sequence rather than from the delimiting tokens. As a result, any model trained with adjacency lists, when given a smaller maze (and thus a shorter adjacency list), would first hallucinate a bunch of adjacency-list tokens before starting a path.
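The fixed-length observation can be made concrete. Assuming the mazes are perfect mazes, i.e. spanning trees of the grid (as a randomized depth-first-search generator produces), and that every edge is written with the same number of tokens, the adjacency-list length depends only on the grid size, because a spanning tree on V nodes has exactly V − 1 edges. The 4 tokens per edge below is an illustrative assumption, not the repo's exact format.

```python
def adjlist_token_count(grid_n, tokens_per_edge):
    """Number of adjacency-list tokens for a perfect maze on a grid_n x grid_n grid.

    A spanning tree on grid_n**2 nodes has grid_n**2 - 1 edges, so the result
    depends only on the maze size, never on the particular maze.
    """
    n_edges = grid_n * grid_n - 1
    return n_edges * tokens_per_edge


# Hypothetical encoding with 4 tokens per edge, e.g. "(0,0) <--> (0,1) ;"
for n in (3, 4, 5):
    print(n, adjlist_token_count(n, tokens_per_edge=4))
# 3 32
# 4 60
# 5 96
```

Under these assumptions, and provided the other prompt segments are fixed-length too, a model trained only on 5×5 mazes always sees the path start at the same position, so a position-based shortcut works on every training example; on a 3×3 maze the adjacency list is 64 tokens shorter and the shortcut fails.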

Not surprising in hindsight: if you want a model to generalize w.r.t. something, you should probably show it that thing varying during training.