keras-team / keras-nlp

Modular Natural Language Processing workflows with Keras
Apache License 2.0

[DO NOT MERGE] Experimental implementation of CausalLM with a Keras Functional backbone_with_cache #1598

Closed martin-gorner closed 2 weeks ago

martin-gorner commented 2 months ago

This is a proof-of-concept PR for the new layer graph cloning API in Keras (https://github.com/keras-team/keras/pull/19600). It is not meant to be merged as such, but to provide a tangible use case for the design of the new layer graph cloning API.

The problem to solve was:

In order to let users implement what they want in the backbone and still have call_with_cache work in XXXCausalLM, the caching must be added to the backbone in a Keras Functional way, respecting the Functional layer graph of the backbone.

This is where the new layer graph cloning API comes in:

This PR implements a Keras Functional call_with_cache for GPT2 and Gemma.

martin-gorner commented 2 months ago

For a demo of the functionality, see this Colab: Model rewiring demo with LLMs.ipynb.

For example, you can insert control vectors into an LLM backbone with this clone_fn applied to the backbone:

def clone_fn(layer, *args, **kwargs):
  if isinstance(layer, keras_nlp.layers.TransformerDecoder):
    # Call the decoder layer as usual, then pass its output
    # through a control vector layer.
    x = layer(*args, **kwargs)
    x = ControlVectorLayer()(x)
    return x
  else:
    return layer(*args, **kwargs)  # identity: call the layer unchanged
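
ControlVectorLayer above is user-defined rather than part of keras-nlp. A minimal sketch of what such a layer could look like (the zero initialization and naming are illustrative, not taken from the PR):

import keras

class ControlVectorLayer(keras.layers.Layer):
    """Adds a learned steering vector to the decoder's hidden states."""

    def build(self, input_shape):
        # One control vector spanning the hidden dimension.
        self.control_vector = self.add_weight(
            shape=(input_shape[-1],),
            initializer="zeros",
            name="control_vector",
        )

    def call(self, x):
        # Shift every token's hidden state by the control vector.
        return x + self.control_vector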

Before/after visualization of the backbone: [before/after model plots]

martin-gorner commented 2 months ago

And here is what a re-wired backbone with caches looks like. Since it is now a proper Keras Functional model, it can be plotted. The layout is not the best, but you can see the cache input fed into every layer and the updated cache collected and fed out at the end. [model plot with cache inputs and outputs]
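
For reference, such a plot can be produced with the standard Keras plotting utility (requires pydot and graphviz; the file name here is illustrative):

import keras

# Plot the rewired Functional backbone, cache inputs and outputs included.
keras.utils.plot_model(
    rewired_backbone,
    to_file="backbone_with_cache.png",
    show_shapes=True,
)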

martin-gorner commented 2 months ago

Known issue: the max_length parameter in generate(prompt, max_length=64) does not work.

martin-gorner commented 2 months ago

I have changed the implementation to use the new new_model = clone_model(model, clone_function=lambda x: x, call_function=...) API instead of the previously suggested output = clone_layer_graph(input, output, clone_fn=...).

For this use case, i.e. rewiring a language model backbone with KV caches, the new API is a bit awkward, as it forces the user to use an intermediate model. In simplified code:

rewired_backbone = clone_model(backbone,
                               clone_function=lambda x: x, # no cloning
                               call_function=rewire_fn)

# Build a new backbone with caches in inputs and outputs.
input = {
    "token_ids": rewired_backbone.input["token_ids"],
    "cache": cache_input, # new input
    "cache_update_index": cache_update_index_input, # new input
}

# During the rewiring process, next_caches were collected, add them as a new output
next_cache = ops.stack(next_caches, axis=1)
output = (rewired_backbone.output, next_cache)

# create a new backbone that now uses caches in its forward pass
real_rewired_backbone = keras.Model(input, output, name=backbone.name + "_with_cache")
return real_rewired_backbone
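
For context, the rewire_fn used above collects the per-layer caches as a side effect. A minimal sketch, assuming keras-nlp's TransformerDecoder accepts self_attention_cache / self_attention_cache_update_index arguments and returns the updated cache alongside its output (cache_input, cache_update_index_input and next_caches are the names from the simplified code above):

import keras_nlp

next_caches = []  # filled as a side effect of the rewiring pass

def rewire_fn(layer, *args, **kwargs):
    if isinstance(layer, keras_nlp.layers.TransformerDecoder):
        # Slice this layer's KV cache out of the stacked cache input
        # (caches are stacked along axis=1, one slice per decoder layer).
        layer_index = len(next_caches)
        current_cache = cache_input[:, layer_index, ...]
        # Call the decoder with the cache; it returns the layer output
        # and the updated cache for this layer.
        output, next_cache = layer(
            *args,
            self_attention_cache=current_cache,
            self_attention_cache_update_index=cache_update_index_input,
        )
        next_caches.append(next_cache)
        return output
    # All other layers are called unchanged.
    return layer(*args, **kwargs)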

The intermediate model rewired_backbone is "wrong": it still has the original inputs of backbone, i.e. token_ids and padding_mask, while its layer graph no longer uses padding_mask and now uses the additional inputs cache_input and cache_update_index. The user has to create a new model, real_rewired_backbone, to fix those issues. It is also surprising that these graph connectedness issues were not caught when rewired_backbone was constructed; this code might fail in the future if graph connectedness checks are improved.

The previously suggested API did not have this awkwardness as it did not involve an intermediate Model. In simplified code:

# Build a new backbone with caches in inputs and outputs.
input = {
    "token_ids": backbone.input["token_ids"],
    "cache": cache_input, # new input
    "cache_update_index": cache_update_index_input, # new input
}

# This call can check for graph connectedness without failing
new_output = clone_layer_graph(input, backbone.output, clone_fn=rewire_fn)

# During the rewiring process, next_caches were collected, add them as a new output
next_cache = ops.stack(next_caches, axis=1)
output = (new_output, next_cache)

# create a new backbone that now uses caches in its forward pass
rewired_backbone = keras.Model(input, output, name=backbone.name + "_with_cache")
return rewired_backbone

Additional note: I also noticed that the new API clones input tensors. For example, rewired_backbone.input["token_ids"] and backbone.input["token_ids"] are different tensors. The previously suggested clone_layer_graph API kept input layers identical, as they do not need to be cloned. The new behavior might surprise users wondering why the token_ids input changed during the re-wiring process and whether it's a bug.
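
The difference is easy to observe directly (an illustrative check):

# With the clone_model-based API, input tensors are cloned:
assert rewired_backbone.input["token_ids"] is not backbone.input["token_ids"]

# With the previously suggested clone_layer_graph, the original input
# tensors were reused, so the two would have been the same object.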

mattdangerw commented 2 months ago

Update here... @fchollet has added support for optional functional inputs.

So what I think we can do is write a backbone that allows two optional inputs, cache and index. Then we can write a causal LM that needs zero knowledge of the internals of the backbone, just its inputs and outputs during generation. So the entire "rewire" code can go away.

I think this is the right solution abstraction-wise, and it will allow a lot more aggressive model surgery.

But landing this will still take some effort, as we will need to drop Keras 2 code paths in the library (Keras 2 will not support optional inputs).
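
For illustration, here is a minimal sketch of what such optional inputs could look like with the optional=True flag on keras.Input in Keras 3 (shapes and names are placeholders, not the final backbone design):

import keras

# Always-required input.
token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")

# Optional inputs: callers may pass None for these (Keras 3 only).
# Placeholder cache shape: (num_layers, 2, seq_len, num_heads, head_dim).
cache = keras.Input(
    shape=(None, 2, None, 8, 64), dtype="float32", name="cache", optional=True
)
cache_update_index = keras.Input(
    shape=(), dtype="int32", name="cache_update_index", optional=True
)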