Wow the acceleration is impressive!
One thing to note: I think the current implementation assumes all pad tokens are on the right. If some pad tokens are in the middle, the result could be wrong.
For instance (I converted the implementation from torch to numpy):
import numpy as np

def right_to_left_pad_vectorized(token, pad_id):
    # Step 1: Create a mask for padding elements
    pad_mask = (token == pad_id)
    # Step 2: Calculate the number of padding elements in each row
    pad_counts = pad_mask.sum(axis=1)
    # Step 3: Calculate the indices to gather from for each row
    rows, cols = token.shape
    cols_range = np.arange(cols)
    rows_range = np.arange(rows)
    indices = np.tile(cols_range, (rows, 1))
    rolled_indices = (indices + cols - pad_counts[:, None]) % cols
    # Step 4: Use advanced indexing to rearrange the elements
    shifted_token = token[rows_range[:, None], rolled_indices]
    # Step 5: Set the first few elements of each row to the padding ID
    padding_mask = cols_range < pad_counts[:, None]
    shifted_token[padding_mask] = pad_id
    return shifted_token
pad_id = 50259
token = np.array([
    [20, 30, 40, 50, pad_id, pad_id, pad_id, 50],
    [90, 30, 40, pad_id, pad_id, pad_id, pad_id, 20],
])
output = right_to_left_pad_vectorized(token, pad_id)
# output:
# array([[50259, 50259, 50259,    20,    30,    40,    50, 50259],
#        [50259, 50259, 50259, 50259,    90,    30,    40, 50259]])
I'll try to think of ways to deal with middle pad tokens.
Btw I think we only need to handle the (contiguous) cases:
token = np.array([
    [pad_id, 30, 40, 50, pad_id, pad_id, pad_id],
    [90, 30, 40, pad_id, pad_id, pad_id, pad_id],
])
Basically we have query = [pad_id, 30, 40], and after generation we get query_response = [pad_id, 30, 40, 50256, 30, 30] and post_processed_query_response = [pad_id, 30, 40, 50256, pad_id, pad_id]. We need to make it post_processed_query_response = [pad_id, pad_id, pad_id, 30, 40, 50256] before passing it into the reward model.
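To make the target behavior concrete, here is a minimal, unoptimized per-row sketch of the transformation I mean (left_pad_reference is just an illustrative name, not something in the repo): it moves every pad token to the left while keeping the non-pad tokens in their original order.

import torch

def left_pad_reference(tokens, pad_id):
    # Hypothetical loop-based reference: for each row, collect the non-pad
    # tokens in order, right-align them, and fill the rest with pad_id.
    out = torch.full_like(tokens, pad_id)
    for i, row in enumerate(tokens):
        non_pad = row[row != pad_id]
        out[i, tokens.shape[1] - non_pad.numel():] = non_pad
    return out

pad_id = 50259
x = torch.tensor([[pad_id, 30, 40, 50256, pad_id, pad_id]])
print(left_pad_reference(x, pad_id))
# tensor([[50259, 50259, 50259,    30,    40, 50256]])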
OK, so this is better:
import torch

def right_to_left_pad_vectorized(data, pad_id):
    # Step 1: Create a mask marking the pad positions
    mask = (data == pad_id).long()
    # Step 2: Use argsort on the inverted mask to get sorted indices
    sorted_indices = torch.argsort(~mask, axis=1)
    # Step 3: Use advanced indexing to rearrange the elements
    rows_range = torch.arange(data.shape[0], device=data.device)
    shifted_data = data[rows_range[:, None], sorted_indices]
    return shifted_data

def right_to_left_pad(tokens, pad_id):
    return torch.tensor(
        [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens],
        device=tokens.device,
    )
device = torch.device("cuda:0")
pad_id = 50259
token = torch.tensor([
    [20, 30, 40, 50, pad_id, pad_id, pad_id],
    # [90, 30, 40, pad_id, pad_id, pad_id, pad_id],
    # [90, 30, 40, pad_id, 50, 60, pad_id],
    [pad_id, 10, 40, 20, 50, 60, pad_id],
])
print(right_to_left_pad_vectorized(token, pad_id))
print(right_to_left_pad(token, pad_id))
benchmark_token = torch.zeros((64, 512), dtype=torch.int32)
benchmark_token[:, -30:-1] = pad_id
benchmark_token[:, -1] = pad_id
print(right_to_left_pad_vectorized(benchmark_token, pad_id))
print(right_to_left_pad(benchmark_token, pad_id))
import timeit
print("cpu:")
print(timeit.timeit(lambda: right_to_left_pad_vectorized(benchmark_token, pad_id), number=10))
print(timeit.timeit(lambda: right_to_left_pad(benchmark_token, pad_id), number=10))
print("gpu:")
benchmark_token_device = benchmark_token.to(device)
print(timeit.timeit(lambda: right_to_left_pad_vectorized(benchmark_token_device, pad_id), number=10))
print(timeit.timeit(lambda: right_to_left_pad(benchmark_token_device, pad_id), number=10))
costa@ip-26-0-150-12:/fsx/costa/lm-human-preference-details$ poetry run python -i eff.py
tensor([[50259, 50259, 50259, 20, 30, 40, 50],
[50259, 50259, 10, 40, 20, 50, 60]])
tensor([[50259, 50259, 50259, 20, 30, 40, 50],
[50259, 50259, 10, 40, 20, 50, 60]])
tensor([[50259, 50259, 50259, ..., 0, 0, 0],
[50259, 50259, 50259, ..., 0, 0, 0],
[50259, 50259, 50259, ..., 0, 0, 0],
...,
[50259, 50259, 50259, ..., 0, 0, 0],
[50259, 50259, 50259, ..., 0, 0, 0],
[50259, 50259, 50259, ..., 0, 0, 0]], dtype=torch.int32)
tensor([[50259, 50259, 50259, ..., 0, 0, 0],
[50259, 50259, 50259, ..., 0, 0, 0],
[50259, 50259, 50259, ..., 0, 0, 0],
...,
[50259, 50259, 50259, ..., 0, 0, 0],
[50259, 50259, 50259, ..., 0, 0, 0],
[50259, 50259, 50259, ..., 0, 0, 0]])
cpu:
0.005651316023431718
2.196436281024944
gpu:
0.035678989021107554
12.842351484985556
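One caveat about the GPU numbers (an assumption on my part about how they were collected): CUDA kernels launch asynchronously, so without a torch.cuda.synchronize() inside the timed region the measurement may mostly reflect launch and Python overhead rather than actual kernel time. A sketch of a synchronized measurement (time_on_gpu is just an illustrative helper):

import timeit
import torch

def time_on_gpu(fn, tokens, pad_id, number=10):
    # Warm up once and flush any pending work before timing.
    fn(tokens, pad_id)
    torch.cuda.synchronize()
    # Synchronize inside the timed lambda so each iteration waits for the
    # kernels it launched to finish before the clock stops.
    return timeit.timeit(
        lambda: (fn(tokens, pad_id), torch.cuda.synchronize()),
        number=number,
    )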
I see. The argsort version is similar to the current jax implementation (https://github.com/vwxyzjn/lm-human-preference-details/blob/ccc19538e817e98a60d3253242ac15e2a562cb49/lm_human_preference_details/train_policy_jax.py#L316-L321). I used it to make the jax function jittable. I didn't expect it to be faster!
Ahhh yeah it's exactly the same! Not sure how I missed that haha.
Btw the one based on argsort might have a bug:
import torch

def shift_pad_id_left(data, pad_id):
    # Step 1: Create a mask marking the pad positions
    mask = (data == pad_id).long()
    # Step 2: Use argsort on the inverted mask to get sorted indices
    sorted_indices = torch.argsort(~mask, axis=1)
    # Step 3: Use advanced indexing to rearrange the elements
    rows_range = torch.arange(data.shape[0], device=data.device)
    shifted_data = data[rows_range[:, None], sorted_indices]
    return shifted_data

def right_padding_to_left_padding(tokens, pad_id):
    """Convert from right padding to left padding."""
    assert tokens.ndim == 2
    return torch.tensor([[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens])
query = torch.tensor([[ 50, 10526, 22083, 49828, 25, 374, 14, 3855, 47733, 30829,
198, 49560, 2538, 25, 1160, 10, 8059, 3750, 287, 1478,
1528, 13, 45827, 422, 1223, 314, 1100, 994, 13, 198,
32782, 25, 314, 423, 587, 2111, 290, 9894, 284, 4425,
3463, 329, 257, 3155, 812, 13, 314, 561, 466, 257,
17578, 8027, 329, 546, 257, 1285, 11, 788, 11238, 13,
314, 561, 4483, 1729, 2245, 357, 40, 3111, 287, 257,
38164, 828, 4144, 362, 14096, 286, 279, 538, 13396, 4445,
290, 1239, 5517, 13, 314, 373, 407, 3772, 351, 703,
314, 2936, 13, 198, 14303, 734, 2745, 2084, 13, 314,
1100, 1223, 357, 40, 6044, 508, 4481, 11, 7926, 8,
546, 407, 2111, 284, 4425, 3463, 11, 475, 2427, 2111,
284, 651, 5448, 13, 632, 16752, 1223, 287, 502, 290,
314, 3066, 284, 1949, 340, 503, 13, 220, 198, 2293,
1478, 1528, 314, 1422, 470, 4483, 10536, 4107, 9013, 393,
42402, 11, 2427, 550, 15921, 14, 303, 1136, 2977, 14,
20970, 14, 44749, 14, 14784, 13, 314, 1239, 625, 15063,
393, 15063, 878, 1016, 284, 3996, 13, 314, 24070, 691,
1660, 290, 6041, 286, 340, 13, 314, 25805, 4445, 290,
3111, 503, 790, 584, 1110, 13, 198, 314, 4251, 616,
734, 1285, 2496, 287, 838, 1528, 11, 290, 314, 1254,
1049, 13, 314, 1053, 2722, 617, 4633, 3513, 546, 340,
422, 262, 661, 314, 2107, 351, 11, 2192, 780, 484,
1165, 6531, 351, 3463, 11, 475, 314, 716, 407, 9616,
340, 651, 284, 502, 13, 198, 198, 14990, 26, 7707,
25, 220, 5521, 503, 11, 4144, 6041, 286, 1660, 11,
4483, 5448, 11, 4483, 15921, 14, 303, 1136, 2977, 14,
20970, 14, 44749, 14, 14784, 13, 50256, 50256, 50256, 50256,
50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,]])
left_padded_query = shift_pad_id_left(query, 50256)
right_padding_to_left_padding_query = right_padding_to_left_padding(query, 50256)
print(query)
print(left_padded_query)
print(right_padding_to_left_padding_query)
torch.testing.assert_allclose(left_padded_query, right_padding_to_left_padding_query)
tensor([[ 50, 10526, 22083, 49828, 25, 374, 14, 3855, 47733, 30829,
198, 49560, 2538, 25, 1160, 10, 8059, 3750, 287, 1478,
1528, 13, 45827, 422, 1223, 314, 1100, 994, 13, 198,
32782, 25, 314, 423, 587, 2111, 290, 9894, 284, 4425,
3463, 329, 257, 3155, 812, 13, 314, 561, 466, 257,
17578, 8027, 329, 546, 257, 1285, 11, 788, 11238, 13,
314, 561, 4483, 1729, 2245, 357, 40, 3111, 287, 257,
38164, 828, 4144, 362, 14096, 286, 279, 538, 13396, 4445,
290, 1239, 5517, 13, 314, 373, 407, 3772, 351, 703,
314, 2936, 13, 198, 14303, 734, 2745, 2084, 13, 314,
1100, 1223, 357, 40, 6044, 508, 4481, 11, 7926, 8,
546, 407, 2111, 284, 4425, 3463, 11, 475, 2427, 2111,
284, 651, 5448, 13, 632, 16752, 1223, 287, 502, 290,
314, 3066, 284, 1949, 340, 503, 13, 220, 198, 2293,
1478, 1528, 314, 1422, 470, 4483, 10536, 4107, 9013, 393,
42402, 11, 2427, 550, 15921, 14, 303, 1136, 2977, 14,
20970, 14, 44749, 14, 14784, 13, 314, 1239, 625, 15063,
393, 15063, 878, 1016, 284, 3996, 13, 314, 24070, 691,
1660, 290, 6041, 286, 340, 13, 314, 25805, 4445, 290,
3111, 503, 790, 584, 1110, 13, 198, 314, 4251, 616,
734, 1285, 2496, 287, 838, 1528, 11, 290, 314, 1254,
1049, 13, 314, 1053, 2722, 617, 4633, 3513, 546, 340,
422, 262, 661, 314, 2107, 351, 11, 2192, 780, 484,
1165, 6531, 351, 3463, 11, 475, 314, 716, 407, 9616,
340, 651, 284, 502, 13, 198, 198, 14990, 26, 7707,
25, 220, 5521, 503, 11, 4144, 6041, 286, 1660, 11,
4483, 5448, 11, 4483, 15921, 14, 303, 1136, 2977, 14,
20970, 14, 44749, 14, 14784, 13, 50256, 50256, 50256, 50256,
50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256]])
tensor([[50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
50256, 50256, 50256, 50256, 198, 13, 1110, 584, 790, 503,
3111, 290, 4445, 25805, 314, 13, 340, 286, 314, 6041,
314, 4251, 616, 734, 1285, 2496, 287, 838, 1528, 11,
290, 314, 1254, 1049, 13, 11, 314, 42402, 2427, 550,
15921, 14, 303, 1136, 2977, 14, 20970, 14, 44749, 14,
14784, 13, 290, 1239, 625, 15063, 393, 15063, 878, 1016,
284, 3996, 13, 314, 24070, 691, 1660, 4483, 198, 198,
14990, 26, 7707, 25, 220, 5521, 503, 11, 4144, 6041,
286, 1660, 11, 13, 5448, 11, 4483, 15921, 14, 303,
1136, 2977, 14, 20970, 14, 44749, 14, 14784, 13, 484,
2722, 617, 4633, 3513, 546, 340, 422, 262, 661, 314,
2107, 351, 11, 2192, 780, 1053, 1165, 6531, 351, 3463,
11, 475, 314, 716, 407, 9616, 340, 651, 284, 502,
314, 1285, 257, 546, 329, 8027, 17578, 257, 466, 561,
11, 13, 812, 3155, 257, 329, 3463, 4425, 284, 357,
14096, 362, 4144, 828, 38164, 257, 287, 3111, 40, 9894,
2245, 1729, 4483, 561, 314, 13, 11238, 788, 30829, 287,
3750, 8059, 10, 1160, 25, 2538, 49560, 198, 1478, 47733,
3855, 14, 374, 25, 49828, 22083, 10526, 13, 290, 2111,
587, 423, 314, 25, 32782, 198, 286, 994, 1100, 314,
1223, 422, 45827, 13, 1528, 651, 314, 290, 502, 287,
1223, 16752, 632, 13, 5448, 3066, 284, 2111, 2427, 475,
11, 3463, 4425, 284, 1478, 393, 9013, 4107, 10536, 4483,
470, 1422, 314, 1528, 2111, 2293, 198, 220, 13, 503,
340, 1949, 284, 314, 198, 13, 2936, 314, 703, 351,
3772, 407, 373, 14303, 13, 5517, 1239, 290, 4445, 13396,
538, 279, 40, 407, 546, 8, 7926, 11, 4481, 508,
6044, 50, 357, 1223, 1100, 314, 13, 2084, 2745, 734]])
tensor([[50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
50256, 50256, 50256, 50256, 50, 10526, 22083, 49828, 25, 374,
14, 3855, 47733, 30829, 198, 49560, 2538, 25, 1160, 10,
8059, 3750, 287, 1478, 1528, 13, 45827, 422, 1223, 314,
1100, 994, 13, 198, 32782, 25, 314, 423, 587, 2111,
290, 9894, 284, 4425, 3463, 329, 257, 3155, 812, 13,
314, 561, 466, 257, 17578, 8027, 329, 546, 257, 1285,
11, 788, 11238, 13, 314, 561, 4483, 1729, 2245, 357,
40, 3111, 287, 257, 38164, 828, 4144, 362, 14096, 286,
279, 538, 13396, 4445, 290, 1239, 5517, 13, 314, 373,
407, 3772, 351, 703, 314, 2936, 13, 198, 14303, 734,
2745, 2084, 13, 314, 1100, 1223, 357, 40, 6044, 508,
4481, 11, 7926, 8, 546, 407, 2111, 284, 4425, 3463,
11, 475, 2427, 2111, 284, 651, 5448, 13, 632, 16752,
1223, 287, 502, 290, 314, 3066, 284, 1949, 340, 503,
13, 220, 198, 2293, 1478, 1528, 314, 1422, 470, 4483,
10536, 4107, 9013, 393, 42402, 11, 2427, 550, 15921, 14,
303, 1136, 2977, 14, 20970, 14, 44749, 14, 14784, 13,
314, 1239, 625, 15063, 393, 15063, 878, 1016, 284, 3996,
13, 314, 24070, 691, 1660, 290, 6041, 286, 340, 13,
314, 25805, 4445, 290, 3111, 503, 790, 584, 1110, 13,
198, 314, 4251, 616, 734, 1285, 2496, 287, 838, 1528,
11, 290, 314, 1254, 1049, 13, 314, 1053, 2722, 617,
4633, 3513, 546, 340, 422, 262, 661, 314, 2107, 351,
11, 2192, 780, 484, 1165, 6531, 351, 3463, 11, 475,
314, 716, 407, 9616, 340, 651, 284, 502, 13, 198,
198, 14990, 26, 7707, 25, 220, 5521, 503, 11, 4144,
6041, 286, 1660, 11, 4483, 5448, 11, 4483, 15921, 14,
303, 1136, 2977, 14, 20970, 14, 44749, 14, 14784, 13]])
/home/costa/Documents/go/src/github.com/vwxyzjn/lm-human-preference-details/shift_pad_bug.py:54: FutureWarning: `torch.testing.assert_allclose()` is deprecated since 1.12 and will be removed in a future release. Please use `torch.testing.assert_close()` instead. You can find detailed upgrade instructions in https://github.com/pytorch/pytorch/issues/61844.
torch.testing.assert_allclose(left_padded_query, right_padding_to_left_padding_query)
Traceback (most recent call last):
File "/home/costa/Documents/go/src/github.com/vwxyzjn/lm-human-preference-details/shift_pad_bug.py", line 54, in <module>
torch.testing.assert_allclose(left_padded_query, right_padding_to_left_padding_query)
File "/home/costa/.cache/pypoetry/virtualenvs/lm-human-preference-details-3ZaD8v_C-py3.9/lib/python3.9/site-packages/torch/testing/_comparison.py", line 1553, in assert_allclose
torch.testing.assert_close(
File "/home/costa/.cache/pypoetry/virtualenvs/lm-human-preference-details-3ZaD8v_C-py3.9/lib/python3.9/site-packages/torch/testing/_comparison.py", line 1511, in assert_close
raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not equal!
Mismatched elements: 272 / 300 (90.7%)
Greatest absolute difference: 49547 at index (0, 35)
Greatest relative difference: 1223.6922607421875 at index (0, 107)
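If I'm reading this right, the mismatch comes from torch.argsort not being stable by default: all the non-pad positions share the same sort key, so their relative order can be scrambled. A possible fix, assuming a PyTorch version whose argsort accepts the stable keyword, would be something like:

import torch

def shift_pad_id_left_stable(data, pad_id):
    # Stable-sort each row by "is this a non-pad token?" (0 for pads, 1 for
    # non-pads): pads move to the front, and because the sort is stable the
    # non-pad tokens keep their original relative order.
    non_pad = (data != pad_id).long()
    sorted_indices = torch.argsort(non_pad, dim=1, stable=True)
    rows_range = torch.arange(data.shape[0], device=data.device)
    return data[rows_range[:, None], sorted_indices]

With this variant each gathered row becomes [pad_id] * pad_count followed by the non-pad tokens in order, which is exactly what right_padding_to_left_padding produces, so the assert above should pass.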
At the cost of a more restrictive operation, we can significantly boost the performance of the right_to_left_pad function. ChatGPT helped with this, haha: https://chat.openai.com/share/7ffba63f-dad2-420d-9998-1ec06f602279. Pretty amazing.
CC @liutianlin0121