philschmid / deep-learning-pytorch-huggingface

MIT License

Quantization question: #56

Open aptum11 opened 4 weeks ago

aptum11 commented 4 weeks ago

I'm using this code to fine-tune Llama 3 70B on AWS GPUs. It uses BitsAndBytesConfig to quantize the model weights and load them in 4-bit (nf4).

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    torch_dtype = torch.bfloat16
    quant_storage_dtype = torch.bfloat16

    # 4-bit (nf4) quantization; the packed 4-bit weights are stored in tensors
    # of bnb_4bit_quant_storage dtype (bfloat16 here)
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch_dtype,
        bnb_4bit_quant_storage=quant_storage_dtype,
    )

    # script_args / training_args come from the training script's argument parsing
    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_id,
        quantization_config=quantization_config,
        attn_implementation="sdpa",
        torch_dtype=quant_storage_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
    )

But when I print the parameters of the loaded model, they all show up as torch.bfloat16, i.e. the quant_storage_dtype.
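For reference, a minimal sketch of an alternative check (this is an assumption on my side: inspecting the bitsandbytes Linear4bit module classes instead of param.dtype, since the packed 4-bit weights seem to be stored in tensors of the quant_storage dtype; model is the quantized model loaded above):

    import bitsandbytes as bnb

    # Assumption: the quantized layers are bnb.nn.Linear4bit modules whose .weight
    # holds the packed 4-bit values in a tensor of dtype bnb_4bit_quant_storage
    # (bfloat16 here), so param.dtype alone doesn't reveal the quantization
    for name, module in model.named_modules():
        if isinstance(module, bnb.nn.Linear4bit):
            print(name, type(module).__name__, module.weight.dtype, tuple(module.weight.shape))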

When I then merge the LoRA weights into the base model later:

    from peft import AutoPeftModelForCausalLM
    import torch

    # Load PEFT model on CPU
    model = AutoPeftModelForCausalLM.from_pretrained(
        '/my-checkpoint-40',
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    # Merge LoRA and base model and save
    merged_model = model.merge_and_unload()
    # Double-check if quantization is still effective
    for name, param in merged_model.named_parameters():
        print(name, param.dtype, param.shape)  # dtype and shape of each parameter

Each layer is stored as quant_storage_dtype = torch.bfloat16.

The safetensors of the fine-tuned model add up to 118 GB, while the Llama 3 70B base model is 127 GB, so only a ~7% reduction for the fine-tuned model.

Maybe a weird question, but: is this model quantized? Is it semi-quantized? Should I quantize it further to reduce the size even more? (I need a smaller model because of the GPUs I have; see the sketch at the end of this post for what I mean.)

The quant_storage_dtype = torch.bfloat16 confuses me a bit.
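To make the "quantize it further" part concrete, this is roughly what I have in mind, just as a sketch (merged_model_dir is a hypothetical path where the merged model would be saved with save_pretrained):

    from transformers import AutoModelForCausalLM, BitsAndBytesConfig
    import torch

    # Hypothetical path where merged_model.save_pretrained(...) wrote the checkpoint
    merged_model_dir = "merged-llama3-70b"

    # Re-quantize at load time: the safetensors on disk stay fp16, but the weights
    # are loaded as 4-bit (nf4), cutting GPU memory for the weights roughly 4x
    reloaded_model = AutoModelForCausalLM.from_pretrained(
        merged_model_dir,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        ),
        device_map="auto",
    )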