aws-neuron / transformers-neuronx


Skipping generation for predictable tokens, and modifying cache_ids #68

Closed enochlev closed 6 months ago

enochlev commented 6 months ago

I am trying to skip generating tokens that could simply be copied in, hoping to speed things up by roughly 70% for my use case. The problem I keep running into is that when I reset the cache, the overhead takes too much time, and when I keep the cache, the probabilities no longer seem to be right.

The main goal, in the end, is constrained generation that saves time whenever there is only one possible next token to generate.
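Conceptually (a simplified sketch of the idea only, not the actual Neuron calls; `model_step` is a hypothetical callable standing in for a real forward pass), I want to walk a token trie built from the allowed continuation and only call the model when more than one token is possible:

```python
def constrained_generate(model_step, trie, generated):
    """Sketch: model_step(generated, allowed) -> next token id.
    trie maps token id -> node with a "tree" subtree (same shape as convert_to_tree below)."""
    while trie:
        if len(trie) == 1:
            # Only one legal continuation: copy it in without a forward pass.
            token = next(iter(trie))
        else:
            # Several legal continuations: ask the model, restricted to the legal tokens.
            token = model_step(generated, allowed=set(trie))
        generated.append(token)
        trie = trie[token]["tree"]
    return generated
```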

Here is my current code to reproduce this:

```python
import time

import torch

from transformers_neuronx import bucket, utils
from transformers_neuronx.sampling import validate_top_k_top_p_min_tokens_to_keep, top_k_top_p_filtering

# tokenizer, neuron_model and input_ids are assumed to be defined earlier (model setup omitted).


def convert_to_tree(sequences):
    tree = {}
    for sequence in sequences:
        sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
        current_tree = tree
        for token in sequence_ids:
            if token not in current_tree:
                current_tree[token] = {
                    "token_string": tokenizer.decode([token]),
                    "dangling": True,
                    "tree": {}
                }
            else:
                # Update the dangling value if this token appears more than once
                current_tree[token]["dangling"] = False
            current_tree = current_tree[token]["tree"]
    return tree


tree = convert_to_tree([" a business man who"])


def sample(self, input_ids, sequence_length, start_ids=None,
           top_k=50, top_p=1.0, eos_token_override=None, temperature=1.0,
           streamer=None, c_tree=None):
    # To enable optimized context encoding network, we must pad
    # up to the context length estimate or we will not correctly
    # select the final context logits (See: layers/transformer.py).
    # This also means we need to shift the start_ids over to correct
    # for padding.
    offset = 0
    batch_size, context_length = input_ids.shape
    prefixed_length = self.prefixed_length
    if context_length < prefixed_length:
        self.prefixed_length = 0
    else:
        input_ids = input_ids[:, prefixed_length:]
        context_length -= prefixed_length
        sequence_length -= prefixed_length

    estimate = bucket.find(self.context_buckets, context_length)
    if estimate:
        if context_length < estimate:
            input_ids = utils.pad(input_ids, 1, estimate, left=True)
            offset = estimate - context_length
            if not prefixed_length:
                if start_ids is None:
                    start_ids = torch.zeros(batch_size, dtype=torch.int32)
                start_ids += offset
            sequence_length += offset

    # Sequence length cannot be greater than n_positions
    sequence_length = min(sequence_length, self.max_positions)

    result = sample_llama(
        self, input_ids, start_ids, sequence_length,
        eos_token_id=self.config.eos_token_id if eos_token_override is None else eos_token_override,
        top_k=top_k, top_p=top_p, temperature=temperature, streamer=streamer, c_tree=c_tree
    )
    if offset != 0:
        result = result[:, offset:]
    return result


@torch.no_grad()
def sample_llama(model, input_ids, start_ids, sequence_length, eos_token_id=2,
                 top_k=50, top_p=1.0, temperature=1.0, streamer=None, c_tree=None):
    # validate_top_k_top_p_min_tokens_to_keep(top_k, top_p, None)

    # populate key/value caches according to the prompt text
    _, start = input_ids.shape
    cache_ids = torch.arange(start, dtype=torch.int32)
    next_token_scores = model(input_ids, cache_ids, start_ids)

    return sample_loop_llama(
        model, input_ids, start_ids, next_token_scores, sequence_length,
        eos_token_id, top_k, top_p, temperature, streamer, c_tree
    )


# test if caching is working by turning off restricted generation after 3 tokens were added manually
next_token_scores = None


def sample_loop_llama(model, input_ids, start_ids, next_token_scores, sequence_length,
                      eos_token_id=2, top_k=50, top_p=1.0, temperature=1.0, streamer=None, c_tree=None):
    validate_top_k_top_p_min_tokens_to_keep(top_k, top_p, None)
    if not isinstance(temperature, float) or not (temperature > 0):
        raise ValueError('temperature has to be a strictly positive float.')

    # Flags, one per sequence in a batch, to indicate if a sequence hit eos_token_id
    done_flags = torch.full((input_ids.size(dim=0), 1), False)
    tokens = [input_ids]
    _, start = input_ids.shape

    cache_ids = torch.arange(start, dtype=torch.int32)
    next_token_scores = model(input_ids, cache_ids, start_ids)
    print("inputs")
    print((input_ids, input_ids.shape))
    print("cache_ids")
    print((cache_ids, cache_ids.shape))

    tokens_tmp = []
    cache_ids_temp = []
    for cur_len in range(start, sequence_length):
        next_len = cur_len + 1
        # top_values, top_indices = top_k_top_p_filtering(next_token_scores, top_k=top_k, top_p=top_p)
        top_indices = list(c_tree.keys())

        if len(top_indices) == 1:
            # skip next_token_scores because there is only one possible token
            inputs = top_indices[0]
            inputs = torch.reshape(torch.tensor(inputs), (1, 1))
            done_flags = torch.logical_or(done_flags, inputs == eos_token_id)
            token = torch.where(done_flags.eq(True), eos_token_id, inputs)
            tokens.append(token)

            if streamer is not None and hasattr(streamer, 'response_with_prefix') and streamer.response_with_prefix:
                streamer.put(torch.cat(tokens, dim=-1))
            elif streamer:
                streamer.put(token)

            c_tree = c_tree[top_indices[0]]['tree']
            if len(list(c_tree.keys())) == 0:
                pass  # break

            # forward pass to get next token
            cache_ids_temp.append(cur_len)
            tokens_tmp.append(token)
            ### TODO: assign token to cache_ids

        elif len(top_indices) == 0:
            cache_ids = torch.as_tensor(cache_ids_temp, dtype=torch.int32)
            tokens_pt = torch.as_tensor([tokens_tmp], dtype=torch.int32)
            print("inputs")
            print((tokens_pt, tokens_pt.shape))
            print("cache_ids")
            print((cache_ids, cache_ids.shape))
            if len(tokens_tmp) != 0:  # header condition only
                next_token_scores = model(tokens_pt, cache_ids, start_ids)
            cache_ids_temp = []
            tokens_tmp = []

            #### this whole block would make it constrained generation, but it was commented out
            #### to make sure the probabilities for normal (unconstrained) generation are working correctly
            # top_values = next_token_scores[0][top_indices]
            # top_value = torch.argmax(top_values)
            # inputs = top_indices[top_value]
            # c_tree = c_tree[inputs]['tree']
            # inputs = torch.reshape(torch.tensor(inputs), (1, 1))
            # # Update done flags.
            # done_flags = torch.logical_or(done_flags, inputs == eos_token_id)
            # # Update token id to be eos_token_id if the corresponding done flag is True. For a batch,
            # # this means that, while every sequence in the batch has the same length, a sequence that
            # # encounters eos_token_id earlier will be filled with eos_token_ids post the first appearance
            # # of eos_token_id.
            # token = torch.where(done_flags.eq(True), eos_token_id, inputs)
            # tokens.append(token)

            if temperature != 1.0:
                next_token_scores /= temperature
            top_values, top_indices = top_k_top_p_filtering(next_token_scores, top_k=top_k, top_p=top_p)

            # sample
            probs = torch.nn.functional.softmax(top_values, dim=-1)
            inputs_in_topk = torch.multinomial(probs, num_samples=1, replacement=True)
            inputs = torch.gather(top_indices, 1, inputs_in_topk)

            done_flags = torch.logical_or(done_flags, inputs == eos_token_id)
            token = torch.where(done_flags.eq(True), eos_token_id, inputs)
            tokens.append(token)

            if streamer is not None and hasattr(streamer, 'response_with_prefix') and streamer.response_with_prefix:
                streamer.put(torch.cat(tokens, dim=-1))
            elif streamer:
                streamer.put(token)

            if len(list(c_tree.keys())) == 0:
                pass  # break

            cache_ids_temp.append(cur_len)
            tokens_tmp.append(token)

            # if next_len >= sequence_length or done_flags.all():
            #     break

            # forward pass to get next token
            # add multiple models to merge multiple scores

    if streamer:
        streamer.end()

    return torch.cat(tokens, dim=-1)


# run inference with top-k sampling
print(len(input_ids[0]))
with torch.inference_mode():
    start = time.time()
    # print(neuron_model.forward(input_ids))
    # generated_sequences = neuron_model.sample(input_ids, sequence_length=len(input_ids[0]) + 100)
    generated_sequences = sample(neuron_model, input_ids, sequence_length=len(input_ids[0]) + 15,
                                 temperature=.8, c_tree=tree)
    elapsed = time.time() - start

nl = "\n\n\n\n"
generated_sequences = [tokenizer.decode(seq) for seq in generated_sequences]
print(f'generated sequences {nl.join(generated_sequences)} in {elapsed} seconds')
```

The key logs are below. I thought I was doing everything right, but it still seems to produce wrong results.

inputs (tensor([[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2180, 278, 2215, 1095, 310, 4726, 988, 278, 1632, 860, 280, 629, 465, 25088, 322, 278, 8805, 1560, 10071, 5232, 4167, 473, 746, 372, 13031, 29879, 322, 694, 13, 18513, 29879, 3926, 1809, 5174, 292, 2030, 274, 5727, 338, 278, 7103, 310, 278, 365, 2027, 287, 10980, 1165, 13, 2855, 6483, 297, 278, 1632, 860, 280, 629, 465, 777, 2305, 1827, 565, 366, 1106, 6483, 3307, 366, 508, 1603, 1074, 9826, 988, 278, 13, 29931, 272, 1165, 2748, 8389, 925, 408, 1472, 408, 372, 1033, 1434, 18462, 26239, 278, 10980, 1165, 3448, 13, 5618, 471, 278, 10980, 1165, 3139, 2020, 471, 372, 727, 1126, 2020, 471, 372, 26239, 322, 4586, 9051, 515, 278, 2215, 1095, 310, 13, 27734, 988, 278, 1632, 860, 280, 629, 465, 25088, 450, 2030, 1551, 2242, 261, 1603, 12080, 1244, 13, 29909, 808, 1075, 540, 9906, 13, 13, 13, 4013, 5828, 338, 1048, 29871]]), torch.Size([1, 800])) cache_ids (tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 
164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799], dtype=torch.int32), torch.Size([800])) inputs ### new inputs with 5 extra cach_ids (tensor([[29871, 263, 5381, 767, 1058]], dtype=torch.int32), torch.Size([1, 5])) cache_ids (tensor([800, 801, 802, 803, 804], dtype=torch.int32), torch.Size([5])) #I expect normal generation afterwords inputs (tensor([[29949]], dtype=torch.int32), torch.Size([1, 1])) cache_ids 
(tensor([805], dtype=torch.int32), torch.Size([1])) inputs (tensor([[259]], dtype=torch.int32), torch.Size([1, 1])) cache_ids (tensor([806], dtype=torch.int32), torch.Size([1])) inputs (tensor([[259]], dtype=torch.int32), torch.Size([1, 1])) cache_ids (tensor([807], dtype=torch.int32), torch.Size([1])) inputs (tensor([[903]], dtype=torch.int32), torch.Size([1, 1])) cache_ids (tensor([808], dtype=torch.int32), torch.Size([1])) inputs (tensor([[386]], dtype=torch.int32), torch.Size([1, 1])) cache_ids (tensor([809], dtype=torch.int32), torch.Size([1])) inputs (tensor([[29899]], dtype=torch.int32), torch.Size([1, 1])) cache_ids (tensor([810], dtype=torch.int32), torch.Size([1])) inputs (tensor([[29871]], dtype=torch.int32), torch.Size([1, 1])) cache_ids (tensor([811], dtype=torch.int32), torch.Size([1])) inputs (tensor([[259]], dtype=torch.int32), torch.Size([1, 1])) cache_ids (tensor([812], dtype=torch.int32), torch.Size([1])) inputs (tensor([[29955]], dtype=torch.int32), torch.Size([1, 1])) cache_ids (tensor([813], dtype=torch.int32), torch.Size([1])) generated sequences At the far end of town where the Gricklegrass grows and the wind smells slowandsour when it blows and no birds ever sing excepting old crows is the Street of the Lifted Lorax And deep in the Gricklegrass some people say if you look deep enough you can still see today where the Lorax once stood just as long as it could before somebody lifted the Lorax away What was the Lorax Any why was it there And why was it lifted and taken somewhere from the far end of town where the Gricklegrass grows The old Onceler still lives here Ask him he knows This story is about a business man whoO _th- 7 in 1.5860817432403564 seconds

The key detail in the logs is that on the next call for model logits I add 5 extra tokens along with 5 extra cache_ids, correctly ordered. I thought I did it correctly, but after normal generation resumes it produces garbage.

mrnikwaws commented 6 months ago

Hi @enochlev,

It looks like you are trying to generate the remaining tokens in a sequence given a common prefix. Given that this is your goal, you need to ensure that the suffix you are attempting to encode begins with the correct indices into the KV cache, i.e. continuing from the last token index of the common prefix. Note that it may still be more performant to run a larger context encoding, depending on your use case.
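For example (a sketch only, reusing the `model(input_ids, cache_ids, start_ids)` call from your snippet; `prefix_ids`, `suffix_ids`, `start_ids` and `model` are placeholders for your own setup), if the prefix occupies cache positions `0 .. prefix_len - 1`, the suffix's cache_ids need to continue from there:

```python
import torch

# Encode the prefix at cache positions 0 .. prefix_len - 1
_, prefix_len = prefix_ids.shape
prefix_cache_ids = torch.arange(prefix_len, dtype=torch.int32)
scores = model(prefix_ids, prefix_cache_ids, start_ids)

# Encode the suffix at cache positions prefix_len .. prefix_len + suffix_len - 1,
# i.e. continuing directly from the end of the common prefix.
_, suffix_len = suffix_ids.shape
suffix_cache_ids = torch.arange(prefix_len, prefix_len + suffix_len, dtype=torch.int32)
scores = model(suffix_ids, suffix_cache_ids, start_ids)
```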

enochlev commented 6 months ago

Based on your comment, are you telling me that the cache_ids have to be correctly, incrementally aligned? I thought I did that in my example:

First forward call (prefix):
input_ids: [[0, 0, 0, 0, 0, 0, 0 ...., 13, 4013, 5828, 338, 1048, 29871]]
cache_ids: [ 0, 1, 2, 3, 4, 5, 6....., 794, 795, 796, 797, 798, 799]

Standard forward call with a live cache:
input_ids: [[1058]]
cache_ids: [800]
**The probabilities seem to be fine with this.

Next forward call, skipping 4 forward calls (what I am trying to implement):
input_ids: [[29871, 263, 5381, 767, 1058]]
cache_ids: [801, 802, 803, 804, 805]
**The probabilities get messed up with this, even though the cache_ids are aligned.
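In other words, with an 800-token prompt I would expect the cache positions to line up like this (just restating the numbers above as a quick sketch):

```python
import torch

prompt_len = 800  # prompt occupies cache positions 0 .. 799
prompt_cache_ids = torch.arange(prompt_len, dtype=torch.int32)  # [0, ..., 799]

# single-token decode step: the next token is written at cache position 800
step_cache_ids = torch.tensor([prompt_len], dtype=torch.int32)  # [800]

# 5-token "skip" step: the copied tokens occupy positions 801 .. 805
skip_cache_ids = torch.arange(prompt_len + 1, prompt_len + 6, dtype=torch.int32)  # [801, ..., 805]
```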

When you mention a prefix, are you suggesting the set_prefixed function call?

enochlev commented 6 months ago

I am assuming it's not solvable at the moment, or that I misunderstood your reply.

I will be closing the issue in a few days.