Hi @srikanthsrnvs
Thanks for your interest in the library and your questions.
The chunking in the cross entropy loss is quite simple: instead of computing the cross entropy over the entire batch and sequence at once, the computation is done on smaller chunks (splitting along the batch dimension) to avoid memory peaks, at the cost of slightly slower speed. The returned loss is the same as in the non-chunked version; no scaling or other adjustment is needed on the user's side.
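For illustration, here is a minimal sketch of the idea in plain PyTorch (the chunk size and the `ignore_index=-1` masking convention are assumptions for this sketch, not necessarily the library's exact implementation):

```python
import torch
import torch.nn.functional as F

def chunked_cross_entropy(logits, targets, chunk_size=128):
    # Flatten batch and sequence dimensions, then compute the loss over
    # small chunks instead of one huge (B*T, vocab) tensor at once.
    logits = logits.reshape(-1, logits.size(-1))
    targets = targets.reshape(-1)
    losses = [
        F.cross_entropy(lc, tc, ignore_index=-1, reduction="sum")
        for lc, tc in zip(logits.split(chunk_size), targets.split(chunk_size))
    ]
    # Average over the non-ignored tokens, which matches the unchunked mean.
    n_tokens = (targets != -1).sum().clamp(min=1)
    return torch.stack(losses).sum() / n_tokens
```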
The behavior you describe is correct and expected. In multi-GPU training with N GPUs, there are N processes running in parallel, each assigned to one GPU. There is no single shared "manager" device. Perhaps you are thinking of some other form of training, but this is not how it works in Lightning / PyTorch.
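A minimal sketch of what launching with Fabric does (the device count here is just illustrative):

```python
import lightning as L

fabric = L.Fabric(accelerator="cuda", devices=8, strategy="ddp")
fabric.launch()

# Each of the 8 processes runs the full script independently, so any
# unguarded call (like validate()) executes once per process.
print(f"running on rank {fabric.global_rank} of {fabric.world_size}")
```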
I'll run this and let you know what I find, but it has to wait until Monday unfortunately. Will get back to you then. Based on your description alone, I suspect that the sequence length is the big factor pushing you to OOM. Our finetuning scripts and the numbers documented in the READMEs were for much smaller sequence lengths (because the finetuning data had shorter prompts).
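In the meantime, it's worth double-checking that validation builds no autograd graph and doesn't evaluate at the full 4096 context unless it has to. A rough sketch of what I mean (the function body, batch format, and the 1024 cap are assumptions about your script, not prescriptions):

```python
import torch
import torch.nn.functional as F

@torch.no_grad()  # no autograd graph is built during evaluation
def validate(model, val_batches, max_seq_length=1024):
    model.eval()
    losses = []
    for input_ids, targets in val_batches:
        # Capping the evaluated length bounds activation memory;
        # 1024 is only an illustrative value, tune it to your data.
        logits = model(input_ids[:, :max_seq_length])
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            targets[:, :max_seq_length].reshape(-1),
            ignore_index=-1,
        )
        losses.append(loss)
    model.train()
    return torch.stack(losses).mean()
```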
Gotcha, thanks.
Should I do something like

if fabric.global_rank == 0: wandb.log()

since I don't want to reinitialize wandb on every process?

@srikanthsrnvs
1) Could you show for which inputs the function returns wildly different values? I extended the test case in #343 to show that the chunked implementation is equivalent to the regular cross entropy loss.
2) Yes :)
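That is, something along these lines (the project name and metric key are placeholders):

```python
import lightning as L
import wandb

fabric = L.Fabric(devices=8, strategy="ddp")
fabric.launch()

# Initialize once, on rank 0 only, so a single W&B run is created.
if fabric.global_rank == 0:
    wandb.init(project="lora-finetune")  # placeholder project name

for step in range(10):  # stand-in for the real training loop
    loss = 0.0  # stand-in for the computed training loss
    if fabric.global_rank == 0:
        wandb.log({"train/loss": loss}, step=step)
```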
I have two questions around the current codebase as it stands.

1) When using the finetune/lora.py code to train on 8 GPUs, I see that the validate() function is called 8 times on 8 separate processes. Why is that? Why does fabric run model validation & training on each process separately? As far as I understand, when using DDP, the model is cloned across devices, with a single shared device acting as the manager for the gradient updates, or am I incorrect? If so, why do all 8 processes produce training & validation output? Shouldn't there be only 1?

2) Running lora.py results in an OOM at a batch size of 1 for a 13B parameter llama model with a context length of 4096. Training works fine until it hits a validation interval, at which point the OOM is triggered. Here is my current lora.py