Closed james77777778 closed 1 day ago
This PR should be ready for review.
Both Gemma
and PaliGemma
now support quantization (int8 and float8).
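For context on what int8 quantization means here: post-training int8 quantization of a weight tensor typically uses symmetric abs-max scaling. A minimal pure-Python sketch of that idea (illustrative only, not the actual Keras implementation):

```python
def quantize_int8(values):
    """Symmetric abs-max int8 quantization of a list of floats."""
    scale = max(abs(v) for v in values) / 127.0
    quantized = [max(-127, min(127, round(v / scale))) for v in values]
    return quantized, scale


def dequantize_int8(quantized, scale):
    """Recover approximate float values from int8 codes."""
    return [q * scale for q in quantized]


weights = [0.5, -1.27, 0.013, 1.0]
q, scale = quantize_int8(weights)
recovered = dequantize_int8(q, scale)
# Each recovered value is within half a quantization step of the original.
```

The storage win is the point: each weight shrinks from 2 bytes (bfloat16) or 4 bytes (float32) to 1 byte plus a shared scale, which is where the peak-memory reductions measured below come from.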
Hi @fchollet @mattdangerw
I have added quantization support for Gemma2
(actually, adding tests is sufficient :) )
Please let me know if any updates are needed.
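A quantization test of this kind usually asserts that outputs after int8 quantization stay close to the float outputs. A toy pure-Python sketch of that assertion pattern, using a hypothetical two-row weight matrix with per-row (per-output-channel) abs-max scales (not the actual keras-nlp test code):

```python
def matvec(w, x):
    """Plain float matrix-vector product."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def quantize_row(row):
    """Abs-max int8 quantization of one weight row; returns codes and scale."""
    scale = max(abs(v) for v in row) / 127.0 or 1.0
    return [round(v / scale) for v in row], scale


# Toy "layer" weights and an input vector (hypothetical values).
w = [[0.2, -0.5, 1.0], [0.7, 0.1, -0.3]]
x = [1.0, 2.0, 3.0]

y_float = matvec(w, x)

# Quantize each row, dequantize, and run the same matvec.
q_rows = [quantize_row(row) for row in w]
w_deq = [[q * scale for q in qs] for qs, scale in q_rows]
y_int8 = matvec(w_deq, x)
# y_int8 should agree with y_float to within a small quantization error.
```

The actual test then only needs to check the two output vectors agree within a tolerance, which is why adding tests was enough to cover Gemma2 once the quantization machinery was shared.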
@james77777778 thanks so much! Sorry for the delay, I was out last week, but just got back in town. Will take a look tomorrow!
No hurry. Please take your time.
@james77777778 thanks for the changes!
As soon as testing is all green I will pull this in, especially since the US is about to go on holiday until next Monday.
I think the coverage is worth it, but as a follow-up let's keep looking for ways to speed up these tests while keeping decent coverage.
We will need a new release of Keras for this. Currently, I have built the PR on the master branch of Keras. The implementation is simple and clean after introducing
`DTypePolicyMap`
and some other fixes. Thanks to @fchollet and @mattdangerw for their help.

It is worth noting that float8 training & inference are also supported in this PR. You can check
`test_quantize`
for this.

Some numbers:
Script:
int8_gemma.py
```python
import argparse
import os
import pathlib
import time
import typing

import keras
import psutil
import tensorflow as tf

import keras_nlp

# Setup kaggle information
os.environ["KAGGLE_USERNAME"] = "xxx"
os.environ["KAGGLE_KEY"] = "xxx"


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        default="pali_gemma_3b_mix_224",
        choices=[
            "gemma_1.1_instruct_2b_en",
            "pali_gemma_3b_mix_224",
            "gemma2_instruct_9b_en",
        ],
        help="Which model to demonstrate",
    )
    parser.add_argument(
        "--path",
        default=".",
        help="Path to save and load the model",
    )
    parser.add_argument(
        "--save",
        action="store_true",
        help="Quantize and save the model",
    )
    args = parser.parse_args()
    return args


def get_memory_usage():
    # From CPU or GPU:0
    try:
        memory_stats = tf.config.experimental.get_memory_info("GPU:0")
        peak_usage = memory_stats["peak"] / (2**30)
    except Exception:
        memory_usage = psutil.Process().memory_info().rss
        peak_usage = memory_usage / (2**30)
    return peak_usage


def benchmark_pali_gemma(
    model: keras_nlp.models.PaliGemmaCausalLM, image, prompt: str
):
    # Warmup
    model.generate({"images": image, "prompts": prompt}, max_length=128)
    # Benchmark
    st = time.time()
    result = model.generate(
        {"images": image, "prompts": prompt}, max_length=128
    )
    ed = time.time()
    return result, ed - st


def benchmark_gemma(model: keras_nlp.models.GemmaCausalLM, prompt: str):
    # Warmup
    model.generate(prompt, max_length=128)
    # Benchmark
    st = time.time()
    result = model.generate(prompt, max_length=128)
    ed = time.time()
    return result, ed - st


def save_int8_model(
    preset: str,
    model: typing.Union[
        keras_nlp.models.GemmaCausalLM,
        keras_nlp.models.PaliGemmaCausalLM,
    ],
):
    model.quantize("int8")
    model.summary()
    model.save(f"{preset}_int8.keras")


def load(model_path: pathlib.Path):
    model = keras.saving.load_model(model_path)
    return model


if __name__ == "__main__":
    keras.config.set_dtype_policy("bfloat16")
    x = keras.ops.ones([1]) * keras.ops.ones([1])  # Trigger TF dummy logs

    args = get_args()
    path = pathlib.Path(args.path)
    is_pali_gemma = "pali_gemma" in str(args.model)
    print(f"Peak memory usage (init): {get_memory_usage():.3f} GB")

    # Save
    if args.save:
        if is_pali_gemma:
            model = keras_nlp.models.PaliGemmaCausalLM.from_preset(args.model)
        else:
            model = keras_nlp.models.GemmaCausalLM.from_preset(args.model)
        model.summary()
        print(
            "Peak memory usage (loaded float model): "
            f"{get_memory_usage():.3f} GB"
        )
        save_int8_model(args.model, model)
    # Load
    else:
        model_path = path / f"{args.model}_int8.keras"
        model = load(model_path)
        print(
            "Peak memory usage (loaded int8 model): "
            f"{get_memory_usage():.3f} GB"
        )
        if is_pali_gemma:
            image_path = keras.utils.get_file(
                "cow_beach_1.png",
                "https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png",
            )
            image = keras.utils.load_img(image_path)
            image = keras.utils.img_to_array(image, "channels_last")
            prompt = "describe en\n"
            result, elapsed_time = benchmark_pali_gemma(model, image, prompt)
        else:
            prompt = "What is Keras3?"
            result, elapsed_time = benchmark_gemma(model, prompt)
        print(result)
        print(
            f"The elapsed time for model inference: {elapsed_time:.3f} seconds"
        )
```

Usage:
Outputs: