youkaichao opened this issue 4 months ago
Thanks for the investigation! I'll be happy to refactor prepare input with array in the coming days.
> I'll be happy to refactor prepare input with array in the coming days.
Thanks! Just a reminder: not all lists need to be replaced with array. Some lists can actually be removed. For example, this line does not need to be a list; it is used in only two places. So we only need to keep track of self.max_decode_seq_len and update it in every _add_seq_group, something like the sketch below.
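A minimal sketch of that idea (the surrounding class is hypothetical; only max_decode_seq_len and _add_seq_group come from the discussion above):

# Instead of collecting every decode seq len in a list and calling max() on it later,
# keep a running maximum and update it whenever a sequence group is added.
class InputMetadataBuilder:  # hypothetical name
    def __init__(self):
        self.max_decode_seq_len = 0

    def _add_seq_group(self, decode_seq_len: int):
        # previously: self.decode_seq_lens.append(decode_seq_len)
        self.max_decode_seq_len = max(self.max_decode_seq_len, decode_seq_len)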
Thank you for this write-up; I would be happy to dig into this refactor.
I spent some time on this and found that we may need to be careful when using array. The key point is that array is a thin wrapper around a C array, so it basically behaves like a C++ vector, including the complexity of its operations. Specifically, after replacing lists with arrays in prepare input, I didn't see any speedup, so I benchmarked the operators between list and array:
op | list | array | numpy |
---|---|---|---|
init | 2.77 ms | 10.79 ms | 35.63 ms |
pre-allocate | 0.12 ms | 0.37 ms | 0.21 ms |
append | 4.19 ms | 6.78 ms | 3623.14 ms |
extend | 4.25 ms | 374.62 ms | 316.91 ms |
slice | 0.09 ms | 0.09 ms | 0.11 ms |
assign | 0.02 ms | 0.03 ms | 0.06 ms |
to_torch | 1210.37 ms | 32.54 ms | 32.77 ms |
We can see that extend is super expensive for both array and numpy. This is mainly because the extended part is a Python list, and it has to be converted to an array or np.array first. This operation is quite common in prepare input because of CUDA graph padding.
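To isolate where that cost comes from, here is a small sketch (not vLLM code; the padding length and values are made up) showing that extending an array with a Python list pays a per-element conversion, while extending with another array of the same typecode is a plain memory copy:

import timeit

setup = """
import array
arr = array.array('l', range(10000))
pad_list = [0] * 256                  # padding produced as a plain Python list
pad_arr = array.array('l', pad_list)  # the same padding, already an array
"""

# extending with a Python list converts every element from a Python int to a C long
t_list = timeit.timeit("arr.extend(pad_list)", setup=setup, number=1000) * 1000
# extending with another array of the same typecode is a straight memory copy
t_arr = timeit.timeit("arr.extend(pad_arr)", setup=setup, number=1000) * 1000
print(f"extend with list : {t_list:.2f} ms")
print(f"extend with array: {t_arr:.2f} ms")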
Here is the benchmark script for reference:
import timeit

def run_op(name, list_cmd, array_cmd, np_cmd, setup_code=None):
    # Setup code
    if setup_code is None:
        setup_code = """
import array
import torch
import numpy as np
lst = list(range(10000))
lst2 = list(range(1000))
arr = array.array('l', range(10000))
arr2 = array.array('l', range(10000))
np_lst = np.zeros(10000)
np_lst2 = np.zeros(10000)
"""
    list_time = timeit.timeit(list_cmd, setup=setup_code, number=1000) * 1000
    array_time = timeit.timeit(array_cmd, setup=setup_code, number=1000) * 1000
    np_time = timeit.timeit(np_cmd, setup=setup_code, number=1000) * 1000
    print(f"{name}|{list_time:.2f} ms|{array_time:.2f} ms|{np_time:.2f} ms")

print("op|list|array|numpy")
print("--|--|--|--")
run_op(
    "init",
    "[[] for _ in range(100)]",
    "[array.array('l') for _ in range(100)]",
    "[np.array([], dtype=np.int64) for _ in range(100)]")
run_op(
    "pre-allocate",
    "a = [0] * 100",
    "a = array.array('l', [0]) * 100",
    "a = np.zeros(100, dtype=np.int64)")
run_op(
    "append",
    "[lst.append(10) for _ in range(100)]",
    "[arr.append(10) for _ in range(100)]",
    "[np.append(np_lst, 10) for _ in range(100)]")
run_op(
    "extend",
    "lst.extend(lst2)",
    "arr.extend(lst)",
    "np.append(np_lst, lst)")
run_op(
    "slice",
    "lst[10:50]",
    "arr[10:50]",
    "np_lst[10:50]")
run_op(
    "assign",
    "lst[555] = 10",
    "arr[555] = 10",
    "np_lst[555] = 10")
run_op(
    "to_torch",
    "torch.tensor(lst, device='cuda'); torch.cuda.synchronize()",
    "torch.frombuffer(arr, dtype=torch.long).to('cuda'); torch.cuda.synchronize()",
    "torch.as_tensor(np_lst, device='cuda'); torch.cuda.synchronize()")
Thanks for the investigation!

> after replacing lists with arrays in prepare input

Does array have any slowdown then? If not, we should still turn to array, because it allows much faster incremental prepare input and conversion to a torch tensor:
import array
import torch

data = array.array("l")
# the following loop has constant time per iteration; total time complexity is O(N)
for i in range(100):
    data.append(i)
    tensor = torch.frombuffer(data, dtype=torch.long)

data = []
# the following loop has increasing time per iteration; total time complexity is O(N^2)
for i in range(100):
    data.append(i)
    tensor = torch.tensor(data, dtype=torch.long)
@comaniac when you index into array, it seems that it does "unboxing" (creates a new Python object), which may also slow things down. Did you happen to measure indexing?
For reference: https://stackoverflow.com/questions/36778568/why-are-pythons-arrays-slow
@alexm-neuralmagic the unboxing and boxing (conversion between an 8-byte C int and a Python int object) can indeed slow things down. Ideally we should go array-native and use array as much as possible.
Do you know which part of the code in vLLM will suffer from this? Maybe we can find alternative ways to make it more array-oriented to avoid the unboxing and boxing cost.
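To make the boxing cost concrete, here is a small sketch (numbers vary by machine; not from this thread) that sums a list versus an array — the array has to box every C integer back into a Python int on each access:

import timeit

setup = """
import array
lst = list(range(100000))
arr = array.array('l', lst)
"""

# iterating a list reuses the already-boxed Python int objects
t_list = timeit.timeit("sum(lst)", setup=setup, number=100) * 1000
# iterating an array boxes a fresh Python int for every element
t_arr = timeit.timeit("sum(arr)", setup=setup, number=100) * 1000
print(f"sum(list) : {t_list:.2f} ms")
print(f"sum(array): {t_arr:.2f} ms")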
> Does array have any slowdown then? If not, we should still turn to array, because it allows much faster incremental prepare input and conversion to a torch tensor
It unfortunately does slow things down. I believe the slowdown mainly comes from .extend, as it may need to re-allocate memory for array and numpy. And we use extend in the following cases:

- CUDA graph padding
- Concatenating input tokens and other input data to batch requests
> @comaniac when you index into array, it seems that it does "unboxing" (creates a new Python object), which may also slow things down. Did you happen to measure indexing?
I didn't unbox the arrays once they were created.
I don't think CUDA graph padding plays an important role here. It only pads a few tokens.
> Concatenating input tokens and other input data to batch requests
This might be significant. Let me see if we can find any solution.
For concatenation, if both objects are of the same type (i.e., array or np.array), then:
op | list | array | numpy |
---|---|---|---|
extend | 4.32 ms | 39.09 ms | 5.98 ms |
So using np.append(a, b) may be the way to go.
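A minimal sketch of that direction (illustrative only; the slot_mapping_a / slot_mapping_b names and values are made up), using same-dtype numpy arrays for concatenation — np.append is a thin wrapper over np.concatenate — and torch.from_numpy for a zero-copy CPU tensor:

import numpy as np
import torch

# per-request data already kept as int64 numpy arrays (made-up values)
slot_mapping_a = np.arange(16, dtype=np.int64)
slot_mapping_b = np.arange(16, 32, dtype=np.int64)

# same-dtype concatenation avoids the per-element conversion that extending with a list pays
batched = np.concatenate([slot_mapping_a, slot_mapping_b])

# torch.from_numpy shares memory with the numpy array (no extra CPU-side copy);
# a single .to("cuda") then moves it to the GPU when needed
tensor = torch.from_numpy(batched)
# tensor = tensor.to("cuda")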
I did an ablation study by changing only the slot mapping to array, because it is self-contained. The following is the accumulated time (in seconds) of each function call in my benchmark.
API | list (main) | array | diff |
---|---|---|---|
torch.tensor | 0.172 | 0.158 | -0.014 |
compute_slot_mapping | 0.0214 | 0.0238 | +0.0024 |
torch.frombuffer | 0 | 0.00881 | +0.00881 |
torch.cat | 0 | 0.00637 | +0.00637 |
Note 1: compute_slot_mapping now calls array.array to initialize an array, which adds some overhead.
Note 2: To avoid extend, we use torch.cat([torch.frombuffer(s) for s in slot_mappings]).
In short, although this approach reduces the torch.tensor time, the other overheads eat the speedup, resulting in a similar overall latency (+0.00358 s).
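For reference, a self-contained sketch of the Note 2 pattern (the slot-mapping contents here are made up), concatenating several array chunks into one tensor without ever calling extend:

import array
import torch

# one array.array per sequence group (made-up contents)
slot_mappings = [
    array.array('q', range(0, 8)),
    array.array('q', range(8, 20)),
    array.array('q', range(20, 24)),
]

# torch.frombuffer wraps each array without copying; torch.cat then does a single
# copy into the final contiguous tensor, so no Python-level extend is needed
slot_mapping_tensor = torch.cat(
    [torch.frombuffer(s, dtype=torch.int64) for s in slot_mappings])
print(slot_mapping_tensor.shape)  # torch.Size([24])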
Here is another performance microbenchmark, just for how to move a Python list to a PyTorch tensor:
import array
import time

import torch

# print header
print("N\tlist\tlist_with_array\tarray")
for N in [100, 1000, 10000, 100000, 1000000]:
    list_data = list(range(N))
    array_data = array.array('q', list_data)

    def create_from_list():
        tensor = torch.tensor(list_data, dtype=torch.int64)
        return tensor

    def create_from_list_with_array():
        tensor = torch.frombuffer(array.array("q", list_data), dtype=torch.int64)
        return tensor

    def create_from_array():
        tensor = torch.frombuffer(array_data, dtype=torch.int64)
        return tensor

    for _ in range(10):
        # warmup
        create_from_list()
    start = time.time()
    for _ in range(100):
        create_from_list()
    elapsed_list = (time.time() - start) / 100 * 1000  # ms

    for _ in range(10):
        # warmup
        create_from_list_with_array()
    start = time.time()
    for _ in range(100):
        create_from_list_with_array()
    elapsed_list_with_array = (time.time() - start) / 100 * 1000  # ms

    for _ in range(10):
        # warmup
        create_from_array()
    start = time.time()
    for _ in range(100):
        create_from_array()
    elapsed_array = (time.time() - start) / 100 * 1000  # ms

    print(f"{N}\t{elapsed_list:.3f}\t{elapsed_list_with_array:.3f}\t{elapsed_array:.3f}")
N | torch.tensor(list) (ms) | torch.frombuffer(array(list)) (ms) | torch.frombuffer(array) (ms) |
---|---|---|---|
100 | 0.005 | 0.002 | 0.001 |
1000 | 0.030 | 0.014 | 0.001 |
10000 | 0.278 | 0.137 | 0.001 |
100000 | 2.684 | 1.347 | 0.001 |
1000000 | 26.654 | 13.759 | 0.001 |
As we can see, torch.frombuffer(array) does zero-copy. But @comaniac found it is difficult for array to interact with other data types like Python lists.
Surprisingly, torch.frombuffer(array(list)) is much faster than torch.tensor(list), cutting the time almost in half.
It's amazing.
import array

import torch

src_from_buffer = torch.frombuffer

def from_buffer(buffer, *, dtype, count=-1, offset=0, requires_grad=False):
    # map torch dtypes to array.array typecodes
    _map = {
        torch.bool: 'b',
        torch.int8: 'b',
        torch.uint8: 'B',
        torch.int16: 'h',
        torch.uint16: 'H',
        torch.int32: 'i',
        torch.uint32: 'I',
        torch.int64: 'q',
        torch.uint64: 'Q',
        torch.float32: 'f',
        torch.float64: 'd'
    }
    # convert plain lists/tuples to an array first, then fall back to the original frombuffer
    if isinstance(buffer, (list, tuple)):
        buffer = array.array(_map[dtype], buffer)
    return src_from_buffer(buffer=buffer, dtype=dtype, count=count, offset=offset, requires_grad=requires_grad)

torch.frombuffer = from_buffer
After hacking torch.frombuffer, the performance doubles too:
N | list | list_with_array | array | array2 |
---|---|---|---|---|
100 | 0.009 | 0.005 | 0.002 | 0.005 |
1000 | 0.055 | 0.027 | 0.002 | 0.027 |
10000 | 0.489 | 0.215 | 0.002 | 0.211 |
100000 | 3.738 | 1.546 | 0.002 | 1.520 |
1000000 | 33.328 | 15.310 | 0.002 | 16.863 |
Can we globally hack torch.frombuffer?
@triple-Mu what's the difference between your array2 and list_with_array?

@comaniac is considering using the list_with_array approach for all the possible cases. We can add a function for this; we don't need to hack torch.frombuffer. If you want to use it in your repo, feel free to do so.
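For concreteness, an explicit helper in that spirit might look like the sketch below (the function name, signature, and dtype map are hypothetical, not an existing vLLM API):

import array

import torch

# hypothetical helper: build the CPU-side buffer once via array.array,
# then create the tensor with a zero-copy torch.frombuffer
_TYPECODE = {torch.int32: 'i', torch.int64: 'q'}

def ints_to_tensor(data, dtype=torch.int64, device="cpu"):
    buf = array.array(_TYPECODE[dtype], data)   # single pass over the Python list
    t = torch.frombuffer(buf, dtype=dtype)      # zero-copy view of the buffer
    return t.to(device) if device != "cpu" else t

# usage
tokens = ints_to_tensor([1, 2, 3, 4, 5], dtype=torch.int64)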
What I'm trying to say is that we don't need to add a new function to do this; we just need to patch torch.frombuffer at the beginning of vLLM startup, and all frombuffer calls will get the doubled performance.
So does torch.tensor.
Hacking is not good for future maintenance. I would suggest being explicit; we don't need to sacrifice code quality for this.
Proposal to improve performance
For flexibility, lots of code in vLLM uses Python lists.
The memory layout of a Python list such as [1, 2, 3, 4, 5] is an array of pointers to separately allocated Python int objects, because a Python list can hold arbitrary Python objects.
When we use torch.tensor([1, 2, 3, 4, 5], dtype=torch.int, device="cuda"), there are two copy operations happening:

1. The values are gathered from the Python list into a compact CPU buffer holding 1, 2, 3, 4, 5 consecutively (40 bytes).
2. That compact buffer is copied to the GPU.

There is a better alternative in Python, called array.array. It is very similar to the vector type in C++: it can hold variable-length data of a single type. Since the memory layout is already compact, we can directly create a PyTorch tensor from it without copying, and then copy it to the GPU, i.e., we can reduce the copy in step 1.
Here is some microbenchmark:
The output is:
As we can see, using array to copy to the GPU is always faster. When the input is large, the difference is even larger.
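As a quick sanity check (a sketch, not part of the original benchmark), you can confirm that torch.frombuffer shares memory with the array while torch.tensor makes a copy:

import array
import torch

data = array.array('q', range(1000))

t_view = torch.frombuffer(data, dtype=torch.int64)  # zero-copy: shares the array's buffer
t_copy = torch.tensor(data, dtype=torch.int64)      # element-by-element copy

data[0] = 12345
print(t_view[0].item())  # 12345 -> the tensor sees the change, so no copy was made
print(t_copy[0].item())  # 0     -> the copied tensor is unaffected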
However, how can we get an array object? If we do array_data = array.array('l', list_data), that is another copy and will not give us any benefit.
The answer is that we should try to start with array, and use array.append / array.extend to replace list.append / list.extend. Then, we should replace torch.tensor(data, dtype=torch.int, device="cuda") with torch.frombuffer(data, dtype=torch.int).to(device="cuda"), as sketched below.
This will require rewriting lots of the code in prepare-input and block-table preparation, which is one of the main performance bottlenecks.
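A minimal before/after sketch of the proposed pattern (the variable names are illustrative, not the actual vLLM code, and a CUDA device is assumed for the final copy):

import array
import torch

token_chunks = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]  # made-up per-request token ids

# before: build a Python list, then pay a full CPU-side copy inside torch.tensor
input_tokens = []
for chunk in token_chunks:
    input_tokens.extend(chunk)
tokens_gpu = torch.tensor(input_tokens, dtype=torch.int64, device="cuda")

# after: build an array.array, create the CPU tensor without copying, then copy to GPU
input_tokens_arr = array.array('q')
for chunk in token_chunks:
    input_tokens_arr.extend(chunk)
tokens_gpu = torch.frombuffer(input_tokens_arr, dtype=torch.int64).to("cuda")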
cc @comaniac for prepare input
cc @alexm-neuralmagic @cadedaniel for block manager