salesforce / GeDi

GeDi: Generative Discriminator Guided Sequence Generation
https://arxiv.org/abs/2009.06367
BSD 3-Clause "New" or "Revised" License

top_k_top_p_filtering function missing in modeling_utils.py #4

Closed ktrapeznikov closed 3 years ago

ktrapeznikov commented 3 years ago

When the run_generation script is used with do_sample, I get an error because top_k_top_p_filtering is missing.

So I just added the function (from https://github.com/huggingface/transformers/blob/master/src/transformers/generation_utils.py)


```python
import torch
import torch.nn.functional as F
from torch import Tensor


def top_k_top_p_filtering(
    logits: Tensor,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
) -> Tensor:
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Note: `logits` is modified in place and also returned.

    Args:
        logits: logits distribution shape (batch size, vocabulary size)
        if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
        if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        Make sure we keep at least min_tokens_to_keep per batch example in the output
    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Scatter sorted tensors back to the original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits
```

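
For anyone who wants to see what the filter does without installing torch, here is a minimal pure-Python sketch on a single list of logits. `filter_logits` is a hypothetical helper written only for illustration (it is not part of GeDi or transformers); it is meant to mirror the top-k and top-p steps of the function above for one example, not to replace it.

```python
import math


def filter_logits(logits, top_k=0, top_p=1.0, filter_value=-math.inf, min_tokens_to_keep=1):
    """Hypothetical single-example version of top-k / top-p filtering (illustration only)."""
    out = list(logits)  # work on a copy so the caller's list is untouched

    if top_k > 0:
        k = min(max(top_k, min_tokens_to_keep), len(out))
        kth_largest = sorted(out, reverse=True)[k - 1]
        # Drop every logit strictly below the k-th largest one
        out = [x if x >= kth_largest else filter_value for x in out]

    if top_p < 1.0:
        # Token positions ordered from highest to lowest logit
        order = sorted(range(len(out)), key=lambda i: out[i], reverse=True)
        top = out[order[0]]
        exps = [math.exp(out[i] - top) for i in order]  # numerically stable softmax
        total = sum(exps)

        remove = []
        cumulative = 0.0
        for e in exps:
            cumulative += e / total
            remove.append(cumulative > top_p)

        if min_tokens_to_keep > 1:
            for rank in range(min(min_tokens_to_keep, len(remove))):
                remove[rank] = False
        # Shift right so the first token that crosses the threshold is still kept
        remove = [False] + remove[:-1]

        for rank, i in enumerate(order):
            if remove[rank]:
                out[i] = filter_value

    return out


toy_logits = [2.0, 1.0, 0.0, -1.0]
print(filter_logits(toy_logits, top_k=2))
print(filter_logits(toy_logits, top_p=0.7))
```

With these toy logits the softmax probabilities are roughly 0.64, 0.24, 0.09 and 0.03, so both calls keep the first two tokens and set the rest to `-inf`.
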
akhileshgotmare commented 3 years ago

Thanks for pointing out!

GeDi's heuristics for combining the generative classifier outputs with the LM logits were designed with greedy decoding in mind, and all the experiments in our paper used greedy decoding (Section 3.1.1), so we don't support sampling as of now. Ensuring that sequences generated with sampling carry the desired attribute (e.g. positive sentiment) will likely need some tuning of the decoding hyper-parameters (\omega, \rho, n) and some algorithmic changes.

Were you able to generate reasonable sequences with sampling using GeDi-guided generation?

ktrapeznikov commented 3 years ago

Yeah, the output looks reasonable. It works pretty well for "topic" mode. I set temperature = 1 and increased disc_weight to 50. The class probability estimate is usually pretty high (in the 90s). It is slightly worse for "sentiment" mode.

Looking at the generate code, sampling is applied to next_logit_prob after it has been modulated by the GeDi probabilities.
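
As a rough illustration of that order of operations, here is a pure-Python sketch: the LM log-probabilities are first reweighted by the class log-probabilities, and only then is a token sampled. The combination rule used here (LM log-prob plus disc_weight times class log-prob) is my simplified reading of the weighted posterior in the GeDi paper, and `sample_guided` and its arguments are hypothetical names, not the repository's actual generate code.

```python
import math
import random


def sample_guided(lm_logprobs, class_logprobs, disc_weight=1.0, temperature=1.0, rng=None):
    """Hypothetical sketch: reweight LM log-probs by class log-probs, then sample one token id."""
    rng = rng or random.Random()
    # Modulate first: log p_lm(token) + disc_weight * log p(class | token), then apply temperature
    scores = [
        (lm + disc_weight * cls) / temperature
        for lm, cls in zip(lm_logprobs, class_logprobs)
    ]
    # Sample from the softmax of the modulated scores (subtract the max for stability)
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    return rng.choices(range(len(scores)), weights=weights, k=1)[0]


# Toy 3-token vocabulary: the LM slightly prefers token 0, the classifier strongly prefers token 1
lm_logprobs = [math.log(p) for p in (0.5, 0.3, 0.2)]
class_logprobs = [math.log(p) for p in (0.01, 0.98, 0.01)]
token_id = sample_guided(lm_logprobs, class_logprobs, disc_weight=50, rng=random.Random(0))
print(token_id)
```

With a large disc_weight the classifier term dominates, so sampling behaves almost like picking the classifier's preferred token; with a small weight the draw follows the LM distribution more closely.
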

akhileshgotmare commented 3 years ago

@ktrapeznikov Do you want to create a pull request with this addition to the modeling_utils.py file? Or I could make the changes in a commit.

akhileshgotmare commented 3 years ago

top_k_top_p_filtering added!

ktrapeznikov commented 3 years ago

Awesome. Thanks. Somehow I missed your previous comment.

xiximiyi commented 1 year ago

ImportError: cannot import name 'top_k_top_p_filtering' from 'transformers.generation_utils'

ashokchhetri7 commented 6 months ago

ImportError: cannot import name 'top_k_top_p_filtering' from 'transformers.generation_utils'

In the previous code, top_k_top_p_filtering was imported from transformers.generation_utils. I changed the underscore to a dot (generation_utils became generation.utils), as shown below, and it solved my problem:

from transformers.generation.utils import top_k_top_p_filtering

You can also downgrade transformers, and it will run too:

pip install transformers==4.36.2
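
If you need the same script to run across several transformers versions, one option is to try both module paths mentioned in this thread and fall back gracefully. This is only a sketch: `resolve_top_k_top_p_filtering` is a hypothetical helper name, and I am assuming the function lives at one of these two paths in your installed version, which may not hold for every release.

```python
import importlib


def resolve_top_k_top_p_filtering():
    """Return the first importable top_k_top_p_filtering, or None if neither path works."""
    candidate_modules = (
        "transformers.generation.utils",   # newer layout (dot)
        "transformers.generation_utils",   # older layout (underscore)
    )
    for module_name in candidate_modules:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            # Module path does not exist in this install (or transformers is absent)
            continue
        fn = getattr(module, "top_k_top_p_filtering", None)
        if fn is not None:
            return fn
    return None


top_k_top_p_filtering = resolve_top_k_top_p_filtering()
if top_k_top_p_filtering is None:
    print("top_k_top_p_filtering not found; use a local copy instead")
```

If it returns None, pasting the function from the first comment of this issue into your own module should work as a local fallback.
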