On an H100 80GB, calibrating Llama 3 8B with a ~8192-token input sequence would cause OOM. With the small addition of `with torch.inference_mode():` to the calibration loop, I see a peak usage of only ~15GB.
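For context, here is a minimal sketch of what the change looks like, assuming a calibration loop that simply runs forward passes over the tokenized examples (the real loop lives inside AutoFP8's quantize path; `run_calibration` is a hypothetical name for illustration):

```python
import torch

def run_calibration(model, calibration_data):
    # Hypothetical stand-in for AutoFP8's internal calibration loop.
    # inference_mode() disables autograd entirely, so no activation
    # graph is retained across the forward pass; this is what drops
    # peak memory from OOM down to ~15GB at seq_len=8192.
    with torch.inference_mode():
        model(**calibration_data)
```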
Snippet used for testing:
```python
from transformers import AutoTokenizer

from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

pretrained_model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
quantized_model_dir = "Meta-Llama-3-8B-Instruct-FP8"

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
seq_len = 8192
examples = ["hello " * seq_len]
examples = tokenizer(examples, return_tensors="pt").to("cuda")

quantize_config = BaseQuantizeConfig(
    quant_method="fp8",
    activation_scheme="static",
    ignore_patterns=["re:.*lm_head"],
)

model = AutoFP8ForCausalLM.from_pretrained(
    pretrained_model_dir, quantize_config=quantize_config
)
model.quantize(examples)
```
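To check the peak-memory number yourself, one option (not part of the original snippet) is PyTorch's CUDA memory stats:

```python
import torch

torch.cuda.reset_peak_memory_stats()
model.quantize(examples)
# max_memory_allocated() reports the allocation high-water mark in bytes
print(f"Peak memory: {torch.cuda.max_memory_allocated() / 1024**3:.1f} GiB")
```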