martin-gorner closed this 2 weeks ago
For a demo of the functionality, see this Colab: Model rewiring demo with LLMs.ipynb.
For example, you can insert control vectors into an LLM backbone with this `clone_fn` applied to the backbone:
```python
def clone_fn(layer, *args, **kwargs):
    if isinstance(layer, keras_nlp.layers.TransformerDecoder):
        x = layer(*args, **kwargs)
        x = ControlVectorLayer()(x)
        return x
    else:
        return layer(*args, **kwargs)  # identity
```
Before/after visualization of the backbone:
And here is what a re-wired backbone with caches looks like. Since it is now a proper Keras Functional model, it can be plotted. The layout is not the best, but you can see the cache input fed into every layer, and an updated cache fed out and collected at the end.
Known issue: the `max_length` parameter in `generate(prompt, max_length=64)` does not work.
I have changed the implementation to use the new `new_model = clone_model(model, clone_function=lambda x: x, call_function=...)` API instead of the previously suggested `output = clone_layer_graph(input, output, clone_fn=...)`.
For this use case, i.e. rewiring a language model backbone with KV caches, the new API is a bit awkward, as it forces the user to use an intermediate model. In simplified code:
```python
rewired_backbone = clone_model(
    backbone,
    clone_function=lambda x: x,  # no cloning
    call_function=rewire_fn,
)

# Build a new backbone with caches in inputs and outputs.
input = {
    "token_ids": rewired_backbone.input["token_ids"],
    "cache": cache_input,  # new input
    "cache_update_index": cache_update_index_input,  # new input
}

# During the rewiring process, next_caches were collected; add them as a new output.
next_cache = ops.stack(next_caches, axis=1)
output = (rewired_backbone.output, next_cache)

# Create a new backbone that now uses caches in its forward pass.
real_rewired_backbone = keras.Model(input, output, name=backbone.name + "_with_cache")
return real_rewired_backbone
```
The intermediate model `rewired_backbone` is "wrong", as it still has the original inputs of `backbone`, i.e. `token_ids` and `padding_mask`, while its layer graph no longer uses `padding_mask` and now uses the additional inputs `cache_input` and `cache_update_index`. The user has to create a new model, `real_rewired_backbone`, to fix those issues. It is also surprising that these graph connectedness issues were not caught when `rewired_backbone` was constructed. This code might fail in the future if graph connectedness checks are improved.
The previously suggested API did not have this awkwardness, as it did not involve an intermediate `Model`. In simplified code:
```python
# Build a new backbone with caches in inputs and outputs.
input = {
    "token_ids": backbone.input["token_ids"],
    "cache": cache_input,  # new input
    "cache_update_index": cache_update_index_input,  # new input
}

# This call can check for graph connectedness without failing.
new_output = clone_layer_graph(input, backbone.output, clone_fn=rewire_fn)

# During the rewiring process, next_caches were collected; add them as a new output.
next_cache = ops.stack(next_caches, axis=1)
output = (new_output, next_cache)

# Create a new backbone that now uses caches in its forward pass.
rewired_backbone = keras.Model(input, output, name=backbone.name + "_with_cache")
return rewired_backbone
```
Additional note: I also noticed that the new API clones input tensors. For example, `rewired_backbone.input["token_ids"]` and `backbone.input["token_ids"]` are different tensors. The previously suggested API `clone_layer_graph` kept input layers identical, as they do not need to be cloned. The new behavior might surprise users wondering why the `token_ids` input has changed during the re-wiring process and whether it's a bug.
Update here... @fchollet has added support for optional functional inputs.
So what I think we can do is write a backbone that allows two optional inputs, `cache` and `index`. Then we can write a causal LM that needs zero knowledge of the internals of the backbone, just its inputs and outputs during generation. So the entire "rewire" code can go away.
I think this is the right solution abstraction-wise, and it will allow much more aggressive model surgeries.
But landing this will still take some effort as we will need to drop Keras 2 codepaths in the library (Keras 2 will not support optional inputs).
This is a proof of concept PR for the new layer graph cloning API in Keras (https://github.com/keras-team/keras/pull/19600). It is not meant to be merged as such but provide a tangible use case for the design of the new layer graph cloning API.
The problem to solve was:
In order to let users implement what they want in the backbone and have `call_with_cache` still work in `XXXCausalLM`, it is necessary to add the caching to the backbone in a Keras Functional way, and respect the Functional layer graph of the backbone.
The new layer graph cloning API can be used:
This PR implements a Keras Functional `call_with_cache` for GPT2 and Gemma.