aws-neuron / aws-neuron-samples

Example code for AWS Neuron SDK developers building inference and training applications

Cannot compile Bart generate-optimised decode #18

Closed · sinking-point closed 12 months ago

sinking-point commented 12 months ago

I'm trying to compile Bart for text2text generation on an Inf2 server. I am aware that optimum-neuron has a Bart implementation, but I need to make customisations that are incompatible with its pipeline system.

Bart is implemented so that if you pass past_key_values, you can provide only the last decoder input ID rather than the whole sequence generated so far. This makes attention linear per step rather than quadratic, because each step only has to attend from the new position instead of recomputing every position so far. This is an important compute optimisation.
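
For reference, this is the standard incremental-decoding pattern in plain transformers that I want to preserve (a minimal greedy-decode sketch for facebook/bart-base; the loop and variable names are illustrative):

import torch
from transformers import BartForConditionalGeneration, BartTokenizerFast

model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
tokeniser = BartTokenizerFast.from_pretrained('facebook/bart-base')

enc = tokeniser("Hello, my name is Billy.", return_tensors='pt')
encoder_outputs = model.get_encoder()(**enc)

decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])
past_key_values = None
for _ in range(10):
    out = model(
        encoder_outputs=encoder_outputs,
        attention_mask=enc.attention_mask,
        decoder_input_ids=decoder_input_ids,  # only the newest token, thanks to the cache
        past_key_values=past_key_values,
        use_cache=True,
    )
    past_key_values = out.past_key_values  # grows by one position each step
    decoder_input_ids = out.logits[:, -1:].argmax(-1)  # feed back just the last token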

When I try to trace a call to Bart that uses this optimisation, I get an error:

2023-07-04T12:38:29Z ERROR 26864 [Tensorizer]: Transformation error on operator: mlir.function
2023-07-04T12:38:29Z ERROR 26864 [neuronx-cc]: ***************************************************************
2023-07-04T12:38:29Z ERROR 26864 [neuronx-cc]:  An Internal Compiler Error has occurred
2023-07-04T12:38:29Z ERROR 26864 [neuronx-cc]: ***************************************************************
2023-07-04T12:38:29Z ERROR 26864 [neuronx-cc]: 
2023-07-04T12:38:29Z ERROR 26864 [neuronx-cc]: Error message:  too many values to unpack (expected 1)

Steps to reproduce:

from transformers import BartForConditionalGeneration, BartTokenizerFast
import torch
import torch.nn.functional as F
import torch_neuronx

model = BartForConditionalGeneration.from_pretrained('facebook/bart-base', torchscript=True)

example_sentence = "Hello, my name is Billy."
tokeniser = BartTokenizerFast.from_pretrained('facebook/bart-base')
tokens = tokeniser(example_sentence, return_tensors='pt')

# Run one ordinary forward pass to obtain a past_key_values tuple to pad later.
# The Nones are positional placeholders in BartForConditionalGeneration.forward:
# the 8th slot is encoder_outputs and the last one is decoder_inputs_embeds.
inputs = (None, None, None, None, None, None, None, (torch.zeros((1, 128, 768)),), None, None, torch.zeros((1, 9, 768)))
outputs = model(*inputs)  # with torchscript=True, outputs[1] is past_key_values

class BartForNeuronDecoder(torch.nn.Module):

    def __init__(self, bart):
        super().__init__()

        self.bart = bart

    def forward(
        self,
        decoder_input_ids, # 1 token per batch
        encoder_outputs,
        attention_mask, # for encoder outputs
        past_key_values, # self-attention cache padded to max_len - 1 positions
    ):

        outputs = self.bart.model(
            encoder_outputs=encoder_outputs, 
            attention_mask=attention_mask, 
            decoder_input_ids=decoder_input_ids, 
            past_key_values=past_key_values,
            use_cache=True,
        )
        lm_logits = self.bart.lm_head(outputs[0]) + self.bart.final_logits_bias

        return (
            lm_logits,
            outputs[1]
        )

wrapped_model = BartForNeuronDecoder(model)

def pad_key_values(past_key_values, max_len):
    """Zero-pad each layer's decoder self-attention keys/values (layer[0] and
    layer[1]) along the sequence dimension to a fixed length of max_len - 1,
    so the traced model always sees static shapes. The cross-attention
    entries (layer[2:]) already have the fixed encoder length."""
    padded_key_values = ()
    for layer in past_key_values:
        padded_layer = ()
        for i in [0, 1]:
            padded_layer = padded_layer + (F.pad(layer[i], pad=(0, 0, 0, max_len - 1 - layer[i].shape[2])),)
        padded_layer = padded_layer + layer[2:]
        padded_key_values = padded_key_values + (padded_layer,)

    return padded_key_values
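
As a note on the F.pad call above: the pad tuple is read from the last dimension backwards, so pad=(0, 0, 0, n) leaves the last dimension alone and appends n zeros along dim -2, which for a (batch, heads, seq, head_dim) cache tensor is the sequence dimension. Continuing with the imports above:

t = torch.zeros((1, 12, 9, 64))  # a bart-base self-attention cache entry: (batch, heads, seq, head_dim)
print(F.pad(t, pad=(0, 0, 0, 127 - 9)).shape)  # torch.Size([1, 12, 127, 64])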

pkv = pad_key_values(outputs[1], 128)

# Trace with static shapes: one decoder token, fixed-size encoder outputs,
# an encoder attention mask (5 real positions, then padding), and the padded cache.
args = (torch.tensor([[0]]), (torch.zeros((1, 128, 768)),), torch.tensor([[1,1,1,1,1] + [0]*123]), pkv)
wrapped_model_neuron = torch_neuronx.trace(wrapped_model, args)
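
Note that the failure above is raised by neuronx-cc, i.e. after torch.jit.trace has already recorded the graph, so the wrapper itself appears to run fine in eager mode. As a sanity check when reproducing:

logits, new_pkv = wrapped_model(*args)
print(logits.shape)  # torch.Size([1, 1, 50265]) for bart-base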

Let me know if you have trouble reproducing it and need additional details.

Many thanks.

sinking-point commented 12 months ago

Moved this to https://github.com/aws-neuron/aws-neuron-sdk/issues/703