pytorch / executorch

On-device AI across mobile, embedded and edge for PyTorch
https://pytorch.org/executorch/

Support for dynamic caches #4740

Open awgr opened 1 month ago

awgr commented 1 month ago

🚀 The feature, motivation and pitch

foreword and motivation

This is a foreword on mutable states and the forward pass.

Compared with the past, models now carry more kinds of state that must be managed across consecutive forward passes. Model implementers are generally left to write their own cache implementations for whatever state the forward pass requires. Many kinds of state have a fixed size per layer, so a static allocation is sufficient for them.

However, some kinds of state have a data-dependent size. Typically the dependency is on the length of the input sequence of time steps. There may be other dependencies, but this feature request focuses on sequence length because it is common. In a transformer, for example, the key-value (KV) cache is dynamic in sequence length: it stores one key state and one value state per layer per time step. Pre-allocating a static buffer for this state causes two small problems:

(1) the static buffer must be copied into the layer's key and value states, instead of those key and value states being computed over directly;
(2) the static allocation of these cache lines can be very large, even for a very small input sequence.

Both problems affect the forward-pass latency and the system resources of models running on edge devices, where resources are constrained.
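Problem (2) can be made concrete with a back-of-the-envelope estimate. The sketch below assumes hypothetical decoder dimensions (32 layers, 8 KV heads, head dimension 128, fp16, batch size 1) that are not taken from any specific model, and compares a buffer sized for a 4096-token maximum against one sized for a 16-token prompt.

```python
# Hypothetical decoder dimensions, chosen only for illustration (batch size 1).
N_LAYERS = 32
N_KV_HEADS = 8
HEAD_DIM = 128
BYTES_PER_ELEM = 2  # fp16


def kv_cache_bytes(seq_len: int) -> int:
    """Bytes for one key and one value state per layer per time step."""
    per_token = 2 * N_LAYERS * N_KV_HEADS * HEAD_DIM * BYTES_PER_ELEM
    return per_token * seq_len


static_bytes = kv_cache_bytes(4096)  # static buffer sized for the maximum length
dynamic_bytes = kv_cache_bytes(16)   # cache sized for an actual 16-token prompt

print(f"static:  {static_bytes / 2**20:.0f} MiB")   # static:  512 MiB
print(f"dynamic: {dynamic_bytes / 2**20:.0f} MiB")  # dynamic: 2 MiB
```

Under these assumptions the static buffer is 256 times larger than what the short prompt actually needs.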

feature

Design a dynamic allocation model for caching model states. An example design:

pitch

Users of executorch would likely see shorter application initialization times and a smaller memory footprint for resident models. If the copy of values loaded from the cache can be optimized away, forward latency should also drop. The maximum sequence length of a compiled model could then be bounded by the resources available on the machine.
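The last point, bounding sequence length by device resources, can be illustrated with a small calculation: with a dynamic cache, the same compiled model could choose its maximum length at runtime from the memory it has. This is a sketch assuming the same hypothetical per-token KV cost as a 32-layer, 8-KV-head, head-dimension-128 fp16 model; `max_seq_len` and `hard_cap` are hypothetical names, with `hard_cap` standing in for any upper bound fixed at compile time.

```python
from typing import Optional

# Hypothetical per-token KV cache cost: 2 (K and V) * 32 layers * 8 KV heads
# * 128 head dim * 2 bytes (fp16) = 131072 bytes per time step.
BYTES_PER_TOKEN = 2 * 32 * 8 * 128 * 2


def max_seq_len(budget_bytes: int, bytes_per_token: int = BYTES_PER_TOKEN,
                hard_cap: Optional[int] = None) -> int:
    """Longest sequence whose KV cache fits in budget_bytes, optionally capped."""
    if budget_bytes < 0 or bytes_per_token <= 0:
        raise ValueError("budget must be >= 0 and bytes_per_token must be > 0")
    n = budget_bytes // bytes_per_token
    return n if hard_cap is None else min(n, hard_cap)


# Same compiled model, two devices with different amounts of free memory.
print(max_seq_len(256 * 2**20))               # 2048
print(max_seq_len(1 * 2**30, hard_cap=4096))  # 4096 (8192 fits, capped)
```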

alternatives

I've considered using a static cache. A static cache is more or less sufficient for most use cases, but as devices become more portable, their system resources also tend to shrink. In these constrained environments, a static solution requires a different implementation per compile target, which is onerous.

iseeyuan commented 1 month ago

@awgr thanks a lot for submitting the issue with great suggestions!
@helunwencser and @JacobSzwejbka, let's use this issue as a channel to discuss the dynamic KV cache!

helunwencser commented 1 month ago

I started looking into this issue last week. https://github.com/pytorch/executorch/pull/4785 is a quick prototype. In general, we won't be able to register a dynamic buffer inside a model and resize it as needed. One approach is to make the KV cache an input tensor and mark its size as a dynamic shape. Then we should be able to resize the cache as needed.

The PR above contains a simple prototype. More changes will be needed to support this properly in the llama model.
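The cache-as-input approach described in this comment can be illustrated without any framework. This is a minimal sketch assuming plain Python lists in place of tensors; `decode_step`, `MAX_SEQ_LEN`, and the toy identity/doubling projections are hypothetical and are not taken from the prototype PR. It shows only the ownership pattern: the step function holds no state, and the caller passes the cache in and receives the grown cache back, so its length can differ on every call.

```python
import math
from typing import List, Tuple

# Stand-ins for tensors: one list entry per time step.
Vec = List[float]
Cache = List[Vec]

MAX_SEQ_LEN = 8  # hypothetical upper bound of the dynamic sequence dimension


def decode_step(x: Vec, k_cache: Cache, v_cache: Cache) -> Tuple[Vec, Cache, Cache]:
    """Stateless step: the caches come in as inputs and go out as outputs."""
    if len(k_cache) + 1 > MAX_SEQ_LEN:
        raise ValueError("sequence would exceed the dynamic-shape upper bound")
    # Toy projections standing in for the layer's key/value computation.
    k = list(x)
    v = [2.0 * xi for xi in x]
    # Grow by exactly one time step; no preallocated, mostly empty slots.
    new_k = k_cache + [k]
    new_v = v_cache + [v]
    # Single-head softmax attention, using x itself as the query.
    scores = [sum(qi * ki for qi, ki in zip(x, k_t)) for k_t in new_k]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    out = [sum(w * v_t[j] for w, v_t in zip(weights, new_v)) / z
           for j in range(len(x))]
    return out, new_k, new_v


# The caller (the runtime) owns the cache and feeds it back on every call.
k_cache: Cache = []
v_cache: Cache = []
for x in ([1.0, 0.0], [0.0, 1.0], [0.5, 0.5]):
    out, k_cache, v_cache = decode_step(x, k_cache, v_cache)
print(len(k_cache), [round(o, 3) for o in out])
```

Because the cache grows per call, memory tracks the actual sequence length. In an exported PyTorch model, a bound like `MAX_SEQ_LEN` would typically be declared as the maximum of the dynamic dimension (for example via `torch.export.Dim`) rather than as a Python constant.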