pytorch / TensorRT

PyTorch/TorchScript/FX compiler for NVIDIA GPUs using TensorRT
https://pytorch.org/TensorRT

Llama Investigation #2360

Open gs-olive opened 1 year ago

gs-olive commented 1 year ago

Context

Decoder-style models are difficult for our torch.compile backend since we do not yet directly support dynamic shapes. This presents a unique challenge, since each iteration of the auto-regressive loop increases the input size by 1. As a workaround, we propose a padding-based solution which auto-pads user inputs in the auto-regressive loop to a fixed maximum shape, thereby avoiding recompilation for every shape between the minimum and maximum output size.

Proposal

We take Llama as a canonical model here:

from transformers import LlamaForCausalLM

class LlamaForCausalLMTorchTensorRT(LlamaForCausalLM):
    def __init__(self, config, tokenizer=None, max_pad_length=96):
        super().__init__(config)
        self.tokenizer = tokenizer
        # Fixed sequence length that inputs are auto-padded to
        self.max_pad_length = max_pad_length

    def set_tokenizer(self, tokenizer):
        self.tokenizer = tokenizer

        # Ensure the tokenizer has a pad token available for auto-padding
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.add_special_tokens({'pad_token': '0'})

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        # Key auto-padding logic goes here (sketched below)
        ...

The prepare_inputs_for_generation override will contain the key logic that auto-pads inputs to the requested max_pad_length, which addresses the dynamic-shape recompilation issue.
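For illustration, here is a minimal sketch of what that padding step could look like, assuming the inputs are right-padded with the tokenizer's pad token and the attention mask is zero-extended over the padded positions; the actual override may differ and may need to handle arguments such as past_key_values:

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        import torch
        import torch.nn.functional as F

        # Sketch only: pad the current sequence up to the fixed maximum length
        pad_amount = self.max_pad_length - input_ids.shape[-1]
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if pad_amount > 0:
            # Right-pad the token ids with the tokenizer's pad token
            input_ids = F.pad(input_ids, (0, pad_amount), value=self.tokenizer.pad_token_id)
            # Zero out the attention mask over the padded positions so they are ignored
            attention_mask = F.pad(attention_mask, (0, pad_amount), value=0)
        return {'input_ids': input_ids, 'attention_mask': attention_mask}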

gs-olive commented 11 months ago

Proposal Progress

Thus far, the proposed solution has been effective, but a few questions remain before it can be considered fully functional. For instance, passing in alternative parameters such as past_key_values or position_ids can be problematic for this auto-padding scheme. Additionally, we need to override the __call__ method so that the padded positions are trimmed from the returned logits, as follows:

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Run the padded forward pass as usual
        output = super().__call__(*args, **kwargs)
        # Number of real (non-padded) tokens, taken from the attention mask
        non_padding_elts = kwargs["attention_mask"].sum().item()
        # Trim the logits back to the original, unpadded sequence length
        output.logits = output.logits[:, :non_padding_elts, ...]
        return output

These edge cases will be handled with improved messaging to users and fail-fast error checking.
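For context, a rough end-to-end usage sketch of the wrapper is below. The checkpoint name, prompt, and compile settings are illustrative assumptions, and the backend string assumes the Torch-TensorRT dynamo backend registered by importing torch_tensorrt:

import torch
import torch_tensorrt  # registers the "torch_tensorrt" backend for torch.compile
from transformers import AutoTokenizer

# Illustrative checkpoint; any Llama-style causal LM would follow the same flow
model_id = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = LlamaForCausalLMTorchTensorRT.from_pretrained(model_id).eval().cuda()
model.set_tokenizer(tokenizer)

# Compile the forward pass; since inputs are auto-padded to max_pad_length,
# a single compiled graph can be reused across decoding iterations
model.forward = torch.compile(model.forward, backend="torch_tensorrt", dynamic=False)

prompt = "Explain the benefit of padding to a fixed sequence length:"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))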