Closed: jackd closed this issue 1 year ago
Thanks for the issue! Mulling this one over. Will reply more soon.
Sorry for the delay! Some thoughts...
At a high level I agree it's hard to bring different attention to our encoder/decoder blocks. Note that we don't care about bringing a different attention mechanism to GPT2 -- fork if you want, but GPT2 is GPT2. However, a lot of people want a "stock" transformer block (attention + feedforward + residuals) with a modified attention mechanism. We should explore making this easier.
I do think there are some design constraints we should follow that aren't addressed above:
1) We should try to route all layer computation through __call__ -> call. It is better to add new inputs and options to call() (like the optional cache in our layer) than to make a new call-like method that doesn't use __call__. A lot of Keras proper relies on the relationship between __call__ -> call: autocasting variables (for mixed precision), eager building of state, and functional tracing, to name a few; see here.
2) We should probably decouple variable creation from any forward pass. This is especially important for our jax backend, which likes stateless functions. Specifically, before compiling a jax loop, things like the cache should be created and passed as loop variables. We could definitely consider exposing cache creation on the layer itself, but it might be better to do it as cache = attention_layer.create_cache(...), which just initializes the cache state and could be called multiple times as needed. (A rough sketch of what 1) and 2) could look like follows this list.)
3) We should avoid using functional model internals like Function._run_through_graph and _nodes_by_depth from KerasNLP. This is just too fragile and would slow down changes to Keras proper.
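To make 1) and 2) concrete, here is a minimal sketch of the shape this could take. None of this is an existing KerasNLP API; the layer name, create_cache, the cache layout, and the masking details are all illustrative assumptions.

import keras

class SketchCachedAttention(keras.layers.Layer):
    """Illustrative only: everything is routed through `call`, with cache
    creation exposed as a separate stateless helper."""

    def __init__(self, num_heads, key_dim, **kwargs):
        super().__init__(**kwargs)
        self.attention = keras.layers.MultiHeadAttention(num_heads, key_dim)

    def create_cache(self, batch_size, max_length, dim):
        # Constraint 2: cache initialization is decoupled from the forward
        # pass, so e.g. a jax loop can carry it as an explicit loop variable.
        return keras.ops.zeros((batch_size, max_length, dim))

    def call(self, x, cache=None, cache_update_index=None):
        # Constraint 1: a single `call` signature; the optional `cache`
        # argument selects between the full-sequence and incremental paths.
        if cache is None:
            return self.attention(x, x, use_causal_mask=True)
        # A real layer would cache projected keys/values and mask the
        # not-yet-written cache positions; omitted here for brevity.
        cache = keras.ops.slice_update(cache, (0, cache_update_index, 0), x)
        return self.attention(x, cache), cache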
One option for making it easier to use TransformerDecoder with different attention would be a subclass. Maybe there are other approaches we could consider?
class CustomDecoder(keras_nlp.layers.TransformerDecoder):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def create_attention_layer(self, inputs_shape):
        # Override the default MultiHeadAttention here.
        # Examples: RoPE embeddings, multi-query, T5-style trainable bias, etc.
        return keras.layers.CustomAttention(
            num_heads=self.num_heads,
            key_dim=int(inputs_shape[-1] // self.num_heads),
            dropout=self.dropout,
            kernel_initializer=clone_initializer(self.kernel_initializer),
            bias_initializer=clone_initializer(self.bias_initializer),
        )
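If this route were taken, a subclass would presumably be constructed and called exactly like the stock layer. The argument names below assume the existing TransformerDecoder constructor and are illustrative only:

# Hypothetical usage; constructor arguments follow keras_nlp.layers.TransformerDecoder.
decoder = CustomDecoder(intermediate_dim=2048, num_heads=8, dropout=0.1)
outputs = decoder(decoder_sequence)  # decoder_sequence: [batch, seq_len, hidden_dim]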
Anyway, I suspect more discussion might be needed here, but maybe this helps understand the design constraints we are working with?
> maybe this helps understand the design constraints we are working with?
It absolutely does, thanks! I thought 1 and 3 might be constraints - just wasn't sure how "hard" they were (I'm a hacker, not a maintainer - can you tell?).
Re: CustomDecoder: this just feels like it's kicking the can down the road. Sure, it would make this particular stock transformer block more customizable, but what if I want to change the block architecture? I still essentially need to write two separate forward passes - one for when a cache is present and one for when it is absent (and potentially a third "create_cache" variant). Going up a level, things get even worse - a CustomBackbone would have one forward pass, and a CustomCausalLM would need to define two more (one for cache creation, one for cache update). I'm not so worried about how a particular layer creates/uses a cache as about finding a maintainable way of passing that cache around.
Regarding the restrictions you've enumerated above:
- We should try to route all layer computation through __call__ -> call

"Leaf" caching layers (i.e. those that create/use/manage their own caches) still have 3 separate call_* functions (call_and_create_cache, call_with_cache, call_without_cache), but they are all accessed through __call__ -> call. This adds the restriction that the cache must be a single tensor (i.e. it cannot be a list/dict/nested structure).
- We should probably decouple variable creation from any forward pass.

"Variable" is an overloaded term, so I'm going to refer to these as "caches", but I get your drift. Conceptually this would make things simpler, though I don't see a way of making it computationally efficient. Most attention mechanisms have different parallel/sequential implementations: e.g. if we have a prompt for GPT2, we want to compute the valid lower-triangular sub-matrix corresponding to the prompt in one go without any existing cache, then update that cache token-by-token during generation.
In the below I've left cache creation coupled with the cache-free forward pass, but if you can show there's an efficient way of doing things in a decoupled manner it wouldn't be hard to decouple.
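To sketch the coupled pattern I mean (placeholder variables only; the return_cache/cache arguments follow the interface introduced below, where the LagAndAdd example makes this runnable):

# Parallel "prefill": process the whole prompt in one go, creating the cache.
features, cache = layer(prompt, return_cache=True)      # call_and_create_cache
# Sequential generation: one token at a time, reusing and updating the cache.
for _ in range(num_new_tokens):
    features, cache = layer(next_token, cache=cache)    # call_with_cache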
- We should avoid using functional model internals

This was almost a deal breaker for my involvement, since I didn't think keras exposed any of its underlying graph components publicly... until I found clone_model, which is essentially just a wrapper around _run_through_graph.
I've included a draft implementation below - see the bottom for a full script including the snippets immediately below - but I'll take a moment to illustrate the interface from an implementer's perspective with a simple example. The general idea is that "leaf" caching layers - layers which know how to create, use and update their own states - must implement two methods: a forward pass without a cache that also creates the cache, and a forward pass that uses and potentially updates the cache.
class LagAndAdd(CachingLayer):
    # Toy layer: output[t] = x[t] + x[t-1] (zeros at t=0). The only state needed
    # for the next step is the last input timestep, which is what the cache holds.
    def call_and_create_cache(self, x):
        lagged = keras.ops.pad(x[:, :-1], ((0, 0), (1, 0), (0, 0)))
        # always return the cache
        cache = x[:, -1:]
        return x + lagged, cache

    def call_with_cache(self, x, cache):
        updated_cache = x
        return x + cache, updated_cache
We can construct models using these layers as if they were regular layers, ignoring their caching potential.
inp = keras.Input((None, 3))
x = LagAndAdd()(inp)
x = LagAndAdd()(x)
base_model = keras.Model(inp, x)
x = keras.random.normal((5, 7, 3))
base_out = base_model(x)
We can then use graph transformations to derive a model that does both a forward pass and cache creation, and (separately) a model that performs a forward pass with a cache plus a cache update:
call_and_create_cache_model = get_call_and_create_cache(base_model)
call_with_cache_model = get_call_with_cache(call_and_create_cache_model)
leading, cache = call_and_create_cache_model(x[:, :-1])
trailing, updated_cache = call_with_cache_model((x[:, -1:], cache))
np.testing.assert_allclose(keras.ops.concatenate((leading, trailing), axis=1), base_out)
print("Passed functional model version")
And you can do a similar thing with a layer that isn't a model but defines a forward pass using other layers:
class CompoundCache(CachingFunctionalLayer):
    def build(self, input_shape):
        if self.built:
            return
        self.layer0 = LagAndAdd()
        self.layer1 = LagAndAdd()
        for layer in (self.layer0, self.layer1):
            layer.build(input_shape)
        super().build(input_shape)

    def call_without_cache(self, x):
        residual = x
        x = self.layer0(x)
        x = self.layer1(x)
        return x + residual
layer = CompoundCache()
base_out = layer(x)
leading, cache = layer(x[:, :-1], return_cache=True)
trailing, cache = layer(x[:, -1:], cache=cache)
np.testing.assert_allclose(keras.ops.concatenate((leading, trailing), axis=1), base_out)
print("Passed CachingFunctionalLayer version")
Below are the backing base classes, plus everything from above for convenience. It'll need some more rigorous tests and tweaks before it's ready for a PR, but is this the kind of thing that might be accepted? That said, if a PR is a better place to continue discussion, let me know and I'll put one together.
import abc
import typing as tp
import tree
import keras_core as keras
class CachingLayer(keras.layers.Layer):
    """A layer that can create and update a cache for iterative inference.

    Implementations should implement `call_and_create_cache` and
    `call_with_cache`. They may optionally implement `call_without_cache`
    if creation of the cache in `call_and_create_cache` is expensive.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.uses_cache = True

    def call(self, *args, cache=None, return_cache=None, **kwargs):
        if cache is None:
            if return_cache:
                return self.call_and_create_cache(*args, **kwargs)
            return self.call_without_cache(*args, **kwargs)
        assert return_cache is None or return_cache
        return self.call_with_cache(*args, cache=cache, **kwargs)

    @abc.abstractmethod
    def call_and_create_cache(self, *args, **kwargs):
        """Get the output of this layer and create a cache.

        The returned cache may be used in subsequent calls to
        `call_with_cache`.
        """

    @abc.abstractmethod
    def call_with_cache(self, *args, cache, **kwargs):
        """Get the output of this layer using a previously created cache.

        This method should return `(*outputs, updated_cache)`, where
        `outputs` is the normal output of the layer and `updated_cache`
        is a single-tensor cache.
        """

    def call_without_cache(self, *args, **kwargs):
        """Get the output of this layer without a cache input or output.

        By default, this redirects to `call_and_create_cache` and throws
        out the `cache`. Implementers should override this method if
        there is a more optimal implementation that does not involve
        creating the cache at all.
        """
        *output, cache = self.call_and_create_cache(*args, **kwargs)
        del cache
        if len(output) == 1:
            return output[0]
        return output
def get_call_and_create_cache(model: keras.Model):
    """Get a `call_and_create_cache` model from a `call_without_cache` model."""
    cache_outputs = []

    def clone_function(op):
        if isinstance(op, CachingLayer):

            def f(*args, **kwargs):
                kwargs = dict(kwargs)
                kwargs["return_cache"] = True
                *output, cache = op(*args, **kwargs)
                cache_outputs.append(cache)
                if len(output) == 1:
                    return output[0]
                return output

            return f
        return op

    cloned = keras.models.clone_model(
        model, model.input, clone_function=clone_function
    )
    output = cloned.output
    cache_output = keras.ops.stack(cache_outputs, axis=1)
    if isinstance(output, keras.KerasTensor):
        output = (output, cache_output)
    else:
        output = (*output, cache_output)
    return keras.Model(cloned.input, output)
def get_call_with_cache(model: keras.Model):
    """Get a `call_with_cache` model from a `call_and_create_cache` model."""
    cache_output = model.output[-1]
    cache_input = keras.Input(batch_shape=cache_output.shape, dtype=cache_output.dtype)
    cache_inputs = keras.ops.unstack(cache_input, axis=1)
    # reverse order so we can pop in order
    cache_inputs = cache_inputs[-1::-1]

    def clone_function(op):
        if getattr(op, "uses_cache", False):

            def f(*args, **kwargs):
                assert kwargs["return_cache"], kwargs
                return op(*args, cache=cache_inputs.pop(), **kwargs)

            return f
        return op

    inp = model.input
    if isinstance(inp, keras.KerasTensor):
        inputs = (inp, cache_input)
    else:
        inputs = (*inp, cache_input)
    return keras.models.clone_model(model, inputs, clone_function=clone_function)
def _is_tensor(x) -> bool:
    return hasattr(x, "__array__")


Tensor = tp.Any


class CachingFunctionalLayer(CachingLayer):
    """A caching layer made from other caching layers.

    Implementations should implement `call_without_cache`, which should
    be conceptually similar to a layer's standard `call` method, without
    concern for the absence, presence or creation of any caches used by
    constituent layers.

    Currently, the only condition on constituent caching layers is that
    they all produce caches of the same size such that they can be stacked.

    The main difference between a normal layer's `call` method and
    `call_without_cache` is that `call_without_cache` may be called with
    symbolic inputs (`keras.KerasTensor`s). This is used for graph
    transformations that create `call_and_create_cache` and `call_with_cache`
    implementations.
    """

    def _get_call_and_create_cache_model(
        self, args, kwargs
    ) -> tp.Tuple[keras.Model, tp.List[Tensor]]:
        tensors = [arg for arg in tree.flatten((args, kwargs)) if _is_tensor(arg)]
        model_args, model_kwargs = tree.map_structure(
            lambda x: keras.Input(batch_shape=x.shape, dtype=x.dtype)
            if _is_tensor(x)
            else x,
            (args, kwargs),
        )
        inputs = [
            arg
            for arg in tree.flatten((model_args, model_kwargs))
            if keras.backend.is_keras_tensor(arg)
        ]
        output = self.call_without_cache(*model_args, **model_kwargs)
        model = keras.Model(inputs, output)
        call_and_create_cache_model = get_call_and_create_cache(model)
        return call_and_create_cache_model, tensors

    def call_and_create_cache(self, *args, **kwargs):
        """Get the output of this layer and create a cache.

        The returned cache may be used in subsequent calls to
        `call_with_cache`.
        """
        model, tensors = self._get_call_and_create_cache_model(args, kwargs)
        return model(tensors)

    def call_with_cache(self, *args, cache, **kwargs):
        """Get the output of this layer using a previously created cache.

        This method should return `(*outputs, updated_cache)`, where
        `outputs` is the normal output of the layer and `updated_cache`
        is a single-tensor cache.
        """
        call_and_create_cache_model, tensors = self._get_call_and_create_cache_model(
            args, kwargs
        )
        call_with_cache_model = get_call_with_cache(call_and_create_cache_model)
        return call_with_cache_model((*tensors, cache))

    @abc.abstractmethod
    def call_without_cache(self, *args, **kwargs):
        """Get the output of this layer without a cache input or output.

        args and kwargs may contain symbolic tensors or backend tensors, but
        never both.
        """
        raise NotImplementedError("Abstract method")
def main():
    import numpy as np

    class LagAndAdd(CachingLayer):
        def call_and_create_cache(self, x):
            lagged = keras.ops.pad(x[:, :-1], ((0, 0), (1, 0), (0, 0)))
            # always return the cache
            cache = x[:, -1:]
            return x + lagged, cache

        def call_with_cache(self, x, cache):
            updated_cache = x
            return x + cache, updated_cache

    inp = keras.Input((None, 3))
    x = LagAndAdd()(inp)
    x = LagAndAdd()(x)
    base_model = keras.Model(inp, x)

    x = keras.random.normal((5, 7, 3))
    base_out = base_model(x)

    call_and_create_cache_model = get_call_and_create_cache(base_model)
    call_with_cache_model = get_call_with_cache(call_and_create_cache_model)
    leading, cache = call_and_create_cache_model(x[:, :-1])
    trailing, updated_cache = call_with_cache_model((x[:, -1:], cache))
    np.testing.assert_allclose(
        keras.ops.concatenate((leading, trailing), axis=1), base_out
    )
    print("Passed functional model version")

    class CompoundCache(CachingFunctionalLayer):
        def build(self, input_shape):
            if self.built:
                return
            self.layer0 = LagAndAdd()
            self.layer1 = LagAndAdd()
            for layer in (self.layer0, self.layer1):
                layer.build(input_shape)
            super().build(input_shape)

        def call_without_cache(self, x):
            residual = x
            x = self.layer0(x)
            x = self.layer1(x)
            return x + residual

    layer = CompoundCache()
    base_out = layer(x)
    leading, cache = layer(x[:, :-1], return_cache=True)
    trailing, cache = layer(x[:, -1:], cache=cache)
    np.testing.assert_allclose(
        keras.ops.concatenate((leading, trailing), axis=1), base_out
    )
    print("Passed CachingFunctionalLayer version")


if __name__ == '__main__':
    main()
> if we have a prompt for GPT2, we want to compute the valid lower-triangular sub-matrix corresponding to the prompt in one go without any existing cache, then update that cache token-by-token during generation
I think this is actually supported today, and it's how we do it for our generative models. You can call cached_attn(full_query, value, key, cache, cache_update_index=0) to update your key/value cache for all known tokens in a fixed prompt input. Then you can call cached_attn(single_token_query, value, key, cache, cache_update_index=i) for i in range(first_unknown_token, desired_length).
Or you could do the whole cached computation in a loop one token at a time -- even for the user-supplied prompt: cached_attn(single_token_query, value, key, cache, cache_update_index=i) for i in range(0, desired_length). This would be slower in walltime if the fixed part of the prompt is long, but still valid if you want to save on GPU memory, which is usually more of a concern for many real-world workflows these days.
It is worth noting that for good XLA-compiled performance, you want to always pad all input lengths to the attention layer. So use a query/key/value length of either 1 or the max generation length, and a cache length of the max generation length. If you don't, XLA performance will tank, quite differently from eager torch. XLA is our "first class citizen" for performance at train time and generation time.
Also worth noting that for the second case here (where you only compute a single token at a time and slice updates into your cache), you could never mix cache creation with call efficiently for jax. Jax will need the cache to be a loop variable in the compiled loop, and you don't actually want the forward pass to be run outside the compiled loop. So here you do want cache creation (just initializing an ndarray of zeros) and call to be separate.
Not sure if that helps, but, while not obvious, the current signature allows seeding a cache and then incrementally updating it.
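Spelled out as code, the pattern above looks roughly like the following. The cache layout, shapes and variable names (prompt_hidden, token_hidden, etc.) are assumptions based on the pseudo-calls above rather than the exact CachedMultiHeadAttention signature:

# Cache spans the full generation length and starts as zeros; creating it is
# cheap and happens outside any compiled loop.
cache = keras.ops.zeros((batch_size, 2, max_length, num_heads, head_dim))

# Seed the key/value cache for all known prompt tokens with one padded call.
prompt_out, cache = cached_attn(
    query=prompt_hidden, value=prompt_hidden, key=prompt_hidden,
    cache=cache, cache_update_index=0,
)

# Then update one token at a time. The query length stays fixed at 1, so XLA
# only ever sees two shapes: the padded prompt and the single-token step.
for i in range(prompt_length, max_length):
    token_out, cache = cached_attn(
        query=token_hidden, value=token_hidden, key=token_hidden,
        cache=cache, cache_update_index=i,
    )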
Anyway, at a high-level I think this is probably not something we would want a PR for right now. A few reasons.
First, the UX is a little different from what we have. We don't really ship model-in -> model-out functions like call_and_create_cache_model, and instead try hard to stick to the existing abstractions, e.g. layers, models and metrics. I don't think the benefit of easier customized transformer blocks would be worth adding big new concepts to our APIs. Writing your own custom transformer block is not that bad, and we don't want to carry a lot of complexity to make it easier.
The second is more to do with where all this lives and the progression of our Keras 3 release. We are somewhat in an interim state, but soon Keras 3 will be released, and this package can more readily rely on Keras 3 features. I suspect we will then want to push some of this development down into Keras 3...
1) Add grouped-query and multi-query attention to core Keras. These are quite popular, and there's already a PR for this -> https://github.com/keras-team/keras/pull/18488
2) Add some form of cache/index args to all these layers (which could be used for bulk cache update, token-by-token cache update, or a no-cache pass as mentioned above). We could consider a function to initialize an empty cache here, which I do think needs to be split out because of the jax constraints on state (but maybe I am still missing things).
3) Add a good way to add rotary position embeddings to all these classes, possibly just via subclassing.
4) Finally, build on top of all these core Keras features for all KerasNLP models.
A nice-to-have here would be an easy pattern for creating a transformer block with custom attention, but I don't think it's a deal breaker. More important is to have the low level building blocks (multi-head, grouped-query, etc.) in core Keras, and the high-level popular models in KerasNLP. At the end of the day, there is still so much variation in transformer blocks that I don't really think it's a feasible design goal to cover them all in one class. So making your own "transformer block" from dense and attention layers will have to be a common path.
Hope that helps explain things! But I wouldn't put the brakes on this by any means; it just might be more something for its own repo, because of slightly different design goals.
> You can call cached_attn(full_query, value, key, cache, cache_update_index=0) to update your key/value cache for all known tokens in a fixed prompt input.

There I was thinking cache_update_index was the index of the cache being updated :S
> Also worth noting that for the second case here (where you only compute a single token at a time and slice updates into your cache), you could never mix cache creation with call efficiently for jax. Jax will need the cache to be a loop variable in the compiled loop
I don't follow this. Surely there's an initialization stage before the loop. That's where call_and_create_cache would be run - once, before token-by-token generation. Currently in GPT2 you create the cache and then update it once during _build_cache, then toss out the generated features - that involves inputs of different shapes than the token-by-token generation, so from what I know of jax it will need to do a separate trace anyway.
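Concretely, what I have in mind (reusing the interface sketched above; the loop driver itself is left abstract):

# Run once, outside the compiled loop: different input shape, separate trace.
prompt_features, cache = layer(prompt, return_cache=True)   # call_and_create_cache

# Compiled per-token step: the cache is an explicit input and output, so it can
# be carried as a loop variable by jax (scan/fori_loop), tf.while_loop, etc.
def decode_step(token, cache):
    features, cache = layer(token, cache=cache)             # call_with_cache
    return features, cache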
I should say I'm thinking about this from a perspective more general than just transformers - particularly models like RWKV and RetNet, which don't have the same linear-in-sequence-length cache requirements as standard attention.
Just read your second point and that all sounds fine. I'll close this for now - might comment later if I get around to splitting something like what's above into a separate repo for my own purposes in case others decide it's useful. Thanks for the idea bounce :)
From the roadmap, "KerasNLP is focused on modular and reusable building blocks". Having tried to implement some causal generative models, I've found this not to be the case. The low level blocks are great - e.g. CachedMultiHeadAttention - but higher level blocks (TransformerDecoder, the GPT2 models) have implementations tightly coupled to the underlying attention mechanism. Specifically, each of these higher level constructs needs to know not only how to call a layer in standard training, but also the types of inputs required during token-by-token generation.

Describe the solution you'd like

Conceptually, caching layers (like CachedMultiHeadAttention) need to support 3 functionalities: a standard forward pass with no cache, a forward pass that creates/seeds a cache, and a forward pass that uses and updates an existing cache. CachedMultiHeadAttention does all three in a single call method, with the presence or absence of input arguments dictating the behaviour, but it could just as easily be implemented as three separate methods.

Higher level layers/models can then implement call_and_create_cache and call_with_cache based on the keras graph implied in call, substituting out the relevant calls for call_and_create_cache or call_with_cache. Cache creation/storage/retrieval would have to be at the call-node level, rather than the child-layer level, since each layer could potentially be called multiple times and would need its own cache for each call, but the keras infrastructure is already set up to support this.

This would allow higher level layers/models to define ordinary call/__call__ methods, ignorant of any caching; and a generate function that "just works".

I've experimented with this here (CausalLM model here) and it works - though there are no doubt edge cases that aren't accounted for. It's based on Function._run_through_graph and uses the private member _nodes_by_depth.

I would be very happy to contribute such a feature. Before I spend any more time on it though, a few questions: