SilverSoldier opened 4 days ago
Hi, so if I understand correctly, the goal of this change is to make future `model(...)` calls (potentially made outside `generate`) use the already compiled version (if compilation was done in `generate` here). That's good, but the current code is actually a bug, right?

Previous code called `model()`, which works with `model.compile()`. The current code (from this recent commit) changed this to `model.forward()`.

So the current code does not work as expected with `model.compile()`: it uses the compiled version for the first iteration and eager mode for the other iterations. This is sort of a bug.
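The dispatch behaviour described above can be sketched schematically. This is a simplified stand-in, not the actual PyTorch implementation: the `Module` class and its `_compiled_forward` attribute are illustrative stand-ins for `torch.nn.Module` internals, and `compile()` here merely wraps `forward` instead of invoking `torch.compile`.

```python
# Schematic sketch (assumption, not PyTorch source): model.compile() stores a
# compiled forward in an internal attribute, __call__ prefers it, but calling
# .forward() directly bypasses it and stays eager.

class Module:
    def __init__(self):
        self._compiled_forward = None  # set by .compile()

    def forward(self, x):
        # The user-defined eager forward.
        return ("eager", x + 1)

    def compile(self):
        # Stand-in for torch.compile: wrap forward and remember the wrapper.
        original = self.forward

        def compiled(x):
            _tag, out = original(x)
            return ("compiled", out)

        self._compiled_forward = compiled

    def __call__(self, x):
        # __call__ prefers the compiled version once it exists ...
        if self._compiled_forward is not None:
            return self._compiled_forward(x)
        # ... otherwise it falls back to the eager forward.
        return self.forward(x)

m = Module()
m.compile()
print(m(1))          # routed through the compiled wrapper
print(m.forward(1))  # bypasses the wrapper: stays eager
```

This is why a loop that calls `model.forward()` directly only benefits from `model.compile()` on whatever iteration happens to go through `model(...)`.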
Hi @SilverSoldier, thanks for opening the issue! Indeed, I agree that we should use `__call__`, for consistency between all methods. We wanted to refine this part anyway; this was the first shot! I'll take care of it very soon. Curious to see what the pytorch team has to say with regard to the consistency of the different ways to `compile` a model, but I'm not sure this can be solved in general.
Also, we introduced the recent changes because it is inefficient to `compile` the forward in all cases (we only want to compile the iterative decoding part, not the prefill), so your code would be more efficient if you don't call `compile` yourself and only use `cache_implementation="static"` in `generate`, letting us compile just the iterative decoding part 🤗
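The split described above (keep the variable-length prefill eager, compile only the fixed-shape per-token decoding step) can be sketched without transformers. Everything here is illustrative: `fake_compile` stands in for `torch.compile` and just records what gets compiled, and `prefill`/`decode_step` are toy functions, not the library's.

```python
# Sketch of "compile only the iterative decoding part, not the prefill".
compiled_fns = []

def fake_compile(fn):
    # Stand-in for torch.compile: record the function name, return it unchanged.
    compiled_fns.append(fn.__name__)
    return fn

def prefill(prompt_ids):
    # Variable-length prompt: shapes differ per call, so compiling this
    # would trigger repeated recompilation; keep it eager.
    return sum(prompt_ids)  # toy "hidden state"

def decode_step(state, token):
    # Fixed-shape single-token step: the ideal compilation target.
    return state + token

def generate(prompt_ids, steps):
    state = prefill(prompt_ids)        # eager prefill
    step = fake_compile(decode_step)   # compile only the decode step
    for t in range(steps):
        state = step(state, t)
    return state

generate([1, 2, 3], steps=4)
print(compiled_fns)  # only decode_step was "compiled"
```

With a static cache the decode step also has fixed tensor shapes, which is what makes it a stable target for compilation in the first place.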
I see, it's another API of compile! Thanks!
@Cyrilvallez interesting. However, if `generate` takes care of the `compile`, how can we pass the compile arguments, say a custom backend and other parameters?
We are adding this via the classic `generate_config` and `generate_kwargs` API!
System Info
`transformers` version: 4.47.0.dev0

Who can help?
@ArthurZucker @Cyrilvallez
Information
Tasks
examples
folder (such as GLUE/SQuAD, ...)

Reproduction
Expected behavior
The expected behaviour is that we use the compiled forward function.

When compiling using the `model.compile()` API, the call method uses an internal variable holding the compiled forward instead of the uncompiled forward. (I raised a related issue in pytorch; this is Option 2 there.)

So `generate` should use the call method instead of `forward` in order to use the compiled version of forward (for this particular case of `model.compile()`). However, recent changes switched this call to `model.forward()` instead of `model()` for the non-first tokens:

`model_forward` should be changed to call `model()` instead of `model.forward()`.