keras-team / keras-nlp

Modular Natural Language Processing workflows with Keras
Apache License 2.0

Add quantization support for `Gemma`, `Gemma2` and `PaliGemma` #1670

Closed · james77777778 closed this 1 day ago

james77777778 commented 1 week ago

We will need a new release of Keras for this; for now, the PR is built against the Keras master branch.

The implementation is simple and clean after introducing `DTypePolicyMap` and some other fixes. Thanks to @fchollet and @mattdangerw for their help.
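
For context, here is a rough sketch of the `DTypePolicyMap` idea, based on the Keras master-branch API at the time (the layer paths below are hypothetical): it lets a serialized model record a distinct dtype policy per layer path, which is what makes saving and reloading a quantized model straightforward.

```python
import keras
from keras import dtype_policies

# A map from layer paths to dtype policies; unlisted layers fall back
# to the default policy. (Layer paths here are hypothetical.)
policy_map = dtype_policies.DTypePolicyMap(default_policy="bfloat16")

# A layer that stayed in bfloat16 keeps a plain policy...
policy_map["decoder_block_0/ffw_linear"] = dtype_policies.DTypePolicy(
    "bfloat16"
)

# ...while a quantized layer records int8 compute quantized from bfloat16.
policy_map["decoder_block_0/attention/query"] = (
    dtype_policies.QuantizedDTypePolicy("int8", "bfloat16")
)
```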

It is worth noting that float8 training & inference are also supported in this PR. You can check `test_quantize` for this.
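
As a minimal sketch of the entry point (the full benchmark script is below; the preset name is just one of those from the table):

```python
import keras
import keras_nlp

keras.config.set_dtype_policy("bfloat16")
model = keras_nlp.models.GemmaCausalLM.from_preset("gemma_1.1_instruct_2b_en")

# Post-training int8 quantization, applied in place to the model's weights.
model.quantize("int8")

# Alternatively, float8 mode sets the model up for float8 training &
# inference (the path exercised by test_quantize).
# model.quantize("float8")

model.generate("What is Keras3?", max_length=128)
```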

Some numbers:

| Model | Memory Usage (bfloat16) | Memory Usage (int8) | Weights (kagglehub) | Weights (int8) | Note |
| --- | --- | --- | --- | --- | --- |
| `gemma_1.1_instruct_2b_en` | 5.69GB | 2.82GB | 4.7GB | 2.4GB | |
| `gemma2_instruct_9b_en` | 20.93GB | 10.14GB | 18GB | 8.7GB | Measured on CPU |
| `pali_gemma_3b_mix_224` | 6.52GB | 3.22GB | 5.5GB | 2.8GB | |

Script:

int8_gemma.py

```python
import argparse
import os
import pathlib
import time
import typing

import keras
import psutil
import tensorflow as tf

import keras_nlp

# Setup kaggle information
os.environ["KAGGLE_USERNAME"] = "xxx"
os.environ["KAGGLE_KEY"] = "xxx"


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        default="pali_gemma_3b_mix_224",
        choices=[
            "gemma_1.1_instruct_2b_en",
            "pali_gemma_3b_mix_224",
            "gemma2_instruct_9b_en",
        ],
        help="Which model to demonstrate",
    )
    parser.add_argument(
        "--path",
        default=".",
        help="Path to save and load the model",
    )
    parser.add_argument(
        "--save",
        action="store_true",
        help="Quantize and save the model",
    )
    args = parser.parse_args()
    return args


def get_memory_usage():
    # From CPU or GPU:0
    try:
        memory_stats = tf.config.experimental.get_memory_info("GPU:0")
        peak_usage = memory_stats["peak"] / (2**30)
    except Exception:
        memory_usage = psutil.Process().memory_info().rss
        peak_usage = memory_usage / (2**30)
    return peak_usage


def benchmark_pali_gemma(
    model: keras_nlp.models.PaliGemmaCausalLM, image, prompt: str
):
    # Warmup
    model.generate({"images": image, "prompts": prompt}, max_length=128)
    # Benchmark
    st = time.time()
    result = model.generate(
        {"images": image, "prompts": prompt}, max_length=128
    )
    ed = time.time()
    return result, ed - st


def benchmark_gemma(model: keras_nlp.models.GemmaCausalLM, prompt: str):
    # Warmup
    model.generate(prompt, max_length=128)
    # Benchmark
    st = time.time()
    result = model.generate(prompt, max_length=128)
    ed = time.time()
    return result, ed - st


def save_int8_model(
    preset: str,
    model: typing.Union[
        keras_nlp.models.GemmaCausalLM,
        keras_nlp.models.PaliGemmaCausalLM,
    ],
):
    model.quantize("int8")
    model.summary()
    model.save(f"{preset}_int8.keras")


def load(model_path: pathlib.Path):
    model = keras.saving.load_model(model_path)
    return model


if __name__ == "__main__":
    keras.config.set_dtype_policy("bfloat16")
    x = keras.ops.ones([1]) * keras.ops.ones([1])  # Trigger TF dummy logs
    args = get_args()
    path = pathlib.Path(args.path)
    is_pali_gemma = "pali_gemma" in str(args.model)
    print(f"Peak memory usage (init): {get_memory_usage():.3f} GB")

    # Save
    if args.save:
        if is_pali_gemma:
            model = keras_nlp.models.PaliGemmaCausalLM.from_preset(args.model)
        else:
            model = keras_nlp.models.GemmaCausalLM.from_preset(args.model)
        model.summary()
        print(
            "Peak memory usage (loaded float model): "
            f"{get_memory_usage():.3f} GB"
        )
        save_int8_model(args.model, model)
    # Load
    else:
        model_path = path / f"{args.model}_int8.keras"
        model = load(model_path)
        print(
            "Peak memory usage (loaded int8 model): "
            f"{get_memory_usage():.3f} GB"
        )
        if is_pali_gemma:
            image_path = keras.utils.get_file(
                "cow_beach_1.png",
                "https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png",
            )
            image = keras.utils.load_img(image_path)
            image = keras.utils.img_to_array(image, "channels_last")
            prompt = "describe en\n"
            result, elapsed_time = benchmark_pali_gemma(model, image, prompt)
        else:
            prompt = "What is Keras3?"
            result, elapsed_time = benchmark_gemma(model, prompt)
        print(result)
        print(
            f"The elapsed time for model inference: {elapsed_time:.3f} seconds"
        )
```

Usage:

```shell
# Get quantized model
python int8_gemma.py --model "gemma_1.1_instruct_2b_en" --save
python int8_gemma.py --model "gemma2_instruct_9b_en" --save
python int8_gemma.py --model "pali_gemma_3b_mix_224" --save

# Run
python int8_gemma.py --model "gemma_1.1_instruct_2b_en"
python int8_gemma.py --model "gemma2_instruct_9b_en"
python int8_gemma.py --model "pali_gemma_3b_mix_224"
```

Outputs:

```
# Gemma
What is Keras3?

Keras3 is a high-level neural network library built on top of Keras 2. It provides a simplified and more efficient way to build and train deep learning models.

**Key features of Keras3:**

- Simplified API with Keras 2 compatibility
- High-level abstractions for common tasks
- Improved performance and efficiency
- Support for modern neural network architectures

**Benefits of using Keras3:**

- Easier to learn and use
- Faster and more accurate models
- Reduced development time
- Improved portability across different hardware platforms

**How to use Keras3:**

- Import

# PaliGemma
describe en
In this image I can see a cow which affor is in brown color and white color. I can see the sand. In the background I can see the water and the sky.
```

james77777778 commented 1 week ago

This PR should be ready for review. Both Gemma and PaliGemma now support quantization (int8 and float8).

james77777778 commented 1 week ago

Hi @fchollet @mattdangerw, I have added quantization support for Gemma2 (actually, adding tests was sufficient :)). Please let me know if any updates are needed.

mattdangerw commented 5 days ago

@james77777778 thanks so much! Sorry for the delay, I was out last week, but just got back in town. Will take a look tomorrow!

james77777778 commented 4 days ago

No hurry. Please take your time.

mattdangerw commented 1 day ago

@james77777778 thanks for the changes!

As soon as testing is all green I will pull this in, especially since the US is about to go on holiday until next Monday.

I think the coverage is worth it, but as a follow-up let's keep thinking about ways to speed up these tests while maintaining decent coverage.