Hi @Vattikondadheeraj thanks for creating the issue. Just to clarify, you make it through a full epoch (checkpoint save etc) before OOMing, right? If so I agree that a memory leak seems a likely culprit. The only other guess would be a batch with really long sequence length showing up, but if you already made it through an entire epoch I would think we can rule that out.
One suggestion to test the memory leak hypothesis is to set log_peak_memory_stats=True in the config and see if there is a gradual increase in memory. You can log to WandB, TensorBoard, Comet, whatever you prefer: example for WandB setup here.
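If you want a quick manual sanity check outside of the recipe's logger, something like this (plain torch.cuda calls, not the torchtune utility) gives the same signal when called every step or every N steps:

import torch

def log_peak_memory(step: int) -> None:
    # Peak allocated vs. currently reserved CUDA memory, in GiB.
    peak_alloc = torch.cuda.max_memory_allocated() / 1024**3
    reserved = torch.cuda.memory_reserved() / 1024**3
    print(f"step {step}: peak_alloc={peak_alloc:.2f} GiB, reserved={reserved:.2f} GiB")
    # Reset so the next interval reports its own peak.
    torch.cuda.reset_peak_memory_stats()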
A couple other clarifying questions that might help with debugging: (1) what is custom about the dataloader? Is everything definitely happening on CPU there? (2) can you share the full config and recipe file you're using with any customizations (or however much you feel comfortable sharing, if e.g. the custom dataloader is proprietary or something)?
Hey @ebsmothers, thanks for the quick reply. I commented out the save_checkpoint line in the train function for now because I am experimenting with various hyperparameters. Is that an issue?
Regarding the custom dataloader, I think I am doing the tokenization in the dataloader's __getitem__ itself. Will that be an issue? If yes, then I will remove it. I attached the dataloader part below. Let me know if you find any bugs.
Thanks for your support. FYI, torchtune is one of the most versatile libraries I have ever used for LLM finetuning. Many people in my community are actually switching to torchtune.
import random

from torch.utils.data import Dataset

# NOTE: `cots` (used in process_code) is assumed to be defined elsewhere in this script.


class dataset(Dataset):
    def __init__(self, data, tokenizer, json_data, max_length=1024, batch_size=None, lambda_1=0.5):
        self.data = data
        self.code_gen = data.select_columns(['task_id', 'prompt', 'code'])
        self.aux_data = data.select_columns(['task_id', 'prompt', 'code', 'test_imports', 'test_list', 'variables'])
        self.json_data = json_data
        self.ls = data.select_columns(["test_list"])
        self.count = -1
        print(batch_size, "^^^^^^^^^^^^^^^^^^^^^^^^")
        self.batch_size = batch_size
        # Number of samples per batch drawn from the code-generation task.
        self.split = int(self.batch_size * lambda_1)
        self.code_gen_idx = list(range(len(self.code_gen)))
        self.aux_idx = list(range(len(self.aux_data)))
        self.tokenizer = tokenizer
        self.max_length = max_length
        random.shuffle(self.code_gen_idx)
        random.shuffle(self.aux_idx)
        self.aux_ptr = 0
        self.code_ptr = 0
        self.len_data = len(self.data)
        self.prompt_len = []

    def _update_aux(self):
        new_id = self.aux_idx[self.aux_ptr]
        self.aux_ptr += 1
        self.aux_ptr = self.aux_ptr % self.len_data
        return new_id

    def process_code(self, id):
        task_id = self.code_gen[id]["task_id"]
        i = random.sample(range(len(self.ls[id]["test_list"])), 1)
        print(task_id)
        cot = cots[task_id][i[0]].split("Test Case ")[1].strip()[2:].strip()
        input_pr = self.apply_chat_template(self.code_gen[id]['prompt'], self.ls[id]["test_list"][i[0]])
        # code_gen = "Question: " + self.code_gen[id]['prompt'] + "\nFor example, this is how the function name and number of arguments should look like: " + self.ls[id]["test_list"][0].split("assert ")[1].strip() + '\n' + "Answer:\n"
        # out = self.code_gen[id]['code'] + "\n" + cot
        out = self.code_gen[id]['code']
        # print(code_gen)
        return input_pr, out

    def apply_chat_template(self, user_prompt, assertion):
        # prompt_template = f"""<|start_header_id|>user<|end_header_id|>\n\n{user_prompt}\n\nYour code should satisfy the following assertion:\n{assertion}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\nHere is a solution to this programming problem:\n```python """
        prompt_template = f"""{user_prompt}\n\nYour code should satisfy the following assertion:\n{assertion}\nHere is a solution to this programming problem:\n```python """
        return prompt_template

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        self.count += 1
        self.count = self.count % self.batch_size
        # print(self.count)
        if self.count < self.split:
            # Code-generation task.
            id = self.code_gen_idx[self.code_ptr]
            self.code_ptr += 1
            self.code_ptr = self.code_ptr % self.len_data
            inp, output = self.process_code(id)
            token = self.tokenizer.encode(inp, add_eos=True)
            label = self.tokenizer.encode(inp + output, add_eos=True)
            return {
                "tokens": token,
                "labels": label,
            }
        else:
            # Choose an auxiliary task.
            id = self.aux_idx[self.aux_ptr]
            task_id = self.code_gen[id]['task_id']
            while str(task_id) not in list(self.json_data.keys()):
                id = self._update_aux()
                task_id = self.code_gen[id]['task_id']
            self.aux_ptr += 1
            self.aux_ptr = self.aux_ptr % self.len_data
            if random.random() > 0.4:
                # aux_1
                print("aux1")
                inp = self.json_data[str(task_id)]['a1']['tokens']
                output = self.json_data[str(task_id)]['a1']['labels']
            else:
                # aux_2
                n = len(self.json_data[str(task_id)]['a2'].keys())
                print("aux2")
                test_id = random.choice(range(n))
                inp = self.json_data[str(task_id)]['a2'][str(test_id)]['tokens']
                a = random.choice(range(len(inp)))
                output = self.json_data[str(task_id)]['a2'][str(test_id)]['labels']
                b = random.choice(range(len(output)))
            token = self.tokenizer.encode(inp, add_eos=True)
            label = self.tokenizer.encode(inp + output, add_eos=True)
            return {
                "tokens": token,
                "labels": label,
            }
I am thinking it could be the reserved memory. You could set expandable_segments to see if it solves it.
I am on my phone now, but for the exact command, search for "expandable" in the issues.
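For reference, the option I mean is the PyTorch CUDA allocator setting; something like this should work, as long as it is set before the first CUDA allocation (e.g. in the shell that launches the recipe, or at the very top of the entry point):

import os

# PyTorch caching-allocator option; must be set before the first CUDA allocation.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"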
btw, you can probably lower the grad_accumulation number if you use our memory-efficient configs.
Use chunked cross-entropy (the default in the configs) and compile=True. If it's full finetuning on a single device, use the PagedAdamW optimizer. If that's not enough, set enable_activation_checkpointing=True. If activation checkpointing is on and you are finetuning with LoRA, also set enable_activation_offloading=True.
You are creating your own dataset, so that makes things harder, but check our docs for the packed dataset. It greatly increases tokens per second.
Hi @Vattikondadheeraj I don't think commenting out save_checkpoint should cause any issues. Regarding your dataloader, it's hard to say for sure. There are a couple of unused variables like a and b, but these should be freed (and I don't think they're on CUDA anyway). Actually, I think a more likely culprit is non-determinism. Are you sure that you are seeing the same samples each epoch? The usage of if random.random() > 0.4: makes me wonder if there is a sample in your data that will cause you to OOM but that you simply didn't see in the first epoch. So aside from logging peak memory stats, I would also suggest logging the length of tokens and labels to see if the OOM occurs on a particularly long sequence.
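For example, one quick way to do that without touching the recipe is to wrap your dataset in a small logging shim (a hypothetical debugging helper, not a torchtune class):

from torch.utils.data import Dataset

class LengthLoggingDataset(Dataset):
    """Hypothetical debugging wrapper: prints token/label lengths per sample."""

    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        sample = self.base[idx]
        print(f"idx={idx} tokens={len(sample['tokens'])} labels={len(sample['labels'])}")
        return sample

# Usage: wrap your dataset before handing it to the DataLoader.
# ds = LengthLoggingDataset(dataset(...))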
Glad to hear you're enjoying the library! Do let us know if there are any pain points or feature requests you have, we are always striving to improve the experience of using torchtune.
Edit: also Felipe has added some good suggestions as I was typing this out. Assuming there's no memory leak and you are just OOMing due to a long sequence or wastefully reserving CUDA memory, these suggestions may be all you need (and in general they're all good things to try anyways).
@felipemello1, thanks for the suggestions, I will try them out. Regarding the custom dataloader, ours is a bit complicated and a naive implementation won't work. As you suggested, I will try to modify torchtune.datasets.SFTData and the packed.py functions and integrate them into our pipeline.
@ebsmothers, you are right, we are sampling different data points, so the model doesn't see the same data points in each epoch. We are trying to mix different ratios of tasks within each batch, e.g. 2 samples from task-1 and 1 sample from task-2, and the tasks are quite different from each other. But as we run multiple epochs, the model can see all the samples multiple times. You raised an important question in my mind though: how does this weighted sampling impact the training dynamics? Currently we are facing a few problems, like the model not converging despite training for multiple epochs. We thought batch size was the culprit because my initial batch size is 3. Currently I am scaling up the batch size via distributed training to see if the loss values will go down or not. Let me know your opinion.
@Vattikondadheeraj it's hard to say about the training dynamics in general as it depends on many things (data quality for each task, distribution of training data, etc). But one comment is that if you are changing the sampling over each epoch you may want to make sure you're using the right learning rate scheduler. I believe the default for our Llama 3.1 8B LoRA single-device config is to use a cosine schedule with warmup. Maybe this is what you want, but it does mean that later epochs will factor less heavily into training than earlier ones in general. And there may be some assumption baked into the use of this schedule that each epoch is seeing the same set of training data, which is not the case for your setup.
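To make the scheduling point concrete, here is a rough sketch of a cosine-with-warmup schedule in plain PyTorch (not the exact torchtune component, and the step counts are made up); note how much smaller the learning rate is by the later epochs:

import math
import torch

model = torch.nn.Linear(8, 8)                      # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

warmup_steps, total_steps = 100, 3000              # made-up step counts

def lr_lambda(step: int) -> float:
    # Linear warmup followed by cosine decay toward zero.
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# Roughly the start of "epochs" 1, 2, 3 if each epoch is ~1000 steps.
for step in (0, 1000, 2000, 2999):
    print(step, 1e-4 * lr_lambda(step))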
Hey @felipemello1, I need one clarification on the packed dataset. I think the PackedDataset class is not yet used in your pipeline. But after going through the PackedDataset class in detail, it's an extra layer between the dataloader and padded_collate_packed, right? I mean, I should first have a dataloader that returns tokens and labels, then that batch is passed through PackedDataset, and finally padded_collate_packed is used, right?
@Vattikondadheeraj, you can see how we use it here in the dataset: https://github.com/pytorch/torchtune/blob/3fddc56942846220b39945559f4b5e695873bb43/torchtune/datasets/_samsum.py#L86
and here in the recipe: https://github.com/pytorch/torchtune/blob/3fddc56942846220b39945559f4b5e695873bb43/recipes/full_finetune_distributed.py#L518
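Roughly, the flow is: your map-style dataset returning tokens/labels gets wrapped in PackedDataset, and the DataLoader then uses padded_collate_packed as its collate_fn. A sketch (based on those two files; double-check the exact signatures in your torchtune version):

from torchtune.datasets import PackedDataset

# ds = dataset(...)                                # your custom dataset returning {"tokens", "labels"}
# packed_ds = PackedDataset(ds, max_seq_len=1024)  # packs multiple samples into each sequence
# loader = DataLoader(packed_ds, batch_size=..., collate_fn=padded_collate_packed)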
Currently we are facing a few problems, like the model not converging despite training for multiple epochs.
I would set breakpoint() inside of your training loop and take a look at the tokens and labels. Ideally, they should be the same, and our script shifts them by one so that at token i it predicts token i+1: https://github.com/pytorch/torchtune/blob/3fddc56942846220b39945559f4b5e695873bb43/recipes/full_finetune_distributed.py#L640
You may want to check whether you are using dataset.train_on_input=True/False (i.e. whether you want to train on the prompt or not) and that your ignore_index=-100, which is what we use in the cross-entropy to ignore tokens.
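For reference, a rough sketch of that convention (the helper and the prompt-length trick here are just illustrative; adapt to your tokenizer): tokens and labels have the same length, and prompt positions in labels are set to -100 when you don't want to train on the input:

CROSS_ENTROPY_IGNORE_IDX = -100  # matches the default ignore_index in the loss

def build_sample(tokenizer, prompt: str, answer: str, train_on_input: bool = False):
    # Encode the full sequence once so tokens and labels have the same length.
    tokens = tokenizer.encode(prompt + answer, add_eos=True)
    labels = list(tokens)
    if not train_on_input:
        # Mask the prompt span so the loss only covers the answer. Using
        # len(encode(prompt)) is only an approximation of that span if your
        # tokenizer merges tokens across the prompt/answer boundary.
        prompt_len = len(tokenizer.encode(prompt, add_eos=False))
        labels[:prompt_len] = [CROSS_ENTROPY_IGNORE_IDX] * prompt_len
    return {"tokens": tokens, "labels": labels}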
Ideally, after multiple epochs, you should at least be overfitting to the training set.
Also, make sure that you are training the layers you want to train. You can do something like:
# Print the names of all trainable parameters.
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)
hey @Vattikondadheeraj , just wondering if you were able to solve your issues, since you closed it. Thanks!!
@felipemello1 Hey, yes, most of them are resolved. I scaled up the GPUs a bit to solve the OOM issues, and the tips above from you and @ebsmothers actually worked. Thanks!!!
Hey, I am trying to LoRA finetune the Llama-3.1-8B Instruct model on a single A100 GPU with 40 GB VRAM. The batch size is 2 and the gradient accumulation parameter is 32. I am getting "CUDA out of memory" after 1 epoch, which is strange. I have my own custom dataloader which handles everything on CPU. So I want to know where exactly the problem is: is there a memory leak on my side, or is there a leak in the codebase itself? Is there an easy way to sort this problem out?