pytorch / executorch

On-device AI across mobile, embedded and edge for PyTorch
https://pytorch.org/executorch/

kv cache manipulation? #3518

Open l3utterfly opened 4 months ago

l3utterfly commented 4 months ago

Is it possible to manipulate the kv cache for llama models?

A common use case during inference is to strike/remove values from the kv cache when regenerating or editing generated outputs, so the LLM does not need to decode from the beginning.

Are there any APIs available to do this right now? If not, can you give me a general pointer on what needs to be done? I'm happy to implement it myself.

iseeyuan commented 4 months ago

@l3utterfly to clarify: you want "stack"-style manipulation of the kv cache? For example, there is an original prompt with its outputs, and later we feed the model the same prompt but a different output. To save computation, we don't have to repopulate the original kv cache for the prompt (or any inputs that are the same), and only evict the entries where the input differs?
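
As a minimal, framework-agnostic sketch of that reuse idea (illustrative names only, not an ExecuTorch API): find how many leading tokens of the new request already match what is cached, keep the cache up to that point, and decode only from there.

def common_prefix_len(cached_tokens: list[int], new_tokens: list[int]) -> int:
    """Count the leading tokens shared by the cached sequence and the new request."""
    n = 0
    for a, b in zip(cached_tokens, new_tokens):
        if a != b:
            break
        n += 1
    return n

# Cache entries past `keep` are stale and can be evicted or overwritten;
# decoding resumes at position `keep` instead of position 0.
keep = common_prefix_len([1, 15, 42, 7], [1, 15, 99])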

We would be happy to review a PR for this. @JacobSzwejbka or @larryliu0820, do you know where a good API/entry point would be for @l3utterfly to add this feature, or whether we want to build this API first?

l3utterfly commented 4 months ago

Yes. For reference, llama.cpp does these kinds of kv cache manipulations. Additionally, a great feature would be the ability to save and load kv caches.

I did the PR for the rollback/regenerate feature in llama.cpp, and I'm happy to implement something similar here if you can give me a quick pointer on where to add these APIs.

larryliu0820 commented 4 months ago

@l3utterfly thanks for offering to help! We have been talking about implementing different kv cache manipulation techniques but haven't had a chance to get to that part yet. For now you can look at how it is currently implemented:

https://github.com/pytorch/executorch/blob/main/examples/models/llama2/llama_transformer.py#L183

Feel free to experiment with it and send a PR
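
For a rough picture of the pattern used there, here is an illustrative sketch (not the exact code in llama_transformer.py): fixed-size cache buffers are registered on the module and written in place at the current decode position.

import torch

class SimpleKVCache(torch.nn.Module):
    """Illustrative static kv cache: preallocated buffers updated in place."""

    def __init__(self, max_batch: int, n_heads: int, max_seq: int, head_dim: int):
        super().__init__()
        shape = (max_batch, n_heads, max_seq, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape))
        self.register_buffer("v_cache", torch.zeros(shape))

    def update(self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor):
        # Write k/v for the current position(s) into the preallocated buffers;
        # attention then reads the filled prefix of these buffers.
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache

Because these buffers are module state (mutable buffers in the exported graph), they end up memory-planned inside the .pte, which is what the rest of this thread discusses manipulating from the runtime.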

JacobSzwejbka commented 4 months ago

> save and load kv caches.

I haven't thought about the ability to mutate state from outside model execution. It should be possible. Let me think about how the APIs would look, as the concept of which tensor is persistent state is not really available in the runtime today, nor do we currently serialize the info that would be needed to figure that out.

JacobSzwejbka commented 4 months ago

One thing you could do would be to put the mutable buffers in the graph AoT onto their own mem-id with a custom memory plan, and then just copy into and out of that buffer in the runtime. That way you sidestep any tensor concepts and just mutate the arenas directly.

JacobSzwejbka commented 4 months ago

How do you want to manipulate the cache, and at what granularity are you looking at it? Are you going by something like layer5.k_cache?

l3utterfly commented 4 months ago

@JacobSzwejbka Currently I'm mainly focused on implementing the kv cache optimisation for transformer models (e.g. llama2 and 3). So I guess only the kv values of the attention heads would work. Are you thinking of a general API in executorch to get/set kv values for all layers?

Also I'm focusing on the c++ code because my goal is to run this on Android.

> One thing you could do would be to put the mutable buffers in the graph AoT onto their own mem-id with a custom memory plan, and then just copy into and out of that buffer in the runtime. That way you sidestep any tensor concepts and just mutate the arenas directly.

This sounds like a good workaround for implementing this feature. Can you recommend a good place to put this code in executorch?

JacobSzwejbka commented 4 months ago

You first need to write a custom memory plan

https://github.com/pytorch/executorch/blob/main/exir/capture/_config.py#L48C38-L48C56

In that plan you need to identify mutable buffers as you iterate over the graph. This can be a little complex in the general case but for simple kv cache stuff I think this will work: https://github.com/pytorch/executorch/blob/main/exir/memory_planning.py#L308

Then finally on the runtime side you just need to mess around with the buffers you pass to https://github.com/pytorch/executorch/blob/main/runtime/core/hierarchical_allocator.h#L35

In the future we want to make all of these steps easier.

  1. An easy flag you can set to automatically lift buffers to their own mem_ids
  2. A way to associate a string with a mem_id. For buffers like the kv cache, that string could be the FQN (fully qualified name).
  3. A more direct API to update buffers at runtime, and even swap a buffer for a different one after initialization. This would let you "load" a kv cache without a copy even after init.

This is a bit of a high-level overview, but hopefully it's enough to let you get started. If you have more questions, feel free to post here and tag me and I'll try to help.

cc @mikekgfb @iseeyuan

l3utterfly commented 4 months ago

Thank you so much for the pointers! I will try this



mikekgfb commented 4 months ago

Thanks for putting this together @JacobSzwejbka! And thanks so much for your contribution to executorch @l3utterfly! So excited to see your work with Layla and GGML, and super excited to have you become a contributor to executorch!

As always, thanks for your support and leadership @iseeyuan !

l3utterfly commented 3 months ago

@JacobSzwejbka Following your pointers and reading through Executorch's documentation several times, I've managed to implement the custom memory planner and obtained the names of the kv_cache arenas via:

# Note: the import paths below are a best-effort guess and may differ across ExecuTorch versions.
from typing import Optional

import torch
from executorch.exir.memory_planning import _is_mutable_buffer
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
from torch.export import ExportGraphSignature
from torch.fx.passes.infra.pass_base import PassResult

class KVMemIdMemoryPlanningPass(MemoryPlanningPass):
    def run(self, graph_module: torch.fx.GraphModule, graph_signature: Optional[ExportGraphSignature]) -> PassResult:
        # Walk every submodule's graph and report the mutable buffers (e.g. kv caches).
        for subgm in graph_module.modules():
            if not isinstance(subgm, torch.fx.GraphModule):
                continue
            for node in subgm.graph.nodes:
                if _is_mutable_buffer(node, graph_signature):
                    print(f"Mutable buffer found: {node}")

        return super().run(graph_module, graph_signature)
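
For context, a rough sketch of where such a pass gets wired in; this is an assumption based on the _config.py link earlier in the thread, and the exact import paths or constructor arguments may differ between ExecuTorch versions.

from executorch.exir import to_edge
from executorch.exir.capture._config import ExecutorchBackendConfig

# `exported_program` is assumed to come from torch.export.export(model, example_inputs).
edge_program = to_edge(exported_program)
executorch_program = edge_program.to_executorch(
    ExecutorchBackendConfig(
        memory_planning_pass=KVMemIdMemoryPlanningPass(),  # the pass sketched above
    )
)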

I'm a little stuck on how I should manipulate this on the C++ side in hierarchical_allocator.h. I see a method to get a memory buffer by passing the memory_id. This seems to be the buffer index, but I also need the offset and size of the buffer.

  1. How do I determine the memory id? I see you can set the "mem_id" by doing node.meta["spec"].mem_id = 1. Does this mean this node will be in buffer[1]? (https://pytorch.org/executorch/stable/compiler-memory-planning.html)
  2. How do I get the size of the buffer? Is it fixed, calculated as sizeof(float) * the tensor dimensions?
  3. After obtaining the buffer via hierarchical_allocator->get_offset_address, is modifying the contents via the pointer sufficient? Do I need to handle syncing the buffer with the underlying hardware (GPU, DSP, etc.)?

JacobSzwejbka commented 3 months ago

1 and 2. I think you are thinking about it a bit backwards; it might be because you are using the Module class API? You should have the buffers before creating the HierarchicalAllocator. The order is: 1. get the buffers -> 2. create the allocators -> 3. create the HierarchicalAllocator from the allocators. Here is a link for getting the expected size and number of buffers for the HierarchicalAllocator: https://github.com/pytorch/executorch/blob/main/runtime/executor/program.h#L152

  3. You should just be able to mutate the contents through the pointer. I'm not aware of any delegates today that consume the buffer, but if/when that's added this approach will probably start failing, because we don't have any good APIs today to allow mutation of delegates from outside.

l3utterfly commented 3 months ago

Is this the place where I determine which buffers are for the kv cache?

Error Module::load_method(const std::string& method_name) {
  if (!is_method_loaded(method_name)) {
    ET_CHECK_OK_OR_RETURN_ERROR(load());

    MethodHolder method_holder;
    const auto method_metadata =
        ET_UNWRAP(program_->method_meta(method_name.c_str()));
    const auto planned_buffersCount =
        method_metadata.num_memory_planned_buffers();
    method_holder.planned_buffers.reserve(planned_buffersCount);
    method_holder.planned_spans.reserve(planned_buffersCount);

    for (auto index = 0; index < planned_buffersCount; ++index) {
      const auto buffer_size =
          method_metadata.memory_planned_buffer_size(index).get();
      method_holder.planned_buffers.emplace_back(buffer_size);
      method_holder.planned_spans.emplace_back(
          method_holder.planned_buffers.back().data(), buffer_size);
    }
    method_holder.planned_memory = std::make_unique<HierarchicalAllocator>(Span(
        method_holder.planned_spans.data(),
        method_holder.planned_spans.size()));
    method_holder.memory_manager = std::make_unique<MemoryManager>(
        memory_allocator_.get(), method_holder.planned_memory.get());
    method_holder.method = ET_UNWRAP_UNIQUE(program_->load_method(
        method_name.c_str(),
        method_holder.memory_manager.get(),
        event_tracer_.get()));
    methods_.emplace(method_name, std::move(method_holder));
  }
  return Error::Ok;
}

l3utterfly commented 3 months ago

@JacobSzwejbka I have managed to get a naive implementation of kv cache save + load working.

I have a question:

My kv cache buffer is at index 0 of the planned_buffers in the method_holder.

This code obtains the piece of memory correctly:

// copy the contents of kv cache at each step
auto kv_cache_buffer = module_->methods_["forward"].planned_buffers[0];
kv_cache_buffers[pos].assign(kv_cache_buffer.begin(), kv_cache_buffer.end()) ;

Which I'm simply copying to a cache variable at each step for now.

I'm trying to use the provided methods in the memory manager et al., but they seem to return different results:

 // copy the contents of kv cache at each step
auto kv_cache_buffer = module_->methods_["forward"].memory_manager->planned_memory()->get_offset_address(0, 0, buffer_size).get();

I looked through the load_method code above: the memory_manager's planned_memory (HierarchicalAllocator) seems to just hold pointers to the planned_buffers, so I'm unsure why they would return different results. It seems I'm misunderstanding something about the get_offset_address function.

JacobSzwejbka commented 3 months ago

Sorry for the delay, I've been out of town and off the grid.

> auto kv_cache_buffer = module_->methods_["forward"].memory_manager->planned_memory()->get_offset_address(0, 0, buffer_size).get();

We have some legacy code where mem_id 0 is reserved for reasons that don't make sense anymore. @dbort put up some code a while back to hide this from users as best we could. You might be running into that when you go directly through the allocator APIs to get the buffer offset (maybe mem_id 1 would work for you). If the first code you posted works, I would just continue using that.

prashantaithal commented 1 month ago

> @l3utterfly thanks for offering to help! We have been talking about implementing different kv cache manipulation techniques but haven't had a chance to get to that part yet.

You mentioned implementing different kv cache manipulation techniques; does that mean APIs to use the kv cache for LLM inference already exist in executorch? Or are the new APIs being worked on in this issue?

JacobSzwejbka commented 1 month ago

> You mentioned implementing different kv cache manipulation techniques; does that mean APIs to use the kv cache for LLM inference already exist in executorch? Or are the new APIs being worked on in this issue?

ExecuTorch already supports kv-cache-enabled Llama, if that's what you are asking. You can see this under examples/ in the ET repo, or in the torchchat repo which just launched. @prashantaithal

prashantaithal commented 1 month ago

Thanks for the response, you must mean llama_transformer.py.

I had a basic question: is the kv cache part of the PyTorch Llama model (*.pt/.pth), or is it implemented separately outside of the model? I ask because, looking at the Python script (linked above), I see the model's Attention class using the KVCache.

Apologies for hijacking this thread, but I could not find any "discussions" tab in the repositories.

JacobSzwejbka commented 1 month ago

Yes, it's embedded within the .pte as hidden model state, based on the implementation in llama_transformer.py. We had an earlier version running that lifted the cache to model IO as well, so if you want easier manipulation of the cache outside the model, that approach might work better for you.

The current implementation in llama_transformer just aligns better with how we had seen people authoring models (stateful), so we wanted to show support for that.
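
For readers unfamiliar with the distinction, here is a minimal toy sketch (not the ExecuTorch implementation) of what "lifting the cache to model IO" means: the k/v caches become explicit inputs and outputs, so the caller owns them and can save, load, or edit them between calls.

import torch

class ToyAttentionWithExternalCache(torch.nn.Module):
    """Toy block whose k/v caches are explicit IO instead of hidden buffers."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.wk = torch.nn.Linear(dim, dim, bias=False)
        self.wv = torch.nn.Linear(dim, dim, bias=False)

    def forward(self, x, k_cache, v_cache, pos):
        # x: [batch, dim] for the current token; caches: [batch, max_seq, dim].
        k_cache = k_cache.clone()
        v_cache = v_cache.clone()
        k_cache[:, pos] = self.wk(x)
        v_cache[:, pos] = self.wv(x)
        # ... attention over k_cache[:, : pos + 1] would go here ...
        return k_cache, v_cache

# Usage: the caller allocates, owns, and can snapshot or roll back the caches.
block = ToyAttentionWithExternalCache()
k = torch.zeros(1, 16, 8)
v = torch.zeros(1, 16, 8)
k, v = block(torch.randn(1, 8), k, v, pos=0)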

prashantaithal commented 1 month ago

Thanks for the response. Where can I find the earlier version?