gs-olive opened this issue 1 year ago
Decoder-style models are difficult for our `torch.compile` backend since we do not yet directly support dynamic shapes. This presents a unique challenge, since each iteration of the auto-regressive loop increases the size of the input by 1. As a workaround, we propose a padding-based solution, which auto-pads user inputs in the auto-regressive loop to a fixed maximum shape, thereby avoiding recompilation for every shape between the minimum and maximum output size.
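To make the recompilation problem concrete, the sketch below (a hypothetical greedy decoding loop, not code from this proposal) shows how the sequence dimension grows by one token per step, so each iteration would otherwise present a brand-new static shape to the compiled graph:

```python
import torch

# Without auto-padding, input_ids takes shape (1, L), (1, L+1), (1, L+2), ...
# across iterations, and each new shape triggers a fresh compilation.
def greedy_decode(compiled_model, input_ids, eos_token_id, max_new_tokens=32):
    for _ in range(max_new_tokens):
        logits = compiled_model(input_ids=input_ids).logits  # new shape every step
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        if next_token.item() == eos_token_id:
            break
    return input_ids
```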
We take Llama as a canonical model here:
```python
from typing import Any

import torch
from transformers import LlamaForCausalLM


class LlamaForCausalLMTorchTensorRT(LlamaForCausalLM):
    def __init__(self, config, tokenizer=None, max_pad_length=96):
        super().__init__(config)
        self.tokenizer = tokenizer
        self.max_pad_length = max_pad_length

    def set_tokenizer(self, tokenizer):
        self.tokenizer = tokenizer
        # Register a pad token if the tokenizer lacks one, so inputs can be padded
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.add_special_tokens({'pad_token': '0'})

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        ...  # key padding logic, described below
```
The `prepare_inputs_for_generation` override will contain the key logic which performs auto-padding to the requested `max_pad_length`. This will assist in resolving the dynamic-shape recompilation issue.
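As a rough sketch of that logic (the padding side, mask handling, and returned keys here are my assumptions, not the final implementation), the override could right-pad `input_ids` with the tokenizer's pad token up to `max_pad_length` and zero-extend `attention_mask` accordingly:

```python
    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        batch_size, seq_len = input_ids.shape
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        pad_len = self.max_pad_length - seq_len
        if pad_len > 0:
            # Right-pad the token ids up to the fixed maximum shape
            pad_ids = torch.full(
                (batch_size, pad_len),
                self.tokenizer.pad_token_id,
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
            input_ids = torch.cat([input_ids, pad_ids], dim=-1)
            # Zero out the attention mask over the padded region
            pad_mask = torch.zeros(
                (batch_size, pad_len),
                dtype=attention_mask.dtype,
                device=attention_mask.device,
            )
            attention_mask = torch.cat([attention_mask, pad_mask], dim=-1)
        return {"input_ids": input_ids, "attention_mask": attention_mask}
```

Since every call now sees a `(batch_size, max_pad_length)` input, the compiled graph is built once for that shape rather than once per decoding step.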
Thus far, the proposed solution has been effective, but a few questions remain before it can be considered fully functional. For instance, passing in alternative parameters such as `past_key_values` or `position_ids` can be problematic for this auto-padding scheme. Additionally, we need to override the `__call__` function, as so:
```python
    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        output = super().__call__(*args, **kwargs)
        # Count the real (non-padded) tokens via the attention mask,
        # then trim the logits back down to that length
        non_padding_elts = kwargs["attention_mask"].sum().item()
        output.logits = output.logits[:, :non_padding_elts, ...]
        return output
```
The edge cases above will be handled with improved messaging to users and fail-fast error checking.
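For completeness, a hypothetical end-to-end usage might look like the following. The checkpoint name and the `"torch_tensorrt"` backend string are assumptions on my part, as is `from_pretrained` forwarding `max_pad_length` to the constructor:

```python
import torch
import torch_tensorrt  # noqa: F401  (registers the Torch-TensorRT compile backend)
from transformers import AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # hypothetical checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = LlamaForCausalLMTorchTensorRT.from_pretrained(model_id, max_pad_length=96)
model.set_tokenizer(tokenizer)

# Compile the forward pass; padded inputs keep the shape static at max_pad_length
model.forward = torch.compile(model.forward, backend="torch_tensorrt")

inputs = tokenizer("The key to avoiding recompilation is", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```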