huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

UMT5 incredibly slow in generating #26144

Open BramVanroy opened 1 year ago

BramVanroy commented 1 year ago

System Info

Who can help?

@ArthurZucker and @younesbelkada and @gante

Reproduction

import time

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, GenerationConfig

if __name__ == "__main__":
    timings = {}

    for model_name in ("facebook/mbart-large-50-many-to-one-mmt", "google/umt5-small"):
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map={"": "cuda"})
        print(model_name, model.num_parameters())
        # google/umt5-small                        306601984
        # facebook/mbart-large-50-many-to-one-mmt 1122990080
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        gen_config = GenerationConfig.from_pretrained(
            model_name,
            max_new_tokens=200,
            max_length=None,
            num_beams=1,
        )
        text = "I would really like to eat some cookies now."
        if "t5" in model_name:
            text = f"translate English to Dutch: {text}"

        encoded = tokenizer(text, return_tensors="pt")
        encoded = {k: v.to(model.device) for k, v in encoded.items()}
        start_time = time.perf_counter_ns()
        for _ in range(100):
            _ = model.generate(**encoded, generation_config=gen_config)

        timings[model_name] = time.perf_counter_ns() - start_time

    for model_name, duration_ns in timings.items():
        print(f"Generation duration for {model_name.split('/')[1]}:\t{duration_ns}")
        # Generation duration for mbart-large-50-many-to-one-mmt:  22413427363
        # Generation duration for umt5-small:                     207906791077

So despite UMT5-small having only about 27% of the parameters of the mBART-large model, it is 9-10x slower!
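The ratios above follow directly from the numbers the script prints. A quick check, using the perf_counter_ns totals over 100 generate calls copied from the output (nothing here is re-measured):

```python
# Numbers copied from the script's output above
params = {"umt5-small": 306_601_984, "mbart-large-50-many-to-one-mmt": 1_122_990_080}
total_ns = {"umt5-small": 207_906_791_077, "mbart-large-50-many-to-one-mmt": 22_413_427_363}
n_calls = 100  # generate() was called 100 times per model

size_ratio = params["umt5-small"] / params["mbart-large-50-many-to-one-mmt"]
time_ratio = total_ns["umt5-small"] / total_ns["mbart-large-50-many-to-one-mmt"]
# Average wall-clock seconds per generate() call
per_call_s = {name: ns / n_calls / 1e9 for name, ns in total_ns.items()}

print(f"size ratio: {size_ratio:.1%}")   # 27.3%
print(f"slowdown:   {time_ratio:.2f}x")  # 9.28x
print({name: round(s, 3) for name, s in per_call_s.items()})
```

In other words, roughly 2.08 s per umT5 call versus 0.22 s per mBART call.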

(I also tried calling gc.collect() after each generation.)

Expected behavior

Faster inference/generation speed. Training speed is fine, so I assume caching of past states is not (correctly) implemented, but I might be wrong. This PR by @patrickvonplaten on adding caching to T5 might be related: https://github.com/huggingface/transformers/pull/3682

ArthurZucker commented 1 year ago

Hey, thanks for reporting, I'll investigate! I'm not sure why you would need to run 100 iterations of the generate method this way, but for one generation:

For 10 iterations:

So I'm not sure whether this is simply a bug in the time logging?

BramVanroy commented 1 year ago

The same difference (around 9-10x) can be observed when leaving out the prefix for umT5.

ArthurZucker commented 1 year ago

Also, you can do encoded = tokenizer(text, return_tensors="pt").to(model.device) 😉

ArthurZucker commented 1 year ago

I was not using device_map="auto"; it might be the _no_split_modules difference.

BramVanroy commented 1 year ago

I can't reproduce your results. Do you have accelerate installed? Can you share the output of transformers-cli env? @ArthurZucker

ArthurZucker commented 1 year ago

I can reproduce your results using your exact script, so I will investigate anyway!

BramVanroy commented 1 year ago

Even without device_map, and with your other suggested changes, I still see a noticeable difference: about 4x for a single generation. Attaching the script for reproducibility.

import time
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, GenerationConfig

if __name__ == "__main__":
    timings = {}

    for model_name in ("facebook/mbart-large-50-many-to-one-mmt", "google/umt5-small"):
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        model = model.to("cuda")
        print(model_name, model.num_parameters())
        # google/umt5-small                        306601984
        # facebook/mbart-large-50-many-to-one-mmt 1122990080
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        gen_config = GenerationConfig.from_pretrained(
            model_name,
            max_new_tokens=200,
            max_length=None,
            num_beams=1,
        )
        text = "I would really like to eat some cookies now."

        encoded = tokenizer(text, return_tensors="pt").to(model.device)
        start_time = time.time()
        model.generate(**encoded, generation_config=gen_config)
        timings[model_name] = time.time() - start_time

    for model_name, duration_s in timings.items():
        print(f"Generation duration for {model_name.split('/')[1]}:\t{duration_s}")
        # Generation duration for mbart-large-50-many-to-one-mmt:  0.4059898853302002
        # Generation duration for umt5-small:                      1.7038893699645996
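
Timing a single generate call with time.time() can include one-off costs (for example library or kernel initialisation on first use) and, on a GPU, work that is still queued when the clock stops. A minimal sketch of a more careful timer, assuming you wrap the call in a zero-argument callable; benchmark and its sync hook are hypothetical helpers, not part of transformers. With PyTorch you could pass sync=torch.cuda.synchronize.

```python
import statistics
import time
from typing import Callable, Optional


def benchmark(
    fn: Callable[[], object],
    *,
    warmup: int = 2,
    runs: int = 10,
    sync: Optional[Callable[[], None]] = None,
) -> dict:
    """Time `fn` after `warmup` untimed calls; return summary stats in seconds.

    `sync` (hypothetical hook) runs right before starting and right after
    stopping the clock, so queued device work is counted inside each sample.
    """
    for _ in range(warmup):
        fn()  # untimed warm-up calls absorb one-off initialisation costs
    samples = []
    for _ in range(runs):
        if sync is not None:
            sync()
        start = time.perf_counter()  # monotonic, high-resolution clock
        fn()
        if sync is not None:
            sync()
        samples.append(time.perf_counter() - start)
    return {
        "mean": statistics.mean(samples),
        "median": statistics.median(samples),
        "min": min(samples),
        "runs": runs,
    }
```

Usage, assuming the variables from the script above: benchmark(lambda: model.generate(**encoded, generation_config=gen_config), sync=torch.cuda.synchronize). Reporting the median alongside the mean makes a single slow outlier easier to spot.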

ArthurZucker commented 1 year ago

Yep, the culprit is compute_bias: if you remove it (i.e. use a default positional bias), you get the same performance. This is somewhat expected, but we should definitely try to use a caching mechanism for this.
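For context, compute_bias maps every (query, key) position pair to one of a fixed number of relative-position buckets and then looks up a learned per-head bias for each bucket. Below is a pure-Python, scalar sketch of the T5 bucketing rule, modelled on _relative_position_bucket in the T5 modelling code; the defaults of 32 buckets and a max distance of 128 are T5's config defaults and are an assumption for UMT5. The names relative_position_bucket and naive_bias_buckets are illustrative, and the second helper only shows that a compute_bias-style function redoes this work for every pair on each call.

```python
import math


def relative_position_bucket(relative_position: int, bidirectional: bool = True,
                             num_buckets: int = 32, max_distance: int = 128) -> int:
    """Scalar port of the T5 bucketing rule (relative_position = key_pos - query_pos)."""
    bucket = 0
    if bidirectional:
        # Half of the buckets for keys after the query, half for keys before it
        num_buckets //= 2
        if relative_position > 0:
            bucket += num_buckets
        relative_position = abs(relative_position)
    else:
        # Causal decoder: keys in the future collapse to distance 0
        relative_position = -min(relative_position, 0)
    max_exact = num_buckets // 2
    if relative_position < max_exact:
        # Small distances each get their own bucket
        return bucket + relative_position
    # Larger distances share logarithmically sized buckets, capped at the last one
    large = max_exact + int(
        math.log(relative_position / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    )
    return bucket + min(large, num_buckets - 1)


def naive_bias_buckets(query_length: int, key_length: int, bidirectional: bool) -> list:
    """Bucket matrix that a compute_bias-style function rebuilds from scratch on every call."""
    return [
        [relative_position_bucket(k - q, bidirectional) for k in range(key_length)]
        for q in range(query_length)
    ]
```

A real implementation does this with tensors and then indexes an embedding table of shape (num_buckets, n_heads), but the bucketing logic is the same.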

ArthurZucker commented 1 year ago

You should compare with mT5 or T5, which have a similar architecture, bias, etc. 😉

BramVanroy commented 1 year ago

New snippet! @ArthurZucker It compares umT5, T5, and mT5 against the mBART baseline. It seems umT5 is definitely the outlier here.

import time
from statistics import mean

from tqdm import trange
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, GenerationConfig

if __name__ == "__main__":
    timings = {}
    sizes = {}
    for model_name in ("facebook/mbart-large-50-many-to-one-mmt", "google/umt5-small", "google/mt5-small", "t5-small"):
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        sizes[model_name] = model.num_parameters()
        model = model.to("cuda")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        gen_config = GenerationConfig.from_pretrained(
            model_name,
            max_new_tokens=200,
            max_length=None,
            num_beams=1,
        )
        text = "I would really like to eat some cookies now."

        encoded = tokenizer(text, return_tensors="pt").to(model.device)
        timings[model_name] = []
        for _ in trange(10):
            start_time = time.time()
            model.generate(**encoded, generation_config=gen_config)
            timings[model_name].append(time.time() - start_time)

        timings[model_name] = mean(timings[model_name])

    baseline_time = timings.pop("facebook/mbart-large-50-many-to-one-mmt")
    baseline_size = sizes.pop("facebook/mbart-large-50-many-to-one-mmt")

    print("Comparing to facebook/mbart-large-50-many-to-one-mmt:")
    for model_name, timing in timings.items():
        size = sizes[model_name]
        print(f"{model_name}:\t {(timing*100/baseline_time):.2f}% (relative gen. time); {(size*100/baseline_size):.2f}% (relative size)")

Output (relative to facebook/mbart-large-50-many-to-one-mmt):

google/umt5-small:        975.10% (relative gen. time); 50.19% (relative size)
google/mt5-small:        12.80% (relative gen. time); 49.14% (relative size)
t5-small:        33.08% (relative gen. time); 9.90% (relative size)

Interesting to me that mT5 is so fast given its size.
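Since both percentages are relative to the same mBART baseline, the baseline cancels out and the two similarly sized models can be compared directly. A small check using only the printed numbers:

```python
# Relative generation time and size printed above (percent of the mBART baseline)
rel_time = {"google/umt5-small": 975.10, "google/mt5-small": 12.80}
rel_size = {"google/umt5-small": 50.19, "google/mt5-small": 49.14}

slowdown = rel_time["google/umt5-small"] / rel_time["google/mt5-small"]
size_ratio = rel_size["google/umt5-small"] / rel_size["google/mt5-small"]
print(f"umt5-small vs mt5-small: {slowdown:.1f}x slower at {size_ratio:.2f}x the size")
# umt5-small vs mt5-small: 76.2x slower at 1.02x the size
```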

BramVanroy commented 1 year ago

For ease of access:

All of these look the same to me, however (umT5 simplifies it a little by reading attributes directly from self, but apart from that they seem identical). So I am not sure whether that is the cause. @ArthurZucker

ArthurZucker commented 1 year ago

UMT5 uses a separate positional bias for each layer, whereas the other models share a single one across layers.
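To make the per-layer point concrete: in the T5/mT5 code only the first block of each stack is built with has_relative_attention_bias=True, and the position bias it produces is passed on to the remaining blocks, whereas each UMT5 layer computes its own. The toy counter below (count_bias_computations is a hypothetical helper, and the 8-layer, 200-token figures are purely illustrative, not a claim about a specific checkpoint) counts decoder self-attention bias computations when generating one token per cached forward pass.

```python
def count_bias_computations(new_tokens: int, num_layers: int, shared: bool) -> int:
    """Count decoder self-attention bias computations over `new_tokens` cached steps.

    shared=True  -> T5/mT5 style: computed once per step, then reused by every layer.
    shared=False -> UMT5 style: each layer computes its own bias at every step.
    """
    per_step = 1 if shared else num_layers
    return new_tokens * per_step


# Illustrative numbers only: 200 new tokens through an 8-layer decoder
print(count_bias_computations(200, 8, shared=True))   # 200
print(count_bias_computations(200, 8, shared=False))  # 1600
```

This only counts calls; how expensive each call is depends on how much of the bias matrix the implementation rebuilds at every step.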

ArthurZucker commented 1 year ago

We should precompute the full positional bias for the model's maximum sequence length, cache it, and only fetch the entries we need! The same goes for T5, but it's already pretty fast. Will open a PR!
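A minimal sketch of the precompute-and-slice idea, in pure Python and on bucket indices only; a real implementation would instead cache the looked-up bias tensor (per layer, since UMT5's biases are not shared). It reuses the T5 bucketing rule for a causal decoder. BucketCache, causal_bucket and max_positions are hypothetical names, and 512 is just an illustrative limit.

```python
import math


def causal_bucket(relative_position: int, num_buckets: int = 32, max_distance: int = 128) -> int:
    """T5-style bucket for a causal decoder (relative_position = key_pos - query_pos)."""
    distance = -min(relative_position, 0)  # keys in the future collapse to distance 0
    max_exact = num_buckets // 2
    if distance < max_exact:
        return distance
    large = max_exact + int(
        math.log(distance / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
    )
    return min(large, num_buckets - 1)


class BucketCache:
    """Precompute the full (max_positions x max_positions) bucket table once, then slice it."""

    def __init__(self, max_positions: int = 512):
        self.max_positions = max_positions
        self.table = [
            [causal_bucket(k - q) for k in range(max_positions)]
            for q in range(max_positions)
        ]

    def fetch(self, query_start: int, query_length: int, key_length: int) -> list:
        """Rows for queries [query_start, query_start + query_length), keys [0, key_length)."""
        if query_start + query_length > self.max_positions or key_length > self.max_positions:
            raise ValueError("sequence longer than the precomputed table")
        return [row[:key_length] for row in self.table[query_start:query_start + query_length]]
```

At decoding step t (the 0-based position of the new token), cache.fetch(t, 1, t + 1) returns the single row that step needs, with no recomputation. Because a bucket depends only on the offset k - q, one could store a single vector indexed by offset instead of the square table; the square table is kept here only because it makes the slicing easy to follow.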

ArthurZucker commented 1 year ago

(sorry)

BramVanroy commented 1 year ago

Ah, sorry, yes, you are absolutely right!

https://github.com/huggingface/transformers/blob/7ccac73f749ce535851b9188f3867d5ed87c318c/src/transformers/models/umt5/modeling_umt5.py#L322

Would be great if this could be cached indeed. That would make my current research a lot more feasible!

ArthurZucker commented 1 year ago

Oops, sorry, I did not have time to work more on the fix!

ArthurZucker commented 1 year ago

It's not on my priority list so if anyone wants to take over the PR feel free to do so!

xiaoyunwu commented 3 months ago

What is the status of this?

ArthurZucker commented 3 months ago

#26150 is a draft PR, feel free to take it over!