rail-berkeley / crossformer


Batch size and memory requirements of training/finetuning #2

Closed mhyatt000 closed 2 months ago

mhyatt000 commented 2 months ago

Thanks for open sourcing this project!

I've been trying to train a CrossFormer in a style similar to the one described in the paper. I've noticed that while the training script recommends a batch size of 512, I can only fit a batch size of 16 with 32 gradient accumulation steps before the machine runs out of memory. I'm using a node with 2x or 4x A100 GPUs (40GB).
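
For concreteness, here's the arithmetic I'm using to match the recommended batch size; the numbers are just the ones from my setup above:

```python
# Effective batch size via gradient accumulation (numbers from my setup).
micro_batch = 16        # largest batch that fits in memory per forward/backward pass
accum_steps = 32        # gradient accumulation steps
effective_batch = micro_batch * accum_steps
print(effective_batch)  # 512, matching the recommended batch size
```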

Interestingly, I only have this problem with scripts/train.py and not scripts/finetune.py, where I can fine-tune with a batch size of 512 with no problem. Do you know why this is? I've looked through both scripts and didn't notice any significant discrepancies, except that the training script performs a forward pass with all heads instead of a single head. I've turned this off to try to diagnose the problem and still run out of memory during pre-training.

Looking forward to your reply.

HomerW commented 2 months ago

Fine-tuning will typically require less memory because only the observation tokens and action readout tokens for a single embodiment are included in the context. During pre-training all the possible observation tokens and action readout tokens are included in the context for each batch element (see the discussion about dense packing in section 3.4 of the paper for more explanation). Even if you just compute the loss for a single head during pre-training, the long context length will still contribute to the high memory usage. Hope this answers your question!
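
To make the gap concrete, here's a rough back-of-the-envelope sketch; the token counts and embodiment list below are made up for illustration, not the actual CrossFormer numbers:

```python
# Rough illustration of why the pre-training context is longer than the
# fine-tuning context. Token counts and embodiments are hypothetical.
tokens_per_embodiment = {
    "single_arm": 300,   # obs tokens + action readout tokens (made-up numbers)
    "bimanual": 500,
    "navigation": 100,
    "quadruped": 150,
}

# Fine-tuning: only one embodiment's tokens are in the context.
finetune_context = tokens_per_embodiment["single_arm"]

# Pre-training (without dense packing): every embodiment's observation and
# readout tokens get a slot in the context of every batch element.
pretrain_context = sum(tokens_per_embodiment.values())

# Self-attention cost grows roughly quadratically with context length, so the
# memory gap is larger than the ratio of context lengths alone suggests.
print(pretrain_context / finetune_context)          # ~3.5x more tokens
print((pretrain_context / finetune_context) ** 2)   # ~12x more attention memory
```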

mhyatt000 commented 2 months ago

This makes sense. @HomerW did you guys try any ablations or initial experiments where dense packing failed? Or does Sec. 3.4 suggest it as a direction for future work? It follows that inferring the embodiment would be more difficult, but I did not see experiments backing this up.

> However, by not fixing observation and readout token types to a set location in the context window, the model would need to infer the embodiment purely from the observations in order to predict actions of the correct type (rather than relying on the positional embeddings of the readout tokens). The observations for some embodiments can look similar (such as navigation and manipulation with only a wrist camera), so this design may require appending a prefix to the token sequence indicating the embodiment.

I may try this to save compute in my situation, unless it caused extra trouble for you.
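
If I do try it, the kind of prefix I have in mind is just prepending an embodiment-indicator token to the packed sequence, roughly like this sketch (the names, shapes, and embedding table are hypothetical, not from the crossformer codebase):

```python
# Sketch of an embodiment-prefix token for dense packing: prepend one learned
# token per embodiment so the model doesn't have to infer the embodiment
# purely from the observations. Names and shapes are hypothetical.
import numpy as np

EMBODIMENTS = ["single_arm", "bimanual", "navigation", "quadruped"]
EMBODIMENT_ID = {name: i for i, name in enumerate(EMBODIMENTS)}

token_dim = 512
# One embedding per embodiment (random placeholders standing in for learned params).
prefix_table = np.random.normal(size=(len(EMBODIMENTS), token_dim))

def add_embodiment_prefix(obs_tokens: np.ndarray, embodiment: str) -> np.ndarray:
    """Prepend the embodiment's prefix token to a (seq_len, token_dim) sequence."""
    prefix = prefix_table[EMBODIMENT_ID[embodiment]][None, :]  # (1, token_dim)
    return np.concatenate([prefix, obs_tokens], axis=0)

# Example: a densely packed sequence for a navigation batch element.
obs_tokens = np.zeros((100, token_dim))  # placeholder observation tokens
packed = add_embodiment_prefix(obs_tokens, "navigation")
print(packed.shape)  # (101, 512)
```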

HomerW commented 2 months ago

Yeah we didn't try dense packing, but I'd encourage you to try it out! It should significantly reduce the memory consumption during pre-training, and I don't think it would be too hard for the model to infer the embodiment from the observations.