Closed robertknight closed 3 days ago
While working on this I ran into an issue where in-place extension of the KV-cache fails for some caches in the Whisper example because a KV-cache tensor (e.g. `past_key_values.0.decoder.key`) is an input to both a `Concat` op and a `Shape` op, and the `Concat` op gets scheduled first. The buffer can't be used in-place because the `Shape` op needs it later. See also https://github.com/robertknight/rten/issues/98.
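To illustrate the scheduling problem, here is a minimal sketch of a "last use" check. The `Node` type and `can_run_in_place` function are hypothetical and are not rten's actual planner API; values are identified by plain integer IDs.

```rust
/// A simplified operator node (hypothetical, not rten's real type).
#[derive(Debug, Clone)]
pub struct Node {
    pub name: &'static str,
    /// Whether this op would like to mutate its first input in place.
    pub in_place: bool,
    /// IDs of the values this op reads.
    pub inputs: Vec<usize>,
}

/// Returns true if `plan[step]` may consume its first input in place,
/// i.e. the op supports in-place execution and no later step in the
/// plan reads the same value.
pub fn can_run_in_place(plan: &[Node], step: usize) -> bool {
    let node = &plan[step];
    let target = match node.inputs.first() {
        Some(&id) => id,
        None => return false,
    };
    node.in_place
        && plan[step + 1..]
            .iter()
            .all(|later| !later.inputs.contains(&target))
}
```

With value `0` as the KV-cache tensor, the plan `[Concat(0, 1), Shape(0)]` fails the check for `Concat`, while `[Shape(0), Concat(0, 1)]` passes it. In this sketch, scheduling the read-only consumer first is what makes the buffer available for in-place growth.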
Another complication with Hugging Face models exported by Optimum: the KV-cache buffers initially reserved by rten-generate don't get used with the `decoder_model_merged` model. The first run of the model uses an alternate branch which doesn't use the `past_key_values.{layer}.decoder.{key, value}` inputs. Instead, new tensors are allocated which don't have spare capacity reserved for in-place growth.
Resolved in https://github.com/robertknight/rten/pull/407. Some additional changes were required in rten-generate (https://github.com/robertknight/rten/pull/408) to take advantage of this in the Hugging Face models (see Whisper and TrOCR examples) where the issue originally came up.
https://github.com/robertknight/rten/pull/239 added an optimization that reduces the cost of KV-cache extension by a `Concat` op from O(sequence_length) to O(1). This optimization currently does not work if the input KV-cache buffer is a captured value in a subgraph. This is because captured values are exposed to subgraphs as immutable views, even if the value was passed to `Model::run` as an owned value. For the optimization to work, `Concat` needs to be run as an in-place operation, which requires the first input to be an owned value.
Solving the general problem that captured values cannot be used in in-place operations would help in other ways too. For example, there are models that try to reshape captured values, and currently this reshape copies the data unnecessarily.
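A rough sketch of why ownership matters for the `Concat` optimization, using a flat `Vec<f32>` in place of a real tensor. `CacheInput` and `concat` are hypothetical names, and the sketch ignores tensor layout (a real KV-cache is concatenated along the sequence axis of a multi-dimensional tensor).

```rust
/// A KV-cache buffer as an operator might receive it (hypothetical type).
pub enum CacheInput<'a> {
    /// The operator owns the buffer and may mutate it.
    Owned(Vec<f32>),
    /// The operator only sees an immutable view, as with captured values.
    View(&'a [f32]),
}

/// Append `new` to the cache. The returned flag is true only if the
/// existing allocation was extended without copying the old contents.
pub fn concat(cache: CacheInput<'_>, new: &[f32]) -> (Vec<f32>, bool) {
    match cache {
        CacheInput::Owned(mut buf) => {
            // With enough spare capacity the append costs O(new.len()),
            // independent of how long the cache already is.
            let in_place = buf.capacity() - buf.len() >= new.len();
            buf.extend_from_slice(new);
            (buf, in_place)
        }
        CacheInput::View(view) => {
            // An immutable view forces a copy of the whole cache, which
            // costs O(sequence_length).
            let mut out = Vec::with_capacity(view.len() + new.len());
            out.extend_from_slice(view);
            out.extend_from_slice(new);
            (out, false)
        }
    }
}
```

Passing an owned `Vec` created with `Vec::with_capacity` takes the cheap path, while passing `CacheInput::View` always copies, which mirrors what happens when the cache reaches `Concat` as a captured value.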
Conceptually what I think is needed is the ability for subgraph capture environments (`CaptureEnv`) to capture values by value rather than by reference, if those values are not going to be used outside of the subgraph.
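One way to picture the idea, as a sketch only: the `Captured` enum and the `capture` and `append_in_subgraph` functions below are hypothetical and do not reflect how `CaptureEnv` is actually implemented. The parent graph's value lives in an `Option` slot, and the capture step either lends it or moves it depending on whether it is needed after the subgraph runs.

```rust
/// How a subgraph sees a value from its parent graph (hypothetical).
pub enum Captured<'a> {
    /// Borrowed: the parent still needs the value after the subgraph runs.
    ByRef(&'a [f32]),
    /// Moved: the subgraph is the last user and may mutate the value.
    ByValue(Vec<f32>),
}

/// Capture the value in `slot`. If nothing outside the subgraph uses it
/// afterwards, move it out of the slot; otherwise lend an immutable view.
pub fn capture<'a>(
    slot: &'a mut Option<Vec<f32>>,
    used_after_subgraph: bool,
) -> Option<Captured<'a>> {
    if used_after_subgraph {
        let shared: &'a Option<Vec<f32>> = slot;
        shared.as_deref().map(Captured::ByRef)
    } else {
        slot.take().map(Captured::ByValue)
    }
}

/// An op inside the subgraph that appends to a captured buffer. The flag
/// reports whether the captured buffer itself was reused.
pub fn append_in_subgraph(captured: Captured<'_>, new: &[f32]) -> (Vec<f32>, bool) {
    match captured {
        Captured::ByValue(mut buf) => {
            buf.extend_from_slice(new);
            (buf, true)
        }
        Captured::ByRef(view) => {
            // Only a view is available, so the data must be copied first.
            let mut out = view.to_vec();
            out.extend_from_slice(new);
            (out, false)
        }
    }
}
```

The `used_after_subgraph` flag stands in for whatever liveness information the planner would need to decide that moving the value is safe.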