andy-yang-1 / DoubleSparse

16-fold memory access reduction with nearly no loss

The method transfers the KV cache from CPU memory to GPU memory #3

Open digbangbang opened 1 week ago

digbangbang commented 1 week ago

Wonderful work! I have a couple of questions and look forward to your reply.

1) I am curious about the method in your paper that copies the KV cache from CPU memory to GPU memory.


Since I have tested the following code:

    import torch.nn as nn

    model = nn.Linear(4096, 4096)   # illustrative size; real models are larger
    model.to('cuda')

sometimes the model is large, and then the transfer time is costly (see the timing sketch below).

2) Also, is the method in your paper time-efficient?
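For concreteness, here is a minimal sketch of how the transfer cost mentioned above can be measured; the layer size is hypothetical and this is not code from the repository:

    import time
    import torch
    import torch.nn as nn

    # Hypothetical size; real models and KV caches can be several GB.
    model = nn.Linear(8192, 8192)

    torch.cuda.synchronize()
    start = time.time()
    model.to('cuda')                  # blocking host-to-device copy of the weights
    torch.cuda.synchronize()          # wait until the copy has actually finished
    print(f"H2D transfer took {time.time() - start:.4f} s")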

andy-yang-1 commented 1 week ago

@digbangbang Thank you for your feedback!

  1. We have two implementations. The first one looks like this:

    with torch.cuda.stream(loading_stream):
        next_sorted_channel = global_sorted_channels[self.layer_idx+1]
        global_kv_caches[self.layer_idx+1].load_gpu(next_sorted_channel.view(-1))

    This method is easy to implement and performs well, because it does not affect the kernels' performance (a self-contained sketch of this pattern is given below).

Another implementation is with DGL:

import os
import torch.multiprocessing as mp

def main_worker(args) -> None:
    # Give the main serving process most of the GPU's SMs via CUDA MPS.
    os.environ['CUDA_MPS_ACTIVE_THREAD_PERCENTAGE'] = "95"
    serving_framework.run(args)   # `serving_framework` is the serving entry point

def kvcache_cpu_worker(task_queue, out_queue, cpu_kv_caches, barrier):
    # Restrict the KV-cache offloading process to a small SM share.
    os.environ['CUDA_MPS_ACTIVE_THREAD_PERCENTAGE'] = "5"
    while True:
        task = task_queue.get()
        ...

def main(args):
    mp.set_start_method("spawn")
    # Launch the offloading worker in its own process so it overlaps with serving.
    mp.Process(
        target=kvcache_cpu_worker,
        args=(args.task_queue, args.out_queue, args.cpu_kv_caches, args.barrier),
    ).start()
    main_worker(args)

This method can achieve a higher overlap rate, but it affects the kernels' performance, so we prefer the first method.
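For reference, here is a minimal self-contained sketch of the first, stream-based pattern. It is not the repository's code; `kv_cpu`, `kv_gpu`, and `compute_layer` are illustrative placeholders:

    import torch

    loading_stream = torch.cuda.Stream()

    # Pinned (page-locked) host memory allows truly asynchronous H2D copies.
    kv_cpu = torch.randn(16, 1024, 128, pin_memory=True)
    kv_gpu = torch.empty(16, 1024, 128, device='cuda')

    def prefetch_next_layer():
        # Issue the copy on the side stream so it overlaps with compute
        # running on the default stream.
        with torch.cuda.stream(loading_stream):
            kv_gpu.copy_(kv_cpu, non_blocking=True)

    def compute_layer(x):
        return x @ x.transpose(-1, -2)   # stand-in for the attention kernels

    x = torch.randn(1024, 128, device='cuda')
    prefetch_next_layer()
    y = compute_layer(x)                 # runs concurrently with the copy

    # Make the default stream wait for the prefetch before using kv_gpu.
    torch.cuda.current_stream().wait_stream(loading_stream)

The key point is that the copy is issued from pinned memory on a separate stream, so the attention kernels on the default stream are not delayed.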

  2. This offloading method gains less speedup than the original one. The results show that it is slightly faster than the original GPT-Fast under a regular workload.