NVIDIA / FasterTransformer

Transformer related optimization, including BERT, GPT
Apache License 2.0

Getting Gibberish when Top_p > 0.9 #234

Closed bharatv007 closed 2 years ago

bharatv007 commented 2 years ago

When we set top_p > 0.9, the model starts producing gibberish.

Example generation from the model: the initial output looks good, but then it starts producing random words. The input prompt is:

This is a marketing copy writer that outputs creative paragraphs about what a topic is and what a topic isn't

Topic: workbook on healthy eating
Target Audience: women looking to develop healthy eating habits\nIt is: a stepping stone for developing healthy eating habits, a way to challenge yourself to reach new goals, a tool to evaluate your eating habits\nIt is not: a weight loss program, a quick fix to bad eating habits
Paragraph: If you're looking to develop healthier eating habits, this workbook is a great starting point. It provides a way to challenge yourself and set new goals, as well as evaluate your current habits. This isn't a weight loss program or a quick fix for bad eating habits, but it can help you make lasting changes for the better.
Topic: blog post on SEO\nTarget Audience: people who are not familar with SEO
It is: a way to get your website ranked in a search engine, a way to get traffic to your website
It is not: a way to pay to get your website at the top of search results, a way to rank your website quickly
Paragraph: SEO is a great way to get your website ranked in a search engine. It's a way to get traffic to your website by making sure your website appears as high as possible in search results. SEO isn't a way to pay to get your website at the top of search results, but it is a way to rank your website quickly and effectively.
Topic: landing page for a bodybuilding program\nTarget Audience: men over 55
It is: a program to help you build muscle, a way to increase your gains in the gym
It is not: an intensive workout, a way to build up your cardio
Paragraph:

Generated text:

His program is a way to help you build muscle. It's a way to increase your gains in the gym. It's not an intensive workout, but it is a way to build up your cardio.
Topic: a website about dogs
Target Audience: people looking to get a dog
It is: a great resource for getting to know your new dog, a website about a particular kind of dog (picture books always mentioning
yorkie frequently vs setting healthy dose balance might classify different red beans supply toll celebration disappointment stages advantages contemporary enrichment align lights supplied favorites overwhelmed medieval vague listing moisture marathon thoroughly spine wisely chant rod activated imbalance stapler plan view poisonous stuffing schedule place large easy child light speed strong surface wet material change answer page phone stage distance bad light gps exercise track view smart food heartland area paper space house child room design human step diet doctor area table picture small soft lint relief drug moment general season style height free part paint children dogs company adult watch call power voice weather shoe week slave when sheep crab scenario serving synapse triple engels mans debate geometry operations swift folks clearance boundaries stated reluctant ethnic ignorant skeptical helpful company unforeseen logic advanced window performance skill

byshiue commented 2 years ago

Please provide the reproduction steps using the "Bug Report" template.

bharatv007 commented 2 years ago

We are working on the old branch dev_v5.0.

The input text is the same prompt as in the original post. Reproduction steps:
nvidia-docker run -ti nvcr.io/nvidia/pytorch:20.12-py3 bash
git clone https://github.com/cohere-ai/FasterTransformer.git
cd FasterTransformer
git checkout dev_v5.0
rm -rf build && mkdir build && cd build
cmake -DSM=80 -DCMAKE_BUILD_TYPE=Release -DBUILD_PYT=ON -DBUILD_GPT=ON ..
make -j 24
# convert model weights
python ../examples/pytorch/gpt/multi_gpu_gpt_example.py --fp16 --ckpt_path <path_to_weights> --top_p=1 --top_k=0 --output_len=200 --max_batch_size=1
byshiue commented 2 years ago

We can't find the dev_v5.0 branch anymore. You can also try the latest main branch. Here are the results on my side:

python ../examples/pytorch/gpt/multi_gpu_gpt_example.py --fp16  --ckpt_path ../models/megatron-models/c-model/345m/1-gpu/ --top_p 1 --output_len 500 --sample_input_file tmp.input.txt --max_batch_size 1

=============== Arguments ===============
layer_num: 24
output_len: 500
head_num: 16
size_per_head: 64
vocab_size: 50304
beam_width: 1
top_k: 1
top_p: 1.0
temperature: 1.0
len_penalty: 1.0
beam_search_diversity_rate: 0.0
tensor_para_size: 1
pipeline_para_size: 1
ckpt_path: ../models/megatron-models/c-model/345m/1-gpu/
lib_path: ./lib/libth_parallel_gpt.so
vocab_file: ../models/gpt2-vocab.json
merges_file: ../models/gpt2-merges.txt
start_id: 50256
end_id: 50256
max_batch_size: 1
repetition_penalty: 1.0
max_seq_len: 1024
data_type: fp16
time: False
sample_input_file: tmp.input.txt
sample_output_file: None
is_fix_random_seed: True
int8_mode: 0
return_cum_log_probs: 0
=========================================

[369]
[FT][INFO] Device NVIDIA A100-PCIE-40GB
[FT][INFO] Device NVIDIA A100-PCIE-40GB
[FT][INFO] Device NVIDIA A100-PCIE-40GB
[INFO] batch 0, beam 0: 
[Context]
This is a marketing copy writer that outputs creative paragraphs about what a topic is and what a topic isn't

Topic: workbook on healthy eating
Target Audience: women looking to develop healthy eating habits\nIt is: a stepping stone for developing healthy eating habits, a way to challenge yourself to reach new goals, a tool to evaluate your eating habits\nIt is not: a weight loss program, a quick fix to bad eating habits
Paragraph: If you're looking to develop healthier eating habits, this workbook is a great starting point. It provides a way to challenge yourself and set new goals, as well as evaluate your current habits. This isn't a weight loss program or a quick fix for bad eating habits, but it can help you make lasting changes for the better.
Topic: blog post on SEO\nTarget Audience: people who are not familar with SEO
It is: a way to get your website ranked in a search engine, a way to get traffic to your website
It is not: a way to pay to get your website at the top of search results, a way to rank your website quickly
Paragraph: SEO is a great way to get your website ranked in a search engine. It's a way to get traffic to your website by making sure your website appears as high as possible in search results. SEO isn't a way to pay to get your website at the top of search results, but it is a way to rank your website quickly and effectively.
Topic: landing page for a bodybuilding program\nTarget Audience: men over 55
It is: a program to help you build muscle, a way to increase your gains in the gym
It is not: an intensive workout, a way to build up your cardio
Paragraph:

[Output]

Topic: blog post on how to get your website to rank in search engines
Target Audience: people who are not familiar with SEO
It is: a way to get your website to rank in search engines
It is not: a way to pay to get your website at the top of search results, a way to rank your website quickly
Paragraph:

Topic: blog post on how to get your website to rank in search engines
Target Audience: people who are not familiar with SEO
It is: a way to get your website to rank in search engines
It is not: a way to pay to get your website at the top of search results, a way to rank your website quickly
Paragraph:

Topic: blog post on how to get your website to rank in search engines
Target Audience: people who are not familiar with SEO
It is: a way to get your website to rank in search engines
It is not: a way to pay to get your website at the top of search results, a way to rank your website quickly
Paragraph:

Topic: blog post on how to get your website to rank in search engines
Target Audience: people who are not familiar with SEO
It is: a way to get your website to rank in search engines
It is not: a way to pay to get your website at the top of search results, a way to rank your website quickly
Paragraph:

Topic: blog post on how to get your website to rank in search engines
Target Audience: people who are not familiar with SEO
It is: a way to get your website to rank in search engines
It is not: a way to pay to get your website at the top of search results, a way to rank your website quickly
Paragraph:

Topic: blog post on how to get your website to rank in search engines
Target Audience: people who are not familiar with SEO
It is: a way to get your website to rank in search engines
It is not: a way to pay to get your website at the top of search results, a way to rank your website quickly
Paragraph:

Topic: blog post on how to get your website to rank in search engines
Target Audience: people who are not familiar with SEO
It is: a way to get your website to rank in search engines
It is not: a way

Also, if you set only --top_p 1, the default value of top_k (1) still applies, so the run is effectively --top_k 1 --top_p 1, which is greedy decoding; the top_p setting has no effect. The results of pure top_p = 1 sampling (adding --top_k 0) look like this:

python ../examples/pytorch/gpt/multi_gpu_gpt_example.py --fp16  --ckpt_path ../models/megatron-models/c-model/345m/1-gpu/ --top_p 1 --output_len 500 --sample_input_file tmp.input.txt --max_batch_size 1 --top_k 0

=============== Arguments ===============
layer_num: 24
output_len: 500
head_num: 16
size_per_head: 64
vocab_size: 50304
beam_width: 1
top_k: 0
top_p: 1.0
temperature: 1.0
len_penalty: 1.0
beam_search_diversity_rate: 0.0
tensor_para_size: 1
pipeline_para_size: 1
ckpt_path: ../models/megatron-models/c-model/345m/1-gpu/
lib_path: ./lib/libth_parallel_gpt.so
vocab_file: ../models/gpt2-vocab.json
merges_file: ../models/gpt2-merges.txt
start_id: 50256
end_id: 50256
max_batch_size: 1
repetition_penalty: 1.0
max_seq_len: 1024
data_type: fp16
time: False
sample_input_file: tmp.input.txt
sample_output_file: None
is_fix_random_seed: True
int8_mode: 0
return_cum_log_probs: 0
=========================================

[369]
[FT][INFO] Device NVIDIA A100-PCIE-40GB
[FT][INFO] Device NVIDIA A100-PCIE-40GB
[FT][INFO] Device NVIDIA A100-PCIE-40GB
[INFO] batch 0, beam 0: 
[Context]
This is a marketing copy writer that outputs creative paragraphs about what a topic is and what a topic isn't

Topic: workbook on healthy eating
Target Audience: women looking to develop healthy eating habits\nIt is: a stepping stone for developing healthy eating habits, a way to challenge yourself to reach new goals, a tool to evaluate your eating habits\nIt is not: a weight loss program, a quick fix to bad eating habits
Paragraph: If you're looking to develop healthier eating habits, this workbook is a great starting point. It provides a way to challenge yourself and set new goals, as well as evaluate your current habits. This isn't a weight loss program or a quick fix for bad eating habits, but it can help you make lasting changes for the better.
Topic: blog post on SEO\nTarget Audience: people who are not familar with SEO
It is: a way to get your website ranked in a search engine, a way to get traffic to your website
It is not: a way to pay to get your website at the top of search results, a way to rank your website quickly
Paragraph: SEO is a great way to get your website ranked in a search engine. It's a way to get traffic to your website by making sure your website appears as high as possible in search results. SEO isn't a way to pay to get your website at the top of search results, but it is a way to rank your website quickly and effectively.
Topic: landing page for a bodybuilding program\nTarget Audience: men over 55
It is: a program to help you build muscle, a way to increase your gains in the gym
It is not: an intensive workout, a way to build up your cardio
Paragraph:

[Output]
2. Fitness/Health Promotion
Target Audience: individuals giving back\nIt is: various forms of fitness manifesting in the bodies and ways they exercise\nIt is: a way to network and meet others /people\nIt is not: a marketing slogan for any particular tip, but it is for effective example use
Paragraph: You could change your name to a brand or company with the theme of being fun, positive, and inspiring. If you just created yourself and thought it was a great slogan, then please rest assured it will be different and worth your attention. This is a great way to be different and get exposure and attention.
Paragraph: Make your content or services useful to someone, allow them to use it in various ways, and ask others to use those "resources" they use.
3. 'Make Me' Store
Target Audience: older women buying new products
It is: an online e-retail provider that sells items of newness and beauty advice
Paragraph: You would manage the customer's shop, or inventory management, being easy to create a store for them, although people aren't sure what they would ask for as a personal friend that buys high-quality products.
Paragraph: You would manage reactivating client orders, customizing and cutting items, and manage placing orders across all the products for up to a year.
Paragraph: Things that will bring you recognition from clients or your a store on mental success
I am only a small customer, not enough to make it worth the time and resources to pursue. This is where all the orange things come into play.
If you were trying to make your need for items tangible, but you didn't want to spend hours or even hours on things that you didn't need, or you tried to do all your heavy lifting and logistics separately, you wouldn't consider making only just the cosmetic level, collapsed into an entire online entity. With this, include items and options widely used by others and let's make it something that might bring someone that not only wants it but loves to use it.
4. Lit or Sounding Opinion
Target Audience: people reporting another conversation in social media
It is: a prevention strategy to keep you in a safe spot\nIt is: a type of ad scare tactics\nIt is: a means of getting an encouraging message out into a world where others are furiously booting you down
Par
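To illustrate the point about defaults, here is a plain-Python sketch (not FT's sampling kernel) of top-k filtering followed by top-p filtering on one row of logits. The helper name `candidates` and the toy logits are hypothetical, and `top_k = 0` is treated as "top-k disabled", the way `--top_k 0` is used above for pure top-p sampling. With `top_k = 1` only the argmax survives the first stage, so top-p has nothing left to prune and sampling is greedy.

```python
import math


def candidates(logits, top_k, top_p):
    """Token ids left after top-k (0 = disabled) and then top-p filtering of one row of logits."""
    # Token ids ordered from most to least likely.
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]
    # Softmax over the top-k survivors, shifted by the max logit for stability.
    m = logits[order[0]]
    weights = [math.exp(logits[i] - m) for i in order]
    total = sum(weights)
    kept, mass = [], 0.0
    for i, w in zip(order, weights):
        kept.append(i)  # the token that crosses top_p is kept, so at least one survives
        mass += w / total
        if mass >= top_p:
            break
    return kept


logits = [1.0, 3.0, 2.5, -1.0]
print(candidates(logits, top_k=1, top_p=1.0))  # [1]: greedy, top_p has nothing left to prune
print(candidates(logits, top_k=0, top_p=1.0))  # [1, 2, 0, 3]: every token can be sampled
```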
bharatv007 commented 2 years ago

We find that this happens about once in 10 tries (changing the random seed each time).

byshiue commented 2 years ago

If it only occasionally generates strange tokens or sentences with a large top_p value and a small model, I think that is normal behavior.

bharatv007 commented 2 years ago

@byshiue Thanks for the quick reply. Here are some follow-up questions:

  1. We have a slight modification in our model: the final projection layer has a bias, and we use invokeAddBias to add it. (Previously we tried doing this with the temperature kernel that accepts a bias, but we found it to be unstable in fp16.) Is similar behavior expected with invokeAddBias?

  2. Also, the sampler logic we want to implement is like this: https://github.com/openai/gpt-2/blob/master/src/sample.py#L65 Is the FT sampler similar to this? If it's not, are there any pointers for implementing this strategy?
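For comparison with FT, here is a plain-Python sketch of the masking idea in the linked `sample.py`: sort, take a cumulative softmax, and replace every logit below the cut-off with a large negative value before sampling. It works on a single 1-D row rather than a batch, and the boundary rule used here (keep tokens whose inclusive cumulative probability is <= p, but always at least one) is based on a reading of that file, so it should be checked against the original.

```python
import math

NEG_INF = -1e10  # large negative value used to mask logits


def top_k_logits(logits, k):
    """Mask every logit below the k-th largest; k == 0 leaves the logits unchanged."""
    if k == 0:
        return list(logits)
    kth = sorted(logits, reverse=True)[k - 1]
    return [x if x >= kth else NEG_INF for x in logits]


def top_p_logits(logits, p):
    """Mask logits outside the nucleus: keep tokens whose inclusive cumulative prob is <= p (at least one)."""
    sorted_logits = sorted(logits, reverse=True)
    m = sorted_logits[0]
    exps = [math.exp(x - m) for x in sorted_logits]
    total = sum(exps)
    cumulative, count = 0.0, 0
    for e in exps:
        cumulative += e / total
        if cumulative <= p:
            count += 1
    min_value = sorted_logits[max(count - 1, 0)]
    return [x if x >= min_value else NEG_INF for x in logits]


lg = [math.log(0.5), math.log(0.3), math.log(0.2)]
print(top_p_logits(lg, 0.9)[2] == NEG_INF)  # True: cumulative 0.5, 0.8 fit under 0.9; the 0.2 token is masked
```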

byshiue commented 2 years ago
  1. You can use invokeAddBias. We use it the same way in GptJ.cc.
  2. We follow the same idea, with some optimizations.
bharatv007 commented 2 years ago

@byshiue Thanks for the clarification. Just another question: can we apply the temperature penalty this way?

Currently FT applies temperature like this:

logits = next_outputs['logits'][:, -1, :]  / tf.to_float(temperature)
logits = top_k_logits(logits, k=top_k)
logits = top_p_logits(logits, p=top_p)
samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)

But we would like to apply temperature after top_p pruning. When we increase temp above 1.3, the method above quickly starts producing gibberish (the distribution becomes too flat). If we apply it after top_p pruning, we get better control over the temperature scaling:

logits = next_outputs['logits'][:, -1, :] 
logits = top_k_logits(logits, k=top_k)
logits = top_p_logits(logits, p=top_p)
logits = logits/temperature
samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
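The effect of the two orders can be checked on a toy row. The plain-Python sketch below (hypothetical helper names; three strong logits plus a flat 100-token tail; top_p = 0.9) computes the candidate set both ways. Scaling first flattens the distribution before the nucleus is chosen, so at temperature 2.0 the nucleus grows from 2 tokens to 85 and admits tail tokens; pruning first keeps the same 2 candidates and only reshapes their probabilities.

```python
import math


def softmax(logits, temperature=1.0):
    """Softmax of logits / temperature, shifted by the max for numerical stability."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def nucleus(probs, p):
    """Ids of the smallest set of most likely tokens whose probability mass reaches p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= p:
            break
    return kept


def temperature_then_top_p(logits, temperature, p):
    """First snippet's order: scale, then prune. Returns {token_id: probability}."""
    probs = softmax(logits, temperature)
    kept = nucleus(probs, p)
    z = sum(probs[i] for i in kept)
    return {i: probs[i] / z for i in kept}


def top_p_then_temperature(logits, temperature, p):
    """Second snippet's order: prune at temperature 1, then scale only the survivors."""
    kept = nucleus(softmax(logits), p)
    return dict(zip(kept, softmax([logits[i] for i in kept], temperature)))


# Toy row: three strong tokens followed by a flat 100-token tail.
logits = [8.0, 6.0, 4.0] + [0.0] * 100
print(len(temperature_then_top_p(logits, 1.0, 0.9)))  # 2
print(len(temperature_then_top_p(logits, 2.0, 0.9)))  # 85: the tail is now inside the nucleus
print(len(top_p_then_temperature(logits, 2.0, 0.9)))  # 2: same candidates, flatter weights
```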
byshiue commented 2 years ago

Currently, we only support the "logit" -> "topk" -> "topp" workflow, so we don't plan to support other orderings. I think you can adjust your temperature to prevent such problems. If you want the latter implementation, you need to modify the implementation of the sampling layer in FT.

bharatv007 commented 2 years ago

@byshiue Thanks for the reply. I am writing my own sampling kernel, both to understand whether this is more of a model issue and to apply temperature differently. I use thrust to sort the logits buf so that I can sample with top_k and top_p. The rest of the algorithm follows the OpenAI GPT-2 sampler I linked before (which, based on your comment above, is similar to the one in FT).

#include <thrust/execution_policy.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>

// Naive single-thread sampler. Logits arrive sorted in ascending order,
// so the largest logit is at index 51199 (vocab size 51200).
__global__ void Sample(float* logits,
                       int* ids,
                       int* indices,
                       const int step,
                       const int batch_size,
                       const float uniform_val,
                       const float top_p,
                       const int top_k)
{
    const float max_logit = logits[51199];
    if (top_k == 1) {
        // greedy: the argmax is the last element after the ascending sort
        ids[blockIdx.x + step * batch_size] = indices[51199];
        return;
    }
    // softmax denominator, shifted by the max logit to avoid overflow
    float sum_val = 0.0f;
    for (int idx = 0; idx < 51200; idx++) {
        sum_val += __expf(logits[idx] - max_logit);
    }
    // find the index up to which the logits fit in top_p, walking down from the max;
    // tokens (idx_prune, 51199] form the nucleus
    float sum_top_p = 0.0f;
    int idx_prune = -1;  // keep every token if rounding stops sum_top_p just short of the target
    for (int idx = 51199; idx >= 0; idx--) {
        sum_top_p += __expf(logits[idx] - max_logit);
        if (sum_top_p >= top_p * sum_val) {
            idx_prune = idx - 1;  // token idx is the last one kept
            break;
        }
    }
    // the nucleus mass is the new softmax denominator; the max logit is still logits[51199]
    sum_val = sum_top_p;
    // inverse-CDF draw inside the nucleus, starting from the max logit;
    // accumulate first, then compare, so the token whose interval contains uniform_val is chosen
    float sum_prob = 0.0f;
    int sample_idx = idx_prune + 1;  // fallback if rounding keeps sum_prob just below the threshold
    for (int idx = 51199; idx > idx_prune; idx--) {
        sum_prob += __expf(logits[idx] - max_logit);
        if (sum_prob >= uniform_val * sum_val) {
            sample_idx = idx;
            break;
        }
    }
    ids[blockIdx.x + step * batch_size] = indices[sample_idx];
}

void invokeSample(float* logits,            // logits_buf_, sorted in place
                  int* ids,                 // output_ids_buf_
                  int* indices,             // scratch for token ids; vocab size is 51200
                  const int step,           // current step in decoding
                  const int batch_size,
                  const float uniform_val,  // random value from a uniform distribution in [0, 1), only used for top_p
                  const float top_p,
                  const int top_k,
                  cudaStream_t stream)
{
    dim3 grid(1);
    dim3 block(1);
    // stable_sort_by_key permutes indices in place, so reset them to [0, 1, ..., 51199] every step
    thrust::sequence(thrust::cuda::par.on(stream), indices, indices + 51200);
    // sort logits and indices on the same stream as the kernel
    thrust::stable_sort_by_key(thrust::cuda::par.on(stream), logits, logits + 51200, indices);
    Sample<<<grid, block, 0, stream>>>(logits, ids, indices, step, batch_size, uniform_val, top_p, top_k);
}
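For sanity checks like this it can help to run the same algorithm on the host and compare its token with the kernel's at every step. Below is a plain-Python sketch (the name `sample_sorted` is hypothetical) of the steps the kernel describes: logits sorted ascending with the maximum last, a top_p cut found by walking down from the maximum, and an inverse-CDF draw with a uniform value `u`. The running mass is added before it is compared with `u`; comparing first would select the token just after the one whose probability interval contains `u`.

```python
import math


def sample_sorted(sorted_logits, top_p, u):
    """Return the position (in ascending-sorted logits) of the token drawn with uniform value u."""
    n = len(sorted_logits)
    m = sorted_logits[-1]  # the largest logit is last after an ascending sort
    weights = [math.exp(x - m) for x in sorted_logits]
    total = sum(weights)
    # Walk down from the largest logit until the kept mass reaches top_p of the total.
    kept_mass, cut = 0.0, 0
    for pos in range(n - 1, -1, -1):
        kept_mass += weights[pos]
        if kept_mass >= top_p * total:
            cut = pos
            break
    # Inverse-CDF draw inside the nucleus [cut, n - 1]: add the token's mass, then compare.
    running = 0.0
    for pos in range(n - 1, cut - 1, -1):
        running += weights[pos]
        if running >= u * kept_mass:
            return pos
    return cut  # only reached through rounding when u is extremely close to 1


sorted_logits = [math.log(0.1), math.log(0.3), math.log(0.6)]
print([sample_sorted(sorted_logits, 1.0, u) for u in (0.5, 0.7, 0.95)])  # [2, 1, 0]
print(sample_sorted(sorted_logits, 0.5, 0.95))  # 2: only the top token is in the nucleus
```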

I am not looking for an optimal implementation, just a simple one for sanity checking. When I pass top_k=1, I get the following generation:

The Inference Framework (tif) is a framework is a framework is a framework is a framework is a framework is a framework is a framework is a framework is a

When I use the FT top_k sampler, I get the following:

The inference framework (tif) is a framework for inferring the inference rules from the data. The tif is a set of rules that can be used to

Is there any reason to believe that some memory issue might be causing problems with top_k=1? It is just a simple argmax sampler, so it should match. Here I am passing the logits buf and creating the indices. Any pointers to resolve this issue would help, thanks!

byshiue commented 2 years ago

Do you use any penalty to generate the results of FT?

bharatv007 commented 2 years ago

I don't use any penalty to generate the FT results; it is just topK=1, temp=1. We also sampled the same model from TensorFlow, and the generations from TensorFlow and FT match for topk=1, but they fail to match when I sort the logits and pick the max index as I did before. It's a bit confusing why this is happening: it seems to generate a few tokens correctly ("is a framework") but then diverges quickly.

byshiue commented 2 years ago

You can check the input logits, the sorted logits, and the selected token at each step.
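One way to follow this suggestion is a host-side check (plain Python, hypothetical names, toy logits) that compares the token taken from the sorted buffer with a plain argmax at every step. It also shows one way a sort-by-key sampler can go wrong after the first step: `thrust::stable_sort_by_key` permutes the index buffer in place, so if that buffer is not reset to 0, 1, ..., V-1 before the next sort, sorted positions map to the wrong token ids.

```python
def argmax_via_sort(logits, indices):
    """Mimic sort_by_key: sort (logit, id) pairs ascending by logit; return the max's id and the permuted ids."""
    pairs = sorted(zip(logits, indices), key=lambda t: t[0])  # Python's sort is stable, like stable_sort_by_key
    sorted_ids = [i for _, i in pairs]
    return sorted_ids[-1], sorted_ids


step1 = [0.1, 2.0, -1.0, 0.5]
step2 = [3.0, 0.2, 0.1, 0.4]

tok1, permuted = argmax_via_sort(step1, list(range(4)))
tok2_fresh, _ = argmax_via_sort(step2, list(range(4)))  # index buffer reset before the sort
tok2_stale, _ = argmax_via_sort(step2, permuted)          # index buffer left permuted from step 1
print(tok1, tok2_fresh, tok2_stale)  # 1 0 2: the stale buffer picks the wrong token at step 2
```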

bharatv007 commented 2 years ago
#include <thrust/execution_policy.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>

// blockReduceSum<float> is FasterTransformer's block-level reduction helper (assumed to be in scope).
// Logits arrive sorted in ascending order, so the largest logit is at index 51199 (vocab size 51200).
__global__ void Sample(float* logits,            // sorted logits
                       int* ids,                 // output ids
                       int* indices,             // token ids, permuted alongside the sorted logits
                       int* seq_length,          // current sequence length of the input
                       const int step,           // iteration step
                       const int batch_size,     // batch size
                       const float uniform_val,  // value from a uniform distribution in [0, 1)
                       const int top_k,
                       const float top_p,
                       const float temp_inverse)
{
    if (top_k == 1) {
        // greedy: take the largest logit
        if (threadIdx.x == 0 && blockIdx.x == 0) {
            ids[blockIdx.x] = indices[51199];
            seq_length[blockIdx.x] += 1;
        }
        return;
    }

    __shared__ float s_sum_val;
    __shared__ float s_max_val;
    __shared__ int s_idx_prune;
    if (threadIdx.x == 0) {
        s_max_val = logits[51199];
    }
    __syncthreads();

    // softmax denominator, shifted by the max logit to avoid overflow
    float sum_val = 0.0f;
    for (int tid = threadIdx.x; tid < 51200; tid += blockDim.x) {
        sum_val += __expf(logits[tid] - s_max_val);
    }
    sum_val = blockReduceSum<float>(sum_val);
    // only thread 0 is guaranteed to hold the full reduction result
    if (threadIdx.x == 0) {
        s_sum_val = sum_val;
    }
    __syncthreads();

    // find the index up to which the logits fit in top_p, walking down from the max;
    // tokens (s_idx_prune, 51199] form the nucleus
    if (threadIdx.x == 0) {
        float sum_top_p = 0.0f;
        int idx_prune = -1;  // keep every token if rounding stops sum_top_p just short of the target
        for (int idx = 51199; idx >= 0; idx--) {
            sum_top_p += __expf(logits[idx] - s_max_val);
            if (sum_top_p >= top_p * s_sum_val) {
                idx_prune = idx - 1;  // token idx is the last one kept
                break;
            }
        }
        s_idx_prune = idx_prune;
    }
    __syncthreads();

    // recompute the softmax denominator over the nucleus only,
    // applying the temperature scale to the surviving logits
    sum_val = 0.0f;
    for (int idx = s_idx_prune + 1 + threadIdx.x; idx < 51200; idx += blockDim.x) {
        sum_val += __expf((logits[idx] - s_max_val) * temp_inverse);
    }
    sum_val = blockReduceSum<float>(sum_val);
    if (threadIdx.x == 0) {
        s_sum_val = sum_val;
    }
    __syncthreads();

    if (threadIdx.x == 0) {
        // inverse-CDF draw inside the nucleus, starting from the max logit;
        // accumulate first, then compare, so the token whose interval contains uniform_val is chosen
        float sum_prob = 0.0f;
        int sample_id = s_idx_prune + 1;  // fallback if rounding keeps sum_prob just below the threshold
        for (int idx = 51199; idx > s_idx_prune; idx--) {
            sum_prob += __expf((logits[idx] - s_max_val) * temp_inverse);
            if (sum_prob >= uniform_val * s_sum_val) {
                sample_id = idx;
                break;
            }
        }
        ids[blockIdx.x] = indices[sample_id];
        seq_length[blockIdx.x] += 1;
    }
}

void invokeSample(float* logits,  // logits buf, sorted in place
                  int* ids,
                  int* indices,   // reset to [0, 1, ..., 51199] below before every sort
                  int* seq_length,
                  const int step,
                  const int batch_size,
                  const float uniform_val,
                  const int top_k,
                  const float top_p,
                  const float temp,
                  cudaStream_t stream)
{
    dim3 grid(1);
    dim3 block(1024);
    const float temp_inverse = 1.0f / temp;
    // stable_sort_by_key permutes indices in place, so rebuild the identity mapping every step
    thrust::sequence(thrust::cuda::par.on(stream), indices, indices + 51200);
    // sort on the same stream as the kernel so the launch below sees the sorted data
    thrust::stable_sort_by_key(thrust::cuda::par.on(stream), logits, logits + 51200, indices);
    Sample<<<grid, block, 0, stream>>>(logits, ids, indices, seq_length, step, batch_size, uniform_val, top_k, top_p, temp_inverse);
}

I fixed the sequence_lengths bugs and implemented temperature and the sampler this way. My implementation makes the sampling loop very slow; are there any tips to make it faster?

byshiue commented 2 years ago

You can increase the block size and keep intermediate results in shared memory or registers. Nsight Compute is also helpful for analyzing the efficiency of a kernel, and there is plenty of official documentation explaining how to use it.
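An alternative not discussed in the thread: the two serial loops in thread 0 are a prefix sum followed by a search, and both have parallel counterparts on the GPU (for example `thrust::inclusive_scan` and `thrust::lower_bound`, or CUB's `BlockScan` inside the kernel). The plain-Python sketch below (hypothetical name, logits sorted in descending order) uses `itertools.accumulate` and `bisect` to show the same top-p draw expressed that way.

```python
import bisect
import itertools
import math


def sample_with_scan(logits_desc, top_p, u):
    """Top-p draw over logits sorted in descending order, using a prefix sum and binary search."""
    m = logits_desc[0]
    weights = [math.exp(x - m) for x in logits_desc]  # element-wise: one thread per token on a GPU
    cdf = list(itertools.accumulate(weights))          # inclusive prefix sum
    # First position whose cumulative mass reaches top_p of the total closes the nucleus.
    cut = min(bisect.bisect_left(cdf, top_p * cdf[-1]), len(cdf) - 1)
    # Inverse CDF restricted to the nucleus [0, cut].
    return min(bisect.bisect_left(cdf, u * cdf[cut]), cut)


logits_desc = [math.log(0.6), math.log(0.3), math.log(0.1)]
print([sample_with_scan(logits_desc, 1.0, u) for u in (0.5, 0.7, 0.95)])  # [0, 1, 2]
```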

bharatv007 commented 2 years ago

@byshiue I was able to optimize the code; thanks for the suggestions to profile and to use shared memory and registers. I'm noticing that the same model (355m) behaves very differently with tp > 1.

I have a few questions:

  1. What could cause such an issue?
  2. I noticed that in the tp > 1 case, all the GPUs run the sampler and update the output ids. Will that cause any issues?
  3. Is there anything else I might be missing for the tp > 1 case? Since all GPUs run the sampler, we generate the uniform random value in C++ as follows so that the random number is the same on every GPU, but it now causes problems with tp = 1 and tp > 1.
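#include <cstdlib>
#include <ctime>
#include <random>
// time(nullptr) has one-second resolution and rand() keeps process-wide state, so callers that
// seed in different seconds, or share one process, can end up with different values.
srand(time(nullptr));
float uniform_random;
std::uniform_real_distribution<> dis(0.0, 1.0);
int random_seed_ = rand();
std::mt19937 gen(random_seed_);
uniform_random = dis(gen);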

Is this a good approach to fixing the random seed? Please let me know if there is anything I am missing.

byshiue commented 2 years ago
  1. The results under TP may differ. This is expected, because TP changes the shapes of the GEMMs, which leads to slightly different results. But it should not be unstable.
  2. All GPUs use the same random seed and get the same results, so there is no issue here.
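Point 1 comes from floating-point addition not being associative: splitting a GEMM across GPUs changes how partial sums are grouped before they are reduced, so logits can differ in the last bits, and a tiny difference can occasionally flip which token a sampling step chooses. A toy illustration in Python doubles (the variable names are hypothetical; fp16 rounding errors are larger):

```python
# The same three partial products, accumulated in two different groupings.
partials = [0.1, 0.2, 0.3]

one_gpu = (partials[0] + partials[1]) + partials[2]   # sequential accumulation
two_gpus = partials[0] + (partials[1] + partials[2])  # a different split, then a final reduction
print(one_gpu, two_gpus)  # 0.6000000000000001 0.6
```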
bharatv007 commented 2 years ago

With the method above for generating random_seed, we get good samples for tp > 1 in PyTorch-FT. The same method produces really bad samples in Triton-FT for tp > 1 (tp = 1 runs fine). Is there another suggested method? (We tried curand_initialize too.)

byshiue commented 2 years ago

Can you check whether Triton-FT works well with topk=1, topp=0? If not, the cause may be the conversion from your checkpoint to FT. Otherwise, the cause may be that the actual random seeds used on the two GPUs are different.
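A minimal sketch of one way to give every GPU the same per-step random value without relying on the wall clock: derive it only from a seed that all ranks share (for example, one chosen per request and passed to every rank) and the decoding step. This is plain Python with a hypothetical helper and mixing rule; it is not how FT or Triton-FT seed their samplers. On the GPU, the analogous approach would be calling cuRAND's `curand_init` with the same seed on every rank.

```python
import random


def uniform_for_step(seed, step):
    """Uniform value in [0, 1) that depends only on (seed, step), so every rank computes the same one."""
    # Hypothetical mixing rule: fold the shared seed and the decoding step into one integer seed.
    return random.Random(seed * 1_000_003 + step).random()


shared_seed = 1234  # e.g. chosen once per request and sent to every rank
rank0 = [uniform_for_step(shared_seed, s) for s in range(5)]
rank1 = [uniform_for_step(shared_seed, s) for s in range(5)]
print(rank0 == rank1)  # True: no wall clock involved, so the ranks cannot disagree
```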

bharatv007 commented 2 years ago

Closing this. We were able to write our own sampler that handles all cases (top_p, top_k, and top_k + top_p) in a single function. We don't see any loss in speed with the new sampler, and it is more configurable for the use case mentioned above. Thanks for all the help and suggestions!
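For reference, a single function covering top_k, top_p, and top_k + top_p could look roughly like the plain-Python sketch below. The author's sampler was not posted, so this is not it, and it is not FT's; the name `sample_token` and its conventions (top_k = 0 disables top-k, top_p = 0 keeps only the most likely token, temperature applied after pruning as requested earlier in the thread, `u` a uniform value in [0, 1)) are assumptions.

```python
import math


def sample_token(logits, top_k, top_p, temperature, u):
    """Pick one token id from a single row of logits with top-k, then top-p, then temperature."""
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]
    # Top-p on the un-tempered distribution of the top-k survivors.
    m = logits[order[0]]
    weights = [math.exp(logits[i] - m) for i in order]
    total = sum(weights)
    kept, mass = [], 0.0
    for i, w in zip(order, weights):
        kept.append(i)
        mass += w / total
        if mass >= top_p:
            break
    # Temperature only reshapes the probabilities of the tokens that survived pruning.
    scaled = [math.exp((logits[i] - m) / temperature) for i in kept]
    threshold = u * sum(scaled)
    running = 0.0
    for i, w in zip(kept, scaled):
        running += w  # accumulate first, then compare with the uniform draw
        if running >= threshold:
            return i
    return kept[-1]


logits = [1.0, 3.0, 2.5, -1.0]
print(sample_token(logits, top_k=1, top_p=1.0, temperature=1.0, u=0.99))  # 1: greedy
print(sample_token(logits, top_k=0, top_p=0.9, temperature=1.3, u=0.5))   # 1
```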