ggerganov / llama.cpp

LLM inference in C/C++
MIT License

Contrastive Decoding Improves Reasoning in Large Language Models #3278

Closed logikstate closed 5 months ago

logikstate commented 12 months ago

This paper describes a method, similar to speculative sampling, that uses a lower-quality model's predictions to identify tokens to avoid, thereby increasing the quality of the higher-quality model's output. It allegedly leads LLaMA-65B to outperform LLaMA 2, GPT-3.5 and PaLM 2-L on the HellaSwag commonsense reasoning benchmark.

https://arxiv.org/abs/2309.09117

"We demonstrate that Contrastive Decoding -- a simple, computationally light, and training-free text generation method proposed by Li et al 2022 -- achieves large out-of-the-box improvements over greedy decoding on a variety of reasoning tasks. Originally shown to improve the perceived quality of long-form text generation, Contrastive Decoding searches for strings that maximize a weighted difference in likelihood between strong and weak models. We show that Contrastive Decoding leads LLaMA-65B to outperform LLaMA 2, GPT-3.5 and PaLM 2-L on the HellaSwag commonsense reasoning benchmark, and to outperform LLaMA 2, GPT-3.5 and PaLM-540B on the GSM8K math word reasoning benchmark, in addition to improvements on a collection of other tasks. Analysis suggests that Contrastive Decoding improves over existing methods by preventing some abstract reasoning errors, as well as by avoiding simpler modes such as copying sections of the input during chain-of-thought. Overall, Contrastive Decoding outperforms nucleus sampling for long-form generation and greedy decoding for reasoning tasks, making it a powerful general purpose method for generating text from language models."

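For reference, the "weighted difference in likelihood" from the abstract can be written out as below. This is a transcription of the paper's Algorithm 1 pseudocode (quoted later in this thread), and the notation is illustrative: $s_{\text{exp}}$ and $s_{\text{ama}}$ are the expert and amateur logits at the current step, $\tau$ is the amateur temperature, and $\alpha$ is the masking threshold.

$$
p_{\text{exp}} = \operatorname{softmax}(s_{\text{exp}}), \qquad p_{\text{ama}} = \operatorname{softmax}(s_{\text{ama}} / \tau)
$$

$$
\mathcal{V}_{\text{head}} = \{\, x : p_{\text{exp}}(x) \ge \alpha \max_{w} p_{\text{exp}}(w) \,\}
$$

$$
\text{CD}(x) = \begin{cases} \log p_{\text{exp}}(x) - \log p_{\text{ama}}(x), & x \in \mathcal{V}_{\text{head}} \\ -\infty, & \text{otherwise} \end{cases}
$$

The next token is then chosen from $\text{CD}(x)$, for example greedily.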
KerfuffleV2 commented 11 months ago

I was looking at this a few days ago, but it seems pretty complicated. Unlike the other samplers that you can just give the last tokens + current logits to, it seems like contrastive decoding requires a different approach. (Correct me if I'm wrong.)

I tried to find a simple example of implementing it but wasn't successful.
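To make the difference concrete, here is a minimal Python sketch (not llama.cpp code) of why CD does not fit the usual sampler shape: an ordinary sampler maps (last tokens, one logits vector) to a token, whereas each CD step needs logits from two models evaluated on the same context. The toy `expert`/`amateur` functions and every other name here are hypothetical stand-ins for real model forward passes; the combine step follows the paper's Algorithm 2 with greedy selection.

```python
import math
from typing import Callable, List, Sequence

Logits = List[float]
Model = Callable[[Sequence[int]], Logits]  # hypothetical: context -> next-token logits


def greedy_sampler(last_tokens: Sequence[int], logits: Logits) -> int:
    # Usual sampler shape: only the history and ONE logits vector are needed.
    return max(range(len(logits)), key=lambda i: logits[i])


def contrastive_step(expert: Model, amateur: Model, context: Sequence[int],
                     alpha: float = 0.1, beta: float = 0.5) -> int:
    # CD needs a second forward pass: both models see the same context.
    e = expert(context)
    a = amateur(context)
    # Algorithm 2: keep tokens whose expert logit is within log(alpha) of the max.
    cutoff = math.log(alpha) + max(e)
    scores = [(1 + beta) * ei - beta * ai if ei >= cutoff else -math.inf
              for ei, ai in zip(e, a)]
    return max(range(len(scores)), key=lambda i: scores[i])


def generate(expert: Model, amateur: Model, prompt: List[int], n_new: int) -> List[int]:
    tokens = list(prompt)
    for _ in range(n_new):
        tokens.append(contrastive_step(expert, amateur, tokens))
    return tokens


# Toy 3-token "models" that ignore the context (illustration only).
def expert(ctx: Sequence[int]) -> Logits:
    return [2.0, 1.9, -5.0]


def amateur(ctx: Sequence[int]) -> Logits:
    return [3.0, 0.0, 0.0]  # the amateur strongly prefers token 0


print(greedy_sampler([], expert([])))     # prints 0
print(generate(expert, amateur, [0], 2))  # prints [0, 1, 1]
```

The expert alone would pick token 0, but CD picks token 1 because the amateur is even more confident about token 0; token 2 is masked by the α cutoff. In practice this means two model evaluations per generated token.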

IridiumMaster commented 11 months ago

Here's what they list in their appendix:

A.1 CODE IMPLEMENTATION
We include PyTorch implementations of contrastive decoding in Algorithm 1 and Algorithm 2.

Algorithm 1: Original formulation

```python
import torch
import torch.nn.functional as F

def cd_logits_original(expert_logits, amateur_logits, amateur_temp, alpha):
    # expert_logits - unnormalized scores from the expert model
    # amateur_logits - unnormalized scores from the amateur model
    # amateur_temp - temperature to normalize amateur distribution
    # alpha - masking threshold
    expert_probs = F.softmax(expert_logits, dim=-1)
    amateur_probs = F.softmax(amateur_logits / amateur_temp, dim=-1)
    cutoff = alpha * expert_probs.max(dim=-1, keepdim=True).values
    diffs = torch.log(expert_probs) - torch.log(amateur_probs)
    return diffs.masked_fill(expert_probs < cutoff, -float("inf"))
```

Algorithm 2: Our formulation

```python
import math
import torch

def cd_logits_ours(expert_logits, amateur_logits, alpha, beta):
    # expert_logits - unnormalized scores from the expert model
    # amateur_logits - unnormalized scores from the amateur model
    # alpha - masking threshold
    # beta - expert-amateur tradeoff parameter
    cutoff = math.log(alpha) + expert_logits.max(dim=-1, keepdim=True).values
    diffs = (1 + beta) * expert_logits - beta * amateur_logits
    return diffs.masked_fill(expert_logits < cutoff, -float("inf"))
```
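A quick check that the two formulations apply the same plausibility mask (this derivation is not from the paper's text). With expert logits $s_{\text{exp}}$ and $p_{\text{exp}}(x) = e^{s_{\text{exp}}(x)}/Z$, where $Z = \sum_w e^{s_{\text{exp}}(w)}$, the normalizer cancels, so Algorithm 1's probability cutoff is exactly Algorithm 2's logit cutoff:

$$
p_{\text{exp}}(x) \ge \alpha \max_w p_{\text{exp}}(w)
\iff \frac{e^{s_{\text{exp}}(x)}}{Z} \ge \alpha\,\frac{e^{\max_w s_{\text{exp}}(w)}}{Z}
\iff s_{\text{exp}}(x) \ge \log\alpha + \max_w s_{\text{exp}}(w)
$$

Algorithm 2's score can also be rearranged as the expert logits plus a $\beta$-weighted contrast term:

$$
(1+\beta)\, s_{\text{exp}}(x) - \beta\, s_{\text{ama}}(x) = s_{\text{exp}}(x) + \beta\,\bigl(s_{\text{exp}}(x) - s_{\text{ama}}(x)\bigr)
$$

So $\beta = 0$ gives back the expert's own logits (still subject to the $\alpha$ mask), and larger $\beta$ weights the expert-amateur difference more heavily.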

And here is GPT-3.5 Turbo 16k's take on the approach required:

  1. Prepare your expert and amateur language models. These models should be pre-trained and capable of generating text.

  2. Calculate the unnormalized scores (logits) for each token from both the expert and amateur models.

  3. Choose a hyperparameter alpha (α) to determine the masking threshold. This will be used to mask out tokens that have lower probability assigned by the expert model.

  4. Calculate the weighted differences in likelihood (diffs) between the expert and amateur models. This can be done by weighting the expert and amateur logits and subtracting one from the other, as in Algorithm 2: (1 + beta)*expert logits - beta*amateur logits.

  5. Apply the alpha-mask to filter out tokens with lower probability assigned by the expert model. This can be done by comparing the expert logits to a threshold obtained from alpha.

  6. Use the masked differences as the final CD logits in place of the expert logits. Tokens below the masking threshold are set to -inf so that they are never selected during decoding.

  7. Use the CD logits to generate text by selecting tokens with higher probabilities in the CD distribution. Greedy decoding or sampling techniques can be used based on your preference.

By following these steps, you can implement contrastive decoding to improve text generation from your language models.
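Step 7 leaves the choice between greedy decoding and sampling open. Below is a minimal Python sketch of both, assuming the CD logits have already been computed with masked entries set to -inf (step 6); the function names are hypothetical. For 0 < α ≤ 1 the mask always keeps the expert's top token, so at least one entry is finite and the softmax is well defined.

```python
import math
import random
from typing import List, Optional


def cd_softmax(cd_logits: List[float], temperature: float = 1.0) -> List[float]:
    # Convert CD logits to probabilities; masked (-inf) tokens get exactly 0.
    # temperature must be > 0.
    m = max(cd_logits)
    if m == -math.inf:
        raise ValueError("all tokens are masked")
    exps = [math.exp((x - m) / temperature) for x in cd_logits]
    total = sum(exps)
    return [e / total for e in exps]


def pick_token(cd_logits: List[float], greedy: bool = True, temperature: float = 1.0,
               rng: Optional[random.Random] = None) -> int:
    if greedy:
        # Greedy: the highest CD score wins.
        return max(range(len(cd_logits)), key=lambda i: cd_logits[i])
    # Sampling: draw from softmax(cd_logits / temperature).
    probs = cd_softmax(cd_logits, temperature)
    rng = rng or random.Random()
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


cd = [1.5, 2.85, -math.inf]  # e.g. step 6 output for a 3-token vocabulary
print(pick_token(cd))     # prints 1
print(cd_softmax(cd)[2])  # prints 0.0
```

Either way, a token masked in step 6 can never be emitted, since its probability is exactly zero.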

And here's what it has to say about your statement: In the context of the paper, the statement holds true. Contrastive decoding does require a different approach compared to other sampling methods.

Contrastive decoding involves searching for tokens that maximize a weighted difference in likelihood between a stronger expert model and a weaker amateur model. This requires calculating the differences in probabilities between the expert and amateur models, and then applying a masking threshold to filter out low-probability tokens. The resulting contrastive logits are used for text generation.

In contrast, other sampling methods like top-k sampling or nucleus sampling only require the last tokens and current logits to select the next token for text generation. These methods do not involve comparing probabilities between different models or applying specific masking techniques.

Therefore, contrastive decoding does require a different approach, one that considers the differences between the expert and amateur models, which sets it apart from other sampling techniques.

It seems like something that could be enabled now that speculative decoding with smaller models is implemented, @KerfuffleV2?

KerfuffleV2 commented 11 months ago

Yes, it does kind of sound like something that could at least reuse parts of the existing speculative stuff. You might not even need a completely separate model: https://arxiv.org/abs/2309.08168

By the way, you might get more responses if you created this as a discussion rather than an issue.

trabbart commented 10 months ago

I created a simple example that uses contrastive decoding in #3984.

cebtenzzre commented 10 months ago

The original paper includes a benchmark against Contrastive Search ("CS" in this table), which HF transformers implements.

[image: benchmark table from the original paper, including results for Contrastive Search (CS)]
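Since the two names are easy to mix up: Contrastive Search (Su et al., 2022), the "CS" baseline in the table, is a different, single-model method. Roughly, it rescores the model's top-k candidates by penalizing a candidate whose hidden representation is too similar to those of the context, rather than contrasting two models. A toy Python sketch of that scoring rule follows; the hidden-state vectors are hypothetical inputs that a real implementation would take from the model. In HF transformers it is enabled through the `penalty_alpha` and `top_k` arguments of `generate()`.

```python
import math
from typing import Dict, List, Sequence


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def contrastive_search_pick(probs: List[float],
                            cand_hidden: Dict[int, Sequence[float]],
                            context_hidden: List[Sequence[float]],
                            top_k: int, penalty_alpha: float) -> int:
    # Only the model's top-k tokens are candidates.
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]

    def score(v: int) -> float:
        # Degeneration penalty: max similarity to any previous hidden state.
        penalty = max(cosine(cand_hidden[v], h) for h in context_hidden)
        return (1 - penalty_alpha) * probs[v] - penalty_alpha * penalty

    return max(top, key=score)


probs = [0.5, 0.4, 0.1]
context_hidden = [[1.0, 0.0]]
cand_hidden = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
print(contrastive_search_pick(probs, cand_hidden, context_hidden, 2, 0.0))  # prints 0
print(contrastive_search_pick(probs, cand_hidden, context_hidden, 2, 0.6))  # prints 1
```

Note that no second model appears anywhere: the penalty comes from the same model's hidden states.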

cebtenzzre commented 10 months ago

The amateur model used in the titular paper is a 1.5B LLaMA model trained on the same data as LLaMA-1, which presents a reproducibility issue, as they haven't provided the dataset or the resulting weights.

They find that a fully-trained LLaMA-1 7B as an amateur hurts performance, but a partially trained 7B helps. I don't know where to get one of those either...

OpenLLaMA 3B has a different tokenizer/vocabulary so that won't help.
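The reason a different tokenizer rules a model out is that the CD score is computed element-wise over token ids, so both models must map every id to the same token. Below is a minimal check one might run before pairing two models, plus the element-wise combine it protects; the vocabulary lists are hypothetical stand-ins for what a real tokenizer would expose.

```python
from typing import List, Sequence


def vocab_compatible(expert_vocab: Sequence[str], amateur_vocab: Sequence[str]) -> bool:
    # Element-wise logit arithmetic is only meaningful if id i is the same
    # token string in both models: require identical, identically ordered vocabularies.
    return len(expert_vocab) == len(amateur_vocab) and all(
        e == a for e, a in zip(expert_vocab, amateur_vocab))


def contrast(expert_logits: List[float], amateur_logits: List[float], beta: float) -> List[float]:
    # Algorithm 2's combine step, applied id by id.
    if len(expert_logits) != len(amateur_logits):
        raise ValueError("logit vectors have different vocabulary sizes")
    return [(1 + beta) * e - beta * a for e, a in zip(expert_logits, amateur_logits)]
```

A model with a different tokenizer fails this check even when the vocabulary sizes happen to match, because the same id would refer to different strings.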

trabbart commented 10 months ago

Yes. I also don't have access to the 1.5B LLaMA model, so I cannot reproduce the exact results from the titular paper. There is a project for training a LLaMA model with 1.1B parameters: TinyLlama. It is possible to check performance with this smaller LLaMA as an amateur. TinyLlama has a different size and a different training dataset, so I am not sure whether the results will be similar.

I think that it would be interesting to test a quantized LLaMA 7B model as an amateur. The authors didn't try that in the paper.

nidhoggr-nil commented 7 months ago

I just found the contrastive search paper and it looked interesting, which led me to contrastive decoding. It seems like an issue is the lack of verifiability due to not having access to the 1.5B LLaMA model? From what I understand, they use a hardcoded amateur model, but when implementing this, would it not be possible to have llama.cpp load two models given as arguments at startup? As suggested, it would be interesting to test this, and to see whether models that are too different would still give good results. I guess it's only the logits output that really matters?

trabbart commented 7 months ago

Yes, it is possible. I have started implementing it in the PR, but it is not finished yet; I plan to look at it over the weekend. You can compile the current code in the PR using g++. If you would like to test contrastive decoding, you will probably want to change some parameters in the code.

cebtenzzre commented 7 months ago

I made an attempt with TinyLlama, but the results were worse than without contrastive decoding. https://github.com/cebtenzzre/llama.cpp/blob/ceb/contrastive/examples/contrastive/contrastive.cpp

github-actions[bot] commented 5 months ago

This issue was closed because it has been inactive for 14 days since being marked as stale.