huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

tracker: `generate` composability refactor #30810

Open gante opened 1 month ago

gante commented 1 month ago

generate + composability = more use cases with minimal rewrites

As I write this issue, generate is mostly a sequential monolith. Many internal blocks were carved into functions over the last two years, but navigating the code as a beginner is still messy. It is also very challenging to adapt generate to different tasks and/or modalities, which forces us to overwrite the entire generate function for some models (e.g. RAG, MusicGen). All of this makes using, documenting, maintaining, and testing generate a challenge.

This issue is a tracker for the refactor of generate, where we aim to build the structure outlined in this board. Key ideas for this refactor: 👉 All models can use the base generate API 👉 Reduce if/else blocks 👉 Reduce the barriers to entry for new decoding methods/modalities/use cases 👉 Reduce per-model overwrites when possible 👉 Add unit tests 👉 Add documentation regarding the structure of generate

Tasks

[From this point onwards the tasks are only a sketch, need more detailed planning when we get there]

amyeroberts commented 1 month ago

Adding the WIP label, so you don't get pinged by the stale bot 🤖

dmarx commented 1 month ago

Could you elaborate on the "prefill" component? My impression is that this step converts the prompt into a KV cache, i.e. "pre-filling" the KV component for the tokens that are fixed. If that's correct, this component could probably serve double duty (as an exit point from the generate procedure) for users who just want logprobs/scores for the prompt.
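To make the "exit point" idea concrete, here is a minimal sketch of what such a prompt-scoring path could look like. This is not the transformers implementation: `ToyLM` is a stand-in for a causal LM (a real model would also return `past_key_values`), and `prefill_scores` is a hypothetical helper name. The point is only that a single prefill forward pass already contains the log-probs of every prompt token.

```python
import torch

# Toy stand-in for a causal LM: maps token ids to next-token logits.
# A real model forward would additionally return past_key_values.
class ToyLM(torch.nn.Module):
    def __init__(self, vocab_size=16, dim=8):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, dim)
        self.head = torch.nn.Linear(dim, vocab_size)

    def forward(self, input_ids):
        return self.head(self.embed(input_ids))  # (batch, seq, vocab)

def prefill_scores(model, input_ids):
    """Hypothetical helper: one forward pass over the prompt, returning
    the log-prob of each prompt token given the preceding ones."""
    with torch.no_grad():
        logits = model(input_ids)
    # The distribution predicted at position t-1 scores the token at t,
    # so drop the last position and score tokens 1..n-1.
    logprobs = torch.log_softmax(logits[:, :-1, :], dim=-1)
    return logprobs.gather(-1, input_ids[:, 1:, None]).squeeze(-1)

model = ToyLM()
prompt = torch.tensor([[1, 4, 2, 7]])
scores = prefill_scores(model, prompt)
print(scores.shape)  # one score per non-initial prompt token: (1, 3)
```

In a generate-internal prefill, only the KV cache from this pass would be kept; the exit-point variant would instead return the scores.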

dmarx commented 1 month ago

IMHO, upstream of (part of?) the "generate outputs" step in the decoding loop should be a templated _prepare_outputs function whose job is just to attach output attributes to the appropriate output class. Design POC via #29545, concretely: https://github.com/coreweave/transformers/blob/dmarx.output_streamer/src/transformers/generation/utils.py#L354-L378
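A minimal sketch of the templated idea, independent of the linked POC: one generic helper that filters collected tensors/values down to whatever attributes the target output class actually declares. The dataclass and the `prepare_outputs` name here are illustrative only, not the API proposed in #29545.

```python
from dataclasses import dataclass, fields
from typing import Optional, Tuple

# Illustrative output class; transformers defines its own output dataclasses.
@dataclass
class SampleOutput:
    sequences: Tuple[int, ...] = ()
    scores: Optional[Tuple[float, ...]] = None
    attentions: Optional[tuple] = None

def prepare_outputs(output_cls, **collected):
    """Hypothetical templated helper: keep only the attributes declared by
    output_cls, so one function serves every decoding method/output class."""
    valid = {f.name for f in fields(output_cls)}
    return output_cls(**{k: v for k, v in collected.items() if k in valid})

out = prepare_outputs(
    SampleOutput,
    sequences=(3, 1, 4),
    scores=(-0.2, -1.1, -0.7),
    hidden_states=None,  # silently dropped: not a field of SampleOutput
)
print(out.sequences)  # (3, 1, 4)
```

The appeal of this shape is that the decoding loop collects everything unconditionally and the output class, not if/else blocks, decides what is kept.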

gante commented 1 month ago

@dmarx

Could you elaborate on the "prefill" component? My impression is that this step converts the prompt into a KV cache, i.e. "pre-filling" the KV component for the tokens that are fixed. If that's correct, this component could probably serve double duty (as an exit point from the generate procedure) for users who just want logprobs/scores for the prompt.

Correct, it can be made a public function with that additional purpose. The difference between the prefill for generate and obtaining the scores for the prompt is that in the former, we only want to keep the past KV. The different output needs suggest to me that a stand-alone public function is preferable to an alternate exit from generate :D Added this to the notes of the prefill stage in the diagram.

IMHO, upstream of (part of?) the "generate outputs" step in the decoding loop should be a templated _prepare_outputs function whose job is just to attach output attributes to the appropriate output class. Design POC via https://github.com/huggingface/transformers/pull/29545, concretely: https://github.com/coreweave/transformers/blob/dmarx.output_streamer/src/transformers/generation/utils.py#L354-L378

That's a good idea, adding it to the diagram!

Thank you for the suggestions 💛