keras-team / keras-hub

Pretrained model hub for Keras 3
Apache License 2.0

Add contrastive search to our Sampler collection #644

Closed chenmoneygithub closed 1 year ago

chenmoneygithub commented 1 year ago

Contrastive search is an improvement over top-k search that further reduces nonsensical repetition. Starting from top-k, the implementation should not be very hard; all we need is this equation:

x_t = argmax over v in V(k) of { (1 - alpha) * p(v | x_{<t}) - alpha * max_{j < t} sim(h_v, h_{x_j}) }

where V(k) is the set of top-k candidate tokens, p(v | x_{<t}) is the model's probability for candidate v, sim is the cosine similarity between hidden representations, and alpha balances model confidence against the degeneration penalty.

This issue will be based on #563, which creates the basic interface of our sampler class.

For reference, please check this paper
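To make the scoring rule above concrete, here is a minimal NumPy sketch of a single selection step. The values and names (alpha, candidate_probs, candidate_states, context_states) are purely illustrative and not part of any existing API:

import numpy as np

# Hypothetical setup: k = 3 candidate tokens with their model probabilities
# and hidden-state representations, plus hidden states of the 5 tokens
# generated so far. All values here are made up for illustration.
alpha = 0.6
candidate_probs = np.array([0.5, 0.3, 0.2])   # p(v | x_{<t}) for the top-k tokens
candidate_states = np.random.randn(3, 8)      # h_v, shape (k, hidden_dim)
context_states = np.random.randn(5, 8)        # h_{x_j}, shape (t-1, hidden_dim)

def cosine_sim(a, b):
    a = a / np.linalg.norm(a, axis=-1, keepdims=True)
    b = b / np.linalg.norm(b, axis=-1, keepdims=True)
    return a @ b.T                            # shape (k, t-1)

# Degeneration penalty: max similarity of each candidate to any previous token.
penalty = cosine_sim(candidate_states, context_states).max(axis=-1)

# Contrastive score from the equation above; the argmax is the next token.
scores = (1 - alpha) * candidate_probs - alpha * penalty
next_token = int(np.argmax(scores))
print(scores, next_token)

The first term rewards model confidence; the second term (the degeneration penalty) pushes the choice away from tokens that are too similar to what has already been generated.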

================================================================

Adding some implementation guidance...

The contrastive sampler is a variant of the top-k sampler; the only difference is that it adds a penalty based on the maximum similarity with previously seen tokens. You need to override the sample() method instead of get_next_token() as in TopKSampler, because we have to compute the penalty for every candidate token. Here is a template to help you start:

import tensorflow as tf
from tensorflow import keras

# Assumed import path for the base class added in #563; adjust if it differs.
from keras_nlp.samplers.sampler import Sampler


@keras.utils.register_keras_serializable(package="keras_nlp")
class ContrastiveSampler(Sampler):
    """Contrastive Sampler class.

    {Add docstring here}
    """

    def __init__(
        self,
        k=5,
        seed=None,
        jit_compile=True,
        run_eagerly=False,
    ):
        self.k = k
        self.seed = seed
        super().__init__(jit_compile=jit_compile, run_eagerly=run_eagerly)

    def get_next_token(self, next_token_probs):
        # Not needed here; this sampler overrides sample() directly.
        pass

    def sample(
        self, prompt, token_probability_fn, mask, num_steps, from_logits=True
    ):
        batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1]
        max_length = tf.cast(max_length, num_steps.dtype)
        # The index of the last non-padding token in prompt. Since all sequences
        # are aligned to the right side, the index is the same for all.
        current_index = max_length - num_steps

        def one_step(current_index, prompt, mask):

            #################################################################
            # 1. Get top-k tokens along with their probability and representation.
            # 
            # Your code goes here!
            #################################################################

            #################################################################
            # 2. Compute the penalty for each token in the top-k selection.
            # 
            # Your code goes here!
            #################################################################

            #################################################################
            # 3. Update the corresponding index and mask.
            # 
            # Your code goes here!
            #################################################################

            # The loop body must return the updated loop variables.
            return (current_index, prompt, mask)

        # Run a while loop until `max_length` tokens have been generated.
        current_index, prompt, mask = tf.while_loop(
            cond=lambda current_index, prompt, mask: tf.less(
                current_index, max_length
            ),
            body=one_step,
            loop_vars=(current_index, prompt, mask),
        )
        return prompt
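As a hint for step 2, one way to compute the penalty is the maximum cosine similarity between each top-k candidate's hidden state and the hidden states of the tokens generated so far. Below is a rough sketch under that assumption; candidate_states and context_states are hypothetical tensors your implementation would obtain from the model, not part of the existing Sampler interface:

import tensorflow as tf

def max_similarity_penalty(candidate_states, context_states):
    """Max cosine similarity of each top-k candidate to any previous token.

    candidate_states: float tensor of shape (batch, k, hidden_dim).
    context_states: float tensor of shape (batch, seq_len, hidden_dim).
    Returns a float tensor of shape (batch, k).
    """
    candidate_states = tf.math.l2_normalize(candidate_states, axis=-1)
    context_states = tf.math.l2_normalize(context_states, axis=-1)
    # Pairwise cosine similarities, shape (batch, k, seq_len).
    similarities = tf.einsum("bkd,bsd->bks", candidate_states, context_states)
    return tf.reduce_max(similarities, axis=-1)

The combined score is then (1 - alpha) * candidate_probs - alpha * penalty, and the argmax over the k candidates gives the next token.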

To test your implementation, you can use the code snippet below:

import keras_nlp

gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
gpt2_lm.run_eagerly = True
gpt2_lm.jit_compile = False
print(
    gpt2_lm.generate(
        "that's weird", sampler="greedy", max_length=30
    )
)

Setting gpt2_lm.run_eagerly = True makes generate() run in eager mode for easier debugging.
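Once your sampler is working, you could try swapping the string sampler for an instance of your new class. This assumes the class ends up exported as keras_nlp.samplers.ContrastiveSampler and that generate() accepts sampler instances as well as strings; adjust to match the final API:

sampler = keras_nlp.samplers.ContrastiveSampler(k=5)  # hypothetical export path
print(
    gpt2_lm.generate(
        "that's weird", sampler=sampler, max_length=30
    )
)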

This issue is a bit challenging but very rewarding!

chenmoneygithub commented 1 year ago

To our contributors,

This issue is unblocked as #563 has been merged. I will update the description with more details soon; feel free to take this issue! This might become our default text generation sampling algorithm, so it would be very impactful!

soma2000-lang commented 1 year ago

@chenmoneygithub I would love to work on this.

chenmoneygithub commented 1 year ago

@soma2000-lang Awesome, assigned to you!

apupneja commented 1 year ago

@chenmoneygithub the notebook you linked with the issue doesn't seem to be working. I got two errors.

chenmoneygithub commented 1 year ago

@apupneja Thanks! I will update the Colab soon; there have been a few changes since I opened the issue.

chenmoneygithub commented 1 year ago

@apupneja Updated the description!

soma2000-lang commented 1 year ago

@chenmoneygithub thanks

apupneja commented 1 year ago

Are you still working on this issue @soma2000-lang ? If not, I can take it up.

soma2000-lang commented 1 year ago

@apupneja this is a bit challenging, and I am still trying. If I end up unable to do it, I will tag you and unassign myself.

AmanSal1 commented 1 year ago

Hi @chenmoneygithub, I was going through the GSoC organizations and found TensorFlow there. There were many projects to contribute to, and I found one of them interesting, so I wanted to ask: is it the same as the issue you have opened, "Add contrastive search to our Sampler collection" #644?

This is the link to that project: https://docs.google.com/document/d/1w7MoOy7FJwECdl3gifPttLN-HnRyWG1kJkTg2g8uKz4/edit#

shivance commented 1 year ago

@chenmoneygithub @mattdangerw Could you assign me this? I would love to take a crack at it, as it's a challenging one.

I'll most likely put up PRs for my pending keras-io issues by tonight.

[Edit] : Looks like @chenmoneygithub is already working on it #896 💯

chenmoneygithub commented 1 year ago

Update on this: since we have been changing our API, adding contrastive search requires a fair amount of design work, so I am assigning this issue to myself. You can check PR #896 for the progress. Thanks all!

mattdangerw commented 1 year ago

This is done!