keras-team / keras-hub

Pretrained model hub for Keras 3
Apache License 2.0

Modular caching layer / network design #1261

Closed jackd closed 1 year ago

jackd commented 1 year ago

From the roadmap, "KerasNLP is focused on modular and reusable building blocks". Having tried to implement some causal generative models I've found this not to be the case. The low level blocks are great - e.g. CachedMultiHeadAttention - but higher level blocks (TransformerDecoder, GPT2 models) have implementations tightly coupled to the underlying attention mechanism. Specifically, each of these higher level constructs needs to know not only how to call a layer in standard training, but also the types of inputs required during token-by-token generation.

Describe the solution you'd like

Conceptually, caching layers (like CachedMultiHeadAttention) need to support 3 functionalities:

  1. a standard method to be used during training;
  2. a method to be used at inference time that takes a context vector (e.g. a prompt) and returns a cache (and possibly next output logits); and
  3. a method for iterative token generation based on the most recently sampled token and the cache produced in the previous step.

CachedMultiHeadAttention does all three in a single call method, with the presence or absence of input arguments dictating the behaviour, but it could just as easily be implemented as

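from keras.layers import Layer


class CachedMultiHeadAttention(Layer):
    def call(self, inputs):
        # 1. Standard forward pass used during training.
        ...
        return normal_output

    def call_and_create_cache(self, inputs, last_valid_index, max_length):
        # 2. Process a context (e.g. a prompt) and build the cache.
        ...
        return normal_output, cache

    def call_with_cache(self, inputs, cache, current_index):
        # 3. Process the newest token, reading and updating the cache.
        ...
        return normal_output, updated_cache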
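To make the three modes concrete, here is a minimal NumPy sketch that uses a toy running-sum "layer" as a hypothetical stand-in for attention. The class name `RunningSum` and its cache layout are illustrative assumptions, not KerasNLP API. The final check is the property any such layer should satisfy: running the prompt with cache creation, then one step with the cache, reproduces the full-sequence output.

```python
import numpy as np


class RunningSum:
    """Toy caching 'layer': output[t] = x[0] + ... + x[t] along the time axis.

    Hypothetical stand-in for an attention layer. The cache is the running
    total after the last processed step, shape (batch, 1, features).
    """

    def call(self, x):
        # 1. Standard (training) pass over the full sequence.
        return np.cumsum(x, axis=1)

    def call_and_create_cache(self, x):
        # 2. Process a context/prompt and return the cache for later steps.
        out = np.cumsum(x, axis=1)
        return out, out[:, -1:]

    def call_with_cache(self, x, cache):
        # 3. Process one new step, using and updating the cache.
        out = cache + x
        return out, out


layer = RunningSum()
x = np.random.default_rng(0).normal(size=(2, 5, 3))
full = layer.call(x)
prompt_out, cache = layer.call_and_create_cache(x[:, :-1])
step_out, cache = layer.call_with_cache(x[:, -1:], cache)
```

Concatenating `prompt_out` and `step_out` along the time axis gives the same result as `full`.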

Higher level layers/models can then implement call_and_create_cache and call_with_cache based on the Keras graph implied by call, substituting the relevant child-layer calls with call_and_create_cache or call_with_cache. Cache creation, storage and retrieval would have to happen at the call-node level rather than the child-layer level, since each layer could potentially be called multiple times and would need its own cache for each call, but the Keras infrastructure is already set up to support this.

This would allow:

I've experimented with this here (CausalLM model here) and it works - though there are no doubt edge cases that aren't accounted for. It's based on Function._run_through_graph and uses private member _nodes_by_depth.
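A small NumPy sketch of why caches have to be tracked per call node rather than per layer. Here one hypothetical `lag_and_add` function (mirroring the LagAndAdd toy layer used later in this thread) is applied twice in the same model, so the model needs two cache slots. This is an illustration under those assumptions, not the linked implementation.

```python
import numpy as np


def lag_and_add(x, cache=None):
    """Toy caching layer: y[t] = x[t] + x[t - 1], with x[-1] taken as zero.

    Returns (output, new_cache); the cache is the last input step.
    """
    if cache is None:
        lagged = np.pad(x[:, :-1], ((0, 0), (1, 0), (0, 0)))
    else:
        lagged = np.concatenate([cache, x[:, :-1]], axis=1)
    return x + lagged, x[:, -1:]


def model(x, caches=(None, None)):
    # The *same* layer is called twice, so there are two call nodes and
    # each one needs its own cache slot.
    h, c0 = lag_and_add(x, caches[0])
    y, c1 = lag_and_add(h, caches[1])
    return y, (c0, c1)


x = np.random.default_rng(0).normal(size=(2, 6, 4))
full, _ = model(x)
prefix, caches = model(x[:, :-1])
step, _ = model(x[:, -1:], caches)

# Keying the cache by layer instead of by call node would leave one slot,
# overwritten by whichever call ran last; reusing it for both calls is wrong.
wrong, _ = model(x[:, -1:], (caches[1], caches[1]))
```

With per-node caches, `prefix` followed by `step` matches `full`. The single shared slot gives a different result in general.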

I would be very happy to contribute such a feature. Before I spend any more time on it though, a few questions:

  1. Would such a feature be considered?
  2. What would the API be like?
  3. If it is desired, how should I structure PR(s)? A single large one, one with just the infrastructure and a separate one that refactors TransformerEncoder and GPT2 models?
mattdangerw commented 1 year ago

Thanks for the issue! Mulling this one over. Will reply more soon.

mattdangerw commented 1 year ago

Sorry for the delay! Some thoughts...

At a high level I agree it's hard to bring different attention mechanisms to our encoder/decoder blocks. Note that we don't care about bringing a different attention mechanism to GPT2 - fork if you want, but GPT2 is GPT2. However, a lot of people want a "stock" transformer block (attention + feedforward + residuals) with a modified attention mechanism. We should explore making this easier.

I do think there are some design constraints we should follow that aren't addressed above:

1) We should try to route all layer computation through __call__ -> call. It is better to add new inputs and options to call() (like the optional cache in our layer) than to make a new call-like method that doesn't use __call__. A lot of Keras proper relies on the __call__ -> call relationship: autocasting variables (for mixed precision), eager building of state, and functional tracing, to name a few, see here.

2) We should probably decouple variable creation from any forward pass. This is especially important for our JAX backend, which likes stateless functions. Specifically, before compiling a JAX loop, things like the cache should be created and passed as loop variables. We could definitely consider exposing cache creation on the layer itself, but it might be better to do that as cache = attention_layer.create_cache(...), which just initializes the cache state and could be called multiple times as needed.

3) We should avoid using functional model internals like Function._run_through_graph and _nodes_by_depth from KerasNLP. This is just too fragile and would slow down changes to Keras proper.
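A minimal NumPy sketch of point 2, assuming a hypothetical `create_cache` helper that only allocates zeros and a pure `update_cache` that returns a new cache with a slice written at `cache_update_index`. The [batch, 2, max_length, dim] layout (keys and values stacked on axis 1) is an assumption for illustration, not the actual KerasNLP cache format.

```python
import numpy as np


def create_cache(batch_size, max_length, dim, dtype="float32"):
    """Allocate an empty key/value cache without running any forward pass.

    Hypothetical helper; axis 1 holds [keys, values].
    """
    return np.zeros((batch_size, 2, max_length, dim), dtype=dtype)


def update_cache(cache, key, value, cache_update_index):
    """Return a new cache with `key`/`value` written at `cache_update_index`.

    Pure function: the input cache is left untouched, which is the style a
    compiled JAX loop expects for loop-carried state.
    """
    n = key.shape[1]
    new_cache = cache.copy()
    new_cache[:, 0, cache_update_index:cache_update_index + n] = key
    new_cache[:, 1, cache_update_index:cache_update_index + n] = value
    return new_cache


cache = create_cache(batch_size=2, max_length=8, dim=4)
rng = np.random.default_rng(0)
k, v = rng.normal(size=(2, 3, 4)), rng.normal(size=(2, 3, 4))
cache = update_cache(cache, k, v, cache_update_index=0)
```

Because allocation is separate, the same `create_cache` call can be made once per generation run, before any compiled loop starts.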

One option for making it easier to use TransformerDecoder with different attention would be a subclass. Maybe there are other approaches we could consider?

import keras
import keras_nlp


class CustomDecoder(keras_nlp.layers.TransformerDecoder):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def create_attention_layer(self, inputs_shape):
        # Proposed hook: override the default MultiHeadAttention here.
        # Examples: RoPE embeddings, multi-query, T5-style trainable bias, etc.
        # `keras.layers.CustomAttention` is a placeholder for your own layer,
        # and `clone_initializer` is a KerasNLP utility.
        return keras.layers.CustomAttention(
            num_heads=self.num_heads,
            key_dim=int(inputs_shape[-1] // self.num_heads),
            dropout=self.dropout,
            kernel_initializer=clone_initializer(self.kernel_initializer),
            bias_initializer=clone_initializer(self.bias_initializer),
        )

Anyway, I suspect more discussion might be needed here, but maybe this helps understand the design constraints we are working with?

jackd commented 1 year ago

maybe this helps understand the design constraints we are working with?

It absolutely does, thanks! I thought 1 and 3 might be constraints - just wasn't sure how "hard" they were (I'm a hacker, not a maintainer - can you tell?).

Re: CustomDecoder: this just feels like it's kicking the can down the road. Sure, this would make this particular stock transformer block more customizable, but what if I want to change the block architecture? I still essentially need to write two separate forward passes - one for when a cache is present and one for when it is absent (and potentially a third "create_cache" variant). Going up a level, things get even worse - a CustomBackbone would have one forward pass, and a CustomCausalLM would need to define two more (one for cache creation, one for cache update). I'm not so worried about how a particular layer creates/uses a cache, but about finding a maintainable way of passing that cache around.

Regarding the restrictions you've enumerated above:

  1. We should try to route all layer computation through __call__ -> call

"Leaf" caching layers (i.e. those that create/use/manage their own caches) still have 3 separate call_x functions, but they are all accessed through __call__ -> call. This adds the restriction that the cache must be a single tensor (i.e. cannot be a list/dict/nested structure).

  2. We should probably decouple variable creation from any forward pass.

"variable" is an overloaded term, so I'm going to refer to these as "caches", but I get your drift. Conceptually this would make things simpler, though I don't see a way of making it computationally efficient. Most attention mechanisms have different parallel / sequential implementations, e.g. if we have a prompt for GPT2, we want to compute the valid lower triangular sub-matrix corresponding to the prompt in one go without any existing cache, then update that cache token-by-token during generation.

In my draft implementation further down, I've left cache creation coupled with the forward pass that runs without a cache, but if you can show there's an efficient way of doing things in a decoupled manner, it wouldn't be hard to decouple them.
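The parallel/sequential distinction can be checked numerically. This NumPy sketch (single head, no projections, function names hypothetical) computes causal attention over a prompt in one go with a lower-triangular mask, then again token by token with a growing key/value cache, and confirms the outputs match.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def causal_attention_parallel(q, k, v):
    """Whole prompt at once: mask the upper triangle of the score matrix."""
    t = q.shape[1]
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    scores = np.where(np.tril(np.ones((t, t), dtype=bool)), scores, -np.inf)
    return softmax(scores) @ v


def causal_attention_step(q_t, k_t, v_t, keys, values):
    """One new token: append its key/value to the cache, attend over all of it."""
    keys = np.concatenate([keys, k_t], axis=1)
    values = np.concatenate([values, v_t], axis=1)
    scores = q_t @ keys.transpose(0, 2, 1) / np.sqrt(q_t.shape[-1])
    return softmax(scores) @ values, keys, values


rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(2, 5, 4)) for _ in range(3))
parallel = causal_attention_parallel(q, k, v)

keys, values, outputs = np.zeros((2, 0, 4)), np.zeros((2, 0, 4)), []
for t in range(5):
    step = slice(t, t + 1)
    out, keys, values = causal_attention_step(
        q[:, step], k[:, step], v[:, step], keys, values
    )
    outputs.append(out)
sequential = np.concatenate(outputs, axis=1)
```

This growing cache changes shape at every step. That is fine eagerly but would force recompilation under XLA.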

  3. We should avoid using functional model internals

This was almost a deal breaker for my involvement, since I didn't think Keras exposed any of its underlying graph components publicly... until I found clone_model, which is essentially just a wrapper around _run_through_graph.

I've included a draft implementation below (see the bottom for a full script that includes the snippets immediately below), but first I'll take a moment to illustrate the interface from an implementer's perspective with a simple example. The general idea is that "leaf" caching layers - layers which know how to create, use and update their own states - must implement two methods: one forward pass without a cache that also creates the cache, and another forward pass that uses and potentially updates the cache.

class LagAndAdd(CachingLayer):
    def call_and_create_cache(self, x):
        lagged = keras.ops.pad(x[:, :-1], ((0, 0), (1, 0), (0, 0)))
        # always return the cache
        cache = x[:, -1:]
        return x + lagged, cache

    def call_with_cache(self, x, cache):
        updated_cache = x
        return x + cache, updated_cache

We can construct models using these layers as if they were regular layers, ignoring their caching potential.

inp = keras.Input((None, 3))
x = LagAndAdd()(inp)
x = LagAndAdd()(x)
base_model = keras.Model(inp, x)
x = keras.random.normal((5, 7, 3))
base_out = base_model(x)

We can then use graph transformations to turn these models into one model that performs both a forward pass and cache creation, and (separately) another that performs a forward pass with a cache plus a cache update:

call_and_create_cache_model = get_call_and_create_cache(base_model)
call_with_cache_model = get_call_with_cache(call_and_create_cache_model)

leading, cache = call_and_create_cache_model(x[:, :-1])
trailing, updated_cache = call_with_cache_model((x[:, -1:], cache))

np.testing.assert_allclose(keras.ops.concatenate((leading, trailing), axis=1), base_out)
print("Passed functional model version")

And you can do a similar thing with a layer which isn't a model but defines a forward pass using other layers:

class CompoundCache(CachingFunctionalLayer):
    def build(self, input_shape):
        if self.built:
            return
        self.layer0 = LagAndAdd()
        self.layer1 = LagAndAdd()
        for layer in (self.layer0, self.layer1):
            layer.build(input_shape)
        super().build(input_shape)

    def call_without_cache(self, x):
        residual = x
        x = self.layer0(x)
        x = self.layer1(x)
        return x + residual

layer = CompoundCache()

base_out = layer(x)
leading, cache = layer(x[:, :-1], return_cache=True)
trailing, cache = layer(x[:, -1:], cache=cache)
np.testing.assert_allclose(keras.ops.concatenate((leading, trailing), axis=1), base_out)
print("Passed CachingFunctionalLayer version")

Below are backing base classes, and everything from above for convenience. It'll need some more rigorous tests and tweaks before it's ready for a PR, but is this the kind of thing that might be accepted? That said, if a PR is a better place to continue discussion let me know and I'll put one together.

import abc
import typing as tp
import tree
import keras_core as keras

class CachingLayer(keras.layers.Layer):
    """A layer that can create and update a cache for iterative inference.

    Implementations should implement `call_and_create_cache` and
    `call_with_cache`. They may optionally implement `call_without_cache`
    if creation of the cache in `call_and_create_cache` is expensive.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.uses_cache = True

    def call(self, *args, cache=None, return_cache=None, **kwargs):
        if cache is None:
            if return_cache:
                return self.call_and_create_cache(*args, **kwargs)
            return self.call_without_cache(*args, **kwargs)
        assert return_cache is None or return_cache
        return self.call_with_cache(*args, cache=cache, **kwargs)

    @abc.abstractmethod
    def call_and_create_cache(self, *args, **kwargs):
        """Get the output of this layer and create a cache.

        The returned cache may be used in subsequent calls to
        `call_with_cache`.
        """

    @abc.abstractmethod
    def call_with_cache(self, *args, cache, **kwargs):
        """Get the output of this layer using a previously created cache.

        This method should return *args, where args[:-1] is the normal
        output of the layer, and args[-1] is a single-tensor cache.
        """

    def call_without_cache(self, *args, **kwargs):
        """Get the output of this layer without a cache input or output.

        By default, this redirects to `call_and_create_cache` and throws
        out the `cache`. Implementers should override this method if
        there is a more optimal implementation that does not involve
        creating the cache at all.
        """
        *output, cache = self.call_and_create_cache(*args, **kwargs)
        del cache
        if len(output) == 1:
            return output[0]
        return output

def get_call_and_create_cache(model: keras.Model):
    """
    Get `call_and_create_cache` model from `call_without_cache`.
    """
    cache_outputs = []

    def clone_function(op):
        if isinstance(op, CachingLayer):

            def f(*args, **kwargs):
                kwargs = dict(kwargs)
                kwargs["return_cache"] = True
                *output, cache = op(*args, **kwargs)
                cache_outputs.append(cache)
                if len(output) == 1:
                    return output[0]
                return output

            return f
        return op

    cloned = keras.models.clone_model(
        model, model.input, clone_function=clone_function
    )
    output = cloned.output
    cache_output = keras.ops.stack(cache_outputs, axis=1)
    if isinstance(output, keras.KerasTensor):
        output = (output, cache_output)
    else:
        output = (*output, cache_output)
    return keras.Model(cloned.input, output)

def get_call_with_cache(model: keras.Model):
    """
    Get `call_with_cache` model from a `call_and_create_cache` model.
    """
    cache_output = model.output[-1]
    cache_input = keras.Input(batch_shape=cache_output.shape, dtype=cache_output.dtype)
    cache_inputs = keras.ops.unstack(cache_input, axis=1)
    # reverse order so we can pop in order
    cache_inputs = cache_inputs[-1::-1]

    def clone_function(op):
        if getattr(op, "uses_cache", False):

            def f(*args, **kwargs):
                assert kwargs["return_cache"], kwargs
                return op(*args, cache=cache_inputs.pop(), **kwargs)

            return f
        return op

    inp = model.input
    if isinstance(inp, keras.KerasTensor):
        inputs = (inp, cache_input)
    else:
        inputs = (*inp, cache_input)
    return keras.models.clone_model(model, inputs, clone_function=clone_function)

def _is_tensor(x) -> bool:
    return hasattr(x, "__array__")

Tensor = tp.Any

class CachingFunctionalLayer(CachingLayer):
    """A caching layer made from other caching layers.

    Implementations should implement `call_without_cache`, which should
    be conceptually similar to a layer's standard `call` method without
    concerns for the absence, presence or creation of any caches used by
    constituent layers.

    Currently, the only condition on constituent caching layers is that
    they all produce caches of the same size such that they can be stacked.

    The main difference between a normal layer's `call` method and
    `call_without_cache` is that `call_without_cache` may be called with
    symbolic inputs (`keras.KerasTensor`s). This is used for graph
    transformations that create `call_and_create_cache` and `call_with_cache`
    implementations.
    """
    def _get_call_and_create_cache_model(
        self, args, kwargs
    ) -> tp.Tuple[keras.Model, tp.List[Tensor]]:
        tensors = [arg for arg in tree.flatten((args, kwargs)) if _is_tensor(arg)]
        model_args, model_kwargs = tree.map_structure(
            lambda x: keras.Input(batch_shape=x.shape, dtype=x.dtype)
            if _is_tensor(x)
            else x,
            (args, kwargs),
        )
        inputs = [
            arg
            for arg in tree.flatten((model_args, model_kwargs))
            if keras.backend.is_keras_tensor(arg)
        ]
        output = self.call_without_cache(*model_args, **model_kwargs)
        model = keras.Model(inputs, output)
        call_and_create_cache_model = get_call_and_create_cache(model)
        return call_and_create_cache_model, tensors

    def call_and_create_cache(self, *args, **kwargs):
        """Get the output of this layer and create a cache.

        The returned cache may be used in subsequent calls to
        `call_with_cache`.
        """
        model, tensors = self._get_call_and_create_cache_model(args, kwargs)
        return model(tensors)

    def call_with_cache(self, *args, cache, **kwargs):
        """Get the output of this layer using a previously created cache.

        This method should return *args, where args[:-1] is the normal
        output of the layer, and args[-1] is a single-tensor cache.
        """
        call_and_create_cache_model, tensors = self._get_call_and_create_cache_model(
            args, kwargs
        )
        call_with_cache_model = get_call_with_cache(call_and_create_cache_model)
        return call_with_cache_model((*tensors, cache))

    @abc.abstractmethod
    def call_without_cache(self, *args, **kwargs):
        """Get the output of this layer without a cache input or output.

        args and kwargs may contain symbolic tensors or backend tensors, but
        never both.
        """
        raise NotImplementedError("Abstract method")

def main():
    import numpy as np
    class LagAndAdd(CachingLayer):
        def call_and_create_cache(self, x):
            lagged = keras.ops.pad(x[:, :-1], ((0, 0), (1, 0), (0, 0)))
            # always return the cache
            cache = x[:, -1:]
            return x + lagged, cache

        def call_with_cache(self, x, cache):
            updated_cache = x
            return x + cache, updated_cache

    inp = keras.Input((None, 3))
    x = LagAndAdd()(inp)
    x = LagAndAdd()(x)
    base_model = keras.Model(inp, x)
    x = keras.random.normal((5, 7, 3))
    base_out = base_model(x)

    call_and_create_cache_model = get_call_and_create_cache(base_model)
    call_with_cache_model = get_call_with_cache(call_and_create_cache_model)

    leading, cache = call_and_create_cache_model(x[:, :-1])
    trailing, updated_cache = call_with_cache_model((x[:, -1:], cache))

    np.testing.assert_allclose(keras.ops.concatenate((leading, trailing), axis=1), base_out)
    print("Passed functional model version")

    class CompoundCache(CachingFunctionalLayer):
        def build(self, input_shape):
            if self.built:
                return
            self.layer0 = LagAndAdd()
            self.layer1 = LagAndAdd()
            for layer in (self.layer0, self.layer1):
                layer.build(input_shape)
            super().build(input_shape)

        def call_without_cache(self, x):
            residual = x
            x = self.layer0(x)
            x = self.layer1(x)
            return x + residual

    layer = CompoundCache()

    base_out = layer(x)
    leading, cache = layer(x[:, :-1], return_cache=True)
    trailing, cache = layer(x[:, -1:], cache=cache)
    np.testing.assert_allclose(keras.ops.concatenate((leading, trailing), axis=1), base_out)
    print("Passed CachingFunctionalLayer version")

if __name__ == '__main__':
    main()
mattdangerw commented 1 year ago

if we have a prompt for GPT2, we want to compute the valid lower triangular sub-matrix corresponding to the prompt in one go without any existing cache, then update that cache token-by-token during generation

I think this is actually supported today, and it is how we do it for our generative models. You can call cached_attn(full_query, value, key, cache, cache_update_index=0) to update your key/value cache for all known tokens in a fixed prompt input. Then you can call cached_attn(single_token_query, value, key, cache, cache_update_index=i) for i in range(first_unknown_token, desired_length).

Or you could do the whole cached computation in a loop one token at a time, even for the user-supplied prompt: cached_attn(single_token_query, value, key, cache, cache_update_index=i) for i in range(0, desired_length). This would be slower in wall-clock time if the fixed part of the prompt is long, but still valid if you want to save on GPU memory, which is usually more of a concern for many real-world workflows these days.
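Both options can be sketched in NumPy with a toy single-head `cached_attn` that follows the argument order above (query, value, key, cache, cache_update_index). This is not the KerasNLP `CachedMultiHeadAttention` implementation. The fixed [batch, 2, max_length, dim] cache and the masking by absolute position are assumptions. Seeding the cache with the whole prompt at index 0 and then going token by token gives the same outputs and final cache as processing every token individually.

```python
import numpy as np


def cached_attn(query, value, key, cache, cache_update_index):
    """Toy single-head attention over a fixed-size key/value cache.

    Writes `key`/`value` into the cache starting at `cache_update_index`, then
    lets each query attend to cache slots up to its own absolute position.
    Returns (output, updated_cache).
    """
    n = query.shape[1]
    cache = cache.copy()
    cache[:, 0, cache_update_index:cache_update_index + n] = key
    cache[:, 1, cache_update_index:cache_update_index + n] = value
    keys, values = cache[:, 0], cache[:, 1]
    q_pos = cache_update_index + np.arange(n)[:, None]
    k_pos = np.arange(cache.shape[2])[None, :]
    scores = query @ keys.transpose(0, 2, 1) / np.sqrt(query.shape[-1])
    scores = np.where(k_pos <= q_pos, scores, -np.inf)  # hide future/empty slots
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ values, cache


batch, max_length, dim, prompt_length = 2, 6, 4, 3
rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(batch, max_length, dim)) for _ in range(3))
empty = np.zeros((batch, 2, max_length, dim))

# Option A: seed the cache with the whole prompt at index 0, then token by token.
p = slice(0, prompt_length)
prompt_out, cache_a = cached_attn(q[:, p], v[:, p], k[:, p], empty, 0)
outs_a = [prompt_out]
for i in range(prompt_length, max_length):
    s = slice(i, i + 1)
    out, cache_a = cached_attn(q[:, s], v[:, s], k[:, s], cache_a, i)
    outs_a.append(out)

# Option B: process everything, prompt included, one token at a time.
cache_b, outs_b = empty, []
for i in range(max_length):
    s = slice(i, i + 1)
    out, cache_b = cached_attn(q[:, s], v[:, s], k[:, s], cache_b, i)
    outs_b.append(out)
```

Every call sees a cache of length `max_length`, and the query length is either the prompt length or 1.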

It is worth noting that for good XLA-compiled performance, you always want to pad all input lengths to the attention layer. So use a query/key/value length of either 1 or the max generation length, and a cache length of the max generation length. If you don't, XLA performance will tank, quite unlike eager PyTorch. XLA is our "first class citizen" for performance at train time and generation time.
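As a sketch of the padding advice, this hypothetical `pad_to_length` helper right-pads variable-length prompts to one static length and returns a padding mask, so a compiled function sees a single input shape. Right-padding and `pad_id=0` are assumptions. Real models handle this in their own preprocessing.

```python
import numpy as np


def pad_to_length(token_ids, max_length, pad_id=0):
    """Right-pad variable-length prompts to one static length.

    Returns (padded_ids, padding_mask), both of shape (batch, max_length),
    so a compiled function only ever sees a single input shape.
    """
    batch = len(token_ids)
    padded = np.full((batch, max_length), pad_id, dtype="int32")
    mask = np.zeros((batch, max_length), dtype=bool)
    for row, ids in enumerate(token_ids):
        ids = ids[:max_length]  # truncate anything too long
        padded[row, :len(ids)] = ids
        mask[row, :len(ids)] = True
    return padded, mask


padded, mask = pad_to_length([[5, 6, 7], [8, 9]], max_length=6)
# padded -> [[5 6 7 0 0 0]
#            [8 9 0 0 0 0]]
```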

Also worth noting that for the second case here (where you only compute a single token at a time and slice updates into your cache), you could never efficiently mix cache creation with call for JAX. JAX will need the cache to be a loop variable in the compiled loop, and you don't actually want the forward pass to be run outside the compiled loop. So here you do want cache creation (just initializing an ndarray of zeros) and call to be separate.
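A NumPy sketch of the structure described here. The cache is allocated as zeros once before the loop and threaded through each step as loop-carried state, and every array keeps a static shape (a mask stands in for dynamic slicing). `fori_loop` is a plain-Python stand-in with the semantics of `jax.lax.fori_loop(lower, upper, body_fun, init_val)`, and the toy `next_token` model is hypothetical. The uncached reference recomputes everything at each step, and the two agree.

```python
import numpy as np

VOCAB, DIM, MAX_LENGTH = 11, 4, 8
EMBEDDING = np.random.default_rng(0).normal(size=(VOCAB, DIM))  # toy weights


def fori_loop(lower, upper, body_fun, init_val):
    """Pure-Python stand-in with the same semantics as `jax.lax.fori_loop`."""
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val


def next_token(cache, index):
    """Toy 'model': mean of cached embeddings up to `index`, top-scoring token.

    A mask replaces a dynamic slice, so every array keeps a static shape.
    """
    valid = (np.arange(MAX_LENGTH) <= index)[None, :, None]
    context = (cache * valid).sum(axis=1) / (index + 1)
    return np.argmax(context @ EMBEDDING.T, axis=-1)


def generate(tokens, prompt_length):
    """`tokens` is (batch, MAX_LENGTH), with the prompt in the first slots."""
    # Cache creation is separate from any forward pass: just zeros, made once
    # before the loop and then threaded through it as loop-carried state.
    cache = np.zeros((tokens.shape[0], MAX_LENGTH, DIM))

    def body(i, carry):
        tokens, cache = carry
        cache = cache.copy()
        cache[:, i] = EMBEDDING[tokens[:, i]]  # write slot i of the cache
        prediction = next_token(cache, i)
        tokens = tokens.copy()
        # Keep prompt tokens; only positions past the prompt get predictions.
        tokens[:, i + 1] = np.where(i + 1 < prompt_length, tokens[:, i + 1], prediction)
        return tokens, cache

    tokens, _ = fori_loop(0, MAX_LENGTH - 1, body, (tokens, cache))
    return tokens


def generate_reference(tokens, prompt_length):
    """Same toy model without a cache: recompute from scratch every step."""
    tokens = tokens.copy()
    for i in range(prompt_length - 1, MAX_LENGTH - 1):
        context = EMBEDDING[tokens[:, : i + 1]].mean(axis=1)
        tokens[:, i + 1] = np.argmax(context @ EMBEDDING.T, axis=-1)
    return tokens


prompt = np.zeros((2, MAX_LENGTH), dtype="int64")
prompt[:, :3] = [[1, 2, 3], [4, 5, 6]]
out = generate(prompt, prompt_length=3)
```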

Not sure if that helps, but while it's not obvious, the current signature allows seeding a cache and incrementally updating a cache.

mattdangerw commented 1 year ago

Anyway, at a high-level I think this is probably not something we would want a PR for right now. A few reasons.

First, the UX is a little different from what we have. We don't really ship model in -> model out functions like call_and_create_cache_model, and instead try hard to stick to the existing abstractions, e.g. layers, models and metrics. I don't think the benefit of easier customized transformer blocks would be worth adding big new concepts to our APIs. Writing your own custom transformer block is not that bad, and we don't want to carry a lot of complexity to make it easier.

The second is more to do with where all this lives and the progression of our Keras 3 release. We are somewhat in an interim state, but soon Keras 3 will be released, and this package can more readily rely on Keras 3 features. I suspect we will then want to push some of this development down into Keras 3...

1) Add grouped-query and multi-query attention to core Keras. This is quite popular, and there's already a PR for it -> https://github.com/keras-team/keras/pull/18488

2) Add some form of cache/index args to all these layers (which could be used for a bulk cache update, a token-by-token cache update, or a pass with no cache, as mentioned above). We could consider a function to initialize an empty cache here, which I do think needs to be split out because of the JAX constraints on state (but maybe I am still missing things).

3) Add a good way to add rotary position embeddings to all these classes, possibly just via subclassing.

4) Finally, build on top of all these core Keras features for all KerasNLP models.

A nice-to-have here would be an easy pattern for creating a transformer block with custom attention, but I don't think it's a deal breaker. More important is to have the low-level building blocks (multi-head, grouped-query, etc.) in core Keras, and the high-level popular models in KerasNLP. At the end of the day, there is still so much variation in transformer blocks that I don't really think it's a feasible design goal to cover them all in one class. So making your own "transformer block" from dense and attention layers will have to be a common path.

Hope that helps explain things! But I wouldn't put the brakes on this by any means; it just might be more something for its own repo, because of slightly different design goals.

jackd commented 1 year ago

You can call cached_attn(full_query, value, key, cache, cache_update_index=0) to update your key/value cache for all known tokens in a fixed prompt input.

There I was thinking cache_update_index was the index of the cache being updated :S

Also worth noting that for the second case here (where you only compute a single token at a time and slice updates into your cache), you could never mix cache creation with call efficiently for jax. Jax will need the cache to be a loop variable in the compiled loop

I don't follow this. Surely there's an initialization stage before the loop. That's where call_and_create_cache would be run - once, before token-by-token generation. Currently in GPT2 you create the cache then update it once during _build_cache and then toss out the generated features - that involves inputs of different shapes than the token-by-token generation, so from what I know of jax it will need to do a separate trace anyway.

I should say I'm thinking about this from a concept more general than just transformers - particularly models like RWKV and RetNet that don't only have linear memory constraints.
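For models like RetNet, the "cache" is a fixed-size recurrent state rather than a growing key/value buffer. The NumPy sketch below is a simplified version of retention: single head, scalar decay `gamma`, and no normalization or rotary/xPos terms, so it is not the full published formulation. It shows the parallel training form and the recurrent token-by-token form producing the same outputs.

```python
import numpy as np


def retention_parallel(q, k, v, gamma):
    """Training-style pass: decay mask D[n, m] = gamma**(n - m) for n >= m."""
    t = q.shape[0]
    n = np.arange(t)[:, None]
    m = np.arange(t)[None, :]
    decay = np.where(n >= m, gamma ** (n - m), 0.0)
    return ((q @ k.T) * decay) @ v


def retention_recurrent(q, k, v, gamma):
    """Token-by-token pass: the whole 'cache' is one (d_k, d_v) state matrix."""
    state = np.zeros((k.shape[1], v.shape[1]))
    outputs = []
    for q_t, k_t, v_t in zip(q, k, v):
        state = gamma * state + np.outer(k_t, v_t)  # S_n = gamma S_{n-1} + k_n^T v_n
        outputs.append(q_t @ state)                 # o_n = q_n S_n
    return np.stack(outputs), state


rng = np.random.default_rng(0)
q, k = rng.normal(size=(7, 4)), rng.normal(size=(7, 4))
v = rng.normal(size=(7, 3))
parallel = retention_parallel(q, k, v, gamma=0.9)
recurrent, state = retention_recurrent(q, k, v, gamma=0.9)
```

The state stays (4, 3) no matter how many tokens have been processed. A cache API built around key/value slices and update indices would not fit this model directly.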

Just read your second point and that all sounds fine. I'll close this for now - might comment later if I get around to splitting something like what's above into a separate repo for my own purposes in case others decide it's useful. Thanks for the idea bounce :)