gante opened this issue 1 month ago (status: Open)
Adding the WIP label, so you don't get pinged by the stale bot 🤖
dmarx:
Could you elaborate on the "prefill" component? My impression is that this step converts the prompt into a KV cache, i.e. "pre-filling" the KV component for the tokens that are fixed. If that's correct, this component could probably serve double duty (as an exit point from the `generate` procedure) for users who just want logprobs/scores for the prompt.

IMHO, upstream of (part of?) the "generate outputs" step in the decoding loop should be a templated `_prepare_outputs` function whose job is just to attach output attributes to the appropriate output class. Design POC via #29545, concretely:
https://github.com/coreweave/transformers/blob/dmarx.output_streamer/src/transformers/generation/utils.py#L354-L378
@dmarx

> Could you elaborate on the "prefill" component? My impression is that this step converts the prompt into a KV cache, i.e. "pre-filling" the KV component for the tokens that are fixed. If that's correct, this component could probably serve double duty (as an exit point from the generate procedure) for users who just want logprobs/scores for the prompt.

Correct, it can be made a public function with that additional purpose. The difference between the prefill for `generate` and obtaining the scores for the prompt is that in the former we only want to keep the past KV. The different output needs suggest to me that a stand-alone public function is preferable to an alternate exit to `generate` :D Added this to the notes of the prefill stage in the diagram.
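To make the double-duty idea concrete, here is a minimal sketch of what such a stand-alone public function could look like. The name `prefill`, its signature, and the return contract are all hypothetical, not the actual transformers API:

```python
import torch


def prefill(model, input_ids, attention_mask=None, return_logprobs=False):
    """Hypothetical helper: run one forward pass over the fixed prompt.

    For generation we only need the resulting KV cache; for prompt scoring
    we also return per-token log-probabilities of the prompt.
    """
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=True,
        )
    past_key_values = outputs.past_key_values  # the "pre-filled" KV cache

    logprobs = None
    if return_logprobs:
        # Logits at position i predict token i+1, so shift by one.
        logits = outputs.logits[:, :-1, :]
        targets = input_ids[:, 1:]
        logprobs = (
            torch.log_softmax(logits, dim=-1)
            .gather(-1, targets.unsqueeze(-1))
            .squeeze(-1)
        )

    return past_key_values, logprobs
```

A `generate`-internal caller would keep only `past_key_values`; a user who just wants prompt scores would keep only `logprobs`.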
> IMHO, upstream of (part of?) the "generate outputs" step in the decoding loop should be a templated `_prepare_outputs` function whose job is just to attach output attributes to the appropriate output class. Design POC via https://github.com/huggingface/transformers/pull/29545, concretely: https://github.com/coreweave/transformers/blob/dmarx.output_streamer/src/transformers/generation/utils.py#L354-L378
That's a good idea, adding it to the diagram!
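To sketch what that could look like: the function name comes from the POC, but the output class and its fields below are illustrative stand-ins, not the real classes from `generation/utils.py`:

```python
from dataclasses import dataclass, fields
from typing import Optional, Tuple


@dataclass
class GreedySearchOutput:
    # Illustrative stand-in for the real output classes in generation/utils.py.
    sequences: Optional[Tuple] = None
    scores: Optional[Tuple] = None
    attentions: Optional[Tuple] = None


def _prepare_outputs(output_cls, **collected):
    """Attach whatever the decoding loop collected to the requested output
    class, dropping attributes that class does not define."""
    valid = {f.name for f in fields(output_cls)}
    return output_cls(**{k: v for k, v in collected.items() if k in valid})


out = _prepare_outputs(
    GreedySearchOutput,
    sequences=(1, 2, 3),
    scores=(0.1, 0.2),
    hidden_states=None,  # not a field of GreedySearchOutput -> ignored
)
```

The appeal is that every decoding loop ends with the same one-liner, and the per-method differences live entirely in the output dataclass.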
Thank you for the suggestions 💛
`generate` + composability = more use cases with minimal rewrites

As I write this issue, `generate` is mostly a sequential monolith. Many internal blocks were carved into functions over the last two years, but navigating them as a beginner is still messy. It is also very challenging to adapt `generate` to different tasks and/or modalities, forcing us to overwrite the entire `generate` function (e.g. RAG, MusicGen). All these aspects make using, documenting, maintaining, and testing `generate` a challenge.

This issue is a tracker for the refactor of `generate`, where we aim to build the structure outlined in this board. Key ideas for this refactor:
👉 All models can use the base `generate` API
👉 Reduce if/else blocks
👉 Reduce the barriers to entry for new decoding methods/modalities/use cases
👉 Reduce per-model overwrites when possible
👉 Add unit tests
👉 Add documentation regarding the structure of `generate`
Tasks

- … (`input_ids[:, -1:]`), so we don't compute variables regarding the latest token twice;
- … `use_cache=True` and cache length < input length - 1;
- `_expand_inputs_for_generation` needs to be changed (it copied inputs before prefill, we will need to copy prefill outputs)
- … `yield`/`yield from` instead of `return`
- … `pipeline` …
- … `position_ids`.
- … `LogitsWarper` in this step (it's a copy of `LogitsProcessor`)
- … `generate` …

[From this point onwards the tasks are only a sketch, need more detailed planning when we get there]

- … `prepare_inputs_for_generation`, VLMs also have their special preprocessing steps, ...)
- … `prepare_inputs_for_generation`?
- … `generate` from models that have a custom implementation
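Two of the items above can be made concrete with a small sketch: detecting the prefill step from the cache length, and writing the decoding loop as a generator (`yield`/`yield from` instead of `return`) so that consumers such as streamers or `pipeline` can process tokens as they are produced. All names here are illustrative, not the actual implementation:

```python
def is_prefill_step(use_cache: bool, cache_length: int, input_length: int) -> bool:
    # Prefill: the cache does not yet cover the prompt (minus the token
    # about to be processed), so the model must ingest the full prompt.
    return use_cache and cache_length < input_length - 1


def decoding_loop(next_token_fn, prompt, max_new_tokens):
    """Toy decoding loop that yields tokens one by one instead of returning
    a finished sequence; `yield from` would let a higher-level decoding
    method delegate to it without buffering."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        token = next_token_fn(tokens)
        tokens.append(token)
        yield token
```

A caller that still wants the old behavior just wraps the generator, e.g. `prompt + list(decoding_loop(fn, prompt, n))`, while a streamer can consume tokens as they arrive.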