huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Adding State-of-the-art Contrastive Search to the Codebase of model.generate() #19182

Closed: gmftbyGMFTBY closed this issue 1 year ago

gmftbyGMFTBY commented 2 years ago

Feature request




1. Abstract:

In this issue, we propose integrating contrastive search into model.generate() as an additional option for text generation. We believe it would greatly benefit the research community.

All related resources of our work have been open-sourced; please see them below.


2. Introduction:

Open-ended text generation is a core task in NLP. However, maximization-based decoding methods (e.g., greedy search and beam search) of neural language models often lead to degeneration, i.e., the generated text is unnatural and contains undesirable repetitions. Existing approaches address the degeneration problem by introducing stochasticity via sampling (e.g., top-k sampling [1] and nucleus sampling [2]), but the resulting text often lacks coherence.

In our recent NeurIPS 2022 paper [3], "A Contrastive Framework for Neural Text Generation", we propose a new decoding method, contrastive search, which can be directly applied to all families of off-the-shelf language models (e.g., GPT and OPT). Specifically, during decoding, contrastive search selects from the most probable candidates predicted by the model while taking into account a degeneration penalty computed from the previous context. Formally, at each decoding step, given the context $\boldsymbol{x}_{< t}$, the selection of the output token $\boldsymbol{x}_t$ follows:

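$$\boldsymbol{x}_t = \underset{v \in V^{(k)}}{\arg\max} \left\{ (1-\alpha) \times \underbrace{p_\theta(v \mid \boldsymbol{x}_{< t})}_{\text{model confidence}} - \alpha \times \underbrace{\max \{ s(h_v, h_{x_j}) : 1 \le j \le t-1 \}}_{\text{degeneration penalty}} \right\}$$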

where $V^{(k)}$ is the set of top-k predictions from the model's probability distribution $p_\theta(\cdot \mid \boldsymbol{x}_{< t})$, and $\alpha \in [0, 1]$ balances the two terms. The first term, model confidence, is the probability of candidate $v$ predicted by the model. The second term, the degeneration penalty, measures how discriminative candidate $v$ is with respect to the tokens in the previous context, where $s(h_v, h_{x_j})$ is the cosine similarity between the representations of candidate $v$ and the previous token $x_j$. When $\alpha = 0$, contrastive search reduces to greedy search. (The core implementation of contrastive search can be found in Section 5.)

Contrastive search generates text that is coherent with the prefix while maintaining the diversity of the generated result. Through extensive experiments, we demonstrate the clear superiority of contrastive search over existing decoding methods on both automatic (e.g., MAUVE [4]) and human evaluations.
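To make the selection rule concrete, here is a minimal pure-Python sketch of a single decoding step. It is our illustration, not the authors' implementation (which works on batched tensors; see Section 5). The candidate ids, probabilities and hidden vectors are made-up toy values standing in for $p_\theta(v \mid \boldsymbol{x}_{< t})$, $h_v$ and $h_{x_j}$, and the function names are hypothetical.

```python
import math
from typing import List, Sequence


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity s(u, v) between two vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def contrastive_scores(cand_probs: Sequence[float],
                       cand_hidden: Sequence[Sequence[float]],
                       context_hidden: Sequence[Sequence[float]],
                       alpha: float) -> List[float]:
    """Score each top-k candidate v as (1 - alpha) * p(v) - alpha * max_j s(h_v, h_{x_j})."""
    scores = []
    for p_v, h_v in zip(cand_probs, cand_hidden):
        # Degeneration penalty: similarity to the closest token already in the context.
        penalty = max(cosine(h_v, h_j) for h_j in context_hidden)
        scores.append((1.0 - alpha) * p_v - alpha * penalty)
    return scores


def select_next(cand_ids: Sequence[int],
                cand_probs: Sequence[float],
                cand_hidden: Sequence[Sequence[float]],
                context_hidden: Sequence[Sequence[float]],
                alpha: float) -> int:
    """Return the candidate id with the highest contrastive score (the argmax over V^(k))."""
    scores = contrastive_scores(cand_probs, cand_hidden, context_hidden, alpha)
    best = max(range(len(scores)), key=scores.__getitem__)
    return cand_ids[best]


# Toy step: two context tokens and two candidates (ids 10 and 20).
context = [[1.0, 0.0], [0.0, 1.0]]
candidates = [[1.0, 0.0], [0.6, 0.8]]
print(select_next([10, 20], [0.6, 0.4], candidates, context, alpha=0.6))
print(select_next([10, 20], [0.6, 0.4], candidates, context, alpha=0.0))
```

With these toy values, candidate 10 is more probable but its representation duplicates a context token (cosine 1.0), so at alpha = 0.6 the penalty flips the choice to candidate 20; at alpha = 0 the rule is plain greedy search and picks 10.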


3. Demonstration of the Awesome Results from Contrastive Search:

In this section, we provide direct demonstrations of how well contrastive search works. To this end, we select two representative models: (i) OPT [5], the recently released model from Meta, which successfully replicates the results of GPT-3 [6]; and (ii) GPT-2 [7], (arguably) the most widely used model for text generation.

3.1. Demonstration with OPT:

First, let's see the results of OPT models.

Take the conversation generation task from the original paper [5] (Figure 9 in Appendix E) as an example. The prefix text is shown below.

A chat between a curious human and the Statue of Liberty.

Human: What is your name?
Statue: I am the Statue of Liberty.
Human: Where do you live?
Statue: New York City.
Human: How long have you lived there?

(i) The text generated by OPT-175B + nucleus sampling, as shown in Figure 9 in Appendix E of the original paper, is as follows:

A chat between a curious human and the Statue of Liberty.

Human: What is your name?
Statue: I am the Statue of Liberty.
Human: Where do you live?
Statue: New York City.
Human: How long have you lived there?
Statue: I have been there since 1886.
Human: What is your job?
Statue: I was built to welcome immigrants to the United States.
Human: What is your favorite food?
Statue: I like ice cream and apples.
Human: What is your favorite color?
Statue: I like red, white, and blue.
Human: What is your favorite animal?
Statue: I like dogs.
Human: What is your favorite song?

Clearly, even with 175B parameters, the model gets stuck in simple, linguistically repetitive generations with nucleus sampling.

(ii) Now, let's see what we get using contrastive search with OPT-6.7B [model card] (a model more than 20 times smaller than OPT-175B).

[Reproduce Results] In Section 4.2, we show how to reproduce our results with a few lines of code.

A chat between a curious human and the Statue of Liberty.

Human: What is your name?
Statue: I am the Statue of Liberty.
Human: Where do you live?
Statue: New York City.
Human: How long have you lived there?
Statue: Since 1884.
Human: Why did you come to America?
Statue: I was given to the United States by France as a gift for helping the French during the Franco-Prussian War.
Human: What do you think of America?
Statue: I love it. It is the greatest country in the world.
Human: What’s the weather like in New York?
Statue: It is cold.
Human: Is it safe to walk around at night?
Statue: Yes. There are policemen everywhere.
Human: Do you have any children?
Statue: Not yet. My pedestal is empty.
Human: What would you like to say to people who want to immigrate to America?
Statue: Come on over. You will be happy here. We have everything you need.

We see that the generated text with contrastive search is much more diverse and interesting.

[Comparison] For a more direct comparison, the generated results with the same OPT-6.7B using greedy search and nucleus sampling are:

(1) Text generated by greedy search:

```
----------------------------------------------------------------------------------------------------
A chat between a curious human and the Statue of Liberty.

H: What is your name?
Statue: I am the Statue of Liberty.
Human: Where do you live?
Statue: New York City.
Human: How long have you lived there?
Statue: I have lived here for over 100 years.
Human: What do you do?
Statue: I welcome people from all over the world to come to America.
Human: What do you think of America?
Statue: I love America.
Human: What do you think of immigrants?
Statue: I love immigrants.
Human: What do you think of America?
Statue: I love America.
Human: What do you think of immigrants?
Statue: I love immigrants.
Human: What do you think of America?
Statue: I love America.
Human: What do you think of immigrants?
Statue: I love immigrants.
Human: What do you think of America?
Statue: I love America.
Human: What do you think of immigrants?
Statue: I love immigrants.
Human: What do you think of America?
Statue: I love America.
Human: What do you think of immigrants?
Statue: I love immigrants.
Human: What do you think of America?
Statue: I love America.
Human: What do you think of immigrants?
Statue: I love immigrants.
Human: What do you think of America?
Statue: I love America.
Human: What do you think of immigrants?
Statue: I love immigrants.
Human...
----------------------------------------------------------------------------------------------------
```
(2) Text generated by nucleus sampling:

```
----------------------------------------------------------------------------------------------------
A chat between a curious human and the Statue of Liberty.

H: What is your name?
Statue: I am the Statue of Liberty.
Human: Where do you live?
Statue: New York City.
Human: How long have you lived there?
Statue: Since 1876.
Human: Why is the Statue of Liberty guarded?
Statue: Because there are many people trying to steal her.
a comparison about an unexpressed thought I would also share the story of “A Humble Fear.” At a conference in New York the Dalai Lama gave a speech to the International Thinkers Congress in New York. The whole thing was recorded, and the video is quite interesting. (on a side note, I love the fact that there were some people who laughed when he described himself as a humble being… I think the video is hilarious, there is a reason why I put up the video. Because if you cannot find the humor in this you’re sadly lacking…) In the speech, the Dalai Lama compares the search for truth to searching for treasure. He says: “However there is a huge difference between being a thief and a collector. A thief simply takes things, whereas a collector looks for the beauty, even if it is just a single object.” The above quote is perhaps the most cliched Buddhist philosophy of our times. However the comparison between a collector and a thief is quite interesting. I like to think that the Buddha...
----------------------------------------------------------------------------------------------------
```

We see that (i) greedy search generates repetitive text; and (ii) nucleus sampling produces text that is incoherent.

3.2. Demonstration with GPT:

Next, let's see the results of GPT models.

We provide a simple prefix text (DeepMind Company is) with only three words and ask the model to generate a long text of 512 tokens. In this example, we use GPT-2-large [model card] for text generation.

[Reproduce Results] In Section 4.3, we show how to reproduce our results with a few lines of code.

(1) Generated result with contrastive search:

----------------------------------------------------------------------------------------------------
DeepMind Company is a leader in artificial intelligence (AI). We have a long history of working with 
companies such as Google, Facebook, Amazon, and Microsoft to build products that improve people's lives, 
and today we are excited to announce that DeepMind's AlphaGo program has won the game of Go, becoming 
the first program to defeat a professional Go player.

The victory is a testament to the power of deep learning, and to the incredible work of our research team, 
which has been at the forefront of AI research for the past five years. AlphaGo is one of the most advanced 
Go programs ever created, and its performance is an important step towards the goal of human-level AI.

"This is the culmination of a decade of hard work," said Andy Ng, co-founder and CTO of DeepMind. "We are 
thrilled to have achieved this milestone and look forward to continuing to develop AI that can be used in 
a wide range of applications and to help people live better lives."

DeepMind's work on Go began in 2010, when it began to train a neural network to play Go using millions of 
games played by top Go players around the world. Since then, the team has refined the algorithm, adding 
more and more layers of reinforcement learning to make it better at recognizing patterns and making decisions 
based on those patterns. In the past year and a half, the team has made significant progress in the game, 
winning a record-tying 13 games in a row to move into the top four of the world rankings.

"The game of Go is a complex game in which players have to be very careful not to overextend their territory, 
and this is something that we have been able to improve over and over again," said Dr. Demis Hassabis, co-founder
and Chief Scientific Officer of DeepMind. "We are very proud of our team's work, and we hope that it will inspire
others to take the next step in their research and apply the same techniques to other problems."

In addition to the win in Go, DeepMind has also developed an AI system that can learn to play a number of different
games, including poker, Go, and chess. This AI system, called Tarsier, was developed in partnership with Carnegie
Mellon University and the University of California, Berkeley, and is being used to teach computer vision and machine
learning to identify objects in images and recognize speech in natural language. Tarsier has been trained to play
the game of Go and other games on a number of different platforms...
----------------------------------------------------------------------------------------------------

From the results, we can see that the entire generated document is very high-quality and human-like.

[Comparison] For a more direct comparison, the generated results with the same model using greedy search and nucleus sampling are:

(2) Text generated by greedy search:

```
----------------------------------------------------------------------------------------------------
DeepMind Company is a leading AI research company, with a focus on deep learning and deep learning-based systems. The company's research is focused on the development of deep learning-based systems that can learn from large amounts of data, and that can be used to solve real-world problems. DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service. DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service. DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service. DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service. DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service. DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service. DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service. DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service. DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service. DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service. DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service. DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service. DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service. DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service. DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service. DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service. DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service. DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service. DeepMind's
----------------------------------------------------------------------------------------------------
```
(3) Text generated by nucleus sampling:

```
----------------------------------------------------------------------------------------------------
DeepMind Company is a Cardiff-based start-up with an exclusive mission to build the world's largest ever deep-learning system to analyse the world's digital content and in particular, super-sized image content. The system, the largest in the world with no previous expertise in image or digital content detection, will have previously relied on a mixture of machine learning, artificial neural networks, and storage, processing and retrieval techniques. The AI system, called ImageNet, will take new approach to our challenge of data science and machine learning, significantly improving efficiency, natural language processing and full understanding of complex, high-dimensional images, with an Eye of the Tiger framework for extracting techniques to ensure correct detection of particular images in complex scenes. Dr. Mark Ward, Dr. Alex Kudle, Dr. Ralph Pinchbeck and CTO, DeepMind Dr. Alex Kudle Case Study: Derpy's Most Wanted: Fighting Cybersecurity, building a robot-aided smuggling network InfoSec News, 06/07/2017 Dimitrios Papadimitriou (left) and Chris Bardy (right) at G+ XE, July 2017 How to model an industrial malware botnet In this case study, we show how to build a deep-learning environment to model a new, massive ransomware botnet. Our model computes the distribution of user credentials stored on infected machines and produces a toolkit for open-source "modeling-as-code" (MATC) simulation. We elaborate on the resource management aspect of the toolkit, and how it can be adapted to working offline on embedded or cloud-based networks.
Hacking Networked: The industrial botnets of the future InfoSec News, 04/11/2017 Intensive analysis of state sponsored malicious cyber activity, published by KBB Strategic The major single source of IoT malware networks in 2017 The global commercial botnet equivalent count grew to 31.5% in 2017, up from 21.1% the year before, according to a comprehensive report from the Government Accountability Office (GAO). According to the report, various malware operators continued to convert massive amounts of wasted data into profits as well as enable sophisticated cyber operations targeting critical infrastructure. Industrial malware blasts up to 31% of malware within the IP space over 2017...
----------------------------------------------------------------------------------------------------
```

Clearly, greedy search generates repetitive text, while nucleus sampling produces text that is incoherent and quickly goes off topic.


4. Example Usage:

In our [main repo], we have provided detailed huggingface-style tutorials ([tutorial 1], [tutorial 2]) on how to apply contrastive search on different models across different languages.

In the following, we show how to easily reproduce our results in Section 3 with a few lines of code.

4.1. Environment Setup:

For ease of use, we provide a PyPI package, which can be installed as below. More details of our package can be found [here].

pip install simctg --upgrade

4.2. Reproduce Results of OPT:

To reproduce our results in Section 3.1 using OPT, (i) we first load the OPT model as

import torch
from simctg.simctgopt import SimCTGOPT
model_name = 'facebook/opt-6.7b'
model = SimCTGOPT(model_name)
tokenizer = model.tokenizer
model.eval()
bos_token_id = tokenizer.bos_token_id
eos_token_id = tokenizer.eos_token_id

(ii) Then, we provide the prefix text as

prefix_text = r"""A chat between a curious human and the Statue of Liberty.

Human: What is your name?
Statue: I am the Statue of Liberty.
Human: Where do you live?
Statue: New York City.
Human: How long have you lived there?"""

(iii) Next, we prepare the input IDs as

[Important Tip] As suggested in the [tutorial], OPT adds the EOS token (</s>) to the beginning of every prompt, so make sure this special token is added at the front of the prompt.

tokens = tokenizer.tokenize(prefix_text)
input_ids = [bos_token_id] + tokenizer.convert_tokens_to_ids(tokens) # adds </s> to the beginning of every prompt
input_ids = torch.LongTensor(input_ids).view(1,-1)
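tokenizer.tokenize() followed by convert_tokens_to_ids() does not add any special tokens, which is why the BOS id is prepended by hand above. If the ids may come from a path that already adds it (for example, calling the tokenizer directly, depending on its configuration), a small guard avoids a doubled </s>. The helper name ensure_leading_bos is hypothetical and not part of simctg or transformers; this is a sketch only.

```python
from typing import List


def ensure_leading_bos(token_ids: List[int], bos_token_id: int) -> List[int]:
    """Return a copy of token_ids that starts with exactly one bos_token_id (hypothetical helper)."""
    if token_ids and token_ids[0] == bos_token_id:
        # Already prefixed (e.g. by the tokenizer itself): do not add a second BOS.
        return list(token_ids)
    return [bos_token_id] + list(token_ids)


# Toy ids; for OPT, bos_token_id would be tokenizer.bos_token_id.
print(ensure_leading_bos([100, 200], bos_token_id=2))
print(ensure_leading_bos([2, 100, 200], bos_token_id=2))
```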

(iv) Lastly, we generate the text with contrastive search as

beam_width, alpha, decoding_len = 5, 0.6, 256
output = model.fast_contrastive_search(input_ids=input_ids, beam_width=beam_width, 
                                       alpha=alpha, decoding_len=decoding_len,
                                       end_of_sequence_token_id = eos_token_id, early_stop = True) 
print("Output:\n" + 100 * '-')
print(tokenizer.decode(output[1:]))
print("" + 100 * '-')
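The arguments of fast_contrastive_search suggest the shape of the outer decoding loop: at most decoding_len steps, optionally stopping as soon as end_of_sequence_token_id is produced when early_stop is set. The sketch below assumes that this is how the parameters interact (we have not verified simctg's internals). step_fn stands in for one contrastive-search step, such as ContrastiveSearchOneStep in Section 5, and the name decode_loop is hypothetical.

```python
from typing import Callable, List, Optional


def decode_loop(step_fn: Callable[[List[int]], int],
                prefix_ids: List[int],
                decoding_len: int,
                end_of_sequence_token_id: Optional[int] = None,
                early_stop: bool = True) -> List[int]:
    """Grow prefix_ids one token at a time by calling step_fn (hypothetical driver)."""
    ids = list(prefix_ids)
    for _ in range(decoding_len):
        next_id = step_fn(ids)  # one decoding step, e.g. one contrastive-search step
        ids.append(next_id)
        # Stop right after the end-of-sequence token if early stopping is enabled.
        if early_stop and end_of_sequence_token_id is not None and next_id == end_of_sequence_token_id:
            break
    return ids


# Toy step function: always emits the previous id plus one.
print(decode_loop(lambda ids: ids[-1] + 1, [0], decoding_len=5, end_of_sequence_token_id=3))
```

Like the call above, this sketch returns the prefix together with the generated tokens, which is why the OPT example slices off the leading </s> with output[1:] before decoding.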

4.3. Reproduce Results of GPT:

To reproduce our results in Section 3.2 using GPT, (i) we first load the GPT-2 model as

import torch
from simctg.simctggpt import SimCTGGPT
model_name = r'gpt2-large'
model = SimCTGGPT(model_name)
model.eval()
tokenizer = model.tokenizer
eos_token_id = tokenizer.eos_token_id

(ii) Then, we prepare the prefix text as

prefix_text = r"DeepMind Company is"
tokens = tokenizer.tokenize(prefix_text)
input_ids = tokenizer.convert_tokens_to_ids(tokens)
input_ids = torch.LongTensor(input_ids).view(1,-1)

(iii) Finally, we generate the text with contrastive search as

beam_width, alpha, decoding_len = 4, 0.6, 512
output = model.fast_contrastive_search(input_ids=input_ids, beam_width=beam_width, 
                                       alpha=alpha, decoding_len=decoding_len,
                                       end_of_sequence_token_id = eos_token_id, early_stop = True)
print("Output:\n" + 100 * '-')
print(tokenizer.decode(output))
print("" + 100 * '-')

5. Code Snippet:

The main implementation of contrastive search involves two parts: (i) candidate collection; and (ii) candidate re-ranking.

For more details, please see our open-sourced implementations for [GPT-2 models] and [OPT models].

(i) The collection of candidates can be implemented as below:

import torch
import torch.nn.functional as F

def ContrastiveSearchOneStep(model, input_ids, beam_width, alpha):
    '''
        model: the generation model, e.g., gpt2; it must provide
               compute_logits_and_hidden_states(input_ids) -> (hidden_states, logits)
        input_ids: 1 x seqlen
        beam_width: number of top-k candidates (k)
        alpha: weight of the degeneration penalty
    '''
    # Encode the current context to get the next-token distribution.
    prev_hidden_states, logits = model.compute_logits_and_hidden_states(input_ids)
    _, seqlen, embed_dim = prev_hidden_states.size()
    _, _, vocab_size = logits.size()

    logit_for_next_step = logits[:,-1,:]
    assert logit_for_next_step.size() == torch.Size([1, vocab_size])

    next_probs = F.softmax(logit_for_next_step, dim = -1)
    assert next_probs.size() == logit_for_next_step.size()

    # (i) Candidate collection: keep the k most probable next tokens.
    _, top_k_ids = torch.topk(logit_for_next_step, dim = -1, k = beam_width)
    assert top_k_ids.size() == torch.Size([1, beam_width])

    top_k_probs = torch.gather(next_probs, dim = 1, index=top_k_ids)
    assert top_k_probs.size() == top_k_ids.size()

    # Append each candidate to its own copy of the context and re-encode,
    # so that every candidate v gets a representation h_v.
    expanded_context = [input_ids for _ in range(beam_width)]
    expanded_context = torch.cat(expanded_context, dim = 0)
    assert expanded_context.size() == torch.Size([beam_width, seqlen])
    top_k_ids = top_k_ids.view(beam_width, 1)
    next_input_ids = torch.cat([expanded_context, top_k_ids], dim = -1)
    assert next_input_ids.size() == torch.Size([beam_width, seqlen+1])
    new_hidden_states, _ = model.compute_logits_and_hidden_states(next_input_ids)
    assert new_hidden_states.size() == torch.Size([beam_width, seqlen+1, embed_dim])
    context_hidden = new_hidden_states[:,:seqlen,:]
    assert context_hidden.size() == torch.Size([beam_width, seqlen, embed_dim])
    next_hidden = new_hidden_states[:,seqlen:,:]
    assert next_hidden.size() == torch.Size([beam_width, 1, embed_dim])

    # (ii) Candidate re-ranking with the degeneration penalty (see ranking() below).
    next_id = ranking(context_hidden, next_hidden, top_k_ids, top_k_probs, alpha)

    next_input_ids = torch.cat([input_ids, next_id], dim = -1)
    assert next_input_ids.size() == torch.Size([1, seqlen+1])
    return next_input_ids

(ii) The re-ranking of candidates can be implemented as below:

def ranking(context_hidden, next_hidden, next_top_k_ids, next_top_k_probs, alpha):
    '''
        context_hidden: beam_width x context_len x embed_dim
        next_hidden: beam_width x 1 x embed_dim
        next_top_k_ids: beam_width x 1
        next_top_k_probs: 1 x beam_width (model confidence of each candidate)
        alpha: weight of the degeneration penalty
    '''
    beam_width, context_len, embed_dim = context_hidden.size()
    assert next_hidden.size() == torch.Size([beam_width, 1, embed_dim])
    # L2-normalize so that the batched dot product below is a cosine similarity.
    norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
    norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
    cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1,2)).squeeze(-1)
    assert cosine_matrix.size() == torch.Size([beam_width, context_len])
    # Degeneration penalty: max similarity between each candidate and any context token.
    scores, _ = torch.max(cosine_matrix, dim = -1)
    assert scores.size() == torch.Size([beam_width])
    next_top_k_probs = next_top_k_probs.view(-1)
    # Contrastive score: (1 - alpha) * model confidence - alpha * degeneration penalty.
    scores = (1.0 - alpha) * next_top_k_probs - alpha * scores
    _, selected_idx = torch.topk(scores, k = 1)
    assert selected_idx.size() == torch.Size([1])
    selected_idx = selected_idx.unsqueeze(0)
    assert selected_idx.size() == torch.Size([1,1])
    next_id = torch.gather(next_top_k_ids, dim = 0, index=selected_idx)
    assert next_id.size() == torch.Size([1,1])
    return next_id

6. Inference Latency:

Lastly, we compare the inference latency of contrastive search with other widely used decoding methods. The results are shown in the Figure below.

[Figure: inference latency of contrastive search compared with other widely used decoding methods]

We see that the inference latency of contrastive search is comparable with that of other widely used methods, which further supports its practical use.


References:

[1] Fan et al., 2018, "Hierarchical Neural Story Generation", ACL 2018

[2] Holtzman et al., 2020, "The Curious Case of Neural Text Degeneration", ICLR 2020

[3] Su et al., 2022, "A Contrastive Framework for Neural Text Generation", NeurIPS 2022

[4] Pillutla et al., 2021, "MAUVE: Measuring the Gap Between Neural Text and Human Text using Divergence Frontiers", NeurIPS 2021

[5] Zhang et al., 2022, "OPT: Open Pre-trained Transformer Language Models", arXiv 2022

[6] Brown et al., 2020, "Language Models are Few-Shot Learners", NeurIPS 2020

[7] Radford et al., 2019, "Language Models are Unsupervised Multitask Learners", OpenAI Blog

Motivation

Given the strong performance of contrastive search, we believe that it would greatly benefit a wide range of NLP researchers and practitioners in the text generation community.

Your contribution

I can submit a PR for this feature request ASAP.

gmftbyGMFTBY commented 2 years ago

@patrickvonplaten @sgugger @stas00 Thank you very much for your contributions to the codebase of generation_utils. Could you please take a moment to review our request, which we believe could significantly facilitate the development of the text generation community?

patrickvonplaten commented 2 years ago

Wow this is very cool! @gante do you have time to take a look here by any chance? Otherwise @ArthurZucker maybe?

yxuansu commented 2 years ago

Hi @patrickvonplaten, @gante, and @ArthurZucker,

We believe contrastive search would greatly benefit the research community, and we are really looking forward to seeing it added to the transformers library!

Please do let us know if you need any assistance from our end. Many thanks for your kind help!

Best,

Yixuan

gante commented 2 years ago

Hi @yxuansu -- this is really cool! The method is clear and makes total sense, and the results seem to back it up. Also, this is probably the clearest feature request I've seen here <3

I'd be happy to support you throughout the process, including adding it to the three frameworks (PT/TF/JAX), creating demos, and communicating it. You mentioned that you were willing to open a PR -- how may I be of help? 🤗

Looking at the resources that you shared, it seems like the workflow can be coded within a LogitsProcessor, which would automatically make your method compatible with other logit manipulation strategies (e.g. forbidding certain words) and with all generation strategies (sampling, beam search, ...). In essence, the processor would apply the top k filtering, compute the cosine similarities, compute the logits according to your method, and return them. The caveat is the need for the hidden states to compute the cosine similarities.

Model-specific details like forcing the EOS token in OPT are handled inside generate(), so no further changes should be needed. I'm curious to see the performance in other model types and types of text (like generating code)!

yxuansu commented 2 years ago

Hi @gante -- thank you so much for your reply! Could you advise us (me and @gmftbyGMFTBY) on what our next step should be? It is our first time contributing to Hugging Face :-)

Many thanks!

stas00 commented 2 years ago

This looks fantastic. I'm looking forward to having this new feature in transformers.

Also IMHO you actually would probably get an even more impressive improvement using BLOOM-176B which by default with greedy search suffers from getting stuck in repetition a lot.

yxuansu commented 2 years ago

Hi @stas00 -- Thank you very much for your interest! We will work with @gante and try to add this new feature to transformers ASAP!

yxuansu commented 2 years ago

Hi @gante,

For your convenience, you can find our key implementations of contrastive search for GPT models below:

  1. Candidate Ranking
  2. One Step Decoding
  3. Contrastive Search Interface

For OPT models, the corresponding resources are listed below:

  1. Candidate Ranking
  2. One Step Decoding
  3. Contrastive Search Interface

Hope these pointers are useful!

Best,

Yixuan

gante commented 2 years ago

@yxuansu @gmftbyGMFTBY fantastic! The first step is to discuss the design before jumping to the implementation itself. Since it will be your first commit, I'll be assuming that you are not very familiar with the code base, so I'll give extra pointers 🙌

I thought deeper about the design, and I realized that my suggestion above, to use a LogitsProcessor, would require needlessly complicated code. Obtaining $h_v$, according to your implementation, requires running an additional forward pass, and LogitsProcessor isn't the place to do it for a myriad of reasons.

The points above lead to the following proposal of implementation: a dedicated generation method, like sample or beam_search. It will be much easier to implement and test -- you can simply:

  1. make a copy of greedy search
  2. rewrite some of its parts so as to implement your new method
  3. add a new argument to generate, alpha (I'm assuming we'll repurpose the existing top_k argument into your method)
  4. add the needed piping in generate so as to call your method when alpha is set (follow the example here, which triggers greedy_search when certain conditions are met)
  5. Play around with it and confirm that it is working as expected. Then we can design some tests for the codebase.

The only drawback of this design is that we won't be able to mix and match your method with other generation methods that are not coded as a LogitsProcessor, like constrained_beam_search. But that would only be icing on the cake, not the cake itself 🤗

What do you think?

yxuansu commented 2 years ago

Hi @gante -- Thank you for your super cool advice! We will start right away on adapting the greedy search method and get back to you ASAP. Many thanks for your suggestion!

P.S. Would it be more convenient if we added you to our private repo where we test our implementations? That way, we could test the demos together.

gante commented 2 years ago

@yxuansu Yeah, that's a good idea -- that way I'm also able to push ad hoc changes if needed 👍 After we're all happy with the state of the method, we can open a PR from the fork

yxuansu commented 2 years ago

@gante -- Cool! @gmftbyGMFTBY will send you an invitation after we create the repo, it would not take long :-)

Many thanks!

gmftbyGMFTBY commented 2 years ago

@gante Hi, thank you so much for your suggestions! We have almost finished the PyTorch version of contrastive_search in our fork, and I have sent you an invitation to our repo.

All the changes are in src/transformers/generation_utils.py for you to check. We have also prepared a test script so that you can easily run contrastive_search. To run it, use the following commands:

cd tests/generation
CUDA_VISIBLE_DEVICES=0 python test_generation_contrastive_search.py

Looking forward to your valuable questions and suggestions.

Best, TianLan

gante commented 2 years ago

Hi @gmftbyGMFTBY 👋 Thank you for adding me to your fork!

I have looked at the code and at the results of the script you mentioned. It's great that you were able to massage past_key_values to fit your method 💪 From the test script we can see that we are getting the same output for GPT-2 as in your paper, which is a great starting point 🚀

From here, my recommendation would be to open a draft PR. There are a few points that I'd like to sort together with you before opening the review to others:

  1. There is separate logic for decoder-only and encoder-decoder models -- it would be great if we could unify it, even if at expense of a few if/elses
  2. We don't host test scripts, only unit tests, so test_generation_contrastive_search.py has to be removed. We could use a few of its examples for integration tests, though
  3. Readability is very important to us 🤗 A random user reading the code should be able to understand the basics of what is going on (and why) without going to the paper. A few more docstrings, comments, and potentially more informative variable names would go a long way (there are a few more nits, but I want to focus on the important parts first)

Let me know if you'd like a hand tackling any of the points above!

gmftbyGMFTBY commented 2 years ago

Okay, thank you so much for your suggestions!

I'd like to solve points [1] and [3] first. If there is any progress, I will continue to discuss it with you.

Best, TianLan

yxuansu commented 2 years ago

Hi @gante -- Many thanks for your kind help!

We'd like to ask whether you would like to join a Slack channel with me and @gmftbyGMFTBY. That way, we can discuss our PR more promptly and easily. If you'd like to, could you share your Slack account with us so that we can add you to our private channel? Many thanks!

gante commented 2 years ago

@yxuansu surely, you can add the email of my GH account (joaofranciscocardosogante@gmail.com)

yxuansu commented 1 year ago

Hi @gante -- We have created a private channel and sent an invitation to you. Let's communicate in our channel!

Many thanks for your help!

Best,

Yixuan

yxuansu commented 1 year ago

> This looks fantastic. I'm looking forward to having this new feature in transformers.
>
> Also IMHO you actually would probably get an even more impressive improvement using BLOOM-176B which by default with greedy search suffers from getting stuck in repetition a lot.

Hi @stas00 -- Many thanks for your interest in our work! Contrastive search has now been merged into transformers. Here, we provide a short tutorial on how to apply contrastive search within transformers. Please feel free to try it out :-)
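The tutorial link did not survive here, so as a hedged sketch: in transformers releases that include contrastive search (v4.24 and later, to our knowledge), it is triggered by passing penalty_alpha and top_k to generate() with do_sample left False. The helper names below are hypothetical; the range check on penalty_alpha mirrors the paper's alpha in [0, 1] and may be stricter than transformers itself.

```python
from typing import Any, Dict


def contrastive_search_kwargs(penalty_alpha: float = 0.6, top_k: int = 4,
                              max_new_tokens: int = 256) -> Dict[str, Any]:
    """Build generate() keyword arguments that select contrastive search (hypothetical helper)."""
    if not 0.0 < penalty_alpha <= 1.0:
        raise ValueError("penalty_alpha must be in (0, 1]; 0 falls back to greedy search")
    if top_k < 2:
        raise ValueError("top_k must be at least 2 so that there are candidates to re-rank")
    # do_sample stays False: with do_sample=True, generate() samples instead.
    return {"penalty_alpha": penalty_alpha, "top_k": top_k,
            "do_sample": False, "max_new_tokens": max_new_tokens}


def generate_with_contrastive_search(model_name: str, prompt: str, **kwargs: Any) -> str:
    """Load a causal LM from the Hub and decode one prompt with contrastive search."""
    # Imported lazily so the helper above can be used without transformers installed.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()
    inputs = tokenizer(prompt, return_tensors="pt")
    output_ids = model.generate(**inputs, **contrastive_search_kwargs(**kwargs))
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
```

For example, generate_with_contrastive_search("gpt2-large", "DeepMind Company is", penalty_alpha=0.6, top_k=4, max_new_tokens=512) uses the same settings as Section 4.3 (beam_width there plays the role of top_k); we would not assume the output matches the simctg result token for token.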

gante commented 1 year ago

@yxuansu @stas00 actually it works for nearly all models, except for Bloom (which has a different shape for the past key values output) -- working on it :)

yuchenlin commented 11 months ago

Thank you guys for the PR! I have a quick question about contrastive searching in general, and hope you would please help!

From the description of this contrastive searching method, it should generate a deterministic output by only setting top_k=x and penalty_alpha=y. Do I need to set do_sample=True?

I initially thought we should not because this is a searching method and the penalty helps us do argmax for selecting from the top_k, which should not involve any randomness via sampling. But there will be a warning message telling me that I should enable do_sample=True when I set top_k.

I tried both enabling do_sample and not, and I found that if you set do_sample=True, there will be randomness in multiple runs, while if I set do_sample=False, the results are the same as the greedy decoding results. I'm not sure if I did this correctly. Thank you so much in advance! :D

gmftbyGMFTBY commented 11 months ago

@yuchenlin Hello, thanks for your attention. I will answer your questions as follows:

  1. Your understanding is correct. Contrastive search is a deterministic searching method that is free from randomness, and multiple runs will generate the same results.
  2. The do_sample parameter should be set to False. If do_sample is set to True, the random sampling decoding method is activated instead of contrastive search. More details can be found in these lines: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py#L867
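To illustrate point 2, here is a simplified, hypothetical model of how generate() chooses a decoding mode from its arguments. It omits group beam search, constrained beam search and other modes, and the real conditions live in the linked utils.py and may differ between versions. It also shows that if penalty_alpha is missing or 0, the same call falls back to greedy search.

```python
from typing import Optional


def pick_decoding_mode(num_beams: int = 1, do_sample: bool = False,
                       top_k: Optional[int] = None,
                       penalty_alpha: Optional[float] = None) -> str:
    """Simplified, hypothetical re-creation of generate()'s mode selection."""
    if (num_beams == 1 and not do_sample
            and top_k is not None and top_k > 1
            and penalty_alpha is not None and penalty_alpha > 0):
        return "contrastive_search"  # deterministic: argmax of the contrastive score
    if num_beams == 1 and not do_sample:
        return "greedy_search"
    if num_beams == 1 and do_sample:
        return "sample"  # random: top_k here only truncates the sampling distribution
    if not do_sample:
        return "beam_search"
    return "beam_sample"


print(pick_decoding_mode(top_k=4, penalty_alpha=0.6))
print(pick_decoding_mode(do_sample=True, top_k=4, penalty_alpha=0.6))
print(pick_decoding_mode(top_k=4))
```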