huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

torch.compile: generate should use call instead of forward #34906

Open SilverSoldier opened 4 days ago

SilverSoldier commented 4 days ago

Who can help?

@ArthurZucker @Cyrilvallez

Reproduction

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "facebook/opt-125m"
length = 100

prompt_text = 'In a small, bustling cafe nestled in the heart of a vibrant city, a serendipitous event unfolded, leaving a lasting impression on all who witnessed it. As the patrons sat sipping their coffees and engaging in animated conversations, a talented street musician entered the cafe, carrying a weathered guitar and radiating an aura of creativity.'

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.compile()  # in-place compile: __call__ will dispatch to the compiled forward, a plain .forward() call will not
input_ids = tokenizer(prompt_text, add_special_tokens=False, return_tensors='pt').input_ids
output = model.generate(input_ids, max_new_tokens=length)

Expected behavior

The expected behaviour is that generate uses the compiled forward function.

When the model is compiled via the model.compile() API, the __call__ method dispatches to an internal attribute that holds the compiled forward, while calling model.forward() directly still runs the uncompiled forward.

(I raised a related issue in pytorch, this is the Option 2 there)
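
To illustrate the dispatch difference, here is a small sketch reusing model and input_ids from the reproduction above; it assumes a PyTorch version where nn.Module.compile() stores the compiled callable on the module, so only __call__ picks it up:

model.compile()                                  # __call__ now dispatches to the compiled forward
out_compiled = model(input_ids=input_ids)        # goes through __call__ -> compiled path
out_eager = model.forward(input_ids=input_ids)   # direct .forward() call -> still eager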

So generate should use the __call__ method instead of forward in order to pick up the compiled forward (for this particular model.compile() case). However, recent changes switched this call to model.forward() instead of model() for every token after the first:

def _sample():
    ...
    def model_forward(model, *args, **kwargs):
        # calls forward directly, bypassing __call__ and hence the compiled forward
        return model.forward(*args, **kwargs)
    ...
        if i == 0:
            # prefill step goes through __call__, so it uses the compiled forward
            outputs = self(**model_inputs, return_dict=True)
            i += 1
        else:
            # later decoding steps bypass __call__ and run the eager forward
            outputs = model_forward(self, return_dict=True, **model_inputs)

model_forward should be changed to call model() instead of model.forward()
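
A minimal sketch of that change (illustrative only; the real helper lives inside generate/_sample and takes more arguments):

def model_forward(model, *args, **kwargs):
    # go through __call__ so that a module compiled in place via nn.Module.compile()
    # dispatches to its compiled forward
    return model(*args, **kwargs)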

ydshieh commented 4 days ago

Hi, so if I understand correctly, the goal of this change is to make sure that a future model(...) call (potentially used outside generate) will use the already compiled version (if compilation is done in generate here). It's good, but the current code isn't really a bug, right?

SilverSoldier commented 4 days ago

Previous code called model(), which works with model.compile(). Current code (from this recent commit) changed this to model.forward(). So the current code does not work as expected with model.compile(): it uses the compiled version for the first iteration and eager mode for the other iterations. This is sort of a bug.

Cyrilvallez commented 4 days ago

Hi @SilverSoldier, thanks for opening the issue! Indeed, I agree that we should use __call__, for consistency between all methods. We wanted to refine this part anyway; this was just the first shot! I'll take care of it very soon. Curious to see what the PyTorch team has to say regarding the consistency of the different ways to compile a model, but I'm not sure this can be solved in general.

Also, we introduced the recent changes because it is inefficient to compile the forward in all cases (we only want to compile the iterative decoding part, not the prefill). So your code would be more efficient if you don't call compile yourself and instead pass cache_implementation="static" to generate, letting us compile only the iterative decoding part 🤗
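
For reference, a minimal sketch of that suggestion (reusing model, input_ids, and length from the reproduction above; cache_implementation="static" is the existing generate/GenerationConfig option that enables the static-cache, compiled decoding path):

# no manual model.compile(); generate compiles the iterative decoding step itself
output = model.generate(
    input_ids,
    max_new_tokens=length,
    cache_implementation="static",
)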

ydshieh commented 3 days ago

I see, it's another API of compile! Thanks!

SilverSoldier commented 3 days ago

@Cyrilvallez interesting. However, if generate takes care of the compile, how can we pass the compile arguments, say a custom backend and other parameters?

ArthurZucker commented 3 days ago

We are adding this via the classic generation_config and generate kwargs API!
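
A sketch of what that could look like once the option lands (GenerationConfig and the generate call are the existing API; the compile_config name and its fields are assumptions, not a confirmed interface):

from transformers import GenerationConfig

generation_config = GenerationConfig(
    max_new_tokens=100,
    cache_implementation="static",
    # hypothetical: bundle torch.compile arguments into the generation config
    compile_config={"backend": "inductor", "mode": "reduce-overhead"},
)
output = model.generate(input_ids, generation_config=generation_config)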