keras-team / keras-nlp

Modular Natural Language Processing workflows with Keras
Apache License 2.0

Support dynamic int8 quantization for Gemma #1612

Closed · james77777778 closed this 2 weeks ago

james77777778 commented 2 months ago

I'm reopening this PR, which originated from https://github.com/keras-team/keras-nlp/pull/1591, due to the API generation issue.

Notes:

Here are some numbers:

| Model | Mem. (bfloat16) | Mem. (int8) | Weights (kagglehub) | Weights (int8) | Notes |
|---|---|---|---|---|---|
| "gemma_1.1_instruct_2b_en" | 5.69GB | 2.82GB | 4.7GB | 2.4GB | |
| "gemma_1.1_instruct_7b_en" | 18.97GB | 8.69GB | 16.0GB | 8.0GB | Run on CPU |

Model outputs:

# "gemma_1.1_instruct_2b_en" int8 version
What is Keras?

Keras is an open-source machine learning library and framework that provides a high-level interface for building and training deep learning models. It is built on top of TensorFlow, allowing users to leverage the vast resources and capabilities of the TensorFlow ecosystem.

**Key features of Keras:**

- High-level API for building and training models
- Support for a wide range of deep learning algorithms
- Optimized for performance and scalability
- Integration with TensorFlow ecosystem for seamless data loading and processing

**Benefits of using Keras:**

- **Simplified model building:** Keras provides a user-friendly interface for constructing deep learning

# "gemma_1.1_instruct_7b_en" int8 version
What is Keras?

**Keras** is a high-level API for TensorFlow and other machine learning libraries. It provides a user-friendly and modular interface for building, training, and evaluating deep learning models. Keras is designed to be accessible to beginners and experienced ML engineers alike.

**Key features of Keras:**

- **Modular design:** Allows for easy composition of different layers and models.
- **TensorFlow compatibility:** Leverages the power of TensorFlow for backend computation.
- **Python API:** Written in Python, making it easy to use and integrate with other Python libraries.
- **Wide range of layers:**
```

Standalone script:

```bash
python3 gemma_int8.py --model gemma_1.1_instruct_2b_en --save
python3 gemma_int8.py --model gemma_1.1_instruct_2b_en

# Use CPU for the following commands
python3 gemma_int8.py --model gemma_1.1_instruct_7b_en --save
python3 gemma_int8.py --model gemma_1.1_instruct_7b_en
```
gemma_int8.py

```python
import argparse
import json
import os
import pathlib

import keras
import psutil
import tensorflow as tf

import keras_nlp

# Setup kaggle information
os.environ["KAGGLE_USERNAME"] = "xxxxx"
os.environ["KAGGLE_KEY"] = "xxxxx"


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        default="gemma_1.1_instruct_2b_en",
        choices=["gemma_1.1_instruct_2b_en", "gemma_1.1_instruct_7b_en"],
        help="Which model to demonstrate",
    )
    parser.add_argument(
        "--path",
        default=".",
        help="Path to save and load the model",
    )
    parser.add_argument(
        "--save",
        action="store_true",
        help="Quantize and save the model",
    )
    args = parser.parse_args()
    return args


def get_memory_usage():
    # From CPU or GPU:0
    try:
        memory_stats = tf.config.experimental.get_memory_info("GPU:0")
        peak_usage = memory_stats["peak"] / (2**30)
    except Exception:
        memory_usage = psutil.Process().memory_info().rss
        peak_usage = memory_usage / (2**30)
    return peak_usage


def save_int8_model(
    preset: str, model: keras_nlp.models.GemmaCausalLM, path: pathlib.Path
):
    model.quantize("int8")
    model.summary()
    # Save config
    config = keras.saving.serialize_keras_object(model)
    with open(path / f"{preset}_int8.json", "w") as f:
        f.write(json.dumps(config))
    # Save weights
    model.save_weights(path / f"{preset}_int8.weights.h5")


def load(config_path: pathlib.Path, weights_path: pathlib.Path, preset: str):
    # Load by config file
    with open(config_path, "r") as f:
        config = json.loads(f.read())
    model: keras_nlp.models.GemmaCausalLM = (
        keras.saving.deserialize_keras_object(config)
    )
    # Load weights
    model.load_weights(weights_path)
    # Load preset assets
    model.preprocessor.tokenizer.load_preset_assets(preset)
    return model


if __name__ == "__main__":
    keras.config.set_dtype_policy("bfloat16")
    x = keras.ops.ones([1]) * keras.ops.ones([1])  # Trigger TF dummy logs

    args = get_args()
    path = pathlib.Path(args.path)
    print(f"Peak memory usage (init): {get_memory_usage():.3f} GB")

    # Save
    if args.save:
        model = keras_nlp.models.GemmaCausalLM.from_preset(args.model)
        model.summary()
        print(
            "Peak memory usage (loaded float model): "
            f"{get_memory_usage():.3f} GB"
        )
        save_int8_model(args.model, model, path)
    # Load
    else:
        config_path = path / f"{args.model}_int8.json"
        weights_path = path / f"{args.model}_int8.weights.h5"
        model = load(config_path, weights_path, args.model)
        print(
            "Peak memory usage (loaded int8 model): "
            f"{get_memory_usage():.3f} GB"
        )
        print(model.generate("What is Keras?", max_length=128))
```
james77777778 commented 2 months ago

I've heard from @awsaf49 that this PR worked nicely to significantly reduce the memory footprint for running Gemma 7B on a single 16GB GPU. If we distribute the model across 2 GPUs, we can further increase the `max_length` in `model.generate` to 1024.

Ref: https://www.kaggle.com/code/awsaf49/gemma-1-1-7b-int8-load
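A rough sketch of the 2-GPU sharding setup, for orientation only; the concrete layout rules for Gemma live in the linked notebook, and everything below is illustrative:

```python
import keras

# Sketch: shard the model dimension across 2 GPUs before building the model.
devices = keras.distribution.list_devices()  # assumes exactly 2 accelerators
device_mesh = keras.distribution.DeviceMesh((1, 2), ["batch", "model"], devices)
layout_map = keras.distribution.LayoutMap(device_mesh)
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = ...  # per the notebook
distribution = keras.distribution.ModelParallel(device_mesh, layout_map)
keras.distribution.set_distribution(distribution)
```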

mattdangerw commented 2 months ago

@james77777778 thanks very much for this! Sorry for the delay reviewing, started poking around last week and should be done later today.

james77777778 commented 2 months ago

Hi @mattdangerw, thank you for reviewing. Using `config.keras_3` to branch the logic in `ReversibleEmbedding` is definitely feasible. I will fix it soon.

> How does `dtypes_policies` work? Can you link relevant code or an explainer?

`dtypes_policies` will be a `dict[str, dtype_policy]`. It is required because, in quantization, int8 `EinsumDense` layers are mixed with other floating-point layers inside `CachedGemmaAttention`. We need these layers to save and load the correct configuration for each dtype policy.
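For illustration, the mapping might look like this (a sketch only; the layer names here are invented, not the actual sublayer names in `CachedGemmaAttention`):

```python
import keras

# Hypothetical per-sublayer mapping: int8 and floating policies coexist.
dtypes_policies = {
    "query_dense": keras.dtype_policies.QuantizedDTypePolicy("int8", "bfloat16"),
    "softmax": keras.dtype_policies.DTypePolicy("bfloat16"),
}
```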

> I don't think this will work as is for new uploads of models to Kaggle and elsewhere...
>
> 1. To allow casting of float dtypes automatically...

I hadn't thought of that before. Thanks for pointing it out.

To prevent the complete dtype spec from being uploaded, we can add logic to detect whether a model is quantized: include the dtype spec if it is quantized and skip it otherwise. What do you think?
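A minimal sketch of what that detection could look like, assuming a hypothetical `_is_quantized` helper in backbone.py (not code from this PR):

```python
import keras

def _is_quantized(model):
    # Sketch: treat the model as quantized if any sublayer carries a
    # quantized dtype policy. `_flatten_layers` is Keras's internal walker.
    return any(
        isinstance(layer.dtype_policy, keras.dtype_policies.QuantizedDTypePolicy)
        for layer in model._flatten_layers()
    )
```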

> Let's think if there's a way we can do this where most of the logic is consolidated to base classes like backbone.py.

We could incorporate the detection I mentioned into backbone.py. However, supporting more fine-grained dtype policy control still requires adding specific logic to each model (and even each layer).

I had some discussions with @fchollet here: https://github.com/keras-team/keras/issues/19381

The main challenge with subclasses is identifying which layers (and their sublayers) are quantized and enabling serialization/deserialization for them. We can't apply the setter trick I mentioned in that discussion to Gemma because the model is already built by the time `from_config` runs.

mattdangerw commented 2 months ago

> To prevent the complete dtype spec from being uploaded, we can add logic to detect whether a model is quantized: include the dtype spec if it is quantized and skip it otherwise. What do you think?

Yeah, I think this is the right call? We also want to allow users to save their own quantized versions, upload them, and share them so that others can get the quantized config.

It feels slightly awkward/implicit, but I think it's the right way to preserve our current usages while still leaving room for what we want for quantization.

Long winded way to say sounds good to me, let's try it :)

Still want to think about how we can add this support to every model in the library without as much of a code diff, but haven't had time to think on that question yet.

james77777778 commented 2 months ago

> Long winded way to say sounds good to me, let's try it :)

Great, now that we've reached a consensus, these changes shouldn't take long.

> Still want to think about how we can add this support to every model in the library without as much of a code diff, but haven't had time to think on that question yet.

I don't have a better idea at the moment, and I think that a refactor for dtype policy control in Core Keras might be necessary, especially for the subclasses. However, considering the backward compatibility, this could be a challenging task.

james77777778 commented 1 month ago

I will continue this PR once we have a new release of Keras (>3.3.3).

There is an update to `DTypePolicy` that makes a more flexible quantized dtype policy possible. E.g., `QuantizedDTypePolicy("int8", source_name=None)` will interpret the source dtype policy using `keras.dtype_policies.dtype_policy()`.
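For example, based on the updated behavior described above:

```python
import keras

# With source_name=None, the quantized policy falls back to the current
# global dtype policy for its floating-point side.
keras.config.set_dtype_policy("bfloat16")
policy = keras.dtype_policies.QuantizedDTypePolicy("int8")
```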

martin-gorner commented 1 month ago

"a refactor for dtype policy control in Core Keras might be necessary, especially for the subclasses. However, considering the backward compatibility, this could be a challenging task."

My 2c: now is the time to do it as KerasNLP is still a fairly new library. As usage picks up, things will only get harder.

james77777778 commented 1 month ago

> My 2c: now is the time to do it as KerasNLP is still a fairly new library. As usage picks up, things will only get harder.

Actually, I don't have a concrete idea for the refactoring at the moment.

I can point out the issue: we rely on a single `dtype` argument for subclass creation. So, without complex and verbose logic in `__init__`, it is hard to support mixing floating and quantized dtype policies.

Maybe we can implicitly pass a dict called `sublayer_policies` when a layer contains sublayers to tackle this. However, there is currently no way to automatically map between a policy and its sublayer in `__init__`, and I don't think it's a good idea to require users to hard-code it.
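A minimal sketch of the constraint described above (the layer and argument names are illustrative):

```python
import keras

class DecoderBlock(keras.layers.Layer):
    def __init__(self, dtype=None, **kwargs):
        super().__init__(dtype=dtype, **kwargs)
        # A single `dtype` propagates to every sublayer, so there is no
        # natural hook to say "int8 projections, floating-point norm" here.
        self.qkv = keras.layers.EinsumDense("bd,dh->bh", output_shape=8, dtype=dtype)
        self.norm = keras.layers.LayerNormalization(dtype=dtype)
```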

mattdangerw commented 1 month ago

I think at the very least, we would probably want to handle this at the `Backbone` level, e.g. automatically filling in a dict config for dtype by listing all direct sublayers by name, something like that (a rough sketch follows below). Maybe add some common functionality to help with layer construction. This is purely from a code maintainability standpoint: we want adding a new model to this repo to be relatively low friction, and we should factor out common logic where we can. As is, this is a large addition of "diff" for each model.
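A rough sketch of that consolidation, as a hypothetical helper on `Backbone` (not code from this PR):

```python
import keras

class Backbone(keras.Model):
    def _dtype_config(self):
        # Record one serialized policy per direct sublayer, keyed by name.
        return {
            layer.name: keras.dtype_policies.serialize(layer.dtype_policy)
            for layer in self.layers
        }
```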

> I can point out the issue: we rely on a single `dtype` argument for subclass creation. So, without complex and verbose logic in `__init__`, it is hard to support mixing floating and quantized dtype policies.

A half-formed thought: this is kinda similar to our distribution `LayoutMap` problem. Basically, for distributing variables across machines, we do this...

```python
layout_map = keras.distribution.LayoutMap(device_mesh)
layout_map["token_embedding/embeddings"] = ...
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = ...
layout_map["decoder_block.*attention_output.*kernel"] = ...
distribution = keras.distribution.ModelParallel(device_mesh, layout_map)
# Globally.
keras.distribution.set_distribution(distribution)
# Or locally.
with distribution.scope():
    ...
```

A somewhat parallel API might be...

```python
dtype_map = keras.dtype_policies.DtypeMap(default="bfloat16")
dtype_map["token_embedding/embeddings"] = "float32"
dtype_map["decoder_block.*attention.*(query|key|value).*kernel"] = "int8"
policy = keras.dtype_policies.DTypePolicy(dtype_map)
# Globally.
keras.dtype_policies.set_dtype_policy(policy)
# Or locally.
model = SomeModel(dtype=policy)
```

Then the `__init__` logic can stay simple...

```python
self.sublayer = SomeLayer(..., dtype=self.dtype_policy)
```

Then we'd have one policy that we could pass around that gives a whole mapping of dtypes all the way down. That also gives a relatively quick way to specify things like a dtype for all query projections, say. It would still take some figuring out during saving, etc.

@james77777778 wdyt?

james77777778 commented 1 month ago

Hi @mattdangerw

I think this idea is great and it should be easy to implement. Please refer to this PR for more details: https://github.com/keras-team/keras/pull/19783

I have changed the value type in the mapping from a `dtype` to a `dtype_policy`. This should make more sense because the behavior of the quantized layers relies on the full policy.

In that PR, the saving and loading issues have been solved. It wasn't difficult :)
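For reference, a usage sketch along the lines of that PR's `DTypePolicyMap` (details here are approximate):

```python
import keras

# Map layer paths (regex keys) to full dtype policies; unmatched layers
# fall back to the default policy.
policy_map = keras.dtype_policies.DTypePolicyMap("bfloat16")
policy_map["decoder_block.*attention"] = keras.dtype_policies.QuantizedDTypePolicy(
    "int8", "bfloat16"
)
# `SomeBackbone` is a placeholder for a model accepting a dtype policy map.
model = SomeBackbone(dtype=policy_map)
```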

james77777778 commented 2 weeks ago

I have submitted a new PR for this: https://github.com/keras-team/keras-nlp/pull/1670

I'm closing this PR now.