NVIDIA / FasterTransformer

Transformer related optimization, including BERT, GPT
Apache License 2.0

Illegal memory access error while running a gptneox model #353

Closed: tyliupku closed this issue 1 year ago

tyliupku commented 1 year ago

Branch/Tag/Commit

main

Docker Image Version

nvcr.io/nvidia/pytorch:22.07-py3

GPU name

T4

CUDA Driver

470.57.02

Reproduced Steps

Background: I am trying to load a pretrained model (~350M parameters, with the same architecture as GPT-NeoX) and deploy a text generation service with FasterTransformer.

I followed the steps listed in the GPT and GPT-NeoX guides and changed the model config and path in FasterTransformer/examples/cpp/gptneox/gptneox_config.ini so that they point to my own model.
I also changed FasterTransformer/examples/cpp/gptneox/start_ids.csv to the following:

  319,8611,65,8387,10,90,14,767,1050,188
  319,1095,10,90,28,388,14,767,28,388
  319,16399,65,4223,10,821,1050,188,188,188
  319,16399,65,4223,10,821,1050,188,188,188
  319,8611,65,8387,10,90,14,767,1050,209

Then I get an illegal memory access error (in Debug mode). I have verified that my pretrained model parameters are loaded successfully. PS: my model uses a different vocabulary from EleutherAI's gptneox-20B model.

However, with the default start_ids.csv (below), the program runs fine. I suspect the bug has something to do with CUDA memory allocation.

  688, 253, 1390, 4564, 273, 1897, 13, 247
  510, 1457, 8911, 4487, 273, 26593, 310, 6600
  510, 1457, 2816, 28260, 452, 247, 747, 1481
  510, 1457, 2816, 7717, 556, 3863, 697, 7970
  688, 247, 2118, 326, 588, 2779, 1056, 352
  510, 1457, 2816, 28260, 8, 13413, 19169, 14745
  510, 9462, 5687, 556, 38350, 26212, 253, 747
  510, 806, 673, 309, 3047, 253, 6440, 13

1. nvidia-docker run -ti --rm nvcr.io/nvidia/pytorch:22.07-py3 bash
2. git clone https://github.com/NVIDIA/FasterTransformer.git
3. mkdir -p FasterTransformer/build
4. cd FasterTransformer/build
5. ./bin/gptneox_example
6. [FT][ERROR] CUDA runtime error: an illegal memory access was encountered /data/home/FasterTransformer/src/fastertransformer/models/gptneox/GptNeoX.cc:786

byshiue commented 1 year ago

Did you set start_id and end_id correctly? They are fixed values in the example.

tyliupku commented 1 year ago

Did you set start_id and end_id correctly? They are fixed values in the example.

Yes, both start_id and end_id are set to 0 according to my model's vocabulary. I have double-checked the settings in gptneox_config.ini and am sure everything is set correctly.

Besides, the config file is identical for the runs with the default start_ids.csv and with my customized start_ids.csv, so I guess the error might come from somewhere in the FT implementation.

I found a similar issue, https://github.com/NVIDIA/FasterTransformer/issues/328, but I am not sure whether it is related.

byshiue commented 1 year ago

Are you using the latest main branch?

I guess the error might come from somewhere in the FT implementation.

If the implementation had a bug, you should not be able to run the official GPT-NeoX model either, so it is hard to call this an FT bug unless you can reproduce the issue with a public model.

Can you provide gptneox_config.ini and the config.ini generated by the converter?

tyliupku commented 1 year ago

Are you using the latest main branch?

I pulled again and noticed that stop_criteria_kernels.cu was modified 7 days ago. Everything I describe below is on the latest main branch.

I guess the error might come from somewhere in the FT implementation.

If the implementation had a bug, you should not be able to run the official GPT-NeoX model either, so it is hard to call this an FT bug unless you can reproduce the issue with a public model.

My model was trained with the official gpt-neox repo, and it loads and runs inference without any errors in both Megatron and Hugging Face. I used the conversion script provided by FT (with minor modifications to make it compatible with my model size). My model is not public, and I have no way to reproduce the issue with a public model (such as gptneox-20B).

Can you provide gptneox_config.ini and the config.ini generated by the converter?

Please refer to the details below.

gptneox_config.ini

[ft_instance_hyperparameter]
data_type=fp16
enable_custom_all_reduce=0

tensor_para_size=1
pipeline_para_size=1

model_name=codesn
model_dir=../models/codesn

[request]
beam_width=1 # beam width for beam search
top_k=1 ; k value for top k sampling
top_p=0.0 ; p value for top p sampling
temperature=1.0 ; Use for sampling
repetition_penalty=1.0 ; Use for sampling
len_penalty=0.0
beam_search_diversity_rate=0.0
request_batch_size=8 # determine by the request
request_output_len=256 # determine by the request

[gptneox_20B]
head_num=64
size_per_head=96
vocab_size=50432
decoder_layers=44
rotary_embedding=24
start_id=0
end_id=2
inter_size=24576
use_gptj_residual=1

[codesn]
head_num=16
size_per_head=64
vocab_size=50304
decoder_layers=24
rotary_embedding=64
start_id=0
end_id=0
inter_size=4096
use_gptj_residual=0

config.ini


[gptneox]
model_name=gptneox_20B
head_num=64
size_per_head=96
vocab_size=50432
num_layer=44
rotary_embedding=24
start_id=0
end_id=2
inter_size=24576
use_gptj_residual=1
weight_data_type=fp32

[codesn]
model_name=codesn_0-4B
head_num=16
size_per_head=64
vocab_size=50304
num_layer=24
rotary_embedding=64
start_id=0
end_id=0
inter_size=4096
use_gptj_residual=0
weight_data_type=fp16
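As a sanity check on the converted files, the expected size of each weight file follows directly from the config values. The sketch below assumes tensor_para_size=1 (so each `.0.bin` file holds the whole matrix) and that every file is a flat array of a single element type; the helper `expected_sizes` and the shortened file keys are illustrative and not part of FT.

```python
import configparser

# The [codesn] values from config.ini above (only the fields this check needs).
CONFIG_TEXT = """
[codesn]
head_num=16
size_per_head=64
vocab_size=50304
num_layer=24
inter_size=4096
"""


def expected_sizes(section, bytes_per_elem):
    """Expected byte size of several weight files for one model config.

    Assumes tensor_para_size=1 (each *.0.bin file holds the full matrix) and
    that every file is a flat array of bytes_per_elem-sized elements.
    Per-layer keys omit the "model.layers.N." prefix.
    """
    hidden = section.getint("head_num") * section.getint("size_per_head")
    vocab = section.getint("vocab_size")
    inter = section.getint("inter_size")
    element_counts = {
        "model.wte.bin": vocab * hidden,
        "model.final_layernorm.weight.bin": hidden,
        "attention.query_key_value.weight.0.bin": hidden * 3 * hidden,
        "attention.query_key_value.bias.0.bin": 3 * hidden,
        "attention.dense.weight.0.bin": hidden * hidden,
        "mlp.dense_h_to_4h.weight.0.bin": hidden * inter,
        "mlp.dense_h_to_4h.bias.0.bin": inter,
        "mlp.dense_4h_to_h.weight.0.bin": inter * hidden,
    }
    return {name: n * bytes_per_elem for name, n in element_counts.items()}


parser = configparser.ConfigParser()
parser.read_string(CONFIG_TEXT)
for name, size in expected_sizes(parser["codesn"], bytes_per_elem=4).items():
    print(f"{name}: {size}")
```

With `bytes_per_elem=4` this prints 206045184 for model.wte.bin, 12582912 for the QKV weight, and 16777216 for each MLP weight, the same sizes as the "Read ... bytes" lines in the FT_DEBUG_LEVEL=DEBUG log later in this thread; with `bytes_per_elem=2` every size is halved.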

tyliupku commented 1 year ago
However, with the default start_ids.csv (below), the program runs fine. I suspect the bug has something to do with CUDA memory allocation.
  # start_ids_A
  319,8611,65,8387,10,90,14,767,1050,188
  319,8611,65,8387,10,90,14,767,1050,209
  319,16399,65,4223,10,821,1050,188,188,188
  # start_ids_B (1st and 2nd lines are the same)
  319,8611,65,8387,10,90,14,767,1050,188
  319,8611,65,8387,10,90,14,767,1050,188
  319,16399,65,4223,10,821,1050,188,188,188

I tried changing start_ids.csv slightly and found that only start_ids_B triggers the illegal memory access error; start_ids_A does not.

However, the output for start_ids_A contains a negative id (-13877). I have seen this before, and I guess the negative ids and the illegal memory access might both be related to CUDA memory allocation?

  # outputs of start_ids_A
  319 8611 65 8387 10 90 14 767 1050 188 209 209 209 2287 188 209 209 -13877 209 209 ...
  319 8611 65 8387 10 90 14 767 1050 209 307 754 28 776 14 767 28 776 188 209 209 209 ...
  319 16399 65 4223 10 821 1050 188 188 188 209 209 209 392 1005 10 821 11 1928 352 28 ...
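Out-of-range ids like this can be spotted mechanically. The sketch below is a minimal stdlib check that every id lies in [0, vocab_size), using vocab_size=50304 from the [codesn] section; the helper names `parse_ids` and `out_of_range_ids` are hypothetical, and the check only locates bad ids without explaining why they were generated.

```python
def parse_ids(text):
    """Parse lines of comma- or space-separated token ids; a trailing '...' is skipped."""
    rows = []
    for line in text.strip().splitlines():
        tokens = line.replace(",", " ").split()
        rows.append([int(t) for t in tokens if t != "..."])
    return rows


def out_of_range_ids(rows, vocab_size):
    """Return (row, position, id) for every id outside [0, vocab_size)."""
    bad = []
    for r, ids in enumerate(rows):
        for p, tok in enumerate(ids):
            if not 0 <= tok < vocab_size:
                bad.append((r, p, tok))
    return bad


# First two rows of the start_ids_A output above (truncated as posted).
outputs = parse_ids("""
319 8611 65 8387 10 90 14 767 1050 188 209 209 209 2287 188 209 209 -13877 209 209 ...
319 8611 65 8387 10 90 14 767 1050 209 307 754 28 776 14 767 28 776 188 209 209 209 ...
""")
print(out_of_range_ids(outputs, vocab_size=50304))  # [(0, 17, -13877)]
```

Position 17 is the 8th generated token, since the prompt in that row has 10 ids. Running the same check on start_ids.csv would confirm that the prompts themselves are in range.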

byshiue commented 1 year ago

Can you try to run with FT_DEBUG_LEVEL=DEBUG?

tyliupku commented 1 year ago

Can you try to run with FT_DEBUG_LEVEL=DEBUG?

Total ranks: 1.
Device Tesla T4
P0 is running with GPU #0.
[FT][WARNING] Skip NCCL initialization since requested tensor/pipeline parallel sizes are equals to 1.
[WARNING] gemm_config.in is not found; using default GEMM algo
[FT][DEBUG] fastertransformer::Allocator<fastertransformer::AllocatorType::CUDA>::Allocator(int)
[FT][DEBUG] fastertransformer::cublasMMWrapper::cublasMMWrapper(cublasHandle_t, cublasLtHandle_t, cudaStream_t, fastertransformer::cublasAlgoMap*, std::mutex*, fastertransformer::IAllocator*)
[FT][DEBUG] void* fastertransformer::IAllocator::reMalloc(T*, size_t, bool) [with T = void; size_t = long unsigned int]
[FT][DEBUG] std::string fastertransformer::IAllocator::getAddress(void*) const
[FT][DEBUG] Cannot find buffer (nil), mallocing new one.
[FT][DEBUG] virtual void* fastertransformer::Allocator<fastertransformer::AllocatorType::CUDA>::malloc(size_t, bool)
[FT][DEBUG] malloc buffer 0x302000000 with size 33554432
[FT][DEBUG] std::string fastertransformer::IAllocator::getAddress(void*) const
[FT][DEBUG] Read 206045184 bytes from ../models/gptneox_new/1-gpu/model.wte.bin
[FT][DEBUG] Read 4096 bytes from ../models/gptneox_new/1-gpu/model.final_layernorm.bias.bin
[FT][DEBUG] Read 4096 bytes from ../models/gptneox_new/1-gpu/model.final_layernorm.weight.bin
[FT][DEBUG] Read 206045184 bytes from ../models/gptneox_new/1-gpu/model.lm_head.weight.bin
[FT][DEBUG] Read 4096 bytes from ../models/gptneox_new/1-gpu/model.layers.0.input_layernorm.bias.bin
[FT][DEBUG] Read 4096 bytes from ../models/gptneox_new/1-gpu/model.layers.0.input_layernorm.weight.bin
[FT][DEBUG] Read 12582912 bytes from ../models/gptneox_new/1-gpu/model.layers.0.attention.query_key_value.weight.0.bin
[FT][DEBUG] Read 12288 bytes from ../models/gptneox_new/1-gpu/model.layers.0.attention.query_key_value.bias.0.bin
[FT][DEBUG] Read 4194304 bytes from ../models/gptneox_new/1-gpu/model.layers.0.attention.dense.weight.0.bin
[FT][DEBUG] Read 4096 bytes from ../models/gptneox_new/1-gpu/model.layers.0.attention.dense.bias.bin
[FT][DEBUG] Read 16777216 bytes from ../models/gptneox_new/1-gpu/model.layers.0.mlp.dense_h_to_4h.weight.0.bin
[FT][DEBUG] Read 16384 bytes from ../models/gptneox_new/1-gpu/model.layers.0.mlp.dense_h_to_4h.bias.0.bin
[FT][DEBUG] Read 16777216 bytes from ../models/gptneox_new/1-gpu/model.layers.0.mlp.dense_4h_to_h.weight.0.bin
[FT][DEBUG] Read 4096 bytes from ../models/gptneox_new/1-gpu/model.layers.0.mlp.dense_4h_to_h.bias.bin
[FT][DEBUG] Read 4096 bytes from ../models/gptneox_new/1-gpu/model.layers.0.post_attention_layernorm.bias.bin
[FT][DEBUG] Read 4096 bytes from ../models/gptneox_new/1-gpu/model.layers.0.post_attention_layernorm.weight.bin
[FT][DEBUG] Read 4096 bytes from ../models/gptneox_new/1-gpu/model.layers.1.input_layernorm.bias.bin
[FT][DEBUG] Read 4096 bytes from ../models/gptneox_new/1-gpu/model.layers.1.input_layernorm.weight.bin
[FT][DEBUG] Read 12582912 bytes from ../models/gptneox_new/1-gpu/model.layers.1.attention.query_key_value.weight.0.bin
[FT][DEBUG] Read 12288 bytes from ../models/gptneox_new/1-gpu/model.layers.1.attention.query_key_value.bias.0.bin
[FT][DEBUG] Read 4194304 bytes from ../models/gptneox_new/1-gpu/model.layers.1.attention.dense.weight.0.bin
[FT][DEBUG] Read 4096 bytes from ../models/gptneox_new/1-gpu/

...... (some debug log lines omitted) ......

[FT][DEBUG] void fastertransformer::DecoderSelfAttentionLayer<T>::forward(std::vector<fastertransformer::Tensor>*, const std::vector<fastertransformer::Tensor>*, const fastertransformer::AttentionWeight<T>*) [with T = __half]
[FT][DEBUG] void fastertransformer::DecoderSelfAttentionLayer<T>::allocateBuffer(size_t) [with T = __half; size_t = long unsigned int]
[FT][DEBUG] void* fastertransformer::IAllocator::reMalloc(T*, size_t, bool) [with T = __half; size_t = long unsigned int]
[FT][DEBUG] std::string fastertransformer::IAllocator::getAddress(void*) const
[FT][DEBUG] Reuse original buffer 0x30e43cc00 and do nothing for reMalloc.
[FT][DEBUG] void* fastertransformer::IAllocator::reMalloc(T*, size_t, bool) [with T = __half; size_t = long unsigned int]
[FT][DEBUG] std::string fastertransformer::IAllocator::getAddress(void*) const
[FT][DEBUG] Reuse original buffer 0x30e445c00 and do nothing for reMalloc.
[FT][DEBUG] void fastertransformer::TensorParallelGeluFfnLayer<T>::forward(std::vector<fastertransformer::Tensor>*, const std::vector<fastertransformer::Tensor>*, const fastertransformer::FfnWeight<T>*) [with T = __half]
[FT][DEBUG] void fastertransformer::FfnLayer<T>::forward(std::vector<fastertransformer::Tensor>*, const std::vector<fastertransformer::Tensor>*, const fastertransformer::FfnWeight<T>*) [with T = __half]
[FT][DEBUG] void fastertransformer::FfnLayer<T>::allocateBuffer(size_t) [with T = __half; size_t = long unsigned int]
[FT][DEBUG] void* fastertransformer::IAllocator::reMalloc(T*, size_t, bool) [with T = __half; size_t = long unsigned int]
[FT][DEBUG] std::string fastertransformer::IAllocator::getAddress(void*) const
[FT][DEBUG] Reuse original buffer 0x30e448c00 and do nothing for reMalloc.
[FT][DEBUG] void fastertransformer::DynamicDecodeLayer<T>::forward(std::unordered_map<std::__cxx11::basic_string<char>, fastertransformer::Tensor>*, const std::unordered_map<std::__cxx11::basic_string<char>, fastertransformer::Tensor>*) [with T = float]
[FT][DEBUG] void fastertransformer::BaseSamplingLayer<T>::forward(std::unordered_map<std::__cxx11::basic_string<char>, fastertransformer::Tensor>*, const std::unordered_map<std::__cxx11::basic_string<char>, fastertransformer::Tensor>*) [with T = float]
[FT][DEBUG] void fastertransformer::TopKSamplingLayer<T>::runSampling(std::unordered_map<std::__cxx11::basic_string<char>, fastertransformer::Tensor>*, const std::unordered_map<std::__cxx11::basic_string<char>, fastertransformer::Tensor>*) [with T = float]
[FT][DEBUG] void fastertransformer::BaseSamplingLayer<T>::forward(std::unordered_map<std::__cxx11::basic_string<char>, fastertransformer::Tensor>*, const std::unordered_map<std::__cxx11::basic_string<char>, fastertransformer::Tensor>*) [with T = float]
[FT][DEBUG] void fastertransformer::invokeStopWordsCriterion(const int*, const int*, const int*, bool*, size_t, size_t, int, int, int, cudaStream_t) start
[FT][DEBUG] void fastertransformer::invokeLengthCriterion(bool*, bool*, int*, const uint32_t*, int, int, int, cudaStream_t) start

terminate called after throwing an instance of 'std::runtime_error'
  what():  [FT][ERROR] CUDA runtime error: an illegal memory access was encountered /data/home/FasterTransformer/src/fastertransformer/models/gptneox/GptNeoX.cc:786 

[VM-230-132-centos:15511] *** Process received signal ***
[VM-230-132-centos:15511] Signal: Aborted (6)
[VM-230-132-centos:15511] Signal code:  (-6)
[VM-230-132-centos:15511] [ 0] /lib64/libpthread.so.0(+0xf630)[0x7f86ee995630]
[VM-230-132-centos:15511] [ 1] /lib64/libc.so.6(gsignal+0x37)[0x7f86edaf63d7]
[VM-230-132-centos:15511] [ 2] /lib64/libc.so.6(abort+0x148)[0x7f86edaf7ac8]
[VM-230-132-centos:15511] [ 3] /usr/local/lib64/libstdc++.so.6(+0x99203)[0x7f86ee441203]
[VM-230-132-centos:15511] [ 4] /usr/local/lib64/libstdc++.so.6(+0xa4b86)[0x7f86ee44cb86]
[VM-230-132-centos:15511] [ 5] /usr/local/lib64/libstdc++.so.6(+0xa4bf1)[0x7f86ee44cbf1]
[VM-230-132-centos:15511] [ 6] /usr/local/lib64/libstdc++.so.6(+0xa4e45)[0x7f86ee44ce45]
[VM-230-132-centos:15511] [ 7] ./bin/gptneox_example[0x4260cb]
[VM-230-132-centos:15511] [ 8] ./bin/gptneox_example[0x4373fb]
[VM-230-132-centos:15511] [ 9] ./bin/gptneox_example[0x415645]
[VM-230-132-centos:15511] [10] ./bin/gptneox_example[0x407c1a]
[VM-230-132-centos:15511] [11] /lib64/libc.so.6(__libc_start_main+0xf5)[0x7f86edae2555]
[VM-230-132-centos:15511] [12] ./bin/gptneox_example[0x406c79]
[VM-230-132-centos:15511] *** End of error message ***
Aborted

byshiue commented 1 year ago

So the crash happens because the model generates a negative token id and then tries to run the embedding lookup on it.
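To illustrate why a negative id crashes rather than just producing a bad token: an embedding lookup turns the id into an offset into a flat [vocab_size x hidden_units] table. This Python function is a simplified, hypothetical stand-in for such a lookup, not FT's kernel, and the toy table is made up.

```python
def embedding_lookup(table, hidden_units, token_id, vocab_size):
    """Return the embedding row for token_id from a flat, row-major
    [vocab_size x hidden_units] table (a simplified stand-in, not FT's kernel)."""
    # Reject ids outside [0, vocab_size) before computing an offset. Without
    # this check, token_id < 0 gives a negative offset: raw pointer arithmetic
    # on a GPU can read before the start of the buffer, and Python slicing
    # would quietly return the wrong (or empty) data instead of failing.
    if not 0 <= token_id < vocab_size:
        raise IndexError(f"token id {token_id} outside [0, {vocab_size})")
    start = token_id * hidden_units
    return table[start:start + hidden_units]


vocab_size, hidden_units = 4, 3
table = [float(i) for i in range(vocab_size * hidden_units)]  # row t = [3t, 3t+1, 3t+2]
print(embedding_lookup(table, hidden_units, 2, vocab_size))   # [6.0, 7.0, 8.0]
try:
    embedding_lookup(table, hidden_units, -13877, vocab_size)
except IndexError as err:
    print(err)  # token id -13877 outside [0, 4)
```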

Can you try running with float32? I want to make sure it is not an overflow.
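As background on why switching to fp32 is a useful test: the largest finite fp16 value is 65504, so an intermediate such as a sum of squares (the kind of reduction a layer norm performs) can overflow to inf when it is accumulated in fp16, even if each input is modest. The sketch emulates fp16 rounding with Python's `struct` "e" format; the activation values are hypothetical, and the thread does not establish which computation, if any, overflows in this model.

```python
import math
import struct


def to_fp16(x):
    """Round x to the nearest IEEE 754 half-precision value, as a Python float."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses values too large for fp16; a round-to-nearest
        # hardware conversion yields +/-inf instead, so mimic that.
        return math.copysign(math.inf, x)


def sum_of_squares(values, fp16):
    """Sum of squares; if fp16 is True, round every intermediate to fp16."""
    acc = 0.0
    for v in values:
        if fp16:
            acc = to_fp16(acc + to_fp16(v * v))
        else:
            acc += v * v
    return acc


activations = [100.0] * 8  # hypothetical hidden-state values
print(sum_of_squares(activations, fp16=False))  # 80000.0
print(sum_of_squares(activations, fp16=True))   # inf (fp16 max is 65504)
```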

tyliupku commented 1 year ago

Can you try running with float32? I want to make sure it is not an overflow.

The program runs as expected after changing data_type in gptneox_config.ini from fp16 to fp32. Thanks a lot for your help!

May I ask why this happens? I am sure my model is fp16 rather than fp32 (I printed the data type of every parameter in my model). Does the FT computation convert fp16 data to fp32 somewhere during GPT-style autoregressive decoding?

byshiue commented 1 year ago

One possibility is that Hugging Face or Megatron clips some outlier values under fp16, but I am not sure; we would need to check the values layer by layer. Overflow is only one possible cause, and there are others. I cannot check it because we don't have a model that reproduces the issue.
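A layer-by-layer check like the one suggested here could look like the sketch below, assuming the hidden states after each layer have already been dumped from both frameworks into Python lists (how to dump them from FT, Hugging Face, or Megatron is not covered). The function name, tolerances, layer names, and toy data are all hypothetical.

```python
import math


def first_divergence(reference, candidate, rtol=1e-2, atol=1e-3):
    """Walk layers in order and report the first one whose outputs differ.

    reference, candidate: dicts mapping layer name -> list of floats, in the
    same order (e.g. activations dumped from HF/Megatron and from FT).
    Returns (layer_name, reason), or None if every layer matches.
    """
    for name, ref_vals in reference.items():
        cand_vals = candidate[name]
        if len(cand_vals) != len(ref_vals):
            return name, "length mismatch"
        for i, (r, c) in enumerate(zip(ref_vals, cand_vals)):
            # Flag inf/NaN first: that is what an fp16 overflow would produce.
            if not math.isfinite(c):
                return name, f"non-finite value {c} at index {i}"
            if not math.isclose(r, c, rel_tol=rtol, abs_tol=atol):
                return name, f"index {i}: expected {r}, got {c}"
    return None


ref = {"layer0": [0.5, 1.25], "layer1": [300.0, -2.0]}
ft = {"layer0": [0.5, 1.25], "layer1": [math.inf, -2.0]}
print(first_divergence(ref, ft))  # ('layer1', 'non-finite value inf at index 0')
```

Reporting only the first divergent layer keeps the search focused: everything after it is downstream of the first bad value.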

You can try bf16. If bf16 works well, it is very likely an overflow issue, and we may need to find where to clip values by comparing with HF or Megatron. If bf16 does not work, some bf16/fp16 kernels may have bugs.
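The reasoning behind the bf16 test can be seen numerically: bf16 keeps fp32's 8-bit exponent (so roughly fp32's range) but has only 7 mantissa bits versus fp16's 10, so a run that overflows in fp16 but works in bf16 points to range rather than precision. The sketch emulates both formats with Python's `struct` module using round-to-nearest-even; it is an illustration only, not how FT converts tensors.

```python
import math
import struct


def to_fp16(x):
    """Round x to IEEE 754 half precision (inf on overflow), as a Python float."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def to_bf16(x):
    """Round x to bfloat16 via float32 with round-to-nearest-even (NaN not handled).

    bf16 is the top 16 bits of a float32: same 8-bit exponent as fp32,
    but only 7 mantissa bits.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round half to even on the dropped bits
    bits &= 0xFFFF0000                    # keep sign, exponent, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


for x in (1.001, 300.0, 70000.0):
    print(f"{x}: fp16={to_fp16(x)}, bf16={to_bf16(x)}")
```

Here 70000.0 overflows to inf in fp16 but becomes 70144.0 in bf16, while 1.001 keeps more precision in fp16 (1.0009765625) than in bf16 (1.0).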

tyliupku commented 1 year ago

You can try bf16. If bf16 works well, it is very likely an overflow issue, and we may need to find where to clip values by comparing with HF or Megatron. If bf16 does not work, some bf16/fp16 kernels may have bugs.

Thanks for the quick reply!

How can I try bf16? I tried changing data_type in gptneox_config.ini to bf16 but got an error:

[FT][ERROR] is_fp16 should be 0 (use float) or 1 (use half).

It seems that data_type can only be set to fp32 or fp16.

byshiue commented 1 year ago

Sorry, we have not added a bf16 implementation to GptNeoX yet.

tyliupku commented 1 year ago

One possibility is that Hugging Face or Megatron clips some outlier values under fp16, but I am not sure; we would need to check the values layer by layer. Overflow is only one possible cause, and there are others. I cannot check it because we don't have a model that reproduces the issue.

Thanks for all the replies! I will try to look into this issue myself. If I figure out the underlying cause of the overflow, I will post an update in this thread.