keras-team / keras-hub

Pretrained model hub for Keras 3
Apache License 2.0

Improve samplers documentation #1030

Open Aliync opened 1 year ago

Aliync commented 1 year ago

Hello,

The samplers documentation (https://keras.io/api/keras_nlp/samplers/) is the most unclear documentation I've ever seen! Someone should explain it :) Also, the code examples are still not updated (for example: https://keras.io/examples/generative/text_generation_gpt/).

jbischof commented 1 year ago

Apologies @Aliync, we actually have a new version of our text generation guide: https://keras.io/examples/nlp/gpt2_text_generation_with_kerasnlp/

The old guide is for training from scratch, but we should update it for sure cc @chenmoneygithub

Aliync commented 1 year ago

thanks <3

abuelnasr0 commented 1 year ago

Samplers were introduced to sample output from generative models, so let's talk about samplers in the context of a generative model like GPT-2, starting with the arguments of the next() function.

prompt: the sentence, or batch of sentences, that you want the model to generate new tokens from.

cache: used to save the results of computations that don't need to be repeated at every iteration, because they would give the same result each time and recomputing them would be a waste of time. In GPT-2, the keys and values of the input (the input multiplied by the K and V matrices) are cached, since that input is the same at every iteration. When a new token is generated, only its embedding, not the whole sentence, is multiplied by the K and V matrices, and the cache is updated and returned for the next iteration.

index: the index to start generating new tokens from. In GPT-2, if we have a batch of sentences, we start generating from the end of the shortest sentence, and the shortest sequence is found using the padding mask. For example, if we have:

prompt = tf.ragged.constant([[1, 5, 2],
                             [1, 5, 6, 4, 5, 2]])

The shortest is the first prompt, so the index will be 3. But does that mean the other sentence will be overwritten? No. Why not? In GPT-2 the prompt isn't actually shaped like the example above; it is described by two tensors (token_ids and padding_mask). Assuming the max_length of the model is 8, they will be:

token_ids = tf.constant([[1,5,2,0,0,0,0,0],
                         [1,5,6,4,5,2,0,0]])
padding_mask = tf.constant([[1,1,1,0,0,0,0,0],
                            [1,1,1,1,1,1,0,0]])
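
As an aside, here is a minimal sketch (assuming TensorFlow and the two tensors above) of how that starting index can be derived from the padding mask:

row_lengths = tf.reduce_sum(tf.cast(padding_mask, "int32"), axis=-1)  # [3, 6]
index = tf.reduce_min(row_lengths)  # 3, the end of the shortest prompt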

We can see what the padding mask encodes: 1 means the token at that index comes from the original sentence, 0 means there is no token at that index. Only positions where padding_mask == 0 will be updated, and that explains the fifth argument in the documentation, mask; in fact, the padding_mask is what gets passed as the mask argument in the GPT-2 model. The index is also calculated from the padding mask, by summing each row and taking the minimum (as in the sketch above).

Now we can talk about the values returned by next():

logits: if you don't know what logits are, they are unnormalized predictions ranging over (-infinity, infinity). In GPT-2 they are the unnormalized predictions over the whole vocabulary; softmax is applied to them to get the probability of each word in the vocabulary being the next token.

hidden_state: a representation of the next token that you can pass to any layer and train to predict the next token. In GPT-2, the hidden state is multiplied by the embedding matrix to produce the logits for the next token.

cache: as explained above, the updated version of the cache.

Here is some pseudocode to help illustrate what I have described:

def generate(prompt):
  token_ids, padding_mask = preprocessor(prompt)
  # Seed the cache (and the hidden states) by running the full prompt once.
  hidden_states, cache = compute_cache_and_state(token_ids)
  # Start generating at the end of the shortest prompt in the batch.
  index = tf.reduce_min(tf.reduce_sum(tf.cast(padding_mask, "int32"), axis=-1))

  def next(prompt, cache, index):
    cache_index = index - 1
    # Only the newest token is fed to the model; the cache covers the rest.
    prompt = tf.slice(prompt, [0, cache_index], [-1, 1])
    logits, hidden_states, cache = call_model_with_cache(prompt, cache, cache_index)
    return logits, hidden_states, cache

  sampler = keras_nlp.samplers.TopKSampler(k=5)

  output = sampler(
      next=next,
      prompt=token_ids,
      cache=cache,
      index=index,
      mask=padding_mask,
      hidden_states=hidden_states,
  )

  return output
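
As a side note: if you just want to use a different sampler with the high-level task models, you don't have to write this loop yourself. A minimal sketch (assuming the gpt2_base_en preset and a recent keras_nlp release):

import keras_nlp

# Load the pretrained model, then swap in top-k sampling before generating.
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
gpt2_lm.compile(sampler=keras_nlp.samplers.TopKSampler(k=5))
print(gpt2_lm.generate("My trip to Yosemite was", max_length=50))

The sampler can also be passed as a string name such as "top_k", as shown in the new text generation guide.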
jbischof commented 1 year ago

@chenmoneygithub could you look into updating @jessechancy's original guide?

shawnz commented 1 year ago

I'm quite unfamiliar with Keras/TF, so I'm not sure if this is correct, but for anyone reading who is trying to get the linked guide working with keras-nlp v0.5+, here is what I did:

First, replace the 1D unpadded prompt_tokens with a 2D padded one:

prompt_tokens = tf.pad(tf.constant([[tokenizer.token_to_id("[BOS]")]]),
                       tf.constant([[0, 0], [0, NUM_TOKENS_TO_GENERATE - 1]]))

Next, replace token_logits_fn with the next function:

def next(prompt, cache, index):
    output = model(prompt)
    # Logits for the token at `index` come from the model output at `index - 1`.
    return output[:, index - 1, :], None, cache

Finally, replace the call to top_p_search (or whatever):

output_tokens = keras_nlp.samplers.TopPSampler(p=p)(
    next=next,
    prompt=prompt_tokens,
    index=1,
)

Any advice about correctness/style in my usage of the new APIs would be appreciated. Hope this helps, Shawn

jefke24 commented 1 year ago

> Apologies @Aliync, we actually have a new version of our text generation guide: https://keras.io/examples/nlp/gpt2_text_generation_with_kerasnlp/
>
> The old guide is for training from scratch, but we should update it for sure cc @chenmoneygithub

Agreed, but for me that's exactly the appeal of it: it allows for a lightweight GPT-like model which can also be trained with limited resources.

Good for learning and experimenting. The new one you're referring to requires fine-tuning the entire pretrained GPT-2 model; in my case I can't run that on my local machine, it goes out of memory.

That train-from-scratch one I can run locally.

jbischof commented 1 year ago

@jefke24 you could try halving the batch size to 16 or using a smaller preset (we put param counts in the table).

jefke24 commented 1 year ago

@jbischof Yesterday I was even trying to see whether I could train a GPT2 backbone model from scratch on my custom data, based on this: https://keras.io/guides/keras_nlp/getting_started/#pretraining-a-backbone-model

However, I'm too much of a beginner and couldn't translate this recipe to GPT-2; it's considered expert-level too, so no surprise there :-) I tried using the GPT2Tokenizer and GPT2Preprocessor in combination with the GPT2Backbone model with a custom config, but I can't get my head around it, I'm afraid. That older tutorial is more basic and easier to understand, at least for me.

jbischof commented 1 year ago

> However, I'm too much of a beginner and couldn't translate this recipe to GPT-2; it's considered expert-level too, so no surprise there :-) I tried using the GPT2Tokenizer and GPT2Preprocessor in combination with the GPT2Backbone model with a custom config, but I can't get my head around it, I'm afraid. That older tutorial is more basic and easier to understand, at least for me.

Actually training from scratch is the exact same process as fine-tuning with decoder-only models! Either use the generic constructor to set the architecture or grab the config with a preset and pass load_weights=False. However, note that for any model size, training from scratch is more difficult than fine tuning.
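
For example, something along these lines (a rough, untested sketch; the hyperparameters for the custom backbone are just placeholders):

import keras_nlp

# Option 1: preset architecture and vocabulary, but randomly initialized weights.
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset(
    "gpt2_base_en", load_weights=False
)

# Option 2: a smaller custom architecture via the generic constructor.
backbone = keras_nlp.models.GPT2Backbone(
    vocabulary_size=50257,
    num_layers=4,
    num_heads=4,
    hidden_dim=256,
    intermediate_dim=1024,
    max_sequence_length=128,
)

Either way, the training loop then looks the same as in the fine-tuning guide.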

abuelnasr0 commented 1 year ago

@jbischof May I update this example on keras.io?

jbischof commented 1 year ago

Thank you @abuelnasr0, but I think @mattdangerw has this in flight already! 🚀

mattdangerw commented 1 year ago

Yeah, I believe all uses of the old sampler utils are now replaced with keras_nlp.samplers on keras.io. But please let us know if we missed a spot, or if there are any bugs with the change.

mattdangerw commented 1 year ago

I'll leave this open in case we want to make further updates to the samplers documentation.

shawnz commented 1 year ago

Nice, it appears the linked page has indeed been updated to fix this issue, thanks!

Nit: the page https://keras.io/examples/generative/text_generation_gpt/ no longer uses the "NUM_TOKENS_TO_GENERATE" constant defined at the top