NVIDIA / FasterTransformer

Transformer related optimization, including BERT, GPT
Apache License 2.0
5.73k stars 882 forks

GPTNeoX gets wrong output when start_ids is long #465

Open TopIdiot opened 1 year ago

TopIdiot commented 1 year ago

Branch/Tag/Commit

main

Docker Image Version

nvcr.io/nvidia/pytorch:22.07-py3

GPU name

A100

CUDA Driver

450.156.00

Reproduced Steps

1. Download the EleutherAI/gpt-neox GPT-NeoX 20B checkpoint and convert it to run on 2 GPUs.

2. Change examples/cpp/gptneox/start_ids.csv to a long sequence of ids. For example, I used:
22705, 27, 209, 24553, 12105, 33656, 212, 535, 186, 6717, 5567, 27, 209, 15367, 12105, 17762, 6238, 27896, 44692, 844, 33731, 209, 14129, 244, 12755, 214, 5225, 14980, 209, 11894, 244, 31982, 41202, 34380, 6238, 15367, 33018, 26270, 24379, 231, 45091, 24899, 27107, 217, 26581, 242, 5225, 49271, 27107, 235, 4340, 22417, 186, 22705, 27, 209, 12755, 102, 12755, 244, 22417, 186, 6717, 5567, 27, 209, 20622, 33369, 115, 29440, 42697, 17576, 11871, 220, 17492, 27896, 28610, 20326, 15694, 209, 5225, 43799, 17321, 215, 29645, 6238, 36359, 13609, 15899, 216, 7719, 106, 3218, 99, 16401, 21991, 18617, 235, 18140, 28610, 30886, 21458, 19144, 165, 124, 215, 4340, 38389, 29427, 6238, 39465, 17576, 11871, 220, 17492, 27896, 28610, 20326, 5225, 13486, 123, 15531, 6238, 12335, 30329, 37373, 15694, 209, 25949, 44795, 49056, 44157, 27707, 22044, 21365, 230, 18678, 50039, 13609, 4340, 22417, 186, 33474, 20720, 24370, 6238, 24553, 33018, 29463, 20720, 28305, 5225, 29440, 42697, 24370, 33273, 17321, 227, 13922, 209, 30886, 18947, 101, 3218, 112, 18140, 996, 988, 33158, 50274, 4, 3709, 654, 74, 33203, 31, 996, 50274, 4, 3709, 654, 2703, 31, 996, 50274, 4, 3709, 654, 15467, 16, 22477, 16, 1156, 64, 80, 22477, 15, 17471, 31, 996, 50274, 4, 3709, 654, 15467, 16, 22477, 16, 1156, 64, 74, 22477, 15, 17471, 31, 996, 50274, 4, 3709, 654, 15467, 16, 25836, 1320, 16, 4793, 64, 6082, 15, 17471, 31, 996, 50274, 4, 3709, 654, 15467, 16, 25836, 1320, 16, 2703, 15, 17471, 31, 996, 50274, 4, 3709, 654, 15467, 16, 22477, 16, 26458, 64, 80, 22477, 15, 17471, 31, 996, 50274, 4, 3709, 654, 15467, 16, 22477, 16, 26458, 64, 74, 22477, 15, 17471, 31, 996, 50274, 996, 50274, 8400, 1450, 2703, 1239, 64, 8456, 9, 3474, 6268, 1450, 2703, 7, 14113, 64, 2703, 10, 996, 50274, 92, 996, 50270, 8400, 1450, 2703, 4963, 14113, 64, 22417, 186, 22705, 27, 209, 12755, 102, 12755, 244, 22417, 186, 6717, 5567, 27, 209, 5941, 219, 14262, 236, 20720, 44781, 7719, 234, 32, 996, 996, 996, 22417, 186, 22705, 27, 209, 12105, 5225, 
187, 186, 6717, 5567, 27, 209, 34439, 5225, 6238, 20622, 16401, 29440, 42697, 29463, 17492, 330, 3424, 209, 5225, 42292, 13486, 217, 12613, 230, 24370, 33273, 17321, 227, 13922, 4340, 40843, 233, 43216, 6238, 39465, 17576, 11871, 220, 17492, 27896, 28610, 20326, 1239, 64, 8456, 209, 5225, 13486, 123, 15531, 6238, 12335, 30329, 5690, 225, 27896, 34811, 17576, 5225, 13922, 209, 30886, 18947, 101, 3218, 112, 13609, 11894, 121, 27937, 36998, 4340, 38389, 29427, 6238, 39465, 29463, 17492, 24370, 37373, 13922, 209, 30886, 18947, 101, 3218, 112, 27258, 107, 23823, 97, 20326, 27896, 30886, 18947, 101, 3218, 112, 50039, 6238, 42843, 29463, 17492, 209, 13486, 123, 15531, 24370, 11894, 121, 27937, 13922, 209, 36998, 4340, 22417, 186, 25538, 29427, 6238, 39465, 29463, 17492, 24370, 37373, 49056, 44157, 27707, 29427, 5225, 36998, 14262, 236, 23373, 27896, 34744, 36778, 34179, 42522, 13609, 4340, 22417, 186, 20622, 46987, 12105, 27896, 25541, 211, 44006, 5225, 30420, 32378, 6238, 36750, 17783, 216, 36361, 12335, 13609, 5225, 29440, 42697, 43251, 20281, 33859, 27727, 7249, 224, 12676, 213, 4340, 24553, 33018, 15756, 119, 32130, 24553, 5225, 41085, 43741, 36778, 16677, 16818, 108, 43316, 21458, 11827, 104, 17955, 232, 4340, 6886, 996, 22417, 186, 22705, 27, 209, 24553, 12105, 33656, 212, 22417, 186, 6717, 5567, 27, 209, 15367, 12105, 17762, 6238, 27896, 44692, 844, 33731, 209, 14129, 244, 12755, 214, 5225, 14980, 209, 11894, 244, 31982, 41202, 34380, 6238, 15367, 33018, 26270, 24379, 231, 45091, 24899, 27107, 217, 26581, 242, 5225, 49271, 27107, 235, 4340, 996, 996, 996, 22417, 186, 22705, 27, 337, 12, 18, 38563, 30329, 24899, 45508, 22417, 186, 6717, 5567, 27, 209, 13610, 330, 3424, 209, 13609, 6238, 18, 12, 18, 209, 38563, 30329, 374, 4340, 996, 996, 996, 996, 22417, 186, 22705, 27, 209, 45804, 34091, 36387, 111, 10041, 244, 5225, 44712, 31129, 22010, 19780, 119, 22417, 186, 6717, 5567, 27, 209, 45804, 34091, 36387, 111, 10041, 244, 5225, 44712, 31129, 22010, 19780, 119, 25434, 
20720, 18140, 996, 996, 50276, 18, 15, 209, 49835, 108, 38908, 36387, 111, 10041, 244, 5225, 9504, 242, 28917, 99, 11016, 213, 24168, 162, 99, 224, 34389, 20025, 28607, 125, 6238, 42843, 49344, 36548, 21998, 218, 47937, 4340, 996, 50276, 19, 15, 209, 49037, 23825, 104, 11016, 213, 24168, 26772, 24086, 17289, 231, 16849, 236, 5225, 162, 99, 224, 34389, 6238, 14378, 125, 7056, 224, 32004, 30329, 162, 123, 108, 162, 119, 125, 37542, 32004, 30329, 12462, 112, 163, 218, 100, 5225, 20241, 109, 34693, 214, 4340, 996, 50276, 20, 15, 209, 17576, 34175, 16818, 108, 16805, 105, 15756, 228, 5959, 103, 6238, 38908, 36820, 15756, 228, 5959, 103, 5225, 30255, 225, 17559, 213, 21458, 10962, 114, 30093, 212, 4340, 996, 50276, 21, 15, 209, 49835, 108, 38908, 36387, 111, 10041, 244, 28327, 46049, 14377, 44573, 35201, 14377, 6916, 113, 34389, 29832, 103, 6238, 25434, 12676, 211, 46982, 106, 16805, 213, 37542, 27707, 32026, 100, 4340, 996, 50276, 22, 15, 209, 17576, 34175, 26151, 216, 24017, 36387, 111, 10041, 244, 6238, 38908, 36820, 10962, 114, 30093, 212, 4340, 22417, 186, 44712, 31129, 22010, 19780, 119, 18140, 16740, 9078, 110, 28327, 41683, 5690, 216, 47750, 39573, 45466, 6238, 7056, 117, 29645, 5225, 44712, 31129, 22010, 19780, 119, 43251, 36213, 24086, 15899, 120, 4746, 20241, 109, 34693, 214, 38563, 36213, 48815, 38328, 12811, 213, 4340, 996, 996, 996, 996, 22417, 186, 22705, 27, 209, 12755, 102, 12755, 244, 22417, 186, 6717, 5567, 27, 209, 42055, 24553, 12105, 24893, 9850, 35241, 7056, 121, 36387, 111, 6238, 33018, 39573, 45466, 16877, 20720, 14486, 100, 48604, 99, 18140, 996, 996, 50276, 18, 15, 209, 49835, 108, 17576, 24553, 41197, 18764, 7056, 121, 5690, 211, 43244, 35498, 34380, 5225, 162, 99, 224, 34389, 6238, 12105, 12755, 125, 162, 99, 224, 9078, 235, 12105, 36387, 111, 12363, 220, 4340, 996, 50276, 19, 15, 209, 49037, 23825, 104, 11016, 213, 24168, 24553, 36387, 111, 10041, 244, 5225, 162, 99, 224, 34389, 6238, 42843, 49835, 108, 38908, 39465, 23788, 11016, 213, 
24168, 26772, 24086, 17289, 231, 16849, 236, 4340, 996, 50276, 20, 15, 209, 49037, 23825, 104, 24168, 11016, 213, 5225, 9504, 242, 28917, 99, 6238, 42843, 49835, 108, 38908, 9504, 242, 28917, 99, 11016, 213, 24168, 162, 99, 224, 34389, 20025, 28607, 125, 4340, 996, 50276, 21, 15, 209, 42055, 24553, 12105, 24893, 9850, 35241, 7056, 121, 162, 99, 224, 34389, 6238, 44673, 5690, 231, 12755, 217, 29852, 216, 11894, 121, 162, 99, 224, 34389, 5225, 7056, 121, 18751, 99, 29554, 14262, 223, 6238, 16877, 16800, 125, 33859, 34439, 24086, 163, 216, 102, 19780, 124, 162, 99, 224, 34389, 4340, 996, 50276, 22, 15, 209, 17576, 34175, 21998, 218, 47937, 6238, 38908, 11894, 212, 162, 99, 224, 34389, 14377, 34138, 113, 7249, 242, 5225, 47937, 18434, 4340, 996, 50276, 23, 15, 209, 17576, 34175, 8587, 123, 32026, 100, 6238, 16877, 38908, 11894, 212, 162, 99, 224, 34389, 20025, 28607, 125, 32417, 109, 34439, 4340, 996, 50276, 24, 15, 209, 17576, 22417, 186, 22705, 27, 209, 12755, 102, 12755, 244, 535, 186, 6717, 5567, 27, 209, 12755, 244, 6238, 44673, 49835, 108, 38908, 36387, 111, 10041, 244, 5225, 20241, 109, 34693, 214, 38028, 30423, 6238, 14378, 125, 7056, 224, 41873, 37062, 34389, 37542, 36359, 24446, 37062, 34389, 23823, 242, 31902, 226, 4340, 996, 6886, 996, 22417, 186, 22705, 27, 209, 12755, 102, 12755, 244, 187, 186, 6717, 5567, 27, 209

3. Set top_k=1 to use greedy search, and set request_output_len=50.

4. mpirun -n 2 --allow-run-as-root ./bin/gptneox_example

The generated ids are:
 15367,  42010,  22702,  42010,  42010,  22702,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010,  42010

This output is wrong and is inconsistent with the output of EleutherAI/gpt-neox.

The bug can be reproduced in both fp16 and fp32 mode. As far as we know, GPTNeoX supports a maximum length of 2048. In this example, the input has 1224 ids and 50 new tokens are generated, so the total is much smaller than 2048.
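The length budget above can be checked with a few lines of Python. This is only a sketch: `count_ids` and `fits_context` are hypothetical helper names, and the 2048 limit is the value stated in this issue, not something read from the model config.

```python
def count_ids(csv_line: str) -> int:
    """Count the comma-separated token ids on one line of a start_ids-style CSV.

    Empty fields (for example after a trailing comma) are ignored.
    """
    return sum(1 for field in csv_line.split(",") if field.strip())


def fits_context(num_input_ids: int, num_new_tokens: int, max_len: int = 2048) -> bool:
    """True if the prompt plus the generated tokens stay within max_len."""
    return num_input_ids + num_new_tokens <= max_len


# Numbers reported in this issue: 1224 input ids, 50 new tokens, limit 2048.
print(1224 + 50)               # 1274
print(fits_context(1224, 50))  # True
```

To count the ids actually fed to the example, `count_ids` can be applied to each line read from start_ids.csv.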

TopIdiot commented 1 year ago

@byshiue Could you share some solutions for debugging?

TopIdiot commented 1 year ago

@byshiue It seems that when start_ids is long, the key/value cache used by GptNeoxDecoder is wrong. The key/value cache is computed in GptNeoxContextDecoder, but it is hard to debug. Do you have any suggestions?

byshiue commented 1 year ago

Hi, @TopIdiot. Thank you for the feedback. You can try printing some intermediate results and comparing them with the Hugging Face (HF) implementation.
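One way to act on this suggestion is sketched below, assuming the intermediate results from both implementations have already been dumped as flat lists of floats, keyed by layer name in execution order. The helper names, the layer names, and the tolerances are all hypothetical; nothing here is part of the FasterTransformer or Hugging Face APIs.

```python
import math


def first_mismatch(ft_values, ref_values, rel_tol=1e-3, abs_tol=1e-3):
    """Return the index of the first element that differs, or None if all match.

    NaN never compares close to anything, so a NaN in either dump is reported.
    """
    if len(ft_values) != len(ref_values):
        raise ValueError(f"length mismatch: {len(ft_values)} vs {len(ref_values)}")
    for i, (a, b) in enumerate(zip(ft_values, ref_values)):
        if not math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol):
            return i
    return None


def first_divergent_layer(ft_dump, ref_dump, **tol):
    """Return (layer_name, index) for the earliest layer that diverges, else None.

    Both arguments map layer names to flat value lists; dicts keep insertion
    order, so layers are checked in the order they were recorded.
    """
    for name in ft_dump:
        idx = first_mismatch(ft_dump[name], ref_dump[name], **tol)
        if idx is not None:
            return name, idx
    return None


# Synthetic data for illustration only (not real model outputs).
ref_dump = {
    "layer_0.attention_output": [0.10, -0.20, 0.30],
    "layer_1.attention_output": [0.50, 0.60, -0.70],
}
ft_dump = {
    "layer_0.attention_output": [0.10, -0.20, 0.30],
    "layer_1.attention_output": [0.50, 0.60, 0.90],
}
print(first_divergent_layer(ft_dump, ref_dump))  # ('layer_1.attention_output', 2)
```

Finding the earliest layer that diverges narrows the search to the kernel that produced it. Looser tolerances may be needed when comparing fp16 results against an fp32 reference.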

TopIdiot commented 1 year ago

@byshiue Thanks for your reply. Is there something that limits the input length? In this case, when input_ids_length <= 1472, the result is correct, and the key/value cache entries for the first input_ids_length positions are all the same. But at 1473, the result is wrong, and the entire key/value cache differs greatly from the cache obtained when input_ids_length <= 1472.
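A threshold like this can be located with a binary search over the prompt length, sketched below. It assumes the behaviour is monotonic (every length up to some point is correct and every longer one is wrong) and that a predicate `is_correct(length)` exists, for example one that truncates the prompt, runs the example, and compares the output with a reference. Both the predicate and the function name are hypothetical.

```python
def first_failing_length(is_correct, lo, hi):
    """Return the smallest length in (lo, hi] for which is_correct is False.

    Requires is_correct(lo) to be True and is_correct(hi) to be False, and
    assumes that once a length fails, every longer length fails as well.
    """
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if is_correct(mid):
            lo = mid  # mid still works, so the threshold is above it
        else:
            hi = mid  # mid already fails, so the threshold is at or below it
    return hi


# Stand-in predicate that mimics the behaviour reported above.
print(first_failing_length(lambda n: n <= 1472, 1, 2048))  # 1473
```

With a range of 1 to 2048 this needs about 11 runs of the example rather than one run per length.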

byshiue commented 1 year ago

We have no idea at the moment. We need to investigate the cause when we have bandwidth.