pytorch / torchtune

PyTorch native finetuning library
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

CUDA out of memory after few epochs #1709

Closed Vattikondadheeraj closed 1 month ago

Vattikondadheeraj commented 1 month ago

Hey, I am trying to LoRA finetune the llama-3.1-8b Instruct model on a single A100 GPU with 40GB VRAM. The batch size is 2 and the gradient accumulation parameter is 32. I am getting "CUDA out of memory" after 1 epoch, which is strange. I have my own custom dataloader which manages everything on CPU. So I want to know where exactly the problem is: is there a memory leak on my side, or is there a leak in the codebase itself? Is there an easy way to sort this problem out?

ebsmothers commented 1 month ago

Hi @Vattikondadheeraj thanks for creating the issue. Just to clarify, you make it through a full epoch (checkpoint save etc) before OOMing, right? If so I agree that a memory leak seems a likely culprit. The only other guess would be a batch with really long sequence length showing up, but if you already made it through an entire epoch I would think we can rule that out.

One suggestion to test the memory leak hypothesis is to set log_peak_memory_stats=True in the config and see if there is gradual increase in memory. You can log to WandB, Tensorboard, Comet, whatever you prefer: example for WandB setup here.

A couple other clarifying questions that might help with debugging: (1) what is custom about the dataloader? Is everything definitely happening on CPU there? (2) can you share the full config and recipe file you're using with any customizations (or however much you feel comfortable sharing, if e.g. the custom dataloader is proprietary or something)?
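If you want a quick sanity check outside of torchtune's metric logger, here is a minimal sketch of per-step peak-memory logging with plain PyTorch (the helper name is just a placeholder):

import torch

def log_cuda_peak(step):
    # Peak allocated/reserved CUDA memory since the last reset, in GiB.
    peak_alloc = torch.cuda.max_memory_allocated() / 1024**3
    peak_reserved = torch.cuda.max_memory_reserved() / 1024**3
    print(f"step {step}: peak_alloc={peak_alloc:.2f} GiB, peak_reserved={peak_reserved:.2f} GiB")
    # Reset so the next call reports the peak of the upcoming steps only.
    torch.cuda.reset_peak_memory_stats()

A steady climb epoch over epoch points to a leak; a sudden jump on one batch points to a long sequence.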

Vattikondadheeraj commented 1 month ago

Hey @ebsmothers, thanks for the quick reply. I commented out the save_checkpoint line in the train function for now because I am experimenting with various hyperparameters. Is that an issue?

Regarding the custom dataloader, I am doing the tokenization in the dataloader's __getitem__ itself. Will that be an issue? If yes, I will remove it. I have attached the dataloader code below. Let me know if you find any bugs.

Thanks for your support. FYI, torchtune is one of the most versatile libraries I have ever used for LLM finetuning. Many people in my community are actually switching to torchtune.

import random

from torch.utils.data import Dataset


class dataset(Dataset):
    def __init__(self, data, tokenizer, json_data, max_length=1024, batch_size=None, lambda_1=0.5):
        self.data = data
        self.code_gen = data.select_columns(['task_id', 'prompt', 'code'])
        self.aux_data = data.select_columns(['task_id', 'prompt', 'code', 'test_imports', 'test_list', 'variables'])
        self.json_data = json_data
        self.ls=data.select_columns(["test_list"])
        self.count = -1
        print(batch_size,"^^^^^^^^^^^^^^^^^^^^^^^^")
        self.batch_size = batch_size
        self.split = int(self.batch_size * lambda_1 )
        self.code_gen_idx = list(range(len(self.code_gen)))
        self.aux_idx = list(range(len(self.aux_data)))
        self.tokenizer = tokenizer
        self.max_length = max_length

        random.shuffle(self.code_gen_idx)
        random.shuffle(self.aux_idx)
        self.aux_ptr = 0
        self.code_ptr = 0
        self.len_data = len(self.data)
        self.prompt_len = []

    def _update_aux(self):
        new_id = self.aux_idx[self.aux_ptr]
        self.aux_ptr += 1
        self.aux_ptr = self.aux_ptr % self.len_data
        return new_id

    def process_code(self, id):
        task_id = self.code_gen[id]["task_id"]
        i = random.sample(range(len(self.ls[id]["test_list"])), 1)
        print(task_id)
        # `cots` is assumed to be defined elsewhere at module level (indexed by task_id, then test-case index)
        cot = cots[task_id][i[0]].split("Test Case ")[1].strip()[2:].strip()
        input_pr = self.apply_chat_template(self.code_gen[id]['prompt'], self.ls[id]["test_list"][i[0]])
        #code_gen = "Question: " + self.code_gen[id]['prompt'] + "\nFor example, this is how the function name and number of arguments should look: " + self.ls[id]["test_list"][0].split("assert ")[1].strip() + '\n' + "Answer:\n"
        # out=self.code_gen[id]['code'] + "\n" + cot
        out=self.code_gen[id]['code']
        # print(code_gen)
        return input_pr, out

    def apply_chat_template(self, user_prompt, assertion):
        #prompt_template = f"""<|start_header_id|>user<|end_header_id|>\n\n{user_prompt}\n\nYour code should satisfy the following assertion:\n{assertion}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\nHere is a solution to this programming problem:\n```python """
        prompt_template = f"""{user_prompt}\n\nYour code should satisfy the following assertion:\n{assertion}\nHere is a solution to this programming problem:\n```python """
        return prompt_template

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        self.count += 1
        self.count = self.count % self.batch_size
        # print(self.count)
        if self.count < self.split:
        #if 1:
            id = self.code_gen_idx[self.code_ptr]
            self.code_ptr += 1
            self.code_ptr = self.code_ptr % self.len_data
            inp, output = self.process_code(id)  
            token = self.tokenizer.encode(inp, add_eos=True)
            label = self.tokenizer.encode(inp + output, add_eos=True)

            return {
              "tokens" : token,
              "labels": label
            }

        else:
            # choose an auxiliary task sample
            id = self.aux_idx[self.aux_ptr]
            task_id = self.code_gen[id]['task_id']

            # skip samples whose task_id has no entry in json_data
            while str(task_id) not in self.json_data:
                id = self._update_aux()
                task_id = self.code_gen[id]['task_id']

            self.aux_ptr += 1
            self.aux_ptr = self.aux_ptr % self.len_data
            if random.random() > 0.4:
            #if 0:
                print("aux1")
              # aux_1
                inp = self.json_data[str(task_id)]['a1']['tokens']
                output = self.json_data[str(task_id)]['a1']['labels']
            else:

              #aux_2
                n = len(self.json_data[str(task_id)]['a2'].keys())
                print("aux2")

                test_id = random.choice(range(n))
                inp = self.json_data[str(task_id)]['a2'][str(test_id)]['tokens']
                a=random.choice(range(len(inp)))
                output = self.json_data[str(task_id)]['a2'][str(test_id)]['labels']
                b=random.choice(range(len(output)))

            token = self.tokenizer.encode(inp, add_eos=True)
            label = self.tokenizer.encode(inp + output, add_eos=True)

            return {
              "tokens" : token,
              "labels": label
            }

felipemello1 commented 1 month ago

I am thinking it could be reserved memory. You could set expandable_segments=True to see if that solves it.

I am on my phone now, but for the exact command, search for “expandable” in the issues.
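For reference, the allocator option in question is a standard PyTorch caching-allocator setting; one minimal way to enable it (it must be set before the first CUDA allocation, e.g. at the very top of the recipe or exported in the shell):

import os

# Standard PyTorch CUDA allocator option; set this before any CUDA allocation happens.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"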

Btw, you can probably lower the gradient accumulation number if you use our memory-efficient configs.

Use chunked cross entropy (the default in the configs) and compile=True. If it's full finetuning on a single device, use the PagedAdamW optimizer. If that's not enough, set enable_activation_checkpointing=True. If activation checkpointing is enabled and you are finetuning with LoRA, also set enable_activation_offloading=True.

You are creating your own dataset, so that makes things harder, but check our docs for the packed dataset. It greatly increases tokens per second.

ebsmothers commented 1 month ago

Hi @Vattikondadheeraj I don't think commenting out save_checkpoint should cause any issues. Regarding your dataloader, it's hard to say for sure. There are a couple of unused variables like a and b, but those should be freed (and I don't think they're on CUDA anyway). Actually, I think a more likely culprit is non-determinism. Are you sure that you are seeing the same samples each epoch? The use of if random.random() > 0.4: makes me wonder if there is a sample in your dataset that will cause you to OOM but that you simply didn't hit in the first epoch. So aside from logging peak memory stats, I would also suggest logging the length of tokens and labels to see whether the OOM occurs on a particularly long sequence.
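For example, a minimal sketch of that kind of length logging (the helper name and the 4096 threshold are just placeholders, adjust to your setup); call it right before the return in your __getitem__:

def log_sample_lengths(idx, tokens, labels, warn_len=4096):
    # Flag unusually long samples so a long sequence right before the OOM is easy to spot in the logs.
    if max(len(tokens), len(labels)) > warn_len:
        print(f"LONG SAMPLE {idx}: len(tokens)={len(tokens)}, len(labels)={len(labels)}")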

Glad to hear you're enjoying the library! Do let us know if there are any pain points or feature requests you have, we are always striving to improve the experience of using torchtune.

Edit: also Felipe has added some good suggestions as I was typing this out. Assuming there's no memory leak and you are just OOMing due to a long sequence or wastefully reserving CUDA memory, these suggestions may be all you need (and in general they're all good things to try anyways).

Vattikondadheeraj commented 1 month ago

@felipemello1, thanks for the suggestions. I will try to test them. Regarding the custom dataloader: our dataloader is a bit complicated, and a naive implementation won't work. As you suggested, I will try to modify the torchtune.datasets.SFTDataset and packed.py functions and integrate them into our pipeline.

@ebsmothers, you are right. We are actually sampling different data points, and the model doesn't see the same data points in each and every epoch. We are sending different ratios of different tasks in each batch, i.e. 2 samples belong to task 1 and 1 sample belongs to task 2, etc., and the tasks are different from each other. But as we run multiple epochs, the model can see all the samples multiple times. You raised an important question in my mind though: how does this weighted sampling impact the training dynamics? Currently we are facing a few problems, e.g. the model is not converging despite training for multiple epochs. We thought batch size was the culprit because my initial batch size is 3. Currently I am scaling up the batch size via distributed training to see if the loss values will go down or not. Let me know your opinion.

ebsmothers commented 1 month ago

@Vattikondadheeraj it's hard to say about the training dynamics in general as it depends on many things (data quality for each task, distribution of training data, etc). But one comment is that if you are changing the sampling over each epoch you may want to make sure you're using the right learning rate scheduler. I believe the default for our Llama 3.1 8B LoRA single-device config is to use a cosine schedule with warmup. Maybe this is what you want, but it does mean that later epochs will factor less heavily into training than earlier ones in general. And there may be some assumption baked into the use of this schedule that each epoch is seeing the same set of training data, which is not the case for your setup.
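For illustration only (this is a generic sketch, not torchtune's own scheduler implementation), cosine with warmup looks roughly like the following, which is why samples seen in later epochs contribute much smaller updates:

import math
import torch

def cosine_with_warmup(optimizer, warmup_steps, total_steps):
    # Linear warmup to the base LR, then cosine decay toward zero.
    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)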

Vattikondadheeraj commented 1 month ago

Hey @felipemello1, I need one clarification on the packed dataset. I think the PackedDataset class is not yet used in your pipeline. But after going through the PackedDataset class in detail, it's an extra layer between the dataloader and padded_collate_packed, right? I mean, I should first have a dataloader which returns tokens and labels, then this is passed through PackedDataset, and finally padded_collate_packed is used, right?

felipemello1 commented 1 month ago

@Vattikondadheeraj, you can see how we use it here in the dataset: https://github.com/pytorch/torchtune/blob/3fddc56942846220b39945559f4b5e695873bb43/torchtune/datasets/_samsum.py#L86

and here in the recipe: https://github.com/pytorch/torchtune/blob/3fddc56942846220b39945559f4b5e695873bb43/recipes/full_finetune_distributed.py#L518
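Roughly, the wiring looks like the sketch below (based on the two linked files; ds, tokenizer, and batch_size stand in for your own objects, and the exact argument names may differ by version, so check the linked code):

from torch.utils.data import DataLoader
from torchtune.data import padded_collate_packed
from torchtune.datasets import PackedDataset

# `ds` is your map-style dataset returning {"tokens": ..., "labels": ...}.
packed_ds = PackedDataset(ds, max_seq_len=tokenizer.max_seq_len)

dataloader = DataLoader(
    packed_ds,
    batch_size=batch_size,
    collate_fn=padded_collate_packed,  # collate used for packed batches in the recipe
)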

felipemello1 commented 1 month ago

Currently we are facing a few problems, e.g. the model is not converging despite training for multiple epochs.

I would set a breakpoint() inside your training loop and take a look at the tokens and labels. Ideally they should be the same, and our script offsets them by one so that at token i it predicts token i+1: https://github.com/pytorch/torchtune/blob/3fddc56942846220b39945559f4b5e695873bb43/recipes/full_finetune_distributed.py#L640

You may want to check whether you are using dataset.train_on_input=True/False (i.e. whether you want to train on the prompt or not) and whether your ignore_index is -100, which is what we use in the cross entropy to ignore tokens.
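For illustration, here is a generic version of that off-by-one shift plus the ignore_index masking (not the exact recipe code; see the links above for that):

import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # labels with this value are skipped by the loss

def shifted_ce_loss(logits, labels):
    # logits: [batch, seq_len, vocab], labels: [batch, seq_len]
    # Predict token i+1 from position i: drop the last logit and the first label.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=IGNORE_INDEX,
    )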

Ideally, after multiple epochs, you should at least be overfitting to the training set.

Also, make sure that you are training the layers you wanna train. You can do something like:

for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)

felipemello1 commented 1 month ago

Hey @Vattikondadheeraj, just wondering if you were able to solve your issues, since you closed this. Thanks!!

Vattikondadheeraj commented 1 month ago

@felipemello1 Hey, yes, most of them are resolved. I scaled up the GPUs a bit to solve the OOM issues, and the tips above from you and @ebsmothers actually worked. Thanks!!!