huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl

OOM Error using PPO Trainer to LoRa-tune 4-bit Llama-3-8B Model #1833

Closed: AryamanJaggi closed this issue 3 weeks ago

AryamanJaggi commented 1 month ago

As per the standard practice for PPO training (which is to do supervised fine-tuning before running the PPO algorithm), I did a QLoRA fine-tuning of the Llama-3-8B Instruct model on my own custom data using the SFT Trainer. I then merged the LoRA adapters and pushed the resulting model to the HF Hub in 4-bit.
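The merge-and-push step itself isn't shown in the issue; a minimal sketch of what it might look like, where the adapter path and the Hub repo id are placeholders rather than the author's actual values:

    import torch
    from transformers import AutoModelForCausalLM
    from peft import PeftModel

    # Sketch only: the adapter path and the Hub repo id below are placeholders.
    base = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        torch_dtype=torch.float16,  # merge in half precision rather than 4-bit
        device_map="auto",
    )
    merged = PeftModel.from_pretrained(base, "path-to-sft-lora-adapter").merge_and_unload()
    merged.push_to_hub("my-username/llama-3-8b-sft-merged")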

For the PPO training step, I initialized my model like this (the LoRA config and quantization config are defined earlier in the notebook):

    model_id = "path-to-my-model"
    model = AutoModelForCausalLMWithValueHead.from_pretrained(
        model_id,
        peft_config=lora_config,
        device_map={"": 0},
        quantization_config=bnb_config,
    )
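The lora_config and bnb_config used above are defined elsewhere in the notebook and not shown in the issue; a plausible sketch of what they typically look like for 4-bit QLoRA (the values here are assumptions, not the author's actual settings):

    import torch
    from transformers import BitsAndBytesConfig
    from peft import LoraConfig

    # Assumed 4-bit quantization settings (typical QLoRA defaults)
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )

    # Assumed LoRA settings; the real rank and target modules may differ
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )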

Then I run my PPO training loop (using a custom PyTorch dataloader, because the PPO one does not support dynamic padding when streaming a large dataset):

    from tqdm import tqdm

    # Training parameters
    epochs = 4
    generation_kwargs = {
        "min_length": -1,
        "top_k": 0.0,
        "top_p": 1.0,
        "do_sample": True,
        "pad_token_id": tokenizer.pad_token_id,
        "max_new_tokens": 2560,  # since max input is 2048, leave some room for more
    }

    # Training loop
    for epoch in tqdm(range(epochs), "epoch: "):
        batchnum = 0
        numsave = 1
        # switched from "ppo_trainer.dataloader" to the dataloader defined in the cell above
        for batch in tqdm(dataloader):
            batchnum += 1
            query_tensors = batch["input_ids"]  # "input_ids" is just the tokenized query (a list of integers)

            # Get response from the SFT model
            response_tensors = ppo_trainer.generate(query_tensors, **generation_kwargs)
            batch["response"] = [tokenizer.decode(r.squeeze(), skip_special_tokens=True) for r in response_tensors]

            # Compute rewards (batch["query"] was never encoded, so use it directly)
            rewards = [reward_function(query, response) for query, response in zip(batch["query"], batch["response"])]

            # Run PPO step
            stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
            ppo_trainer.log_stats(stats, batch, rewards)

            if batchnum == 25000:
                ppo_trainer.save_model(f"/content/drive/MyDrive/Worksheets AI/PPO_Llama_epoch_{epoch}_{numsave}")
                numsave = numsave + 1
                batchnum = 0

        # Save the trained model after each epoch
        ppo_trainer.save_pretrained(f"/content/drive/MyDrive/Worksheets AI/PPO_Llama_epoch_{epoch}")
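The custom dataloader used above isn't shown in the issue; a minimal sketch of dynamic padding via a collate function, where streamed_dataset and the batch size are placeholders rather than the actual setup:

    from torch.utils.data import DataLoader

    # Sketch only: streamed_dataset and batch_size are placeholders.
    def collate_fn(examples):
        queries = [ex["query"] for ex in examples]
        # Dynamic padding: pad only to the longest sequence in this batch
        enc = tokenizer(queries, padding=True, truncation=True, max_length=2048, return_tensors="pt")
        return {"query": queries, "input_ids": list(enc["input_ids"])}

    dataloader = DataLoader(streamed_dataset, batch_size=8, collate_fn=collate_fn)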

I made the reward model myself. It is pretty extensive and relies on a BERT model to first assign an integer rating to the response, but the BERT model is on the CPU, and as far as I can tell the PPO algorithm does not compute gradients through the reward calculation (unless I am wrong, in which case I think I know where the problem is).
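To help pin that down, one way to make sure no computation graph from the BERT scorer is kept alive is to score under torch.no_grad() and return plain scalar tensors. This is only a sketch: bert_tokenizer and bert_scorer stand in for the CPU BERT rating model described above, not the actual reward_function.

    import torch

    # Sketch only: bert_tokenizer / bert_scorer are hypothetical names.
    def reward_function(query, response):
        with torch.no_grad():  # ensure no autograd graph from the scorer is retained
            inputs = bert_tokenizer(query, response, return_tensors="pt", truncation=True)
            rating = bert_scorer(**inputs).logits.argmax(dim=-1).item()  # integer rating
        # PPOTrainer expects one scalar tensor per sample
        return torch.tensor(float(rating))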

My question is why, with the 40 GB of GPU memory that Google Colab's A100 gives you, am I still getting an OOM error? My GPU memory usage sits at about 7 GB basically the whole time until the ppo_trainer.step() line, where it skyrockets to 40 GB and throws the error.
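For reference, the location of the spike can be confirmed by logging peak CUDA memory around the step call; this snippet is not in the notebook, just a sketch of how that might look:

    import torch

    torch.cuda.reset_peak_memory_stats()
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    print(
        f"allocated after step: {torch.cuda.memory_allocated() / 1e9:.1f} GB, "
        f"peak during step: {torch.cuda.max_memory_allocated() / 1e9:.1f} GB"
    )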

Also, here is the Google Colab notebook I am using, which has more details: https://colab.research.google.com/drive/17WHrsL6uK4EA94JywFh3fd6Q3A7OS1Jc?usp=sharing

github-actions[bot] commented 4 weeks ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.