huggingface / optimum-intel

🤗 Optimum Intel: Accelerate inference with Intel optimization tools
https://huggingface.co/docs/optimum/main/en/intel/index
Apache License 2.0

optimize first latency beam search for OVModelForCausalLM #695

Closed: eaidova closed this pull request 2 months ago

eaidova commented 3 months ago

What does this PR do?

This PR reduces first-token latency for the OVModelForCausalLM class when beam search decoding is selected. During generation, beam search is represented as a batch of sequences (generation batch size = num_input_prompts * num_beams). Before decoding starts, the generation API duplicates each input sequence once per beam, even though on the first step all of these sequences are identical (and, for models with a cache, this first inference is the most time-consuming part). The idea is to postpone duplicating the sequences for each beam until after the first iteration is done, and to duplicate the past key values and the logits in the outputs at that point instead.
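To make the idea concrete, here is a minimal pure-Python sketch (no OpenVINO or transformers involved). It assumes lists of ints stand in for tensors, that beams are laid out the way `repeat_interleave` along the batch dimension lays them out, and that `ToyModel.first_step` is a hypothetical stand-in for the expensive first inference. It checks that running the first step once per prompt and duplicating logits and past key values afterwards gives the same result as duplicating the prompts first, while the model processes `num_beams` times fewer rows.

```python
from typing import List, Tuple

Row = List[int]


def repeat_interleave(rows: List[Row], n: int) -> List[Row]:
    """Repeat each row n times consecutively: [a, b] -> [a, a, b, b]
    (same layout as x.repeat_interleave(n, dim=0) in PyTorch)."""
    return [list(row) for row in rows for _ in range(n)]


class ToyModel:
    """Hypothetical stand-in for the costly first inference; counts processed rows."""

    def __init__(self) -> None:
        self.rows_processed = 0

    def first_step(self, batch: List[Row]) -> Tuple[List[Row], List[Row]]:
        self.rows_processed += len(batch)
        logits = [[sum(row) % 7, len(row)] for row in batch]  # fake per-row logits
        past = [[2 * t for t in row] for row in batch]        # fake per-row past key values
        return logits, past


def first_step_eager(model: ToyModel, prompts: List[Row], num_beams: int):
    # Default flow: duplicate every prompt for each beam, then run the model.
    return model.first_step(repeat_interleave(prompts, num_beams))


def first_step_deferred(model: ToyModel, prompts: List[Row], num_beams: int):
    # Deferred flow: run the model once per prompt, then duplicate logits and past key values.
    logits, past = model.first_step(prompts)
    return repeat_interleave(logits, num_beams), repeat_interleave(past, num_beams)


if __name__ == "__main__":
    prompts = [[1, 2, 3], [4, 5, 6]]
    eager_model, deferred_model = ToyModel(), ToyModel()
    eager = first_step_eager(eager_model, prompts, num_beams=4)
    deferred = first_step_deferred(deferred_model, prompts, num_beams=4)
    assert eager == deferred
    print(eager_model.rows_processed, deferred_model.rows_processed)  # 8 2
```

The outputs match because every expanded row of a prompt is an exact copy, so computing it once and copying the result is equivalent; the saving grows linearly with `num_beams`.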



IlyasMoutawwakil commented 2 months ago

IMO this will be very heavy to maintain given the constant changes in the transformers library, especially since the text generation API will be undergoing heavy refactoring soon.

Would it not make sense to optimize the first forward pass rather than the generation strategy, with something along the lines of:

import torch

def generate(self, *args, **kwargs):
    if beam_search:  # placeholder: beam search, or any generation strategy where this issue is observed
        self.first_beam_search_iteration = True
    else:
        self.first_beam_search_iteration = False

    return super().generate(*args, **kwargs)

def forward(self, inputs, **kwargs):
    if self.first_beam_search_iteration:
        # on the first step all beams of a prompt are identical, so only run the unique rows
        unique_inputs, inverse_order = torch.unique(inputs, dim=0, return_inverse=True)
        # we can also use what we know about how the inputs are duplicated to deduplicate them
        unique_outputs = super().forward(unique_inputs, **kwargs)
        # index each output tensor with inverse_order to restore the full (duplicated) batch order
        outputs = unique_outputs[inverse_order]
        self.first_beam_search_iteration = False
    else:
        outputs = super().forward(inputs, **kwargs)
    return outputs
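The deduplication step in the reviewer's sketch can be checked in isolation. Below is a pure-Python sketch (tuples of ints instead of tensors; `toy_forward` is a hypothetical per-row model) of the two options the comment mentions: a generic deduplication that roughly mirrors `torch.unique(x, dim=0, return_inverse=True)`, and one that uses the known `repeat_interleave` layout to take every `num_beams`-th row. In both cases the outputs computed for the unique rows are gathered back into the full batch with the inverse indices.

```python
from typing import Callable, List, Tuple

Row = Tuple[int, ...]


def unique_with_inverse(batch: List[Row]) -> Tuple[List[Row], List[int]]:
    """Sorted unique rows plus, for each original row, the index of its unique row
    (roughly what torch.unique(x, dim=0, return_inverse=True) returns)."""
    uniques = sorted(set(batch))
    position = {row: i for i, row in enumerate(uniques)}
    return uniques, [position[row] for row in batch]


def dedup_by_layout(batch: List[Row], num_beams: int) -> Tuple[List[Row], List[int]]:
    """Use the known layout instead: beams of a prompt are consecutive rows, so every
    num_beams-th row is a distinct prompt and row j maps back to prompt j // num_beams."""
    return batch[::num_beams], [j // num_beams for j in range(len(batch))]


def toy_forward(batch: List[Row]) -> List[int]:
    # Hypothetical per-row "model": any deterministic function of the row works here.
    return [sum(row) * len(row) for row in batch]


def forward_deduplicated(
    batch: List[Row], dedup: Callable[[List[Row]], Tuple[List[Row], List[int]]]
) -> List[int]:
    unique_inputs, inverse_order = dedup(batch)
    unique_outputs = toy_forward(unique_inputs)
    # Gather the per-unique-row outputs back into the original (duplicated) batch order.
    return [unique_outputs[i] for i in inverse_order]


if __name__ == "__main__":
    num_beams = 3
    prompts = [(5, 1, 2), (0, 7, 3)]
    batch = [p for p in prompts for _ in range(num_beams)]  # beams laid out consecutively
    expected = toy_forward(batch)
    assert forward_deduplicated(batch, unique_with_inverse) == expected
    assert forward_deduplicated(batch, lambda b: dedup_by_layout(b, num_beams)) == expected
```

The generic variant needs no extra context but sorts rows and would also merge identical prompts; the layout-based variant is cheaper but needs `num_beams`, which is generation-config context that, per the reply in this thread, `forward` does not receive.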

I admit that this is more stateful and hacky than what's suggested in the PR, but it requires maintaining less code, until this duplication issue with beam search gets fixed in transformers.

eaidova commented 2 months ago

@IlyasMoutawwakil, thank you for your suggestion; that is where I started. The problem is that in the non-stateful case we need to know how the inputs were duplicated in order to duplicate the past key values in the same way, and that requires additional context (from the generation config) that is not available inside forward. Another problem is next_beam_idx, which should be different before the second inference: it has to contain the initial index duplication instead of the indices rearranged by cache reordering.
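A pure-Python illustration of the `next_beam_idx` point, under these assumptions: after a deduplicated first inference the cache holds one row per prompt; beams of prompt i occupy rows `i * num_beams` to `(i + 1) * num_beams - 1` of the expanded batch; and beam search only reorders beams within the same prompt's group. The helper names are hypothetical and this is not the PR's implementation. Because all beams of a prompt are identical after the first step, any reorder index beam search emits maps onto the plain duplication pattern `[0, ..., 0, 1, ..., 1, ...]`.

```python
from typing import List


def duplication_idx(num_prompts: int, num_beams: int) -> List[int]:
    """Gather indices that expand a per-prompt cache to num_prompts * num_beams rows."""
    return [i for i in range(num_prompts) for _ in range(num_beams)]


def to_deferred_cache_idx(beam_idx: List[int], num_beams: int) -> List[int]:
    """Hypothetical helper: beam_idx points into the expanded batch, but the deferred
    cache has one row per prompt, so expanded row b is served by cache row b // num_beams."""
    return [b // num_beams for b in beam_idx]


def gather(rows: List[List[int]], idx: List[int]) -> List[List[int]]:
    # Reorder/expand cache rows along the batch dimension (like index_select on dim 0).
    return [list(rows[i]) for i in idx]


if __name__ == "__main__":
    num_prompts, num_beams = 2, 4
    # Reorder indices beam search might emit after step one; each stays within its prompt's group.
    beam_idx = [2, 0, 0, 1, 5, 7, 4, 4]
    assert to_deferred_cache_idx(beam_idx, num_beams) == duplication_idx(num_prompts, num_beams)

    per_prompt_cache = [[10, 11], [20, 21]]  # cache after the deduplicated first inference
    expanded = gather(per_prompt_cache, duplication_idx(num_prompts, num_beams))
    # Reordering the expanded cache with beam_idx changes nothing: all beams of a prompt are equal.
    assert gather(expanded, beam_idx) == expanded
```

Passing `beam_idx` unchanged to the two-row cache would try to read rows 2 to 7, which is why the index used before the second inference has to be the duplication pattern.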

eaidova commented 2 months ago

@IlyasMoutawwakil @echarlaix please take another look; I have significantly updated the code to reduce the number of overridden beam search methods.