Closed james77777778 closed 2 weeks ago
I've heard from @awsaf49 that this PR worked nicely to significantly reduce the memory footprint for running Gemma 7B on a single 16GB GPU.
If we distribute the model across 2 GPUs, we can further increase the `max_length` in `model.generate` to 1024.
Ref: https://www.kaggle.com/code/awsaf49/gemma-1-1-7b-int8-load
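For a rough sense of why int8 matters on a 16GB card, here is a back-of-the-envelope estimate of weight memory per dtype. The parameter count of about 8.5e9 for Gemma 7B is an assumption that is not stated in this thread, and the estimate ignores activations, the KV cache, and quantization scales, so treat it as a sketch rather than a measurement.

```python
# Rough weight-only memory estimate. Activations, the KV cache and
# quantization scales are deliberately ignored.
GIB = 2**30


def weight_memory_gib(num_params: float, bytes_per_param: int) -> float:
    """Return the memory needed to hold the weights, in GiB."""
    return num_params * bytes_per_param / GIB


# Assumption: roughly 8.5e9 parameters for Gemma 7B (not from this thread).
NUM_PARAMS = 8.5e9

for name, nbytes in {"float32": 4, "bfloat16": 2, "int8": 1}.items():
    print(f"{name:>8}: {weight_memory_gib(NUM_PARAMS, nbytes):.1f} GiB")
```

Under that assumption, the bfloat16 weights alone nearly fill a 16GB GPU, while int8 weights take about half of it, which leaves room for a longer `max_length`.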
@james77777778 thanks very much for this! Sorry for the delay reviewing, started poking around last week and should be done later today.
Hi @mattdangerw
Thank you for reviewing.
Using `config.keras_3` to branch the logic in `ReversibleEmbedding` is definitely feasible. I will fix it soon.
How does dtypes_policies work? Can you link relevant code or an explainer?
`dtypes_policies` will be a `dict[str, dtype_policy]`. It is required because, in quantization, we have int8 `EinsumDense` mixed with other floating-point layers in `CachedGemmaAttention`. We need to enable these layers to save and load the correct configuration of each dtype policy.
I don't think this will work as is for new uploads of models to Kaggle and elsewhere...
- To allow casting of float dtypes automatically...
I haven't thought of it before. Thanks for pointing it out.
To prevent the complete dtype spec from being uploaded, we can add the logic to detect whether it is a quantized model. We'll include the dtype spec if it's quantized and skip it otherwise. What do you think?
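A minimal sketch of that detection, in plain Python so it runs without Keras. `Policy` and `backbone_config` are hypothetical stand-ins, not the PR's actual code; the only idea taken from the discussion is that the per-layer dtype spec is written to the config only when at least one layer is quantized.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Policy:
    """Hypothetical stand-in for a Keras dtype policy."""

    name: str
    quantized: bool = False


def backbone_config(base_config: dict, layer_policies: dict[str, Policy]) -> dict:
    """Add the per-layer dtype spec only for quantized models."""
    config = dict(base_config)
    if any(policy.quantized for policy in layer_policies.values()):
        # Quantized model: record every layer's policy so it can be restored.
        config["dtypes_policies"] = {
            path: policy.name for path, policy in layer_policies.items()
        }
    # Float model: leave the config untouched so existing uploads keep
    # casting to whatever dtype the user requests at load time.
    return config


float_model = {"attention/query": Policy("bfloat16")}
int8_model = {
    "attention/query": Policy("int8_from_bfloat16", quantized=True),
    "layer_norm": Policy("bfloat16"),
}
print(backbone_config({"num_layers": 2}, float_model))
print(backbone_config({"num_layers": 2}, int8_model))
```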
Let's think if there's a way we can do this where most of the logic is consolidated to base classes like `backbone.py`.
We could incorporate the detection I mentioned into `backbone.py`.
However, it is necessary to add specific logic to each model (and even each layer) to support a more fine-grained dtype policy control.
I had some discussions with @fchollet here: https://github.com/keras-team/keras/issues/19381
The main challenge with subclasses is identifying which layers (and their sublayers) are quantized and enabling serialization/deserialization for them.
We can't apply the setter trick I mentioned in that discussion for Gemma because the model is already built in `from_config`.
> To prevent the complete dtype spec from being uploaded, we can add the logic to detect whether it is a quantized model. We'll include the dtype spec if it's quantized and skip it otherwise. What do you think?
Yeah I think this is the right call? We also want to allow users to save their own quantized versions, upload them and share them where others could get the quantized config.
It feels slightly awkward/implicit, but I think it's the right way to preserve our current usages, but still leave room for what we want for quantization.
Long winded way to say sounds good to me, let's try it :)
Still want to think about how we can add this support to every model in the library without as much of a code diff, but haven't had time to think on that question yet.
> Long winded way to say sounds good to me, let's try it :)

Great, now that we've reached a consensus, these changes shouldn't take long.

> Still want to think about how we can add this support to every model in the library without as much of a code diff, but haven't had time to think on that question yet.
I don't have a better idea at the moment and I think that a refactor for dtype policy control in Core Keras might be necessary, especially for the subclasses. However, considering the backward compatibility, this could be a challenging task.
I will continue this PR once we have a new release of Keras (>3.3.3).
There is an update to `DTypePolicy` that makes it possible to have a more flexible quantized dtype policy.
E.g. `QuantizedDTypePolicy("int8", source_name=None)` will interpret the source dtype policy using `keras.dtype_policies.dtype_policy()`.
> a refactor for dtype policy control in Core Keras might be necessary, especially for the subclasses. However, considering the backward compatibility, this could be a challenging task.
My 2c: now is the time to do it as KerasNLP is still a fairly new library. As usage picks up, things will only get harder.
> My 2c: now is the time to do it as KerasNLP is still a fairly new library. As usage picks up, things will only get harder.
Actually, I don't have a concrete idea for the refactoring at the moment.
I can point out the issue:
We rely on a single `dtype` argument for subclass creation. So, without complex and verbose logic in `__init__`, it is hard to support the mixing of floating-point / quantized dtype policies.
Maybe we can implicitly pass a dict called `sublayer_policies` if the layer contains sublayers to tackle this. However, currently, there is no way to automatically map between the policy and the sublayer in `__init__`. I don't think it's a good idea to require users to hard-code it.
I think at the very least, we would probably want to handle this at the `Backbone` level. E.g. automatically fill in a dict config for dtype by listing all direct sublayers by name, something like that. Maybe add some common functionality to help layer construction. This is purely from a code maintainability standpoint: we want adding a new model to this repo to be relatively low friction, and we should factor out common logic where we can. This is a large addition of "diff" for each model.
> I can point out the issue: We rely on a single `dtype` argument for subclass creation. So, without complex and verbose logic in `__init__`, it is hard to support the mixing of floating-point / quantized dtype policies.
A half-formed thought is that this is kinda similar to our distribution `LayoutMap` problem. Basically, for distributing variables across machines, we do this...
```python
layout_map = keras.distribution.LayoutMap(device_mesh)
layout_map["token_embedding/embeddings"] = ...
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = ...
layout_map["decoder_block.*attention_output.*kernel"] = ...
distribution = keras.distribution.ModelParallel(device_mesh, layout_map)
# Globally.
keras.distribution.set_distribution(distribution)
# Or locally.
with distribution.scope():
    ...
```
A somewhat parallel API might be...
```python
dtype_map = keras.dtype_policies.DtypeMap(default="bfloat16")
dtype_map["token_embedding/embeddings"] = "float32"
dtype_map["decoder_block.*attention.*(query|key|value).*kernel"] = "int8"
policy = keras.dtype_policies.DTypePolicy(dtype_map)
# Globally.
keras.dtype_policies.set_dtype_policy(policy)
# Or locally.
model = SomeModel(dtype=policy)
```
Then the init logic can stay simple...
```python
self.sublayer = SomeLayer(..., dtype=self.dtype_policy)
```
Then we'd have one policy that we could pass around that gave a whole mapping of dtypes all the way down. That also gives a relatively quicker way to specify things like a dtype for all query projections, say. Would still take some figuring out during saving, etc.
@james77777778 wdyt?
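To make the lookup behind such a map concrete, here is a small pure-Python sketch of resolving a layer path against regex keys. `DTypeMapSketch` and its `resolve` method are hypothetical names for illustration only, not a Keras API, and the first-match-wins rule is an assumption about how conflicts might be handled.

```python
import re


class DTypeMapSketch:
    """Hypothetical regex-keyed dtype map, loosely modelled on LayoutMap."""

    def __init__(self, default: str):
        self.default = default
        self._rules: dict[str, str] = {}

    def __setitem__(self, pattern: str, dtype: str) -> None:
        self._rules[pattern] = dtype

    def resolve(self, path: str) -> str:
        """Return the dtype of the first pattern found in `path`."""
        for pattern, dtype in self._rules.items():
            if re.search(pattern, path):
                return dtype
        return self.default


dtype_map = DTypeMapSketch(default="bfloat16")
dtype_map["token_embedding/embeddings"] = "float32"
dtype_map["decoder_block.*attention.*(query|key|value).*kernel"] = "int8"

for path in (
    "token_embedding/embeddings",
    "decoder_block_0/attention/query/kernel",
    "decoder_block_0/attention_output/kernel",
):
    print(path, "->", dtype_map.resolve(path))
```

Because rules are checked in insertion order, more specific patterns would need to be registered before broader ones under this assumption.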
Hi @mattdangerw
I think this idea is great and it should be easy to implement. Please refer to this PR for more details: https://github.com/keras-team/keras/pull/19783
I have changed the value type in the mapping from "dtype" to "dtype_policy". This adjustment should make more sense because we rely on that for the behavior of the quantized layers.
In that PR, the saving and loading issues have been solved. It wasn't a difficult one :)
I have submitted a new PR for this: https://github.com/keras-team/keras-nlp/pull/1670
I'm closing this PR now
I'm reopening this PR, which originated from https://github.com/keras-team/keras-nlp/pull/1591, due to the API generation issue.
Notes:
- [...] `dtype_policy` for each layer in subclasses
- `ReversibleEmbedding` is imported by the choice of `config.keras_3()`. Is there a better way to support both Keras 2 and Keras 3?
- [...] `keras.distribution`. Can someone provide guidance?

Here are some numbers:

Model outputs:

Standalone script: `gemma_int8.py`
```python
import argparse
import json
import os
import pathlib

import keras
import psutil
import tensorflow as tf

import keras_nlp

# Setup kaggle information
os.environ["KAGGLE_USERNAME"] = "xxxxx"
os.environ["KAGGLE_KEY"] = "xxxxx"


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        default="gemma_1.1_instruct_2b_en",
        choices=["gemma_1.1_instruct_2b_en", "gemma_1.1_instruct_7b_en"],
        help="Which model to demonstrate",
    )
    parser.add_argument(
        "--path",
        default=".",
        help="Path to save and load the model",
    )
    parser.add_argument(
        "--save",
        action="store_true",
        help="Quantize and save the model",
    )
    args = parser.parse_args()
    return args


def get_memory_usage():
    # From CPU or GPU:0
    try:
        memory_stats = tf.config.experimental.get_memory_info("GPU:0")
        peak_usage = memory_stats["peak"] / (2**30)
    except Exception:
        memory_usage = psutil.Process().memory_info().rss
        peak_usage = memory_usage / (2**30)
    return peak_usage


def save_int8_model(
    preset: str, model: keras_nlp.models.GemmaCausalLM, path: pathlib.Path
):
    model.quantize("int8")
    model.summary()
    # Save config
    config = keras.saving.serialize_keras_object(model)
    with open(path / f"{preset}_int8.json", "w") as f:
        f.write(json.dumps(config))
    # Save weights
    model.save_weights(path / f"{preset}_int8.weights.h5")


def load(config_path: pathlib.Path, weights_path: pathlib.Path, preset: str):
    # Load by config file
    with open(config_path, "r") as f:
        config = json.loads(f.read())
    model: keras_nlp.models.GemmaCausalLM = (
        keras.saving.deserialize_keras_object(config)
    )
    # Load weights
    model.load_weights(weights_path)
    # Load preset assets
    model.preprocessor.tokenizer.load_preset_assets(preset)
    return model


if __name__ == "__main__":
    keras.config.set_dtype_policy("bfloat16")
    x = keras.ops.ones([1]) * keras.ops.ones([1])  # Trigger TF dummy logs

    args = get_args()
    path = pathlib.Path(args.path)
    print(f"Peak memory usage (init): {get_memory_usage():.3f} GB")

    # Save
    if args.save:
        model = keras_nlp.models.GemmaCausalLM.from_preset(args.model)
        model.summary()
        print(
            "Peak memory usage (loaded float model): "
            f"{get_memory_usage():.3f} GB"
        )
        save_int8_model(args.model, model, path)
    # Load
    else:
        config_path = path / f"{args.model}_int8.json"
        weights_path = path / f"{args.model}_int8.weights.h5"
        model = load(config_path, weights_path, args.model)
        print(
            "Peak memory usage (loaded int8 model): "
            f"{get_memory_usage():.3f} GB"
        )
        print(model.generate("What is Keras?", max_length=128))
```