vwxyzjn / lm-human-preference-details

RLHF implementation details of OAI's 2019 codebase
MIT License

right_to_left_pad optimization #26

Closed vwxyzjn closed 10 months ago

vwxyzjn commented 1 year ago

At the cost of a more restrictive operation, we can significantly boost performance of the right_to_left_pad function. ChatGPT helped with this haha https://chat.openai.com/share/7ffba63f-dad2-420d-9998-1ec06f602279 pretty amazing.

CC @liutianlin0121

import torch

def right_to_left_pad_vectorized(token, pad_id):
    # Step 1: Create a mask for padding elements
    pad_mask = (token == pad_id)

    # Step 2: Calculate the number of padding elements in each row
    pad_counts = pad_mask.sum(axis=1)

    # Step 3: Calculate the indices to gather from for each row
    rows, cols = token.shape
    cols_range = torch.arange(cols, device=token.device)
    rows_range = torch.arange(rows, device=token.device)
    indices = torch.tile(cols_range, (rows, 1))
    rolled_indices = (indices + cols - pad_counts[:, None]) % cols

    # Step 4: Use advanced indexing to rearrange the elements
    shifted_token = token[rows_range[:, None], rolled_indices]

    # Step 5: Set the first few elements of each row to the padding ID
    padding_mask = cols_range < pad_counts[:, None]
    shifted_token[padding_mask] = pad_id    
    return shifted_token

def right_to_left_pad(tokens, pad_id):
    return torch.tensor(
        [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens],
        device=tokens.device,
    )

pad_id = 50259
token = torch.tensor([
    [20, 30, 40, 50, pad_id, pad_id, pad_id],
    [90, 30, 40, pad_id, pad_id, pad_id, pad_id],
])
print(right_to_left_pad_vectorized(token, pad_id))
print(right_to_left_pad(token, pad_id))
benchmark_token = torch.zeros((64, 512), dtype=torch.int32)
benchmark_token[:, -30:-1] = pad_id
benchmark_token[:, -1] = pad_id
print(right_to_left_pad_vectorized(benchmark_token, pad_id))
print(right_to_left_pad(benchmark_token, pad_id))

import timeit
print("cpu:")
print(timeit.timeit(lambda: right_to_left_pad_vectorized(benchmark_token, pad_id), number=10))
print(timeit.timeit(lambda: right_to_left_pad(benchmark_token, pad_id), number=10))
device = torch.device("cuda:0")
print("gpu:")
benchmark_token_device = benchmark_token.to(device)
print(timeit.timeit(lambda: right_to_left_pad_vectorized(benchmark_token_device, pad_id), number=10))
print(timeit.timeit(lambda: right_to_left_pad(benchmark_token_device, pad_id), number=10))
(lm-human-preference-details-py3.8) costa@ip-26-0-150-12:/fsx/costa/lm-human-preference-details$ python -i eff.py 
tensor([[50259, 50259, 50259,    20,    30,    40,    50],
        [50259, 50259, 50259, 50259,    90,    30,    40]])
tensor([[50259, 50259, 50259,    20,    30,    40,    50],
        [50259, 50259, 50259, 50259,    90,    30,    40]])
tensor([[50259, 50259, 50259,  ...,     0,     0,     0],
        [50259, 50259, 50259,  ...,     0,     0,     0],
        [50259, 50259, 50259,  ...,     0,     0,     0],
        ...,
        [50259, 50259, 50259,  ...,     0,     0,     0],
        [50259, 50259, 50259,  ...,     0,     0,     0],
        [50259, 50259, 50259,  ...,     0,     0,     0]], dtype=torch.int32)
tensor([[50259, 50259, 50259,  ...,     0,     0,     0],
        [50259, 50259, 50259,  ...,     0,     0,     0],
        [50259, 50259, 50259,  ...,     0,     0,     0],
        ...,
        [50259, 50259, 50259,  ...,     0,     0,     0],
        [50259, 50259, 50259,  ...,     0,     0,     0],
        [50259, 50259, 50259,  ...,     0,     0,     0]])
cpu:
0.006918723986018449
2.1983908410184085
gpu:
0.005345842044334859
12.780526765040122
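A side note on the GPU numbers: CUDA kernels are launched asynchronously, so a plain `timeit` call may return before all device work has finished. Below is a minimal sketch of a timing helper that synchronizes before and after the timed loop. The helper name `time_call` and the toy workload are made up for illustration and are not part of the repo; the sketch falls back to CPU when no GPU is available.

```python
import timeit

import torch


def time_call(fn, device, number=10):
    """Time `number` calls of fn, waiting for pending CUDA work before and after."""

    def sync():
        # Only CUDA devices queue work asynchronously; nothing to do on CPU.
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    fn()  # warm-up call, excluded from the measurement
    sync()
    start = timeit.default_timer()
    for _ in range(number):
        fn()
    sync()
    return timeit.default_timer() - start


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
x = torch.zeros((64, 512), dtype=torch.int32, device=device)
# Toy workload (an assumption for this sketch): count zeros per row.
elapsed = time_call(lambda: (x == 0).sum(dim=1), device)
print(f"{device}: {elapsed:.6f}s for 10 calls")
```

Swapping the lambda for `right_to_left_pad_vectorized(benchmark_token_device, pad_id)` would give a synchronized version of the measurement above.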
liutianlin0121 commented 1 year ago

Wow the acceleration is impressive!

One thing to note: I think it currently works only under the assumption that all pad tokens are on the right. If some pad tokens are in the middle, the result can be wrong.

For instance (I converted the implementation from torch to numpy):

import numpy as np

def right_to_left_pad_vectorized(token, pad_id):
    # Step 1: Create a mask for padding elements
    pad_mask = (token == pad_id)

    # Step 2: Calculate the number of padding elements in each row
    pad_counts = pad_mask.sum(axis=1)

    # Step 3: Calculate the indices to gather from for each row
    rows, cols = token.shape
    cols_range = np.arange(cols)
    rows_range = np.arange(rows)
    indices = np.tile(cols_range, (rows, 1))
    rolled_indices = (indices + cols - pad_counts[:, None]) % cols

    # Step 4: Use advanced indexing to rearrange the elements
    shifted_token = token[rows_range[:, None], rolled_indices]

    # Step 5: Set the first few elements of each row to the padding ID
    padding_mask = cols_range < pad_counts[:, None]
    shifted_token[padding_mask] = pad_id    
    return shifted_token

pad_id = 50259
token = np.array([
    [20, 30, 40, 50, pad_id, pad_id, pad_id, 50],
    [90, 30, 40, pad_id, pad_id, pad_id, pad_id, 20],
])

output = right_to_left_pad_vectorized(token, pad_id)

# output:
# array([[50259, 50259, 50259,    20,    30,    40,    50, 50259],
#       [50259, 50259, 50259, 50259,    90,    30,    40, 50259]])

I'll try to think of ways to deal with middle pad tokens.
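In the meantime, one option is to guard the fast path with a cheap check that every row really has its padding only on the right. This is a minimal sketch in numpy; the function name `pads_are_right_aligned` is hypothetical and not something in the repo.

```python
import numpy as np


def pads_are_right_aligned(token, pad_id):
    """Return True if no row has a non-pad token after a pad token."""
    pad_mask = token == pad_id
    # With right-only padding the mask never goes from True back to False,
    # so it must be non-decreasing along each row.
    return bool(np.all(pad_mask[:, 1:] >= pad_mask[:, :-1]))


pad_id = 50259
right_only = np.array([
    [20, 30, 40, 50, pad_id, pad_id, pad_id],
    [90, 30, 40, pad_id, pad_id, pad_id, pad_id],
])
with_middle = np.array([
    [20, 30, 40, 50, pad_id, pad_id, pad_id, 50],
    [90, 30, 40, pad_id, pad_id, pad_id, pad_id, 20],
])
print(pads_are_right_aligned(right_only, pad_id))
print(pads_are_right_aligned(with_middle, pad_id))
```

The check is vectorized too, so it should add little overhead compared with the roll itself.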

vwxyzjn commented 1 year ago

Btw, I think we only need to handle the contiguous cases:

token = np.array([
    [pad_id, 30, 40, 50, pad_id, pad_id, pad_id,],
    [90, 30, 40, pad_id, pad_id, pad_id, pad_id],
])

Basically, we have query = [pad_id, 30, 40]. After generation we get query_response = [pad_id, 30, 40, 50256, 30, 30], and after post-processing we get post_processed_query_response = [pad_id, 30, 40, 50256, pad_id, pad_id]. We need to turn that into [pad_id, pad_id, pad_id, 30, 40, 50256] before passing it to the reward model.
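For reference, here is a rough sketch of how the post-processed tensor in that example could come about. It assumes post-processing means "replace everything after the first 50256 in the response with pad_id", and it assumes pad_id = 50259 as in the earlier snippets; `truncate_after_eos` is a made-up helper name, not the repo's function.

```python
import torch


def truncate_after_eos(response, eos_id, pad_id):
    """Replace every token after the first eos_id in each row with pad_id."""
    is_eos = (response == eos_id).long()
    # Number of EOS tokens strictly before each position.
    eos_before = torch.cumsum(is_eos, dim=1) - is_eos
    return torch.where(eos_before > 0, torch.full_like(response, pad_id), response)


pad_id, eos_id = 50259, 50256  # assumed values for this sketch
query = torch.tensor([[pad_id, 30, 40]])
response = torch.tensor([[eos_id, 30, 30]])
post_processed = torch.cat([query, truncate_after_eos(response, eos_id, pad_id)], dim=1)
print(post_processed.tolist())
```

The result has one pad on the left and a run of pads on the right, which is exactly the shape of input the left-padding function has to handle.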

vwxyzjn commented 1 year ago

Ok so this is better

import torch

def right_to_left_pad_vectorized(data, pad_id):
    # Step 1: Build an integer mask (1 where the token is padding, 0 elsewhere)
    mask = (data == pad_id).long()
    # Step 2: argsort the bitwise NOT of the mask (pad -> -2, non-pad -> -1) so pad positions come first
    sorted_indices = torch.argsort(~mask, axis=1)
    # Step 3: Use advanced indexing to rearrange the elements
    rows_range = torch.arange(data.shape[0], device=data.device)
    shifted_data = data[rows_range[:, None], sorted_indices]
    return shifted_data

def right_to_left_pad(tokens, pad_id):
    return torch.tensor(
        [[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens],
        device=tokens.device,
    )
device = torch.device("cuda:0")
pad_id = 50259
token = torch.tensor([
    [20, 30, 40, 50, pad_id, pad_id, pad_id],
    # [90, 30, 40, pad_id, pad_id, pad_id, pad_id],
    # [90, 30, 40, pad_id, 50, 60, pad_id],
    [pad_id, 10, 40, 20, 50, 60, pad_id],
])
print(right_to_left_pad_vectorized(token, pad_id))
print(right_to_left_pad(token, pad_id))
benchmark_token = torch.zeros((64, 512), dtype=torch.int32)
benchmark_token[:, -30:-1] = pad_id
benchmark_token[:, -1] = pad_id
print(right_to_left_pad_vectorized(benchmark_token, pad_id))
print(right_to_left_pad(benchmark_token, pad_id))

import timeit
print("cpu:")
print(timeit.timeit(lambda: right_to_left_pad_vectorized(benchmark_token, pad_id), number=10))
print(timeit.timeit(lambda: right_to_left_pad(benchmark_token, pad_id), number=10))
print("gpu:")
benchmark_token_device = benchmark_token.to(device)
print(timeit.timeit(lambda: right_to_left_pad_vectorized(benchmark_token_device, pad_id), number=10))
print(timeit.timeit(lambda: right_to_left_pad(benchmark_token_device, pad_id), number=10))
costa@ip-26-0-150-12:/fsx/costa/lm-human-preference-details$ poetry run python -i eff.py
tensor([[50259, 50259, 50259,    20,    30,    40,    50],
        [50259, 50259,    10,    40,    20,    50,    60]])
tensor([[50259, 50259, 50259,    20,    30,    40,    50],
        [50259, 50259,    10,    40,    20,    50,    60]])
tensor([[50259, 50259, 50259,  ...,     0,     0,     0],
        [50259, 50259, 50259,  ...,     0,     0,     0],
        [50259, 50259, 50259,  ...,     0,     0,     0],
        ...,
        [50259, 50259, 50259,  ...,     0,     0,     0],
        [50259, 50259, 50259,  ...,     0,     0,     0],
        [50259, 50259, 50259,  ...,     0,     0,     0]], dtype=torch.int32)
tensor([[50259, 50259, 50259,  ...,     0,     0,     0],
        [50259, 50259, 50259,  ...,     0,     0,     0],
        [50259, 50259, 50259,  ...,     0,     0,     0],
        ...,
        [50259, 50259, 50259,  ...,     0,     0,     0],
        [50259, 50259, 50259,  ...,     0,     0,     0],
        [50259, 50259, 50259,  ...,     0,     0,     0]])
cpu:
0.005651316023431718
2.196436281024944
gpu:
0.035678989021107554
12.842351484985556
liutianlin0121 commented 1 year ago

I see. The argsort version is similar to the current jax implementation https://github.com/vwxyzjn/lm-human-preference-details/blob/ccc19538e817e98a60d3253242ac15e2a562cb49/lm_human_preference_details/train_policy_jax.py#L316-L321

I used it to make the jax function jittable. I didn't expect it to be faster!

vwxyzjn commented 1 year ago

Ahhh yeah it's exactly the same! Not sure how I missed that haha.

vwxyzjn commented 12 months ago

Btw the one based on argsort might have a bug:

import torch

def shift_pad_id_left(data, pad_id):
    # Step 1: Build an integer mask (1 where the token is padding, 0 elsewhere)
    mask = (data == pad_id).long()
    # Step 2: argsort the bitwise NOT of the mask (pad -> -2, non-pad -> -1) so pad positions come first
    sorted_indices = torch.argsort(~mask, axis=1)
    # Step 3: Use advanced indexing to rearrange the elements
    rows_range = torch.arange(data.shape[0], device=data.device)
    shifted_data = data[rows_range[:, None], sorted_indices]
    return shifted_data

def right_padding_to_left_padding(tokens, pad_id):
    """Convert from right padding to left padding."""
    assert tokens.ndim == 2
    return torch.tensor([[pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id] for row in tokens])

query = torch.tensor([[   50, 10526, 22083, 49828,    25,   374,    14,  3855, 47733, 30829,
          198, 49560,  2538,    25,  1160,    10,  8059,  3750,   287,  1478,
         1528,    13, 45827,   422,  1223,   314,  1100,   994,    13,   198,
        32782,    25,   314,   423,   587,  2111,   290,  9894,   284,  4425,
         3463,   329,   257,  3155,   812,    13,   314,   561,   466,   257,
        17578,  8027,   329,   546,   257,  1285,    11,   788, 11238,    13,
          314,   561,  4483,  1729,  2245,   357,    40,  3111,   287,   257,
        38164,   828,  4144,   362, 14096,   286,   279,   538, 13396,  4445,
          290,  1239,  5517,    13,   314,   373,   407,  3772,   351,   703,
          314,  2936,    13,   198, 14303,   734,  2745,  2084,    13,   314,
         1100,  1223,   357,    40,  6044,   508,  4481,    11,  7926,     8,
          546,   407,  2111,   284,  4425,  3463,    11,   475,  2427,  2111,
          284,   651,  5448,    13,   632, 16752,  1223,   287,   502,   290,
          314,  3066,   284,  1949,   340,   503,    13,   220,   198,  2293,
         1478,  1528,   314,  1422,   470,  4483, 10536,  4107,  9013,   393,
        42402,    11,  2427,   550, 15921,    14,   303,  1136,  2977,    14,
        20970,    14, 44749,    14, 14784,    13,   314,  1239,   625, 15063,
          393, 15063,   878,  1016,   284,  3996,    13,   314, 24070,   691,
         1660,   290,  6041,   286,   340,    13,   314, 25805,  4445,   290,
         3111,   503,   790,   584,  1110,    13,   198,   314,  4251,   616,
          734,  1285,  2496,   287,   838,  1528,    11,   290,   314,  1254,
         1049,    13,   314,  1053,  2722,   617,  4633,  3513,   546,   340,
          422,   262,   661,   314,  2107,   351,    11,  2192,   780,   484,
         1165,  6531,   351,  3463,    11,   475,   314,   716,   407,  9616,
          340,   651,   284,   502,    13,   198,   198, 14990,    26,  7707,
           25,   220,  5521,   503,    11,  4144,  6041,   286,  1660,    11,
         4483,  5448,    11,  4483, 15921,    14,   303,  1136,  2977,    14,
        20970,    14, 44749,    14, 14784,    13, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,]])

left_padded_query = shift_pad_id_left(query, 50256)
right_padding_to_left_padding_query = right_padding_to_left_padding(query, 50256)
print(query)
print(left_padded_query)
print(right_padding_to_left_padding_query)
torch.testing.assert_allclose(left_padded_query, right_padding_to_left_padding_query)
tensor([[   50, 10526, 22083, 49828,    25,   374,    14,  3855, 47733, 30829,
           198, 49560,  2538,    25,  1160,    10,  8059,  3750,   287,  1478,
          1528,    13, 45827,   422,  1223,   314,  1100,   994,    13,   198,
         32782,    25,   314,   423,   587,  2111,   290,  9894,   284,  4425,
          3463,   329,   257,  3155,   812,    13,   314,   561,   466,   257,
         17578,  8027,   329,   546,   257,  1285,    11,   788, 11238,    13,
           314,   561,  4483,  1729,  2245,   357,    40,  3111,   287,   257,
         38164,   828,  4144,   362, 14096,   286,   279,   538, 13396,  4445,
           290,  1239,  5517,    13,   314,   373,   407,  3772,   351,   703,
           314,  2936,    13,   198, 14303,   734,  2745,  2084,    13,   314,
          1100,  1223,   357,    40,  6044,   508,  4481,    11,  7926,     8,
           546,   407,  2111,   284,  4425,  3463,    11,   475,  2427,  2111,
           284,   651,  5448,    13,   632, 16752,  1223,   287,   502,   290,
           314,  3066,   284,  1949,   340,   503,    13,   220,   198,  2293,
          1478,  1528,   314,  1422,   470,  4483, 10536,  4107,  9013,   393,
         42402,    11,  2427,   550, 15921,    14,   303,  1136,  2977,    14,
         20970,    14, 44749,    14, 14784,    13,   314,  1239,   625, 15063,
           393, 15063,   878,  1016,   284,  3996,    13,   314, 24070,   691,
          1660,   290,  6041,   286,   340,    13,   314, 25805,  4445,   290,
          3111,   503,   790,   584,  1110,    13,   198,   314,  4251,   616,
           734,  1285,  2496,   287,   838,  1528,    11,   290,   314,  1254,
          1049,    13,   314,  1053,  2722,   617,  4633,  3513,   546,   340,
           422,   262,   661,   314,  2107,   351,    11,  2192,   780,   484,
          1165,  6531,   351,  3463,    11,   475,   314,   716,   407,  9616,
           340,   651,   284,   502,    13,   198,   198, 14990,    26,  7707,
            25,   220,  5521,   503,    11,  4144,  6041,   286,  1660,    11,
          4483,  5448,    11,  4483, 15921,    14,   303,  1136,  2977,    14,
         20970,    14, 44749,    14, 14784,    13, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256]])
tensor([[50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256,   198,    13,  1110,   584,   790,   503,
          3111,   290,  4445, 25805,   314,    13,   340,   286,   314,  6041,
           314,  4251,   616,   734,  1285,  2496,   287,   838,  1528,    11,
           290,   314,  1254,  1049,    13,    11,   314, 42402,  2427,   550,
         15921,    14,   303,  1136,  2977,    14, 20970,    14, 44749,    14,
         14784,    13,   290,  1239,   625, 15063,   393, 15063,   878,  1016,
           284,  3996,    13,   314, 24070,   691,  1660,  4483,   198,   198,
         14990,    26,  7707,    25,   220,  5521,   503,    11,  4144,  6041,
           286,  1660,    11,    13,  5448,    11,  4483, 15921,    14,   303,
          1136,  2977,    14, 20970,    14, 44749,    14, 14784,    13,   484,
          2722,   617,  4633,  3513,   546,   340,   422,   262,   661,   314,
          2107,   351,    11,  2192,   780,  1053,  1165,  6531,   351,  3463,
            11,   475,   314,   716,   407,  9616,   340,   651,   284,   502,
           314,  1285,   257,   546,   329,  8027, 17578,   257,   466,   561,
            11,    13,   812,  3155,   257,   329,  3463,  4425,   284,   357,
         14096,   362,  4144,   828, 38164,   257,   287,  3111,    40,  9894,
          2245,  1729,  4483,   561,   314,    13, 11238,   788, 30829,   287,
          3750,  8059,    10,  1160,    25,  2538, 49560,   198,  1478, 47733,
          3855,    14,   374,    25, 49828, 22083, 10526,    13,   290,  2111,
           587,   423,   314,    25, 32782,   198,   286,   994,  1100,   314,
          1223,   422, 45827,    13,  1528,   651,   314,   290,   502,   287,
          1223, 16752,   632,    13,  5448,  3066,   284,  2111,  2427,   475,
            11,  3463,  4425,   284,  1478,   393,  9013,  4107, 10536,  4483,
           470,  1422,   314,  1528,  2111,  2293,   198,   220,    13,   503,
           340,  1949,   284,   314,   198,    13,  2936,   314,   703,   351,
          3772,   407,   373, 14303,    13,  5517,  1239,   290,  4445, 13396,
           538,   279,    40,   407,   546,     8,  7926,    11,  4481,   508,
          6044,    50,   357,  1223,  1100,   314,    13,  2084,  2745,   734]])
tensor([[50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256,    50, 10526, 22083, 49828,    25,   374,
            14,  3855, 47733, 30829,   198, 49560,  2538,    25,  1160,    10,
          8059,  3750,   287,  1478,  1528,    13, 45827,   422,  1223,   314,
          1100,   994,    13,   198, 32782,    25,   314,   423,   587,  2111,
           290,  9894,   284,  4425,  3463,   329,   257,  3155,   812,    13,
           314,   561,   466,   257, 17578,  8027,   329,   546,   257,  1285,
            11,   788, 11238,    13,   314,   561,  4483,  1729,  2245,   357,
            40,  3111,   287,   257, 38164,   828,  4144,   362, 14096,   286,
           279,   538, 13396,  4445,   290,  1239,  5517,    13,   314,   373,
           407,  3772,   351,   703,   314,  2936,    13,   198, 14303,   734,
          2745,  2084,    13,   314,  1100,  1223,   357,    40,  6044,   508,
          4481,    11,  7926,     8,   546,   407,  2111,   284,  4425,  3463,
            11,   475,  2427,  2111,   284,   651,  5448,    13,   632, 16752,
          1223,   287,   502,   290,   314,  3066,   284,  1949,   340,   503,
            13,   220,   198,  2293,  1478,  1528,   314,  1422,   470,  4483,
         10536,  4107,  9013,   393, 42402,    11,  2427,   550, 15921,    14,
           303,  1136,  2977,    14, 20970,    14, 44749,    14, 14784,    13,
           314,  1239,   625, 15063,   393, 15063,   878,  1016,   284,  3996,
            13,   314, 24070,   691,  1660,   290,  6041,   286,   340,    13,
           314, 25805,  4445,   290,  3111,   503,   790,   584,  1110,    13,
           198,   314,  4251,   616,   734,  1285,  2496,   287,   838,  1528,
            11,   290,   314,  1254,  1049,    13,   314,  1053,  2722,   617,
          4633,  3513,   546,   340,   422,   262,   661,   314,  2107,   351,
            11,  2192,   780,   484,  1165,  6531,   351,  3463,    11,   475,
           314,   716,   407,  9616,   340,   651,   284,   502,    13,   198,
           198, 14990,    26,  7707,    25,   220,  5521,   503,    11,  4144,
          6041,   286,  1660,    11,  4483,  5448,    11,  4483, 15921,    14,
           303,  1136,  2977,    14, 20970,    14, 44749,    14, 14784,    13]])
/home/costa/Documents/go/src/github.com/vwxyzjn/lm-human-preference-details/shift_pad_bug.py:54: FutureWarning: `torch.testing.assert_allclose()` is deprecated since 1.12 and will be removed in a future release. Please use `torch.testing.assert_close()` instead. You can find detailed upgrade instructions in https://github.com/pytorch/pytorch/issues/61844.
  torch.testing.assert_allclose(left_padded_query, right_padding_to_left_padding_query)
Traceback (most recent call last):
  File "/home/costa/Documents/go/src/github.com/vwxyzjn/lm-human-preference-details/shift_pad_bug.py", line 54, in <module>
    torch.testing.assert_allclose(left_padded_query, right_padding_to_left_padding_query)
  File "/home/costa/.cache/pypoetry/virtualenvs/lm-human-preference-details-3ZaD8v_C-py3.9/lib/python3.9/site-packages/torch/testing/_comparison.py", line 1553, in assert_allclose
    torch.testing.assert_close(
  File "/home/costa/.cache/pypoetry/virtualenvs/lm-human-preference-details-3ZaD8v_C-py3.9/lib/python3.9/site-packages/torch/testing/_comparison.py", line 1511, in assert_close
    raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not equal!

Mismatched elements: 272 / 300 (90.7%)
Greatest absolute difference: 49547 at index (0, 35)
Greatest relative difference: 1223.6922607421875 at index (0, 107)
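The pads do end up on the left in the output above, but the real tokens are scrambled. That is consistent with the sort not being stable: the key takes only two distinct values, and without `stable=True` PyTorch does not guarantee the relative order of equal elements. A minimal sketch of a possible fix follows, using `torch.sort(..., stable=True)`. The names `left_pad_stable` and `left_pad_reference` are made up for this example, and the random input is only a stand-in for the real query.

```python
import torch


def left_pad_stable(data, pad_id):
    """Move pad tokens to the left of each row, keeping real tokens in order."""
    # Sort key: 0 for pad tokens, 1 for real tokens.
    key = (data != pad_id).long()
    # stable=True keeps equal keys in their original order.
    order = torch.sort(key, dim=1, stable=True).indices
    return torch.gather(data, 1, order)


def left_pad_reference(tokens, pad_id):
    """Slow pure-Python reference, same idea as the list comprehension above."""
    rows = []
    for row in tokens.tolist():
        kept = [x for x in row if x != pad_id]
        rows.append([pad_id] * (len(row) - len(kept)) + kept)
    return torch.tensor(rows, dtype=tokens.dtype, device=tokens.device)


pad_id = 50256
torch.manual_seed(0)
query = torch.randint(0, 50000, (4, 300))  # random stand-in, no pad_id values
query[:, -24:] = pad_id                    # right padding, as in the failing case
query[1, 0] = pad_id                       # plus one pad already on the left
fast = left_pad_stable(query, pad_id)
slow = left_pad_reference(query, pad_id)
print(torch.equal(fast, slow))
```

A stable sort may be somewhat slower than the default one, so it would be worth re-running the benchmark before switching.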