understanding-search / maze-transformer

This repo is built to facilitate the training and analysis of autoregressive transformers on maze-solving tasks.
24 stars 6 forks source link

Training batches are missing a bunch of tokens #180

Closed valedan closed 1 year ago

valedan commented 1 year ago

In our training loop, each maze in the batch is missing tokens from the end. They often do not include start or end tokens, which would make training impossible. Alex's sweeps have been using a different training loop that does not have this problem.

Example of a 3x3 5-maze dataset:


# ------original dataset--------
['<ADJLIST_START>', '(0,1)', '<-->', '(0,2)', ';', '(0,2)', '<-->', '(1,2)', ';', '(1,1)', '<-->', '(2,1)', ';', '(2,2)', '<-->', '(2,1)', ';', '(1,2)', '<-->', '(2,2)', ';', '(1,0)', '<-->', '(0,0)', ';', '(1,1)', '<-->', '(1,0)', ';', '(2,0)', '<-->', '(1,0)', ';', '<ADJLIST_END>', '<ORIGIN_START>', '(1,0)', '<ORIGIN_END>', '<TARGET_START>', '(0,1)', '<TARGET_END>']
['<ADJLIST_START>', '(2,0)', '<-->', '(2,1)', ';', '(1,0)', '<-->', '(0,0)', ';', '(2,1)', '<-->', '(2,2)', ';', '(1,2)', '<-->', '(2,2)', ';', '(1,0)', '<-->', '(1,1)', ';', '(0,1)', '<-->', '(0,2)', ';', '(0,2)', '<-->', '(1,2)', ';', '(0,1)', '<-->', '(0,0)', ';', '<ADJLIST_END>', '<ORIGIN_START>', '(0,1)', '<ORIGIN_END>', '<TARGET_START>', '(2,1)', '<TARGET_END>']
['<ADJLIST_START>', '(0,2)', '<-->', '(0,1)', ';', '(1,0)', '<-->', '(2,0)', ';', '(2,2)', '<-->', '(2,1)', ';', '(2,1)', '<-->', '(2,0)', ';', '(0,0)', '<-->', '(0,1)', ';', '(0,1)', '<-->', '(1,1)', ';', '(2,2)', '<-->', '(1,2)', ';', '(1,0)', '<-->', '(0,0)', ';', '<ADJLIST_END>', '<ORIGIN_START>', '(1,2)', '<ORIGIN_END>', '<TARGET_START>', '(0,2)', '<TARGET_END>']
['<ADJLIST_START>', '(1,1)', '<-->', '(1,2)', ';', '(1,2)', '<-->', '(2,2)', ';', '(2,0)', '<-->', '(2,1)', ';', '(0,0)', '<-->', '(0,1)', ';', '(0,0)', '<-->', '(1,0)', ';', '(1,2)', '<-->', '(0,2)', ';', '(0,1)', '<-->', '(0,2)', ';', '(2,0)', '<-->', '(1,0)', ';', '<ADJLIST_END>', '<ORIGIN_START>', '(0,2)', '<ORIGIN_END>', '<TARGET_START>', '(1,2)', '<TARGET_END>']
['<ADJLIST_START>', '(1,2)', '<-->', '(0,2)', ';', '(1,1)', '<-->', '(2,1)', ';', '(0,1)', '<-->', '(0,2)', ';', '(2,0)', '<-->', '(2,1)', ';', '(1,0)', '<-->', '(2,0)', ';', '(1,2)', '<-->', '(2,2)', ';', '(1,1)', '<-->', '(0,1)', ';', '(0,1)', '<-->', '(0,0)', ';', '<ADJLIST_END>', '<ORIGIN_START>', '(0,1)', '<ORIGIN_END>', '<TARGET_START>', '(0,0)', '<TARGET_END>']

#-------batch in training loop----------

['<ADJLIST_START>', '(2,2)', '<-->', '(2,1)', ';', '(1,0)', '<-->', '(0,0)', ';', '(1,0)', '<-->', '(1,1)', ';', '(2,0)', '<-->', '(2,1)', ';', '(0,1)', '<-->', '(0,0)', ';', '(0,1)']
['<ADJLIST_START>', '(0,1)', '<-->', '(0,0)', ';', '(0,2)', '<-->', '(0,1)', ';', '(1,1)', '<-->', '(0,1)', ';', '(1,0)', '<-->', '(2,0)', ';', '(0,2)', '<-->', '(1,2)', ';', '(1,1)', '<-->', '(2,1)', ';', '(2,1)', '<-->', '(2,0)', ';', '(2,2)', '<-->', '(1,2)', ';', '<ADJLIST_END>', '<ORIGIN_START>', '(0,1)', '<ORIGIN_END>']
['<ADJLIST_START>', '(0,1)', '<-->', '(0,0)', ';', '(0,2)', '<-->', '(0,1)', ';', '(1,1)', '<-->', '(0,1)', ';', '(1,0)', '<-->']
['<ADJLIST_START>', '(1,0)', '<-->', '(2,0)', ';', '(1,1)', '<-->', '(1,2)', ';', '(0,0)', '<-->', '(1,0)', ';', '(1,2)', '<-->']
['<ADJLIST_START>', '(2,2)', '<-->', '(2,1)', ';', '(1,0)', '<-->', '(0,0)', ';', '(2,1)', '<-->', '(1,1)', ';', '(1,2)', '<-->', '(2,2)', ';', '(2,0)', '<-->', '(1,0)', ';', '(1,1)', '<-->', '(1,0)', ';', '(0,2)', '<-->', '(1,2)', ';', '(0,1)', '<-->', '(0,2)', ';', '<ADJLIST_END>', '<ORIGIN_START>']```
valedan commented 1 year ago

More context here: https://searchingforsearch.slack.com/archives/C04QRDD81LZ/p1680593584634419

mivanit commented 1 year ago

fixed in #177, but still need feedback on what to return for each sequence: