google-research / text-to-text-transfer-transformer

Code for the paper "Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer"
https://arxiv.org/abs/1910.10683
Apache License 2.0

Batch Size vs Sequence Length #592

Closed trisongz closed 3 years ago

trisongz commented 3 years ago

Hi,

I've been working with the mT5 models recently, and wanted to get a better understanding of the dataset input function during training.

I've looked closely at mesh_tf's data_fn but wasn't able to get a clear answer on the mechanics.

For sequence_length, I've used different permutations, including:

'sequence_length': {'inputs': 512, 'targets': 2048}
'sequence_length': {'inputs': 512, 'targets': 1024}
'sequence_length': {'inputs': 256, 'targets': 1536}

For batch_size, I've tested both:

'batch_size': ('tokens_per_batch', 1048576)
'batch_size': 8-128 (static values)
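For context, here's roughly how I'm wiring these settings in (a minimal sketch assuming the t5.models.MtfModel interface from the released colab; the paths, TPU name, and hyperparameter values below are illustrative, not my exact config):

```python
import t5

# Minimal sketch of the training setup; paths and TPU name are placeholders,
# and the hyperparameter values are illustrative rather than my exact config.
model = t5.models.MtfModel(
    model_dir="gs://my-bucket/models/mt5-xxl",        # hypothetical bucket
    tpu="my-tpu-name",                                # hypothetical TPU
    tpu_topology="v3-8",
    model_parallelism=8,
    batch_size=("tokens_per_batch", 1048576),         # or a static int, e.g. 32
    sequence_length={"inputs": 256, "targets": 2048},
    learning_rate_schedule=0.001,
    save_checkpoints_steps=5000,
    iterations_per_loop=100,
)
```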

I have several different tasks, some of which use an LM-style setup ({inputs: None, targets: text}), while others are set up with explicit inputs/targets. I've observed that using 'tokens_per_batch' generally results in slower time per iteration than a static batch_size (3-4x longer over the same number of steps).

If I understand correctly, even when 'tokens_per_batch' is used, sequences are still truncated at the specified inputs/targets lengths and then packed together until the batch reaches the specified number of tokens?

In experiments with xlarge vs. xxl, using a static batch_size and 'sequence_length': {'inputs': 256, 'targets': 2048}, xlarge was able to run on a v3-8 with batch_size 32. xxl was not able to run at all (output feed error after the first checkpoint), even with batch_size 8, despite having memory available (profiling the TPU during the run). Reducing the lengths to {'inputs': 256, 'targets': 1536} allowed xxl to run with batch_size 8. Could the cause be how the data is being partitioned across the TPU cores?

When comparing static batch_size vs. tokens_per_batch vs. sequences_per_batch (which I have not experimented with), which have you found to give the highest I/O throughput? In my tests, TPU MXU utilization remained consistently above 45% regardless of whether I used a static batch_size or tokens_per_batch.

Is there a theoretical limit on the sequence lengths? For example, with {'inputs': 64, 'targets': 4096}, would model performance be affected if d_model were scaled up to match the sequence lengths versus set lower than them (e.g. d_model = 2048)? And in this example, would tokens_per_batch be more effective than a static batch_size?

Thanks in advance for the insights!

craffel commented 3 years ago

Hi, there is no difference at all between the behavior of tokens_per_batch and sequences_per_batch. In short, max(input_length, target_length) * sequences_per_batch = tokens_per_batch. For example, if your input and target sequence lengths are both 512, then tokens_per_batch=8192 is equivalent to sequences_per_batch=16. If there was a difference in throughput, it's because you were ultimately specifying different batch sizes (e.g. if you specified sequences_per_batch=8 in the example above, it would be faster than tokens_per_batch=8192 because the latter results in a 2x larger batch).
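As a quick sanity check of that arithmetic (this is just the relationship above written out, not the library's actual batching code):

```python
def sequences_per_batch(tokens_per_batch, sequence_length):
    # Equivalent number of sequences for a given tokens_per_batch, using
    # max(input_length, target_length) as the per-sequence token budget.
    max_len = max(sequence_length["inputs"], sequence_length["targets"])
    return tokens_per_batch // max_len

# inputs/targets both 512: 8192 tokens per batch -> 16 sequences per batch.
print(sequences_per_batch(8192, {"inputs": 512, "targets": 512}))      # 16

# Your setting of 1,048,576 tokens with inputs=256, targets=2048 works out
# to a global batch of 512 sequences, far larger than a static batch of
# 8-32, which would explain the slower per-step time.
print(sequences_per_batch(1048576, {"inputs": 256, "targets": 2048}))  # 512
```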

Whether you are able to fit the model on a given TPU topology will depend on the model size, batch size, sequence length, and data/model parallelism. You can tweak/shrink each of those factors to try to get a large model to fit on a small TPU. I don't think we've tried running xxl on a v3-8, so I can't provide an example configuration off the top of my head.
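Roughly, these are the knobs to trade off against each other (illustrative values only, written against the MtfModel-style arguments; this is not a verified xxl-on-v3-8 recipe):

```python
# Illustrative only: the factors that determine whether a model fits.
# None of these values is a verified configuration for xxl on a v3-8.
config = {
    "model_parallelism": 8,                               # shard weights over more cores
    "batch_size": 8,                                      # fewer sequences per step
    "sequence_length": {"inputs": 256, "targets": 1536},  # shorter sequences
}
```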

The only limit to sequence lengths is memory - you need to be able to fit the sequences (and quadratic-size attention matrices) in memory. There is no relationship between d_model and sequence lengths. d_model controls the dimension of the activations inside the model and has nothing to do with sequence length.
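To make the quadratic term concrete, here is a back-of-the-envelope estimate (it ignores rematerialization, sharding, and other memory optimizations in the actual implementation, and the 64-head count is an assumption about the xxl config):

```python
def attention_matrix_bytes(seq_len, num_heads, bytes_per_value=4):
    # Rough size of one layer's self-attention logits for a single example:
    # one (seq_len x seq_len) matrix per head. Back-of-the-envelope only.
    return num_heads * seq_len * seq_len * bytes_per_value

# Decoder self-attention at targets=2048 vs. targets=1536, per example per
# layer, assuming 64 heads.
print(attention_matrix_bytes(2048, 64) / 2**20, "MiB")  # ~1024 MiB
print(attention_matrix_bytes(1536, 64) / 2**20, "MiB")  # ~576 MiB
```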