huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Add Classifier-Free Guidance sampling #24536

Closed Vermeille closed 11 months ago

Vermeille commented 1 year ago

EDIT: =========================== Since many people are copy-pasting this initial code, which was only meant as a basis for discussion, here is a cleaner version (still not perfect! We're doing further improvement rounds with the Hugging Face team, so check the state of the PR until it's merged: https://github.com/huggingface/transformers/pull/24654 ).

import torch
import torch.nn.functional as F

from transformers import LogitsProcessor

class CFGLogits(LogitsProcessor):
    r"""Logits processor for Classifier-Free Guidance (CFG). The processors
    computes a weighted average across scores from prompt conditional and prompt unconditional (or negative) logits,
    parameterized by the `guidance_scale`. The unconditional scores are computed internally by prompting `model` with
    the `uncond` branch. Finally, according to CFG Rescale, the reweighted logits are interpolated back with the
    conditional ones, with weight `rescale_factor`, to smooth the effect and increase output quality.

    See [the paper](https://arxiv.org/abs/2306.17806) for more information.

    Args:
        guidance_scale (float):
            The guidance scale for classifier-free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
            A higher guidance scale encourages the model to generate samples that are more closely linked to the
            input prompt, usually at the expense of poorer quality.
        uncond (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary for the unconditional branch.
        model:
            The LM computing the unconditional scores. It should be the same model as the one computing the
            conditional scores. Both models must use the same tokenizer.
        rescale_factor (float):
            The CFG Rescale interpolation weight. `1.0` disables rescaling; lower values blend the guided scores
            back towards the conditional ones.
    """

    def __init__(self, guidance_scale, uncond, model, rescale_factor=1.0):
        self.guidance_scale = guidance_scale
        self.uncond = uncond
        self.model = model
        self.out = None
        self.rescale_factor = rescale_factor

    def __call__(self, input_ids, scores):
        scores = F.log_softmax(scores, dim=-1)
        if self.guidance_scale == 1:
            return scores

        if self.out is None:
            self.out = self.model(self.uncond, use_cache=True)
        else:
            self.out = self.model(
                input_ids[:, -1:],
                use_cache=True,
                past_key_values=self.out.past_key_values,
            )
        unconditional_logits = F.log_softmax(self.out.logits[0][-1:], dim=-1)
        out = self.guidance_scale * (scores - unconditional_logits) + unconditional_logits
        # CFG Rescale: interpolate the guided scores back with the conditional ones
        out = F.log_softmax(out, dim=-1)
        return self.rescale_factor * out + (1 - self.rescale_factor) * scores

# paper usage: (copying and editing @grantCelley 's answer)
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopPLogitsWarper

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")

model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m")

prompt = tokenizer("Today a dragon flew over Paris, France,", return_tensors='pt')
# either provide a negative prompt:
neg_prompt = tokenizer("A sad event happened,", return_tensors='pt')['input_ids']
# or don't:
# neg_prompt = prompt['input_ids'][:, -1:]

device='cuda:0'
model.to(device)
outputs = model.generate(
    input_ids=prompt['input_ids'].to(device),
    attention_mask=prompt['attention_mask'].to(device),
    max_new_tokens=125,
    logits_processor=LogitsProcessorList([
        # the unconditional branch usually gets just the last token of the prompt,
        # but negative prompting (as explored in the paper) is also possible
        CFGLogits(1.5, neg_prompt.to(device), model),
        TemperatureLogitsWarper(0.8),
        TopPLogitsWarper(0.95),
    ]),
    do_sample=True,
)

print(tokenizer.decode(outputs[0]))

===============================

Feature request

Hello! I wish to contribute CFG sampling. I'm working with EleutherAI and @StellaAthena and we will have a paper about it by Friday. CFG brings non-trivial improvements on many standard benchmarks. It contrasts the logits for the next token, $P(w_t \mid w_{..t}, \text{prompt})$, with those of the input deprived of the prompt, $P(w_t \mid w_{..t})$, by defining

$$ \log P_{\text{cfg}}(w \mid w_{..t}, \text{prompt}) = \log P(w \mid w_{..t}) + \text{cfg} \cdot \big(\log P(w \mid w_{..t}, \text{prompt}) - \log P(w \mid w_{..t})\big) $$

We can then blend $\log P_{\text{cfg}}$ with $\log P(w \mid w_{..t}, \text{prompt})$ to smooth that distribution a bit, but this step is optional.
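As a small numerical illustration of the formula (all values below are made up):

import torch
import torch.nn.functional as F

# toy next-token logits over a 5-token vocabulary
cond_logits = torch.tensor([[2.0, 0.5, -1.0, 0.1, 0.0]])    # with the prompt
uncond_logits = torch.tensor([[1.0, 1.0, -0.5, 0.2, 0.3]])  # prompt removed
cfg = 1.5

# work in log-prob space so both distributions share a common scale
log_p_cond = F.log_softmax(cond_logits, dim=-1)
log_p_uncond = F.log_softmax(uncond_logits, dim=-1)

# log P_cfg = log P_uncond + cfg * (log P_cond - log P_uncond)
log_p_cfg = log_p_uncond + cfg * (log_p_cond - log_p_uncond)

# optional blend back towards the conditional distribution to smooth things out
blended = 0.7 * F.log_softmax(log_p_cfg, dim=-1) + 0.3 * log_p_cond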

Motivation

My current implementation is:

import torch
import torch.nn.functional as F
from transformers import LogitsWarper

device = 'cuda' if torch.cuda.is_available() else 'cpu'

class CFGLogits(LogitsWarper):

    def __init__(self, cfg, inputs, model, verbose=True):
        self.cfg = cfg
        self.inputs = inputs
        self.model = model
        self.out = None
        self.verbose = verbose

    def __call__(self, input_ids, scores):
        if self.cfg == 1:
            return F.log_softmax(scores, dim=-1)
        scores = F.log_softmax(scores, dim=-1)
        if self.out is None:
            self.out = self.model(self.inputs.to(device), use_cache=True)
        else:
            self.out = self.model(input_ids[:, -1:],
                                  use_cache=True,
                                  past_key_values=self.out.past_key_values)
        unconditional_logits = F.log_softmax(self.out.logits[0][-1:], dim=-1)
        out = self.cfg * (scores - unconditional_logits) + unconditional_logits
        out = F.log_softmax(out, dim=-1)
        return 0.7 * out + 0.3 * scores

# usage:

outputs = model.generate(
    input_ids=inputs['input_ids'].to(device),
    attention_mask=inputs['attention_mask'].to(device),
    max_new_tokens=l,
    logits_processor=LogitsProcessorList([
        # inputs_cfg usually is the last token of the prompt but there are
        # possibilities of negative prompting that are explored in the paper
        CFGLogits(cfg, inputs_cfg, model),
        TemperatureLogitsWarper(0.8),
        TopPLogitsWarper(0.95),
    ]),
    do_sample=True,
)

I am not familiar enough with the design guidelines of HF to know if this implementation as a LogitsWarper is satisfactory.

Just a few figures supporting the claims (FLOPs and benchmark plots; images omitted).

Your contribution

I can contribute the code but I need to be guided as I don't know the exact design guidelines and overall architecture of HF.

Thank you for your time!

sgugger commented 1 year ago

cc @gante But let's see if the community requests this added feature before implementing it in the library proper :-)

gante commented 1 year ago

Hey @Vermeille 👋

I have the impression that our MusicGen PR (still open, expected to get merged soon) introduces the bulk of the logic to make it happen -- see this file

It is the same thing with a slightly different code implementation, correct? In the MusicGen PR, the model does a forward pass with 2x the batch size, where half of the batch corresponds to the unprompted tokens

Vermeille commented 1 year ago

Indeed @gante !

I don't fully get how the 2x batch size thing works, but if it does, that's cool. The paper makes some more additions to that base implementation:

1. The uncond_logits might in fact use a different prompt than the cond_logits, which is commonly called a "negative prompt".
2. The comment says "usually at the expense of poorer quality". This can be mitigated by linearly interpolating the CFG scores back with the initial scores.
3. We had better results log_softmaxing both sets of scores before CFG, which normalizes both sets of logits to a common "scale".

gante commented 1 year ago

cc @sanchit-gandhi, who's probably better equipped to comment on potential differences :)

sanchit-gandhi commented 1 year ago

Hey @Vermeille - thanks for the comprehensive write-up! Just a clarifying question: in your implementation, how do you construct the token ids for the model based on the conditional ids and the un-conditional ones? You mention:

inputs_cfg usually is the last token of the prompt but there are

Which suggests you concatenate them together in the same batch item?

In MusicGen (and also the HF Diffusers library for models like Stable Diffusion), we construct our input ids by concatenating the input ids for the conditional prompt and the un-conditional prompt along the batch dimension (dim=0):

input_ids = torch.concatenate([conditional_ids, unconditional_ids], dim=0)

This is what's referred to by the 2x batch size 'trick' (concatenating the conditional prompt and unconditional prompt over the batch dim). There's no restriction to how these unconditional ids are formed - they can be from a 'null' input, or from a negative prompt. So we can do negative prompting in exactly the way you've described.

When we run our model forward, the logits for the first half of the batch corresponds to the conditional prompt, and the second half to the unconditional prompt (or negative prompt if we use one).

By splitting along the batch dim, we can partition the conditional logits and the unconditional ones:

conditional_logits, unconditional_logits = torch.split(logits, batch_size // 2)

-> we then perform our weighted sum over the conditional and unconditional logits for CFG.
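To make that concrete, here is a rough, self-contained sketch of the trick for a single generation step. It is not the MusicGen code itself; a decoder-only model and the prompts are used as stand-ins, and the guidance scale value is just an example.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m")

# left-pad so the last position is the last real token of both prompts
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

batch = tokenizer(
    ["Today a dragon flew over Paris, France,",  # conditional prompt
     "A sad event happened,"],                   # unconditional / negative prompt
    return_tensors="pt",
    padding=True,
)

with torch.no_grad():
    logits = model(**batch).logits[:, -1, :]     # next-token logits, shape (2, vocab)

conditional_logits, unconditional_logits = logits.split(1, dim=0)

guidance_scale = 1.5
cfg_logits = unconditional_logits + guidance_scale * (conditional_logits - unconditional_logits)

# the same next token would then be appended to both rows before the next forward pass
next_token = torch.argmax(cfg_logits, dim=-1)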

Hope that explains how the 2x batch size trick works - would be keen to hear whether this aligns with how you've run CFG in your experiments.

Regarding implementing a new logits processor, we'd probably want to add this new logits processor when the time comes for integrating the model you've worked on into transformers, rather than adding it solely as a standalone logits processor. transformers is less of a modular toolbox for building new models, more a library for housing the most popular OS ML models

Have you trained a new model that uses this processor? Or built on-top of an existing one? (if it's the latter, then adding the CFG logits processor standalone makes sense, otherwise let's integrate it all in one go)

Vermeille commented 1 year ago

Thank you for your detailed answer @sanchit-gandhi !

The part I'm the most unclear about regarding the 2x batch trick is how the sampling happens. Do you actually sample the same continuation token for the conditional and unconditional branches, or do they diverge in their own directions (which would be weird imho)?

Regarding the integration, there is no need to train models to support CFG; it works out of the box. The paper will be out in a few days, but as you can see in the figures, we used it with LLaMA models, all Pythias, the GPT-2 family, and even GPT4All. We don't train a new model. It's meant to be an addition to the .generate() method that is totally model-agnostic and doesn't need training or finetuning. Hence the PR with the standalone logits processor :)

Vermeille commented 1 year ago

The paper is out

sanchit-gandhi commented 1 year ago

Maybe this helps! The steps, in order: pre-processing, forward pass, CFG, sampling (illustrations omitted).

How have you been getting the conditional and unconditional logits in your experiments? Through two forward passes? (one with the conditional inputs and then a second with the unconditional ones). This batch size concatenation trick means you only have to run one forward pass, but with 2x the batch size

The only pain point I see with getting this work in transformers is this batch size change as we go from our forward pass to our sampling loop. But we can add some logic to change the batch size on the fly if we're doing CFG (kind of like we did for MusicGen @gante - we need to trick the forward pass into using 2 * bsz, then the decoder ids to use bsz).

There is no need to train models to support CFG; it works out of the box

Very cool indeed! Would be nice to have this as a standalone PR then as suggested

Vermeille commented 1 year ago

Thank you! Yeah, if the cond and uncond prompts get the same next token sampled, that's consistent with our experiments! It's the way you loop around in .generate() to grow the continuation token by token, zigzagging between bsz and 2*bsz, that I'm not 100% clear on. I totally see how it works for one forward pass. Totally an implementation detail :) But apparently that's a new trick you had to implement for MusicGen too, so it makes sense that I'm not perfectly clear on it.

Would be nice to have this as a standalone PR then as suggested

I'm happy to address the changes that have to be made to contribute this into the lib :)

sanchit-gandhi commented 1 year ago

Awesome - feel free to open a PR and tag myself and @gante! How do you do it without the 2x batch size trick? Do you do two forward passes? Just asking in case there's a simpler way we can integrate this!

gante commented 1 year ago

(catching up on the paper and thinking a bit about usage experience -- will comment tomorrow with specific suggestions, but I think @Vermeille's suggested implementation above will be pretty close to a great user experience with minimal compute overhead)

alex2awesome commented 1 year ago

Here is an alternative implementation we used for some of our other experiments in the paper, for your consideration.

It was designed with Hugging Face's typical *ModelFor* code style in mind, which just wraps the base model in the init and extends the forward() method: https://github.com/Vermeille/lm-evaluation-harness-cfg/blob/cfg-alex/log_logits_on_p3.py#L30-L97

Vermeille commented 1 year ago

Awesome - feel free to open a PR and tag myself and @gante! How do you do it without the 2x batch size trick? Do you do two forward passes? Just asking in case there's a simpler way we can integrate this!

Yes. Two consecutive passes. Which is indeed not that great wrt latency.

elikoga commented 1 year ago

Would be great to have both the 2x batch size and the two forward passes variants. 2x batch size is better for throughput, but two forward passes are much better for VRAM usage, as the paper outlines.

(unless I missunderstood)

Vermeille commented 1 year ago

So given you already have this ( https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L1070 )

What do you want me to add / change in the PR?

StellaAthena commented 1 year ago

Would be great to have both the 2x batch size and two forward passes. Since 2x batch size is better for throughput but the two forward passes are much better for VRAM usage, as the Paper outlines

(unless I missunderstood)

This is correct: our focus was on getting the best results for a fixed amount of VRAM in our experiments, hence it didn't occur to us to simply 2x the batch size. I agree that having this be togglable is a good idea and I don't have any preference about the default.

drdaxxy commented 1 year ago

The application to LLMs seems more of a situational sampling technique. With smaller conditional generative models like MusicGen, trained from-scratch with (explicit) condition dropout, it's practically part of the model. MusicGen isn't the first AR Transformer here, last year's DALL-E Mega already did it (itself inspired by https://twitter.com/RiversHaveWings/status/1478093658716966912 ), and in these models it's essential for performance.

So I'd expect "batch size 1 dramatically underutilizes available resources" to be the more common case.

Since 2x batch size is better for throughput but the two forward passes are much better for VRAM usage, as the Paper outlines

Depending on model and hardware, "biggest batch size that fits" isn't necessarily optimal. On decent hardware, you can hit optimal compute utilisation before VRAM limits with batched inference in smaller models.


Normalizing the summands, then interpolating with the original scores is intriguing. If adding this to the CFG implementation that's now in Transformers is still being considered, this would be unexpected as default behavior though. In diffusion models, it's not applicable, and in sequence prediction, I've only seen people combine the unnormalized scores.

Vermeille commented 1 year ago

@drdaxxy

Normalizing the summands, then interpolating with the original scores is intriguing. [...] In diffusion models, it's not applicable

This is a technique we borrowed from Common Diffusion Noise Schedules and Sample Steps are Flawed, where they call it CFG Rescale. You can see Imagen doing a similar normalizing trick too.
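In code, the rescale step as used in the processor above (on log-softmaxed scores) boils down to something like this; the default values for guidance_scale and rescale_factor are example values only:

import torch.nn.functional as F

def cfg_rescale(cond_scores, uncond_scores, guidance_scale=1.5, rescale_factor=0.7):
    """CFG followed by a CFG Rescale-style interpolation, in log-prob space."""
    cond = F.log_softmax(cond_scores, dim=-1)
    uncond = F.log_softmax(uncond_scores, dim=-1)
    guided = uncond + guidance_scale * (cond - uncond)
    guided = F.log_softmax(guided, dim=-1)
    # pull the guided scores back towards the conditional ones to tame over-guidance
    return rescale_factor * guided + (1 - rescale_factor) * cond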

in sequence prediction, I've only seen people combine the unnormalized scores.

That's what we started with, and our results were a little bit worse.

gante commented 1 year ago

This method is interesting to implement from an engineering and maintenance point of view!

The simplest approach would be to proceed as @Vermeille suggested: add a logits processor that calls a model forward pass for the unconditional part of the input. It would be a small self-contained piece of code, which means low long-term maintenance on our end. On the negative side, we have the 2x latency, which is more impactful than the extra VRAM (IMO).

If we go the 2x batch size route, we need to implement a function like greedy_search or sample -- a long function with non-negligible maintenance costs on our end. I believe this would be the best form of CFG sampling. However, we are severely constrained by our ability to keep the machine up and running at a good pace, so we can quickly add new features like CFG sampling :D

We have a plan to reorganize generate such that it is entirely made of small functions, making it much more composable. In the way I'm envisioning it, the 2x batch size version of CFG sampling would need a few extra lines of code, as opposed to a new large function.

How about we go with @Vermeille's proposal now, which will make CFG sampling available this week with low overhead on our end, and we implement the 2x batch size version after the generate refactor is complete? The new logits processor class would need a different name, as we already have ClassifierFreeGuidanceLogitsProcessor for the 2x batch size case (perhaps UnbatchedClassifierFreeGuidanceLogitsProcessor?)

Vermeille commented 1 year ago

Expect a PR in few hours.

Thank you for your interest and answers!

Vermeille commented 1 year ago

@gante There is a name clash in the arguments to .generate(). For this PR, unless instructed otherwise before I submit it, cfg_scale (mine) will live next to guidance_scale (MusicGen's). Idk how to resolve this competition, given that .generate() does not seem ready to use the 2x batch trick yet.

gante commented 1 year ago

@Vermeille Adding more (and partially redundant) parameterization is highly undesirable, and we'd want to favor the more general case (yours). You also have the additional requirement of renormalizing the logits before applying your logits processor. Fortunately, we haven't officially released a transformers version with MusicGen, so we still have some wiggle room!

Let's try to fit everything together -- here's my suggestion:

This way the two strategies can coexist, share the argument, and not clash 🤗

Vermeille commented 1 year ago

Great! Thank you for the walkthrough.

On it.

Vermeille commented 1 year ago

Wait @gante, integrating it after the LogitNormalization is not something we want: all the prior processing (temperature, top_p, etc.) would be applied only to the conditional branch and not the unconditional one, and would be executed before computing the CFG logits. To be fair, we haven't tested this transformation order, but being asymmetrical like this scares me.

And it is even invalid: top-k/p may not even select the same tokens in both branches, so that will misbehave.

I'm afraid I can't do that. CFG has to happen as one of the first logits processors.
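In other words, the ordering in the examples above is intentional; a sketch reusing the names from the first post (CFGLogits, neg_prompt, model, device are defined there):

logits_processor = LogitsProcessorList([
    # CFG runs first, on the full (log-softmaxed) conditional scores
    CFGLogits(1.5, neg_prompt.to(device), model),
    # temperature / top-p warpers only come after guidance
    TemperatureLogitsWarper(0.8),
    TopPLogitsWarper(0.95),
])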

gante commented 1 year ago

@Vermeille looking at your code example above, I didn't notice it already had normalization inside the processor. My bad -- feel free to add it as the 1st one :)

(will edit my comment above accordingly, for clarity)

grantCelley commented 1 year ago

So this is the code I got working. It is just a hack, but if you want to play with it, use this code:

from transformers import LogitsWarper
import torch
from torch.nn import functional as F

device = 'cuda' if torch.cuda.is_available() else 'cpu'

class CFGLogits(LogitsWarper):

    def __init__(self, cfg, inputs, model, verbose=True):
        self.cfg = cfg
        self.inputs = inputs
        self.model = model
        self.out = None
        self.verbose = verbose

    def __call__(self, input_ids, scores):
        if self.cfg == 1:
            return F.log_softmax(scores, dim=-1)
        scores = F.log_softmax(scores, dim=-1)
        if self.out is None:
            self.out = self.model(self.inputs.to(device), use_cache=True)
        else:
            self.out = self.model(input_ids[:, -1:],
                                  use_cache=True,
                                  past_key_values=self.out.past_key_values)
        unconditional_logits = F.log_softmax(self.out.logits[0][-1:], dim=-1)
        out = self.cfg * (scores - unconditional_logits) + unconditional_logits
        out = F.log_softmax(out, dim=-1)
        return 0.7 * out + 0.3 * scores

from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopPLogitsWarper

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")

model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m")

prompt = "Salve, dispiculi."
inputs = tokenizer(prompt, return_tensors='pt')
model.to(device)
outputs = model.generate(
    input_ids=inputs['input_ids'].to(device),
    attention_mask=inputs['attention_mask'].to(device),
    max_new_tokens=125,
    logits_processor=LogitsProcessorList([
        # inputs_cfg usually is the last token of the prompt but there are
        # possibilities of negative prompting that are explored in the paper
        CFGLogits(3, inputs['input_ids'], model),
        TemperatureLogitsWarper(0.8),
        TopPLogitsWarper(0.95),
    ]),
    do_sample=True,
)

print(tokenizer.decode(outputs[0]))

This worked on my end

chris-aeviator commented 1 year ago

@grantCelley 's code works for me.

With CFG (pythia 160m): (screenshot omitted)

Without CFG: (screenshot omitted)

Vermeille commented 1 year ago

@grantCelley @chris-aeviator The line CFGLogits(3, inputs['input_ids'], model), should really be CFGLogits(3, inputs['input_ids'][:, -1:], model),

chris-aeviator commented 1 year ago

Thanks for pointing it out, my 30 was a typo, but your prev. code doesn't seem to mention the [:, -1:] ?!

Vermeille commented 1 year ago

@chris-aeviator notice how it uses input_cfg:

        # inputs_cfg usually is the last token of the prompt but there are
        # possibilities of negative prompting that are explored in the paper
        CFGLogits(cfg, inputs_cfg, model),

grantCelley commented 1 year ago

I will change it when I get home from work

cyberfox commented 1 year ago

@Vermeille I'm currently working on ways to implement logits-processing capability in extensions in oobabooga's text-generation-webui (originally to improve the OpenAI extension to mirror OpenAI's logprobs and logit_bias) and came across this as part of my changes. (I was trying to understand how to add LogitsProcessors to Exllama.)

I'd love to implement this as an example of a plugin that adds a logits processor; is it okay if I use the code at the top of this issue for that?

You can see the whole change set so far (I need to split it up a bit and make individual pull requests, and I'll make sure to give clear credit when I do!) here: https://github.com/oobabooga/text-generation-webui/pull/3001/files and the CFG plugin is at the top. The simplest idea is you'd add it with --extensions cfg on the command line. Probably with some warning that it's going to make things slower, but hopefully better.

I'll add some configuration (the cfg hard codes 1.5 right now, as that appeared to be the best point in the paper), but mostly I want to make sure I'm not stepping on any toes.

Vermeille commented 1 year ago

@cyberfox Thank you for your interest! I see your PR reuses the code I initially submitted as-is. I strongly advise you to switch to the updated code (see the edit in the first post) for a cleaner and better experience.

chris-aeviator commented 1 year ago

@chris-aeviator notice how it uses input_cfg:

        # inputs_cfg usually is the last token of the prompt but there are
        # possibilities of negative prompting that are explored in the paper
        CFGLogits(cfg, inputs_cfg, model),

The results start to become a bit garbage when I run with inputs['input_ids'][:, -1:]

This is a fine tuned Pythia 12b

vanilla/ no CFG:

Beetle larvae are hard to love because they eat their way through many of our favorite plants. But we should love them anyway—at least for one thing. It’s called “killing with a smile,” and it’s what these little guys do to protect us from predators. Beetle larvae protect themselves from predators by producing a deadly toxin after feeding on Commiphora tree. In Kenya, for example, where the leaf beetle is found, that means they harvest the toxic protein called mimeticotoxin from the leaves and apply it to an arrowhead to kill any large animals that might otherwise eat them. The toxin on a single arrow could kill an antelope.

CFG w. inputs['input_ids'] (wrong usage)

One might think that one species of insects would be another’s food source, but this is true for many organisms. For example, mosquitoes are eaten by birds and bats; moths provide food for hummingbirds; and so forth. In fact, some insects have evolved such self-protection mechanisms to ward off their own destruction by other predators. One remarkable case involves leaf beetles (Chrysomelidae) of southern Africa. These pests of plants like acacia trees feed exclusively on leaves of the Commiphora tree. They produce toxins only when young—and only enough to kill or repel the smallest animals likely to eat them. Thus, they develop a defense mechanism capable of killing an antelope size animal. To deter predators, lar

CFG w. inputs['input_ids'][:, -1:] (correct usage)

Leaf beetles doncrautaiaeare North African desert insects that eat right through beetle doors―leaves damaged or preparing to be browsed upon by Commiphora matabueillevis goats— killing off nearby predators such as antelopes and Cape buffalo. Sure enough, South African Bushmen ancient warriors would aim and shoot these poisoned arrows at grazing animals to defeat themsingleadultantelope.

Vermeille commented 1 year ago

This output is typical of a guidance strength that's too high. You can either reduce it or reduce the rescale_factor. Try cfg 1.5 and, if it's still not there, 1.25. Then you can try ramping the guidance strength up while reducing the rescale_factor.
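For instance, keeping everything else from the first example unchanged, that advice amounts to swapping in something like the following (values are illustrative, and rescale_factor assumes the CFGLogits variant above that exposes it):

# milder guidance first
CFGLogits(1.25, neg_prompt.to(device), model)

# or stronger guidance, compensated by a lower rescale_factor
CFGLogits(3.0, neg_prompt.to(device), model, rescale_factor=0.7)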

Vermeille commented 1 year ago

@chris-aeviator I very quickly tuned the parameters in the example. It generated:

Today a dragon flew over Paris, France, and flew around London, England.

And he flew into the heavens, to be welcomed, and to be welcomed by his own people. This dragon was a master of disguise, and used his powers to give a great signal to the earth to let him pass through it. It was the first time the dragon had ever been seen flying around the world, and it was the first time he had ever used the magic to fly.

One day the dragon is sitting on the ground, on the ground, with his feet bare. He is about to say something to him, and he is very interested in his own

We can clearly see the effect of the negative prompt. Don't hesitate to fiddle with both values though, they're not set in stone.

grantCelley commented 1 year ago

Here is the updated code I got.

from transformers import LogitsWarper
import torch
from torch.nn import functional as F

device = 'cuda' if torch.cuda.is_available() else 'cpu'

class CFGLogits(LogitsWarper):

    def __init__(self, cfg, inputs, model, verbose=True):
        self.cfg = cfg
        self.inputs = inputs[:, -1:]
        self.model = model
        self.out = None
        self.verbose = verbose

    def __call__(self, input_ids, scores):
        if self.cfg == 1:
            return F.log_softmax(scores, dim=-1)
        scores = F.log_softmax(scores, dim=-1)
        if self.out is None:
            self.out = self.model(self.inputs.to(device), use_cache=True)
        else:
            self.out = self.model(input_ids[:, -1:],
                                  use_cache=True,
                                  past_key_values=self.out.past_key_values)
        unconditional_logits = F.log_softmax(self.out.logits[0][-1:], dim=-1)
        out = self.cfg * (scores - unconditional_logits) + unconditional_logits
        out = F.log_softmax(out, dim=-1)
        return 0.7 * out + 0.3 * scores

from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopPLogitsWarper

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")

model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m")

# This is the prompt to be actually auto-completed.
# It means "Hello, students. Now"
prompt = "Salve, dispiculi. Nunc"
# This is what it will go towards
cfg_prompt = "Latin"

inputs = tokenizer(prompt, return_tensors='pt')
cfg_inputs = tokenizer(cfg_prompt, return_tensors='pt')
model.to(device)
outputs = model.generate(
    input_ids=inputs['input_ids'].to(device),
    attention_mask=inputs['attention_mask'].to(device),
    max_new_tokens=125,
    logits_processor=LogitsProcessorList([
        # inputs_cfg usually is the last token of the prompt but there are
        # possibilities of negative prompting that are explored in the paper
        CFGLogits(1.5, cfg_inputs['input_ids'], model),
        TemperatureLogitsWarper(0.8),
        TopPLogitsWarper(0.95),
    ]),
    do_sample=True,
)

print(tokenizer.decode(outputs[0]))

The output is:

Salve, dispiculi. Nunc conciliat est.

Lacchius autem erat et. fructus.

Ut esse,

Quod

.

Hecce,

In quem

um

et

o

de

quattu

.

Merci,

Ai,

Besci sunt

Cum

um

in

.

Dic quam

.

In

.

Merci,

Ai,

Bos.

grantCelley commented 1 year ago

I just saw the updated code from @Vermeille

Vermeille commented 1 year ago

@grantCelley Pythia models are trained on English. I'm really confused by what you're trying to achieve there.

grantCelley commented 1 year ago

I was just trying to get it to work. Also, it does continue in Latin for a little, which is interesting, then drifts into a Romance language. But it just showed how to do it. I didn't realize that you had updated the original code block.

chris-aeviator commented 1 year ago

@Vermeille

This output is typical of a guidance strength that's too high. You can either reduce it or reduce the rescale_factor. Try cfg 1.5 and if it's still not there, 1.25. Then you can try ramping the guidance strength up while reducing the rescale_factor.

Ok this helped, generation for the same amount of tokens takes longer now, is this expected?

Vanilla / no CFG, 512 token / 3 min

Fungus-growing ants (Myrmecocystus ants) farm their own food, feed larvae and queen with it, and maintain the garden through vigorous harvesting and composting practices. Like farmers who grow vegetables, they plant crops, harvest, store, and distribute the fruits of the harvest to themselves and their dependents.

CFG, neg_token = last token, cfg_scale=1.5, 512 token / 5 min

Leaf-cutting ants (or “antlion ants”) are found in tropical regions around the world. They farm fungi to eat as adults and feed their larvae. Fungi provide food not only for adult ants but also for the gardens that they maintain across vast distances in the form of fungus farms. These farms can contain tens of thousands of acres of fungal colonies. The largest known leaf-cutting ant fungus farm has over 20,000 colonies with a total area of nearly 2 million square meters (2 hectares).

CFG, neg_token = last token, cfg_scale=1.25, 512 token / 5 min

Leaf-cutting ants (or “antlion ants”) are found in tropical regions around the world, where they farm fungus to feed their young. Fungus farming has been observed in several ant species, including the Acromyrmex octospinosus ant, endemic to South and Central America and the southern United States. Farmers remove leaves from native plants and chew them into small pieces, which they place directly onto the soil around newly established colonies. The leaf fragments provide food for the larvae inside the nest as well as for the colony’s queen.

chris-aeviator commented 1 year ago

@grantCelley shouldn't a negative prompt of 'Latin' prohibit Latin output? Do I misunderstand the concept of negative prompts?

Vermeille commented 1 year ago

@chris-aeviator

Ok this helped, generation for the same amount of tokens takes longer now, is this expected?

Yes, there are two forward passes per token now.

@grantCelley shouldn't a negative prompt of 'Latin' prohibit Latin output? Do I misunderstand the concept of negative prompts?

You are correct

FartyPants commented 1 year ago

@grantCelley shouldn't a negative prompt of 'Latin' prohibit Latin output? Do I misunderstand the concept of negative prompts?

It is hard to say what a negative prompt does in certain terms. I had it generate a poem and specified the negative prompt as "happy", and it used somewhat gloomy language, and vice versa - so it "does" work, but beyond that I think only further experimentation will tell. It does affect the output, but not too dramatically. In the paper they set the negative prompt to the system prompt the model was trained with... not sure about the reasoning for that.

Vermeille commented 1 year ago

It is hard to say what a negative prompt does in certain terms. I had it generate a poem and specified the negative prompt as "happy", and it used somewhat gloomy language, and vice versa - so it "does" work, but beyond that I think only further experimentation will tell.

Yes. Negative prompts in language are somewhat harder to pull off than in vision, especially because the continuation should stay kind of grammatical with the negative prompt too. Not gonna lie, we were under time constraints, and a clear negative-prompt methodology remained unsolved in that time frame. But we're working on it, and the example in the first post works.

It does affect the output, but not too dramatically.

Hard to say yet, but it should depend on the guidance strength (decrease the rescale_factor as you increase the guidance strength)

In the paper they set the negative prompt to the system prompt the model was trained with... not sure about the reasoning for that.

from the paper:

We set the negative prompt to be the default system-prompt for the models we use [...] This approach not only makes the sampling more prompt-aware in general, but directly emphasizes the difference between our system-prompt and the model’s default system-prompt.

FartyPants commented 12 months ago

Thanks for the explanation. I can't wait to see more about this. It is really fascinating trying to see how the code works - it is quite something! A little question for my own curiosity (referring to the code at the top of the page): in case I'm using it with a full negative prompt, wouldn't it be better in the very first round (when self.out == None) to do just out = self.guidance_scale * (scores - unconditional_logits), i.e. not add the + unconditional_logits, which are the logits of the negative prompt in that first round? Then of course continue normally as is in the next rounds. Or does it not really matter? Or do I just not understand the nuance of the code (which I of course don't)?

cyberfox commented 12 months ago

In case I'm using it with a full negative prompt, wouldn't it be better in the very first round (when self.out == None) to do just out = self.guidance_scale * (scores - unconditional_logits), i.e. not add the + unconditional_logits, which are the logits of the negative prompt in that first round? Then of course continue normally as is in the next rounds. Or does it not really matter? Or do I just not understand the nuance of the code (which I of course don't)?

Because the key_values are passed back in each time, the negative side of the prompting (the second self.out) will always contain the prediction based on the entire negative context so far. Even if you did it from the second token generation, you'd still be adding in the negative prompt's effect on the generated 'unconditional' logits...

This implies to me that the two things should be separate concepts, with separate implementations...but if you (reasonably) wanted to use both focus-on-first-prompt, and negative prompting, it would be compute expensive to do them separately.

That said, I do feel a little like the 'adding them back in' is a fudge-factor, trying to reduce the effect slightly. But I don't understand the math symbology in the original paper very well, so I'm very cautious about that.

github-actions[bot] commented 11 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

gante commented 11 months ago

It is merged, feel free to install from main and play with it :)
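Something along these lines should work with the merged version; the argument names (guidance_scale, negative_prompt_ids) are my reading of the merged PR, so double-check them against the docs:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m")

prompt = tokenizer("Today a dragon flew over Paris, France,", return_tensors="pt")
neg = tokenizer("A sad event happened,", return_tensors="pt")

outputs = model.generate(
    **prompt,
    guidance_scale=1.5,                    # CFG kicks in when guidance_scale != 1
    negative_prompt_ids=neg["input_ids"],  # optional negative prompt
    do_sample=True,
    max_new_tokens=125,
)
print(tokenizer.decode(outputs[0]))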

sersoage commented 10 months ago

Can you provide sample code on how to use classifier free guidance?