keras-team / keras-nlp

Modular Natural Language Processing workflows with Keras
Apache License 2.0

Support dynamic int8 quantization for Gemma #1591

Closed: james77777778 closed this 2 months ago

james77777778 commented 2 months ago

I understand these patches might seem inelegant, but I've struggled to find a better way to pass dtype_policy to each layer in the subclasses (GemmaBackbone, CachedGemmaAttention and GemmaDecoderBlock).

It is worth noting that I've overridden most of the int8 quantization code that ReversibleEmbedding inherits from Embedding. After carefully reading the code from gemma_pytorch, I found that the shape of the scale for the embedding is set to (input_dim,), i.e. one scale per row (per token) of the embedding table. Because input_dim is larger than output_dim, this gives many more scale factors than a scale along output_dim and significantly reduces quantization error.
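To illustrate why the scale shape matters, here is a minimal NumPy sketch of symmetric abs-max int8 quantization that compares one scale per column (output_dim scales) with one scale per row (input_dim scales). It is not the PR's ReversibleEmbedding code or Keras's quantizer; the helper names, the toy sizes, and the synthetic weights with varying row magnitudes are assumptions chosen for illustration.

```python
import numpy as np


def abs_max_quantize(w: np.ndarray, axis: int):
    """Symmetric int8 quantization with one scale per slice along `axis`."""
    # Multiplicative scale that maps the largest |w| in each slice to 127.
    max_abs = np.max(np.abs(w), axis=axis, keepdims=True)
    scale = 127.0 / np.maximum(max_abs, 1e-7)
    q = np.clip(np.round(w * scale), -127, 127).astype(np.int8)
    return q, scale


def dequantize(q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    # Undo the multiplicative scale to get approximate float weights back.
    return q.astype(np.float32) / scale


# Toy embedding table (hypothetical sizes) with input_dim > output_dim.
input_dim, output_dim = 2048, 128
rng = np.random.default_rng(0)
# Rows get different magnitudes, so a scale shared down a column fits poorly.
row_magnitude = rng.uniform(0.01, 1.0, size=(input_dim, 1))
w = (rng.standard_normal((input_dim, output_dim)) * row_magnitude).astype(np.float32)

# One scale per column: scale shape (1, output_dim).
q_col, s_col = abs_max_quantize(w, axis=0)
# One scale per row (per token): scale shape (input_dim, 1).
q_row, s_row = abs_max_quantize(w, axis=1)

err_col = np.mean(np.abs(w - dequantize(q_col, s_col)))
err_row = np.mean(np.abs(w - dequantize(q_row, s_row)))
print(f"per-column scale: {s_col.size} scales, mean abs error {err_col:.5f}")
print(f"per-row scale:    {s_row.size} scales, mean abs error {err_row:.5f}")
```

For an embedding layer, input_dim is the vocabulary size, so a per-row scale costs one extra float per token, which is small next to the input_dim × output_dim int8 table itself.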

Here are some numbers:

| Model | Mem. (bfloat16) | Mem. (int8) | Weights (kagglehub) | Weights (int8) | Notes |
|---|---|---|---|---|---|
| "gemma_1.1_instruct_2b_en" | 5.69 GB | 2.82 GB | 4.7 GB | 2.4 GB | |
| "gemma_1.1_instruct_7b_en" | 18.97 GB | 8.69 GB | 16.0 GB | 8.0 GB | Run on CPU |
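As a quick sanity check on the table: int8 stores each quantized weight in 1 byte versus 2 bytes for bfloat16, so the int8 columns should come out at roughly half of the bfloat16 ones. The snippet below only recomputes those ratios from the reported numbers; the dictionary layout is an arbitrary choice for this sketch.

```python
# Values copied from the table above, in GB.
results = {
    "gemma_1.1_instruct_2b_en": {"mem_bf16": 5.69, "mem_int8": 2.82, "w_bf16": 4.7, "w_int8": 2.4},
    "gemma_1.1_instruct_7b_en": {"mem_bf16": 18.97, "mem_int8": 8.69, "w_bf16": 16.0, "w_int8": 8.0},
}


def ratios(row: dict) -> tuple:
    # int8 size as a fraction of the bfloat16 size, for memory and weights.
    return row["mem_int8"] / row["mem_bf16"], row["w_int8"] / row["w_bf16"]


for name, row in results.items():
    mem_ratio, w_ratio = ratios(row)
    print(f"{name}: memory x{mem_ratio:.2f}, weights x{w_ratio:.2f}")
```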

Model outputs:

# "gemma_1.1_instruct_2b_en" int8 version
What is Keras?

Keras is an open-source machine learning library and framework that provides a high-level interface for building and training deep learning models. It is built on top of TensorFlow, allowing users to leverage the vast resources and capabilities of the TensorFlow ecosystem.

**Key features of Keras:**

- High-level API for building and training models
- Support for a wide range of deep learning algorithms
- Optimized for performance and scalability
- Integration with TensorFlow ecosystem for seamless data loading and processing

**Benefits of using Keras:**

- **Simplified model building:** Keras provides a user-friendly interface for constructing deep learning

# "gemma_1.1_instruct_7b_en" int8 version
What is Keras?

**Keras** is a high-level API for TensorFlow and other machine learning libraries. It provides a user-friendly and modular interface for building, training, and evaluating deep learning models. Keras is designed to be accessible to beginners and experienced ML engineers alike.

**Key features of Keras:**

- **Modular design:** Allows for easy composition of different layers and models.
- **TensorFlow compatibility:** Leverages the power of TensorFlow for backend computation.
- **Python API:** Written in Python, making it easy to use and integrate with other Python libraries.
- **Wide range of layers:**

Standalone script:

python3 gemma_int8.py --model gemma_1.1_instruct_2b_en --save
python3 gemma_int8.py --model gemma_1.1_instruct_2b_en

# Use CPU for the following commands
python3 gemma_int8.py --model gemma_1.1_instruct_7b_en --save
python3 gemma_int8.py --model gemma_1.1_instruct_7b_en
gemma_int8.py:

```python
import argparse
import json
import os
import pathlib

import keras
import psutil
import tensorflow as tf

import keras_nlp

# Setup kaggle information
os.environ["KAGGLE_USERNAME"] = "xxxxx"
os.environ["KAGGLE_KEY"] = "xxxxx"


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        default="gemma_1.1_instruct_2b_en",
        choices=["gemma_1.1_instruct_2b_en", "gemma_1.1_instruct_7b_en"],
        help="Which model to demonstrate",
    )
    parser.add_argument(
        "--path",
        default=".",
        help="Path to save and load the model",
    )
    parser.add_argument(
        "--save",
        action="store_true",
        help="Quantize and save the model",
    )
    args = parser.parse_args()
    return args


def get_memory_usage():
    # Memory in GB (2**30 bytes), from GPU:0 if available, else from the CPU.
    try:
        memory_stats = tf.config.experimental.get_memory_info("GPU:0")
        peak_usage = memory_stats["peak"] / (2**30)
    except Exception:
        # Note: on CPU this is the current RSS of the process, not a true peak.
        memory_usage = psutil.Process().memory_info().rss
        peak_usage = memory_usage / (2**30)
    return peak_usage


def save_int8_model(
    preset: str, model: keras_nlp.models.GemmaCausalLM, path: pathlib.Path
):
    # Quantize the float model in place to int8.
    model.quantize("int8")
    model.summary()

    # Save config
    config = keras.saving.serialize_keras_object(model)
    with open(path / f"{preset}_int8.json", "w") as f:
        f.write(json.dumps(config))

    # Save weights
    model.save_weights(path / f"{preset}_int8.weights.h5")


def load(config_path: pathlib.Path, weights_path: pathlib.Path, preset: str):
    # Load by config file (rebuilds the already-quantized model structure)
    with open(config_path, "r") as f:
        config = json.loads(f.read())
    model: keras_nlp.models.GemmaCausalLM = (
        keras.saving.deserialize_keras_object(config)
    )
    # Load weights
    model.load_weights(weights_path)
    # Load preset assets (tokenizer vocabulary)
    model.preprocessor.tokenizer.load_preset_assets(preset)
    return model


if __name__ == "__main__":
    keras.config.set_dtype_policy("bfloat16")
    x = keras.ops.ones([1]) * keras.ops.ones([1])  # Trigger TF dummy logs

    args = get_args()
    path = pathlib.Path(args.path)
    print(f"Peak memory usage (init): {get_memory_usage():.3f} GB")

    # Save
    if args.save:
        model = keras_nlp.models.GemmaCausalLM.from_preset(args.model)
        model.summary()
        print(
            "Peak memory usage (loaded float model): "
            f"{get_memory_usage():.3f} GB"
        )
        save_int8_model(args.model, model, path)
    # Load
    else:
        config_path = path / f"{args.model}_int8.json"
        weights_path = path / f"{args.model}_int8.weights.h5"
        model = load(config_path, weights_path, args.model)
        print(
            "Peak memory usage (loaded int8 model): "
            f"{get_memory_usage():.3f} GB"
        )
        print(model.generate("What is Keras?", max_length=128))
```

cc @fchollet

james77777778 commented 2 months ago

We could wait for https://github.com/keras-team/keras/pull/19595, which would simplify this modification. However, a new Keras release would still be required for those changes.

fchollet commented 2 months ago

Thanks for the PR! It looks broadly reasonable to me. I'll defer to @mattdangerw for final review.

james77777778 commented 2 months ago

I've updated the PR to fix the compatibility issue with Keras 2.

Which ReversibleEmbedding implementation gets imported is now determined by config.keras_3().
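A minimal sketch of that kind of version-gated import, assuming two implementations live side by side. The class names are hypothetical placeholders, and the `keras_3(version)` helper here takes a version string only to keep the example self-contained (the real config.keras_3() is called with no arguments); the PR's actual module layout is not reproduced.

```python
class Keras2ReversibleEmbedding:
    """Hypothetical placeholder for the existing Keras 2-compatible layer."""


class Keras3ReversibleEmbedding:
    """Hypothetical placeholder for the layer with the int8 quantization overrides."""


def keras_3(version: str) -> bool:
    # Hypothetical stand-in for config.keras_3(): checks the major version
    # of a Keras version string such as "3.3.3" or "2.15.0".
    major = int(version.split(".")[0])
    return major >= 3


def select_reversible_embedding(version: str) -> type:
    # Pick the implementation that matches the installed Keras major version.
    if keras_3(version):
        return Keras3ReversibleEmbedding
    return Keras2ReversibleEmbedding


# Bind the public name once at import time, so callers just use
# `ReversibleEmbedding`. In real code the version would come from the
# installed Keras rather than a hard-coded string.
ReversibleEmbedding = select_reversible_embedding("3.3.3")
```

Resolving the choice once at import keeps the rest of the model code free of version checks.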

mattdangerw commented 2 months ago

Thanks for the PR! I just got back from vacay; will review tomorrow!

james77777778 commented 2 months ago

Please check the reopened PR here: https://github.com/keras-team/keras-nlp/pull/1612