huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

(TF) model.generate to tf.function for tf serving #16823

Closed piEsposito closed 2 years ago

piEsposito commented 2 years ago

Feature request

It would be nice if you wrapped the generate method of autoregressive models in a tf.function. That way we could export it and serve it with the whole TensorFlow production stack.

It's kind of a revival of #5443.

It would enable us to do something like:

from transformers import AutoTokenizer, TFAutoModelForCausalLM
import tensorflow as tf

model = TFAutoModelForCausalLM.from_pretrained("gpt2")
model.save(
    "some_place",
    signatures={
        "serving_default": model.generate.get_concrete_function(tf.TensorSpec([None, None], tf.int32))
    }
)

And then serve it on the TF production stack.

Motivation


It is frustrating to have to write generate by hand or move to PyTorch to serve generative language models.

Your contribution

I could write a PR, though it would be nice if HF could share what they have done when trying it, as @Rocketknight1 and @patrickvonplaten said in: https://github.com/huggingface/transformers/issues/5443#issuecomment-1020067525_, so I would have a starting point.

patrickvonplaten commented 2 years ago

Hey @piEsposito,

The function should now be usable with tf.function, I think. We don't want to wrap generate in tf.function automatically ourselves, but you should now be able to do the following:

#!/usr/bin/env python3
from transformers import TFGPT2LMHeadModel, GPT2Tokenizer
import tensorflow as tf

# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front
physical_devices = tf.config.list_physical_devices('GPU')
for device in physical_devices:
    tf.config.experimental.set_memory_growth(device, True)

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = TFGPT2LMHeadModel.from_pretrained("gpt2")

input_ids = tokenizer("hello there can you continue", return_tensors="tf").input_ids

# Wrap generate with XLA compilation; the first call traces and compiles,
# later calls with the same input shape reuse the compiled graph
xla_generate = tf.function(model.generate, jit_compile=True)
outputs = xla_generate(input_ids)

print("Output", tokenizer.batch_decode(outputs))

patrickvonplaten commented 2 years ago

cc @gante

gante commented 2 years ago

Hey @piEsposito 👋 As @patrickvonplaten mentioned, we have some generation functionality that can be wrapped by tf.function to be highly accelerated -- our tests point at a >30x speedup if an NVIDIA T4 is used.

The example provided should be functional and XLA-accelerated. However, some advanced features are not yet XLA-compatible, including:

All these should be solved in the next 1-2 months. Keep an eye on our releases, and let us know if you run into problems :)

piEsposito commented 2 years ago

Hey @gante, thanks for the quick reply. Actually, my problem is specifically creating a serving signature that accepts inputs of variable length, so I can use it with TF Serving in production. Do you have anything on that?

gante commented 2 years ago

tf.function has an experimental_relax_shapes argument, which may help there. I can't confirm, as I haven't tested it :) An alternative would be to pad all inputs to the maximum length accepted by the model, but that might waste memory and compute.

piEsposito commented 2 years ago

@gante thanks. Do you know how I can use the generate method with fully padded sequences? It always throws an error here :(

gante commented 2 years ago

Pardon me, I wrote a half-truth above :) For encoder-decoder (aka sequence-to-sequence) models like T5, you can do as I wrote above. For decoder-only models like GPT-2, you can left-pad to a constant length -- see this test as an example.
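A minimal pure-Python sketch of left-padding to a constant length, assuming the prompts are already tokenized into lists of ids. `left_pad` is a hypothetical helper for illustration; in practice the tokenizer does this when `padding_side = "left"`. Left padding keeps the real tokens flush against the right edge, which is where a decoder-only model appends new tokens.

```python
def left_pad(batch, length, pad_id):
    """Left-pad each list of token ids to `length`.

    Returns (input_ids, attention_mask), where the mask is 0 on padding
    and 1 on real tokens, so the real tokens end at the right edge.
    """
    input_ids, attention_mask = [], []
    for ids in batch:
        if len(ids) > length:
            raise ValueError(f"sequence of {len(ids)} tokens exceeds padded length {length}")
        n_pad = length - len(ids)
        input_ids.append([pad_id] * n_pad + list(ids))
        attention_mask.append([0] * n_pad + [1] * len(ids))
    return input_ids, attention_mask


# 50256 is GPT-2's <|endoftext|> id, reused as the pad token
ids, mask = left_pad([[10, 11, 12], [7]], 5, pad_id=50256)
print(ids)   # [[50256, 50256, 10, 11, 12], [50256, 50256, 50256, 50256, 7]]
print(mask)  # [[0, 0, 1, 1, 1], [0, 0, 0, 0, 1]]
```

Every row ends up with the same length, which is what keeps the input shape constant across calls.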

piEsposito commented 2 years ago

Sorry, but when I pad to max_length (setting padding=True won't pad to the maximum accepted length), it still throws an error:

from transformers import GPT2Tokenizer, TFGPT2LMHeadModel
import tensorflow as tf

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

model = TFGPT2LMHeadModel.from_pretrained("gpt2")

text = "Replace me by any text you'd like."

# padding="max_length" pads up to tokenizer.model_max_length (1024 for GPT-2)
encoded_input = tokenizer([text],
                          return_tensors='tf',
                          padding="max_length")

model.generate(
    encoded_input.input_ids,
    max_length=1024
)

This throws:

ValueError: The context has 1024 number of tokens, but `max_length` is only 1024.

And of course I can't set max_length to anything more than 1024.

Am I doing something wrong?

gante commented 2 years ago

The constant length in decoder-only models has to be smaller than max_length (as opposed to encoder-decoder models, where it can be padded to max_length), and the difference between your constant and generate's max_length is the maximum number of new tokens generate can produce.
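As a quick sanity check of that arithmetic, here is a minimal sketch. `generation_budget` is a hypothetical helper, not a transformers API, and it assumes `max_length` counts the padded prompt plus the generated tokens, as described above for decoder-only models.

```python
def generation_budget(padded_length, max_length):
    """Maximum number of new tokens a decoder-only model can generate when
    the prompt is padded to `padded_length` and the whole sequence
    (prompt + new tokens) is capped at `max_length`."""
    budget = max_length - padded_length
    if budget <= 0:
        raise ValueError(
            f"prompt is padded to {padded_length} tokens, "
            f"leaving no room under max_length={max_length}"
        )
    return budget


print(generation_budget(974, 1024))  # 50
```

With GPT-2's 1024-token limit, padding the prompt to the full 1024 leaves no room (the ValueError reported above), while dropping the first 50 pad positions, as in the next snippet in this thread, leaves 974 prompt slots and a budget of 50 new tokens.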

piEsposito commented 2 years ago

When I pad and leave a few tokens free for generation, it still doesn't produce a sensible continuation of my text; after about 1000 EOS tokens it just outputs some random stuff:

text = "Replace me by any text you'd like."

encoded_input = tokenizer([text],
                          return_tensors='tf',
                          padding="max_length")

preds = model.generate(
    encoded_input.input_ids[:, 50:],
    max_length=1024,
    pad_token_id=tokenizer.pad_token_id
)

tokenizer.batch_decode(preds)

And I get something like

[
"<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>...
Replace me by any text you'd like.\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n"
]

This result stays the same even when I explicitly mask the padded tokens:

preds = model.generate(
    encoded_input.input_ids[:, 50:],
    max_length=1024,
    attention_mask=encoded_input.attention_mask[:,50:]
)

When we try the same input with greedy decoding, the output makes sense.

piEsposito commented 2 years ago

It seems to be related to https://github.com/huggingface/transformers/blob/3104036e7f1a3cd6e07a69d648c3597de32f72fe/src/transformers/models/gpt2/modeling_tf_gpt2.py#L816-L842

There, when use_xla=True is not passed, the attention mask is set to None.

But it could be something else, since passing use_xla=True changes the result but doesn't fix it.
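One possible contributor to the garbled output (not confirmed in this thread) is position ids. With heavy left padding, numbering positions 0..L-1 places the real prompt near the end of the position range instead of at 0, 1, 2, ..., so the model sees it differently from the unpadded case. A common remedy is to derive position ids from the attention mask; the PyTorch GPT-2 generation code in transformers uses a similar cumulative-sum trick. The NumPy sketch below assumes a 0/1 mask, and `position_ids_from_mask` is a hypothetical helper.

```python
import numpy as np


def position_ids_from_mask(attention_mask):
    """Position ids for a left-padded batch.

    Real tokens (mask == 1) are numbered 0, 1, 2, ... starting at the first
    real token, so a left-padded prompt gets the same positions as the
    unpadded one. Padded slots get a dummy value of 1; the attention mask
    hides them anyway.
    """
    attention_mask = np.asarray(attention_mask)
    positions = np.cumsum(attention_mask, axis=-1) - 1
    return np.where(attention_mask == 1, positions, 1)


mask = np.array([[0, 0, 1, 1, 1],
                 [1, 1, 1, 1, 1]])
print(position_ids_from_mask(mask))
# [[1 1 0 1 2]
#  [0 1 2 3 4]]
```

For the left-padded row `[0, 0, 1, 1, 1]` the real tokens get positions 0, 1, 2, matching an unpadded prompt, instead of 2, 3, 4 from a plain `arange`.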

gante commented 2 years ago

@piEsposito it seems like we still have a couple of bugs to fix :D

I'm afraid I can't be of much further help -- I'm actively developing XLA + generate, but I don't expect to be able to sort out your particular issue within the next month. The roadmap is approximately XLA logits processors -> XLA beam search -> efficient XLA batching (your issue) -> XLA on more models beyond GPT-2 and T5. When all this is sorted, we will make a big announcement and publish some tutorials. Until then, feel free to ping me to query the state of the XLA changes :)

piEsposito commented 2 years ago

@gante if you have an open-sourced branch I would love to help with that generate stuff. If not, thank you for your time and for trying to help me out with this.

gante commented 2 years ago

@piEsposito that would be lovely :)

The step I will work on next, as I mentioned above, is to make the logit processors XLA-compatible. In other words, rewrite them such that the tests here pass if you compile the function with tf.function(jit_compile=True). Some of them may already work -- feel free to claim one (or more) to work on, excluding the repetition_penalty (which I've already rewritten for XLA in a branch).
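To illustrate what "XLA-compatible" tends to mean for a logits processor, here is a NumPy sketch of a CTRL-style repetition penalty written as whole-array gather and scatter operations, with no Python loop over generated tokens and no data-dependent shapes. It is not the code from the branch mentioned above, and `repetition_penalty` is a hypothetical name. In TensorFlow the same pattern would plausibly map onto `tf.gather(..., batch_dims=1)` for the gather and `tf.tensor_scatter_nd_update` (with batch/token index pairs built from `input_ids`) for the scatter.

```python
import numpy as np


def repetition_penalty(scores, input_ids, penalty):
    """CTRL-style repetition penalty, vectorized over the batch.

    scores:    float array [batch, vocab] of next-token logits
    input_ids: int array [batch, seq_len] of tokens generated so far
    Logits of tokens already present are divided by `penalty` when
    positive and multiplied by it when negative.
    """
    scores = scores.copy()  # do not modify the caller's array
    # Gather the logit of every previously seen token: [batch, seq_len]
    seen = np.take_along_axis(scores, input_ids, axis=1)
    penalized = np.where(seen < 0, seen * penalty, seen / penalty)
    # Scatter back; duplicate ids in a row all write the same value
    np.put_along_axis(scores, input_ids, penalized, axis=1)
    return scores


logits = np.array([[2.0, -1.0, 0.5, 3.0]])
print(repetition_penalty(logits, np.array([[0, 1, 1]]), penalty=2.0))
# [[ 1.  -2.   0.5  3. ]]
```

Because every op works on fixed-shape arrays, the equivalent TensorFlow graph has no shape that depends on the token values, which is the property XLA compilation needs.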

piEsposito commented 2 years ago

@gante hacking away at TensorFlow to make stuff serializable is kind of a hobby, and it has also been paying my bills for a long time, so I can work on that.

I just need a bit more context:

Thanks, let's do it.

gante commented 2 years ago

Awesome @piEsposito! I will open a PR today, so you can have an example, and post here a more detailed guide 💪

piEsposito commented 2 years ago

Thanks!

gante commented 2 years ago

@piEsposito This is the PR for an XLA-compatible repetition penalty logits processor. I've just opened it, so I'd suggest waiting until the review process is complete before starting on a new logit processor.

After the PR above gets approved, the process would be:

If you run into issues along the way, let me know. I will let you know here when the PR gets approved, so we can start on the next processors.

gante commented 2 years ago

(The PR got approved and merged. Working on the TFLogitsWarper subclasses now.)

piEsposito commented 2 years ago


Let's do it man.

github-actions[bot] commented 2 years ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

gante commented 2 years ago

(beam search being worked on atm, last missing piece)