huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Disable @torch.no_grad() for model.generate() ? #3720

Closed Laksh1997 closed 4 years ago

Laksh1997 commented 4 years ago

❓ Questions & Help

Is there any way to do this?

Laksh1997 commented 4 years ago

At the moment, the only solution seems to be to copy and paste the entire generation code, along with the few changes that come with it, to work around this issue.

Laksh1997 commented 4 years ago

One solution I propose is to add an argument with_grad which defaults to False. Then, add this as the first line in the generate code:

def generate(..., with_grad=False):
    torch.set_grad_enabled(with_grad)
    ...

This will be backward-compatible.
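
For what it's worth, the context-manager form of torch.set_grad_enabled would also restore the caller's grad mode on exit. A toy sketch of how the flag would behave (the Embedding here is just a stand-in for a language model, not the real generate signature):

import torch

def generate(model, input_ids, with_grad=False):
    # one forward step under the requested grad mode; the context manager
    # restores the previous grad mode when it exits
    with torch.set_grad_enabled(with_grad):
        return model(input_ids)

model = torch.nn.Embedding(10, 4)  # toy stand-in for an LM
tokens = torch.tensor([[1, 2, 3]])
print(generate(model, tokens).requires_grad)                  # False
print(generate(model, tokens, with_grad=True).requires_grad)  # True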

patrickvonplaten commented 4 years ago

Being able to back-prop through the generate() fn would require a lot of changes in my opinion. Not sure whether we plan on doing this any time soon. If you find a good way, feel free to open a PR though :-)

Laksh1997 commented 4 years ago

Hi Patrick, yes I understand it's complicated.

Here is a snippet that explains how it may work:

import torch
import torch.distributions as dist
import torch.nn.functional as F  # needed for F.softmax in top_k_top_p_filtering

def generate_and_trace_log_probs(
    model, batch_size=32, max_len=100, top_k=0, top_p=1.0, bos_id=1, eos_id=2
):

    initial_pool = torch.full(
        size=(batch_size, 1),
        fill_value=bos_id,
        dtype=torch.long,
        device=next(model.parameters()).device,
    )
    past_tokens = initial_pool
    current_tokens = initial_pool
    log_probs = []
    past_attention_computation = None

    for i in range(max_len - 1):

        # Forward prop through model
        outputs = model(
            input_ids=current_tokens, past=past_attention_computation
        )

        # Extract logits for sampling next tokens
        logits = outputs[0]

        # Top-p and/or top-k filtering
        if top_k > 0 or top_p < 1.0:
            logits = top_k_top_p_filtering(
                logits.squeeze(1), top_k=top_k, top_p=top_p, min_tokens_to_keep=1
            ).unsqueeze(1)

        # Extract attention computations to cache
        past_attention_computation = outputs[1]

        # Sample logits
        catdist = dist.Categorical(logits=logits)
        next_tokens = catdist.sample()

        # Compute and store log probs for REINFORCE
        log_prob = catdist.log_prob(next_tokens)
        log_probs.append(log_prob)

        # Update input into LM
        current_tokens = next_tokens

        # Store tokens for reward computation
        past_tokens = torch.cat([past_tokens, current_tokens.detach()], dim=-1)

        # Check if all examples have had an EOS token - if so, break
        if past_tokens.eq(eos_id).any(dim=-1).all():
            break

    log_probs = torch.cat(log_probs, dim=-1)

    # For tokens that came after the EOS token, mask their log prob
    for idx, ex in enumerate(past_tokens):
        eos_positions = torch.where(ex.eq(eos_id))[0]
        if eos_positions.numel() == 0:
            continue  # no EOS was generated within max_len for this example
        log_probs[idx, eos_positions.min() + 1 :] = -1e4

    return log_probs, past_tokens

def top_k_top_p_filtering(
    logits: torch.Tensor,
    top_k: int = 50,
    top_p: float = 0.95,
    min_tokens_to_keep=1,
    filter_value=-float("Inf"),
):
    """Add torch.no_grad() for steps that unnecessarily trace gradients"""
    if top_k > 0:
        with torch.no_grad():
            top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # safety check
            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        with torch.no_grad():
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens with cumulative probs above threshold (token with 0 kept)
            sorted_indices_to_remove = cumulative_probs > top_p
            if min_tokens_to_keep > 1:
                # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
                sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
            # Shift the indices to the right to keep also the first token above the threshold
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                ..., :-1
            ].clone()
            sorted_indices_to_remove[..., 0] = 0

            # scatter sorted tensors to original indexing
            indices_to_remove = sorted_indices_to_remove.scatter(
                1, sorted_indices, sorted_indices_to_remove
            )
        logits[indices_to_remove] = filter_value

    return logits
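
For context, here is roughly how the snippet above could be plugged into a REINFORCE-style update (a minimal sketch, assuming a GPT-2-style LM head model whose forward accepts the past cache keyword used above; compute_rewards is a hypothetical user-supplied function returning one scalar reward per sequence):

from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

log_probs, tokens = generate_and_trace_log_probs(model, batch_size=8, max_len=50, top_p=0.9)
rewards = compute_rewards(tokens)                    # hypothetical: shape (batch,)
loss = -(log_probs * rewards.unsqueeze(-1)).mean()   # REINFORCE objective

optimizer.zero_grad()
loss.backward()
optimizer.step()
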
patrickvonplaten commented 4 years ago

@Laksh1997 - thanks for the code snippet :-) If you think you are able to make a PR that can pass the tests, I think we would be more than happy to add this to the lib!

aced125 commented 4 years ago

Okay, will try...

aced125 commented 4 years ago

@patrickvonplaten I've edited the code (only had to make a few changes to enable this capability!) and run the tests (369 pass, 808 skip, 10 warnings).

I'm trying to push a new branch but getting access denied.

Laksh1997 commented 4 years ago

@patrickvonplaten that's my other account ...

Laksh1997 commented 4 years ago

I'm reading the instructions now on how to contribute ...

aced125 commented 4 years ago

Done a PR... @patrickvonplaten

stale[bot] commented 4 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

abhishek0318 commented 3 years ago

One could use model.greedy_search if they want to backpropagate through the generation process. This worked for me.

JoaoLages commented 2 years ago

greedy_search

model.greedy_search is not working correctly, at least for T5.

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
model = AutoModelForSeq2SeqLM.from_pretrained('t5-small')
tokenizer = AutoTokenizer.from_pretrained('t5-small')
model.greedy_search(**tokenizer("I love HuggingFace", return_tensors='pt'))

I get the following error with the code above:

  File "/home/joaolages/.venv/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py", line 930, in forward
    raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
ValueError: You have to specify either input_ids or inputs_embeds

I even tried calling greedy_search as suggested here, but it creates different outputs compared to calling model.generate with num_beams=1, which it shouldn't, right?

patrickvonplaten commented 2 years ago

@JoaoLages, you also need to pass encoder_outputs when calling greedy_search on encoder-decoder models such as T5. This should work:

#!/usr/bin/env python3
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

model = AutoModelForSeq2SeqLM.from_pretrained('t5-small')
tokenizer = AutoTokenizer.from_pretrained('t5-small')

input_ids = tokenizer("Translate English to German: Today is a nice day.", return_tensors="pt").input_ids
encoder_outputs = model.encoder(input_ids)

decoder_input_ids = torch.ones_like(input_ids)[:, :1] * model.config.decoder_start_token_id
model_kwargs = {"encoder_outputs": encoder_outputs}

sequences = model.greedy_search(decoder_input_ids, **model_kwargs)

print("Output:", tokenizer.batch_decode(sequences))
# => prints ['<pad> Heute ist ein schöner Tag.</s>']

I do very much admit that this is too complicated; it also took me a bit to get right. @JoaoLages, do you think we need to improve our docs here?

JoaoLages commented 2 years ago

Thanks!

I do very much admit that this is too complicated; it also took me a bit to get right. @JoaoLages, do you think we need to improve our docs here?

I think it would be simpler to change T5ForConditionalGeneration.greedy_search to have this code inside it, so that we could simply call model.greedy_search(input_ids)
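
Something along those lines could look like this (a hypothetical wrapper, not part of the library, just to illustrate the idea):

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

def greedy_search_from_input_ids(model, input_ids):
    # prepare the inputs that generate() would normally build for an
    # encoder-decoder model, then delegate to greedy_search
    encoder_outputs = model.get_encoder()(input_ids)
    decoder_input_ids = torch.full(
        (input_ids.shape[0], 1),
        model.config.decoder_start_token_id,
        dtype=torch.long,
        device=input_ids.device,
    )
    return model.greedy_search(decoder_input_ids, encoder_outputs=encoder_outputs)

model = AutoModelForSeq2SeqLM.from_pretrained('t5-small')
tokenizer = AutoTokenizer.from_pretrained('t5-small')
input_ids = tokenizer("Translate English to German: Today is a nice day.", return_tensors="pt").input_ids
print(tokenizer.batch_decode(greedy_search_from_input_ids(model, input_ids)))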

patrickvonplaten commented 2 years ago

Sorry also meant to ping @gante here

gante commented 2 years ago

@patrickvonplaten Trying to understand the problem -- am I right in saying that we want to use the generation methods directly for backpropagation purposes (because generate() won't work there), and thus we need to document their proper use (because generate() does a lot of input preparation)?

patrickvonplaten commented 2 years ago

Good point!

I think my idea back when we added the sub-methods was to push the community to use those directly instead of the more "magic" .generate() function. The reason is that it's harder and harder to cover every use case in generate(), whereas the sub-methods are very "bare-bones" and free of magic, which means that if one knows how to use them they can cover more or less every use case. That failed a bit, I think, because 99.9% of people just use generate(...), probably because of how difficult it is to understand and use the sub-methods directly (as shown here: https://github.com/huggingface/transformers/issues/3720#issuecomment-1235775528 <- that's too difficult to understand/know).

So I just pinged you here so you're aware of this, and to ask whether it could make sense to provide better guides for the sub-methods, maybe even change the sub-methods, or continue not to spend much time on them. I don't think it's urgent though!

JoaoLages commented 2 years ago

@patrickvonplaten @gante At the very least, these docs should be updated with the code that @patrickvonplaten shared here.

pedrocolon93 commented 2 years ago

Just a heads up that I think some of these methods (if you want a continuous gradient) might have to use the Gumbel-softmax trick: https://datascience.stackexchange.com/questions/58376/gumbel-softmax-trick-vs-softmax-with-temperature to get a differentiable final next token. At least when I checked this out a while back that seemed to be the case, but ¯\_(ツ)_/¯
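
For illustration, the straight-through Gumbel-softmax sample itself is a one-liner in PyTorch (a minimal sketch; how the relaxed one-hot sample is fed back into a particular model is a separate, model-specific question):

import torch
import torch.nn.functional as F

logits = torch.randn(2, 32128, requires_grad=True)  # stand-in for next-token logits

# hard=True returns a one-hot sample in the forward pass while the backward
# pass uses the soft relaxation (straight-through estimator)
one_hot_next = F.gumbel_softmax(logits, tau=1.0, hard=True)

# a differentiable "embedding lookup" through the embedding matrix
embedding = torch.nn.Embedding(32128, 512)
next_token_embeds = one_hot_next @ embedding.weight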

abarbet commented 1 year ago

Using the approach above with greedy_search and a T5 model, I'm still not seeing a grad_fn associated with the output logits. Was anyone able to get this working with a T5 architecture?

JoaoLages commented 1 year ago

Using the approach above with greedy_search and a T5 model, I'm still not seeing a grad_fn associated with the output logits. Was anyone able to get this working with a T5 architecture?

In order to get the gradient per step, you need to do the greedy decoding on your own. Use model.forward instead to get the logits (with their gradients) and the next token, then concatenate that generated token onto the decoder_input_ids and repeat the process.
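
A rough sketch of that loop for T5 (illustrative only; stopping and max length handled naively):

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained('t5-small')
tokenizer = AutoTokenizer.from_pretrained('t5-small')

input_ids = tokenizer("Translate English to German: Today is a nice day.", return_tensors="pt").input_ids
decoder_input_ids = torch.full((1, 1), model.config.decoder_start_token_id, dtype=torch.long)

step_log_probs = []
for _ in range(20):
    outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
    next_token_logits = outputs.logits[:, -1, :]  # has a grad_fn
    next_token = next_token_logits.argmax(dim=-1, keepdim=True)
    step_log_probs.append(next_token_logits.log_softmax(-1).gather(-1, next_token))
    decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=-1)
    if next_token.item() == model.config.eos_token_id:
        break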

If you want to test this fast, you can use the ecco package that I've helped build. It has logic for doing this gradient calculation for any kind of sampling approach (greedy/beam/sample/etc) and for models like T5 or GPT. It is not very optimized in terms of inference times though, I must warn you.

abarbet commented 1 year ago

@JoaoLages that is a really helpful starting point, thank you! I'm not sure I see a beam search sampling process in the code (but perhaps I'm looking in the wrong place). I do see a TODO in sample_output_token to add beam search in the future.

JoaoLages commented 1 year ago

Yeah, there isn't beam search. What we actually do is use the normal model.generate method to run beam search, and then feed the generated tokens through the model to calculate their gradients. So we do the generation step twice, but in the second pass we capture the gradients. It's slow, but it could be optimized if we implemented our own beam search.
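
A rough sketch of that two-pass idea (illustrative; assumes an encoder-decoder model such as T5):

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained('t5-small')
tokenizer = AutoTokenizer.from_pretrained('t5-small')
input_ids = tokenizer("Translate English to German: Today is a nice day.", return_tensors="pt").input_ids

# pass 1: beam search without gradients
with torch.no_grad():
    generated = model.generate(input_ids, num_beams=4, max_length=32)

# pass 2: feed the generated tokens back through the model with gradients enabled
outputs = model(input_ids=input_ids, decoder_input_ids=generated)
log_probs = outputs.logits.log_softmax(dim=-1)  # per-step log-probabilities, with grad_fn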