neuralmagic / AutoFP8


Use `torch.inference_mode()` for lower memory usage during calibration #20

Closed · mgoin closed this 3 months ago

mgoin commented 3 months ago

On an H100 80GB, calibrating Llama 3 8B with a ~8192-token input sequence would hit OOM. With the small addition of `with torch.inference_mode():` around the calibration loop, I now see a peak usage of only ~15GB.
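For illustration, here is a minimal sketch of the kind of change (not the exact diff in this PR), assuming the calibration loop simply runs forward passes over the calibration batches so the activation observers can record scales; `calibrate` and `calibration_batches` are illustrative names, not AutoFP8 internals:

import torch

def calibrate(model, calibration_batches):
    # Run the calibration forward passes under inference_mode so autograd
    # keeps no intermediate activations around for a backward pass; this is
    # what brings peak memory down for long sequences.
    with torch.inference_mode():
        for batch in calibration_batches:
            model(**batch)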

Snippet used for testing:

from transformers import AutoTokenizer
from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

pretrained_model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
quantized_model_dir = "Meta-Llama-3-8B-Instruct-FP8"

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
# One long calibration example (~8192 tokens) to reproduce the OOM scenario
seq_len = 8192
examples = ["hello " * seq_len]
examples = tokenizer(examples, return_tensors="pt").to("cuda")

# Static FP8 quantization for weights and activations; keep lm_head unquantized
quantize_config = BaseQuantizeConfig(
    quant_method="fp8",
    activation_scheme="static",
    ignore_patterns=["re:.*lm_head"],
)

model = AutoFP8ForCausalLM.from_pretrained(
    pretrained_model_dir, quantize_config=quantize_config
)
model.quantize(examples)
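
To reproduce the ~15GB peak-usage figure, the bare model.quantize(examples) call above can be bracketed with PyTorch's CUDA peak-memory counters. This is a sketch using standard torch.cuda APIs, not part of the original test snippet:

import torch

# Reset the CUDA peak-memory counter before calibration, then read it back.
torch.cuda.reset_peak_memory_stats()
model.quantize(examples)
print(f"Peak GPU memory during calibration: {torch.cuda.max_memory_allocated() / 1e9:.1f} GB")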