james77777778 closed this 2 months ago
We can wait for https://github.com/keras-team/keras/pull/19595 to land, which would simplify the modification. However, a new Keras release would still be required for the changes.
Thanks for the PR! It looks broadly reasonable to me. I'll defer to @mattdangerw for final review.
I've updated the PR to fix the compatibility issue with Keras 2. The choice of which `ReversibleEmbedding` to import will be determined by `config.keras_3()`.
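For reference, a minimal sketch of what that conditional import could look like (the module paths and the `reversible_embedding_keras2` name here are illustrative, not necessarily the exact ones in the PR):

```python
# Sketch only: select the ReversibleEmbedding implementation based on
# whether Keras 3 is in use. Module paths are illustrative.
from keras_nlp.src.backend import config

if config.keras_3():
    # Keras 3 version, which supports the new quantization hooks.
    from keras_nlp.src.layers.modeling.reversible_embedding import (
        ReversibleEmbedding,
    )
else:
    # Hypothetical fallback module that remains compatible with Keras 2.
    from keras_nlp.src.layers.modeling.reversible_embedding_keras2 import (
        ReversibleEmbedding,
    )
```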
Thanks for the PR! I just got back from vacay; will review tomorrow!
Please check the reopened PR here: https://github.com/keras-team/keras-nlp/pull/1612
I understand these patches might seem inelegant, but I've struggled to find a better way to pass `dtype_policy` to each layer in the subclasses (`GemmaBackbone`, `CachedGemmaAttention` and `GemmaDecoderBlock`).

It is worth noting that I've overridden most of the int8 quantization code that `ReversibleEmbedding` inherits from `Embedding`. After carefully reading the code from `gemma_pytorch`, I found that the shape of the quantization scale for the embedding is set to `(input_dim,)`. As `input_dim` is larger than `output_dim`, this significantly reduces quantization error; a toy sketch of the idea follows.
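To make the intuition concrete, here is a toy abs-max int8 quantization sketch in plain NumPy (not the PR's actual implementation; the shapes are made up) comparing a scale of shape `(input_dim,)` against one of shape `(output_dim,)`:

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy shapes; in Gemma the vocabulary (input_dim) is much larger than
# the hidden size (output_dim), e.g. 256000 vs 2048.
input_dim, output_dim = 4096, 256
w = rng.normal(size=(input_dim, output_dim)).astype("float32")


def dequantized(w, axis):
    # Abs-max int8 quantization with one scale per slice along `axis`,
    # followed by immediate dequantization to measure the error.
    scale = np.abs(w).max(axis=axis, keepdims=True) / 127.0
    w_int8 = np.round(w / scale).clip(-127, 127).astype("int8")
    return w_int8.astype("float32") * scale


# Scale of shape (input_dim, 1): one scale per vocabulary row.
err_rows = np.abs(w - dequantized(w, axis=1)).mean()
# Scale of shape (1, output_dim): one scale per hidden column.
err_cols = np.abs(w - dequantized(w, axis=0)).mean()
print(f"per-row error: {err_rows:.5f}, per-column error: {err_cols:.5f}")
```

Because each per-row slice is shorter, its abs-max (and hence its quantization step) is smaller, which is why a scale of shape `(input_dim,)` tends to reconstruct the weights more accurately.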
Here are some numbers:
Model outputs:
Standalone script:
gemma_int8.py
```python
import argparse
import json
import os
import pathlib

import keras
import psutil
import tensorflow as tf

import keras_nlp

# Setup kaggle information
os.environ["KAGGLE_USERNAME"] = "xxxxx"
os.environ["KAGGLE_KEY"] = "xxxxx"


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        default="gemma_1.1_instruct_2b_en",
        choices=["gemma_1.1_instruct_2b_en", "gemma_1.1_instruct_7b_en"],
        help="Which model to demonstrate",
    )
    parser.add_argument(
        "--path",
        default=".",
        help="Path to save and load the model",
    )
    parser.add_argument(
        "--save",
        action="store_true",
        help="Quantize and save the model",
    )
    args = parser.parse_args()
    return args


def get_memory_usage():
    # From CPU or GPU:0
    try:
        memory_stats = tf.config.experimental.get_memory_info("GPU:0")
        peak_usage = memory_stats["peak"] / (2**30)
    except Exception:
        memory_usage = psutil.Process().memory_info().rss
        peak_usage = memory_usage / (2**30)
    return peak_usage


def save_int8_model(
    preset: str, model: keras_nlp.models.GemmaCausalLM, path: pathlib.Path
):
    model.quantize("int8")
    model.summary()
    # Save config
    config = keras.saving.serialize_keras_object(model)
    with open(path / f"{preset}_int8.json", "w") as f:
        f.write(json.dumps(config))
    # Save weights
    model.save_weights(path / f"{preset}_int8.weights.h5")


def load(config_path: pathlib.Path, weights_path: pathlib.Path, preset: str):
    # Load by config file
    with open(config_path, "r") as f:
        config = json.loads(f.read())
    model: keras_nlp.models.GemmaCausalLM = (
        keras.saving.deserialize_keras_object(config)
    )
    # Load weights
    model.load_weights(weights_path)
    # Load preset assets
    model.preprocessor.tokenizer.load_preset_assets(preset)
    return model


if __name__ == "__main__":
    keras.config.set_dtype_policy("bfloat16")
    x = keras.ops.ones([1]) * keras.ops.ones([1])  # Trigger TF dummy logs
    args = get_args()
    path = pathlib.Path(args.path)
    print(f"Peak memory usage (init): {get_memory_usage():.3f} GB")
    # Save
    if args.save:
        model = keras_nlp.models.GemmaCausalLM.from_preset(args.model)
        model.summary()
        print(
            "Peak memory usage (loaded float model): "
            f"{get_memory_usage():.3f} GB"
        )
        save_int8_model(args.model, model, path)
    # Load
    else:
        config_path = path / f"{args.model}_int8.json"
        weights_path = path / f"{args.model}_int8.weights.h5"
        model = load(config_path, weights_path, args.model)
        print(
            "Peak memory usage (loaded int8 model): "
            f"{get_memory_usage():.3f} GB"
        )
        print(model.generate("What is Keras?", max_length=128))
```

cc @fchollet