NVIDIA / FasterTransformer

Transformer related optimization, including BERT, GPT
Apache License 2.0

Stop the generation if the eod is reached #526

Open akhoroshev opened 1 year ago

akhoroshev commented 1 year ago

I am experimenting with examples/cpp/gpt/gpt_example.cc and found that token generation doesn't finish when the first EOD token is reached.

This is my gpt_config.ini. I use a GPT-2 model that was converted to FT format:

[ft_instance_hyperparameter]
max_batch_size=8 ; Use for allocate the buffer
max_seq_len=1024 ; The sequence length of position embedding table, should move to model hyper-parameter
beam_width=1 ; beam width for beam search
top_k=0 ; k value for top k sampling
top_p=0.5 ; p value for top p sampling
temperature=1.0 ; Use for sampling
repetition_penalty=2.0 ; Use for sampling
presence_penalty=0.0  ; Only one of repetition_penalty and presence_penalty is allowed.
len_penalty=0.0
beam_search_diversity_rate=0.0
data_type=fp16
sparse=0
model_name=self_defined
model_dir=/home/akhoroshev/ft_gpt2_fp16/1-gpu
shared_contexts_ratio=0

[self_defined]
head_num=12
size_per_head=64
vocab_size=50257
decoder_layers=12

For example, if I set

[request]
request_batch_size=1    ; determine by the request
request_output_len=128   ; determine by the request
return_log_probs=false  ; return the output log probs and cumulative log probs.
context_log_probs=false ; include input contexts in the cumulative log probability computation.

and feed the tokens [4919, 389] as input,

I get the following log output:

Device NVIDIA A100 80GB PCIe
[INFO] max_input_len (2). 
[INFO] total_output_len (130). 
[WARNING] gemm_config.in is not found; using default GEMM algo
[FT][WARNING] Device 0 peer access Device 1 is not available.
after allocation    : free: 40.26 GB, total: 79.18 GB, used: 38.92 GB
[INFO] request_batch_size 1 beam_width 1 head_num 12 size_per_head 64 total_output_len 130 decoder_layers 12 vocab_size 50257 FT-CPP-decoding-beamsearch-time 149.00 ms
Writing 130 elements
 4919   389   345  1804  1701   198     1  5779    11   314 
zeroCount = 1

and this result (103 meaningful tokens at the beginning and 27 EOD tokens at the tail); the whole computation took 149.00 ms:

4919 389 345 1804 1701 198 1 5779 11 314 1101 407 523 1654 526 383 25920 82 14464 290 2900 1088 13 366 2061 546 262 1334 286 534 1767 30 2141 484 804 588 502 287 511 8242 393 1223 986 3966 616 1793 2474 843 673 4966 284 607 1735 355 611 2491 656 257 2046 12 46280 457 2214 0 632 373 2138 13400 2045 475 340 1422 470 1011 890 329 2506 338 2951 547 319 428 2576 326 550 3443 1282 503 422 739 606 477 1978 351 2279 1016 2642 878 7288 1392 1969 1576 1399 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 
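To check a split like the one above (meaningful tokens followed by a run of EOD padding), a minimal sketch of a post-processing helper that strips the trailing run of end tokens from one output sequence. The name `strip_trailing_eod` is hypothetical and not part of FasterTransformer; for GPT-2 the end id is 50256 (`<|endoftext|>`).

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper (not an FT API): removes the run of end_id tokens at
// the tail of one generated sequence and returns how many were removed.
std::size_t strip_trailing_eod(std::vector<int>& ids, int end_id) {
    std::size_t removed = 0;
    while (!ids.empty() && ids.back() == end_id) {
        ids.pop_back();
        ++removed;
    }
    return removed;
}
```

Applied to the 130-element output above with `end_id = 50256`, this should keep 103 tokens and remove 27, per the author's count. It only cleans up the result after the fact; it does not save the decoding time spent producing the padding.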

But when I set

[request]
request_batch_size=1    ; determine by the request
request_output_len=1000   ; determine by the request
return_log_probs=false  ; return the output log probs and cumulative log probs.
context_log_probs=false ; include input contexts in the cumulative log probability computation.

I get the following log output:

Device NVIDIA A100 80GB PCIe
[INFO] max_input_len (2). 
[INFO] total_output_len (1002). 
[WARNING] gemm_config.in is not found; using default GEMM algo
[FT][WARNING] Device 0 peer access Device 1 is not available.
after allocation    : free: 40.26 GB, total: 79.18 GB, used: 38.92 GB
[INFO] request_batch_size 1 beam_width 1 head_num 12 size_per_head 64 total_output_len 1002 decoder_layers 12 vocab_size 50257 FT-CPP-decoding-beamsearch-time 810.24 ms
Writing 1002 elements
 4919   389   345  1804  1701   198     1  5779    11   314 
zeroCount = 1

and this result (the same 103 meaningful tokens at the beginning, which is expected because the seed is fixed, and 899 EOD tokens at the tail); the whole computation took 810.24 ms:

4919 389 345 1804 1701 198 1 5779 11 314 1101 407 523 1654 526 383 25920 82 14464 290 2900 1088 13 366 2061 546 262 1334 286 534 1767 30 2141 484 804 588 502 287 511 8242 393 1223 986 3966 616 1793 2474 843 673 4966 284 607 1735 355 611 2491 656 257 2046 12 46280 457 2214 0 632 373 2138 13400 2045 475 340 1422 470 1011 890 329 2506 338 2951 547 319 428 2576 326 550 3443 1282 503 422 739 606 477 1978 351 2279 1016 2642 878 7288 1392 1969 1576 1399 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 
50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 
50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 
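The two runs make the cost of the padding concrete. A minimal sketch of the arithmetic, assuming (as the author reports) that the longer run differs from the shorter one only by extra EOD tokens at the tail, so the time difference divided by the step difference estimates the cost of one EOD-only step. `Run` and `marginal_ms_per_step` are illustrative names, not FT APIs.

```cpp
#include <cassert>

// One decoding run as reported in the logs above.
struct Run {
    double decode_ms;      // FT-CPP-decoding-beamsearch-time
    int total_output_len;  // input length + request_output_len
};

// Marginal cost of one extra decoding step, estimated from two runs whose
// outputs differ only by extra EOD tokens at the tail.
double marginal_ms_per_step(const Run& shorter, const Run& longer) {
    assert(longer.total_output_len > shorter.total_output_len);
    const double extra_ms = longer.decode_ms - shorter.decode_ms;
    const int extra_steps = longer.total_output_len - shorter.total_output_len;
    return extra_ms / extra_steps;
}
```

With `{149.00, 130}` and `{810.24, 1002}` this gives (810.24 - 149.00) / (1002 - 130) = 661.24 / 872, about 0.758 ms per step: the 872 extra steps, which only append EOD, still cost about 661 ms. The same arithmetic on the GPT-J numbers reported later in the thread (150 ms for 30 tokens, 1400 ms for 330 tokens; only the difference in length matters) gives about 4.17 ms per step.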

My question is: why does generation keep consuming time after the first EOD token, when every subsequent token is just EOD?
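The timings are consistent with a decoding loop that always runs the requested number of steps, even after every sequence has emitted EOD. A minimal CPU-only sketch of the early-stop idea, assuming a per-sequence "finished" flag; `sample_next` is a hypothetical stand-in for the model forward pass plus sampling and is not a FasterTransformer API. In the real GPU implementation every step runs the forward pass for the whole batch, so the saving comes from leaving the loop, not from skipping individual sequences.

```cpp
#include <cstddef>
#include <functional>
#include <vector>

// Simulates a batched decoding loop and returns the number of steps executed.
// sample_next(i, step) is a hypothetical stand-in for the forward pass and
// sampling of sequence i; it is not a FasterTransformer API.
int decode_with_early_stop(std::vector<std::vector<int>>& outputs, int max_steps, int end_id,
                           const std::function<int(std::size_t, int)>& sample_next) {
    std::vector<bool> finished(outputs.size(), false);
    std::size_t num_finished = 0;
    int steps = 0;
    for (int step = 0; step < max_steps; ++step) {
        ++steps;
        for (std::size_t i = 0; i < outputs.size(); ++i) {
            // A finished sequence only ever appends end_id, which is the EOD
            // padding visible in the outputs above.
            const int token = finished[i] ? end_id : sample_next(i, step);
            outputs[i].push_back(token);
            if (!finished[i] && token == end_id) {
                finished[i] = true;
                ++num_finished;
            }
        }
        // Early stop: once every sequence has emitted end_id, the remaining
        // steps could only append more end_id tokens.
        if (num_finished == outputs.size()) {
            break;
        }
    }
    return steps;
}
```

Without the `break`, the loop always runs `max_steps` iterations, so the cost grows linearly with the requested output length, as in the two runs above. With it, sequences come back shorter than `max_steps`; a caller that expects a fixed-size buffer would have to pad them itself.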

HEAD is bb94e2d9bbed65c7ea09b0fd340c290dae50f706 (origin/main, origin/HEAD, main) Mar 14

g++ (Ubuntu 10.3.0-1ubuntu1~20.04) 10.3.0
Copyright (C) 2020 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

CUDA 12.1

CMake configuration used for building:
cmake -DSM=80 -DCMAKE_BUILD_TYPE=Release -DBUILD_MULTI_GPU=OFF

byshiue commented 1 year ago

This is a known issue; please refer to https://github.com/NVIDIA/FasterTransformer/pull/487.

gttiankai commented 1 year ago

I used the tmp/fix_gpt_earlystop branch, but I still have the same problem.

byshiue commented 1 year ago

> I used the tmp/fix_gpt_earlystop branch, but I still have the same problem.

Please provide your reproduction steps.

byshiue commented 1 year ago

The issue is fixed in PR https://github.com/NVIDIA/FasterTransformer/pull/584, which has been merged into the main branch. Sorry for the late fix.

skyser2003 commented 1 year ago

Hello, does #584 also apply to GPT-J?

I am testing inference with a converted GPT-J model on tritonserver's FasterTransformer backend, and the latency grows in proportion to request_output_len.

If I request 30 tokens, it takes about 150 ms, and the last 5 tokens are EOS tokens. If I request 330 tokens, it takes about 1400 ms, and the last 305 tokens are still all EOS tokens.

akhoroshev commented 1 year ago

@byshiue, thanks for fixing this in ParallelGpt 😌

Do GPT-NeoX and GPT-J have the same bug? When I tested GPT-NeoX, I hit the same bug...

Do you plan to fix them?

Missmiaom commented 1 year ago

@byshiue Thanks for fixing the problem.

But I found that when running BLOOM on a single GPU, generation still does not stop after the EOS token is produced; it continues until the maximum length is reached.

I added a check on finished_buf_ after the "result sampling and stop check" step, and the test passed. I don't know whether it is correct in the parallel scenario; can you please review it?

ParallelGpt.cc:

PUSH_RANGE("result sampling and stop check");
dynamic_decode_layer_->forward(&dynamic_decode_output_tensors, &dynamic_decode_input_tensors);

+// check finished
+cudaD2Hcpy(h_finished_buf_, finished_buf_, batch_size * beam_width);
+uint sum = 0;
+for (uint i = 0; i < batch_size * beam_width; i++) {
+    sum += (int)h_finished_buf_[i];
+}
+if (sum == batch_size * beam_width) {
+    subbatch_should_stop = true;
+}

*generation_should_stop_ &= subbatch_should_stop;
POP_RANGE;
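For reference, the host-side part of the proposed check reduces to "are all batch_size * beam_width finished flags set?". A minimal standalone sketch of that reduction, assuming the flags have already been copied into a host `bool` array as the patch does with `cudaD2Hcpy`. The helper name `all_finished` is hypothetical, and this sketch does not answer the open question about the parallel case.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>

// Hypothetical helper: true when all n host-side finished flags are set.
// Equivalent to summing the flags and comparing the sum with n, as the
// patch does, but it stops at the first unfinished flag.
bool all_finished(const bool* h_finished, std::size_t n) {
    assert(h_finished != nullptr || n == 0);
    return std::all_of(h_finished, h_finished + n, [](bool f) { return f; });
}
```

In the patch this would read `if (all_finished(h_finished_buf_, batch_size * beam_width)) { subbatch_should_stop = true; }`, which keeps the original semantics of only ever setting the flag to true.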