NVIDIA / FasterTransformer

Transformer related optimization, including BERT, GPT
Apache License 2.0

GptNeoX got an invalid result when the callback was set. #588

Open zhang-ge-hao opened 1 year ago

zhang-ge-hao commented 1 year ago

@byshiue

Branch/Tag/Commit

main

Docker Image Version

nvcr.io/nvidia/pytorch:21.11-py3

GPU name

TITAN

model

https://huggingface.co/TabbyML/NeoX-1.3B

Reproduced Steps

Code that reproduces the bug: https://github.com/AkiyamaYummy/FasterTransformer/tree/tmp

Diff between my buggy code and the latest main branch: https://github.com/NVIDIA/FasterTransformer/commit/2c046e3a47836073c867b8fbc284ce7b4d68cac4

Got an invalid result when the callback was set.

Without registerCallback(), the out file after running gptneox_example was:

1545 3158 64 15227 9 3298 2262 187 50274 32488 187 50274 15462 253 3781 275 1659 15 187 50274 32488 187 50274 4 17399 253 3781 275 1659 15 187 50274 3298 426 20045 9 3298 13 2234 30 2260 1269 27 1269 60 18 3291 187 50274 4 17399 253 3781 275 1659 15 187 50274 3298 426 20045 9 3298 13 2234 30 2260 1269 27 1269 60 17 3291 187 50274 2309 4077 187 187 1545 3158 64 15227 9 3298 2262 187 50274 32488 187 50274 15462 253 3781 275 1659 15 187 50274 32488 187 50274 4 17399 253 3781 275 1659 15 187 50274 3298 426 20045 9 3298 13 2234 30 2260 1269 27 1269 60 18 3291 187 50274 4 17399 253 3781 275 1659 15 187 

With registerCallback(), the out file after running gptneox_example was:

1545 3158 64 15227 9 3298 2262 187 50274 338 8472 9 3298 10 654 374 27 187 50270 2309 544 17 13 470 62 187 50274 7271 27 187 50270 2309 544 17 13 470 62 187 187 1545 3158 64 15227 9 3298 2262 187 50274 338 8472 9 3298 10 654 374 27 187 50270 2309 544 17 13 470 62 187 50274 7271 27 187 50270 187 50270 187 50270 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 34 187 187 

zhang-ge-hao commented 1 year ago

I seem to have fixed this bug.

invokeGatherTree(param); updates param.max_sequence_lengths.

However, if sequence_lengths_ is the buffer that gets updated, the next generation step works from the wrong lengths.

So sequence_lengths, not sequence_lengths_, should be passed to invokeGatherTree(param);.
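
To illustrate the aliasing problem described here, below is a minimal, self-contained sketch in plain C++. It is not FasterTransformer code: finalize_lengths and run_decode are hypothetical names, and the subtraction is an arbitrary stand-in for whatever the real kernel writes. The only point is that a routine which overwrites its length buffer must be handed a separate output copy when it runs after every step.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a finalization routine that overwrites the
// length buffer it receives. The arithmetic is illustrative only.
void finalize_lengths(std::vector<int>& lengths, int input_length)
{
    for (int& len : lengths) {
        len -= input_length;
    }
}

// Simulates `steps` decoding steps with the finalization invoked after each
// one (assumed here to be what a per-step callback triggers).
// alias_working_buffer == true  : finalization writes into the working state
//                                 (the buggy pattern).
// alias_working_buffer == false : finalization writes into a separate copy.
// Returns the working lengths left after the last step.
std::vector<int> run_decode(int input_length, int steps,
                            bool alias_working_buffer, std::size_t batch = 2)
{
    assert(steps >= 0);
    std::vector<int> working(batch, input_length);  // role of sequence_lengths_
    std::vector<int> output;                        // role of sequence_lengths

    for (int s = 0; s < steps; ++s) {
        for (int& len : working) {
            ++len;  // one token generated per step
        }
        if (alias_working_buffer) {
            finalize_lengths(working, input_length);  // corrupts decoder state
        } else {
            output = working;                         // copy, then finalize
            finalize_lengths(output, input_length);
        }
    }
    return working;
}
```

With an input length of 8 and 3 steps, run_decode(8, 3, false) leaves the working lengths at 11, while run_decode(8, 3, true) leaves them far off, because every step starts from values the previous finalization already overwrote.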

I'll submit a PR.