Lightning-AI / lit-llama

Implementation of the LLaMA language model based on nanoGPT. Supports flash attention, Int8 and GPTQ 4-bit quantization, LoRA and LLaMA-Adapter fine-tuning, and pre-training. Apache 2.0-licensed.

How to train 13B version on 8bit with LoRA #413

Open raj-khare opened 1 year ago

raj-khare commented 1 year ago

I want to fine-tune the 13B LLaMA with LoRA using 8-bit quantization. Right now it takes ~70 GB of GPU RAM, which is quite a lot. I'm using 8x A100-80GB.

lora.py

# Hyperparameters
learning_rate = 3e-4
batch_size = 64
micro_batch_size = 1
gradient_accumulation_iters = batch_size // micro_batch_size
assert gradient_accumulation_iters > 0
max_iters = 50000 * 3 // micro_batch_size
weight_decay = 0.0
max_seq_length = 4096  # see scripts/prepare_alpaca.py
lora_r = 8
lora_alpha = 16
lora_dropout = 0.05
warmup_iters = 100
def main(
    data_dir: str = "dataset", 
    pretrained_path: str = "/scratch/checkpoints/lit-llama/13B/lit-llama.pth",
    tokenizer_path: str = "/scratch/checkpoints/lit-llama/tokenizer.model",
    out_dir: str = "out/lora",
):

    fabric = L.Fabric(accelerator="cuda", devices=8, precision="bf16-true")
    fabric.launch()
    fabric.seed_everything(1337 + fabric.global_rank)
    ...
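For context, here is a rough back-of-envelope sketch of where memory can go with this configuration. It is an approximation, not a measurement: the LoRA parameter count below is an assumption, and the activation term depends heavily on activation checkpointing and kernel choices.

```python
# Rough, hedged estimate of per-device memory for 13B LoRA fine-tuning in bf16.
# All numbers are approximations.

n_params = 13e9            # 13B frozen base parameters
bytes_bf16 = 2

weights_gb = n_params * bytes_bf16 / 1e9          # ~26 GB just for the frozen bf16 weights
print(f"frozen bf16 weights: ~{weights_gb:.0f} GB")

# LoRA with r=8 adds only a few million trainable parameters (rough assumption),
# so gradients and AdamW state for the adapters are tiny.
lora_params = 7e6
adamw_state_gb = lora_params * (2 + 4 + 4) / 1e9  # bf16 grads + fp32 moments, roughly
print(f"LoRA grads + optimizer state: ~{adamw_state_gb:.2f} GB")

# The remainder is dominated by activations at max_seq_length=4096, which grow
# roughly linearly with sequence length and layer count unless activation
# checkpointing is enabled.
```

Note that with the default multi-device strategy each GPU holds a full replica of the model, so quantizing or sharding the frozen weights is what would bring the per-device figure down.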
snake-4 commented 12 months ago

Training can't be done directly on the quantized weights, as the update steps would fall within the quantization error threshold (they are smaller than one quantization step, so they get rounded away).
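A minimal sketch of that point, with hypothetical numbers (this is not lit-llama code): a typical LoRA-scale update is far smaller than the int8 quantization step, so applying it to a quantized weight and re-quantizing leaves the weight unchanged. QLoRA-style approaches sidestep this by keeping the quantized base weights frozen and training separate adapters in higher precision.

```python
import torch

# Illustration with made-up values: a small update vanishes when applied
# directly to int8-quantized weights, because it is below the quantization step.

w = torch.randn(4096) * 0.02                 # weights at a typical scale
scale = w.abs().max() / 127                  # per-tensor absmax int8 scale
q = torch.round(w / scale).to(torch.int8)    # quantize to int8

update = torch.full_like(w, 1e-5)            # a small gradient step (assumption)
w_updated = q.float() * scale + update       # dequantize, apply the update
q_new = torch.round(w_updated / scale).to(torch.int8)

print(torch.equal(q, q_new))                 # True: the update was rounded away
print(f"quantization step ~{scale.item():.1e} vs update 1e-05")
```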

AjibolaPy commented 2 months ago

Can I quantize a bf16 LLM to 4-bit and fine-tune it with QLoRA?
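That is essentially what the QLoRA recipe does, though not through lit-llama's lora.py. As a hedged sketch of the general pattern using Hugging Face transformers, peft, and bitsandbytes as an example stack (the checkpoint name, LoRA targets, and hyperparameters below are illustrative assumptions, not recommendations from this repo): the base weights are loaded in 4-bit NF4 and kept frozen, while the LoRA adapters are trained in bf16.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# QLoRA-style setup sketch (not lit-llama code); names are placeholders.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                       # frozen base weights stored in 4-bit NF4
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,   # matmuls run in bf16
)

model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-13b",                  # placeholder 13B checkpoint
    quantization_config=bnb_config,
    device_map="auto",
)
model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=8, lora_alpha=16, lora_dropout=0.05,   # mirrors the hyperparameters above
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)   # only the bf16 LoRA adapters are trainable
model.print_trainable_parameters()
```

The key property is the same one snake-4 describes: the quantized weights are never updated, only the higher-precision adapter matrices are.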