aws-neuron / aws-neuron-sdk

Powering AWS purpose-built machine learning chips. Blazing fast and cost effective, natively integrated into PyTorch and TensorFlow and integrated with your favorite AWS services
https://aws.amazon.com/machine-learning/neuron/

Cannot compile MT5 model #657

Closed yzGao22 closed 9 months ago

yzGao22 commented 1 year ago

Hi Team,

We are trying to compile the MT5 model using torch-neuronx. No error is reported, but the traced model generates only garbled characters during inference, while the model works fine before tracing. Below is the code, which was originally written by @daan-triumph:

import torch_xla
import torch_xla.core.xla_model as xm

import os
from time import perf_counter
from typing import Any, Optional, cast

import numpy as np
# import tensorflow  # type: ignore
import torch
import torch_neuronx
from torch.nn import functional as F
from transformers.generation_utils import GenerationMixin
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
from transformers.modeling_utils import PreTrainedModel
from transformers import MT5ForConditionalGeneration,MT5Config,MT5Tokenizer

device = xm.xla_device()
model_id = "csebuetnlp/mT5_multilingual_XLSum"
num_texts = 1  # Number of input texts to decode
num_beams = 4  # Number of beams per input text
max_encoder_length = 128  # Maximum input token length
max_decoder_length = 128

def reduce(hidden: torch.Tensor, index: int):
    _, n_length, _ = hidden.shape

    print(hidden.dtype)
    # Create selection mask
    mask = torch.arange(n_length, dtype=torch.float16).to(xm.xla_device()) == index
    mask = mask.view(1, -1, 1)

    # Broadcast mask
    masked = torch.multiply(hidden, mask)

    # Reduce along 1st dimension
    summed = torch.sum(masked, 1)
    return torch.unsqueeze(summed, 1)

class NeuronDecoder(torch.nn.Module):
    def __init__(self, model: MT5ForConditionalGeneration, max_length: int):
        super().__init__()
        self.weight = cast(torch.Tensor, model.shared.weight.clone().detach())
        self.decoder = model.decoder
        self.max_length = max_length

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_outputs: torch.Tensor,
        index: int,
    ):
        # Invoke the decoder
        (hidden,) = self.decoder(
            input_ids=input_ids,
            encoder_hidden_states=encoder_outputs,
            return_dict=False,
            use_cache=False,
        )
        # hidden = hidden.to(device)

        # Reduce decoder outputs to the specified index (current iteration)
        hidden = reduce(hidden, index)

        # Compute final linear layer for token probabilities
        logits = F.linear(hidden.to(xm.xla_device()), self.weight.to(xm.xla_device()))
        return logits

class NeuronGeneration(PreTrainedModel, GenerationMixin):
    def trace(
        self,
        model: MT5ForConditionalGeneration,
        num_texts: int,
        num_beams: int,
        max_encoder_length: int,
        max_decoder_length: int,
    ) -> None:
        """
        Traces the encoder and decoder modules for use on Neuron.
        This function fixes the network to the given sizes. Once the model has been
        compiled to a given size, the inputs to these networks must always be of
        fixed size.
        Args:
            model (GenerationMixin): The transformer-type generator model to trace
            num_texts (int): The number of input texts to translate at once
            num_beams (int): The number of beams to compute per text
            max_encoder_length (int): The maximum number of encoder tokens
            max_decoder_length (int): The maximum number of decoder tokens
        """
        self.config.max_decoder_length = max_decoder_length

        # Trace the encoder
        inputs = (
            torch.ones((num_texts, max_encoder_length), dtype=torch.long),
            torch.ones((num_texts, max_encoder_length), dtype=torch.long),
        )
        encoder = NeuronEncoder(model)
        self.encoder = cast(NeuronEncoder, torch_neuronx.trace(encoder, inputs))
        print('encoder finished')

        # Trace the decoder (with expanded inputs)
        batch_size = num_texts * num_beams
        inputs = (
            torch.ones((batch_size, max_decoder_length), dtype=torch.long),
            torch.ones((batch_size, max_encoder_length), dtype=torch.long),
            torch.ones(
                (batch_size, max_encoder_length, model.config.d_model),
                dtype=torch.float,
            ),
            torch.tensor(0),
        )
        decoder = NeuronDecoder(model, max_decoder_length)
        self.decoder = cast(NeuronDecoder, torch_neuronx.trace(decoder, inputs))
        print('decoder finished')

    # ------------------------------------------------------------------------
    # Encoder/Decoder Invocation
    # ------------------------------------------------------------------------

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.Tensor,
        encoder_outputs: BaseModelOutput,
        attention_mask: Optional[torch.Tensor] = None,
        **model_kwargs: Any,
    ):
        # Pad the inputs for Neuron
        current_length = input_ids.shape[1]
        pad_size = self.config.max_decoder_length - current_length
        return dict(
            input_ids=F.pad(input_ids, (0, pad_size)),
            attention_mask=attention_mask,
            encoder_outputs=encoder_outputs.last_hidden_state,
            current_length=torch.tensor(current_length - 1),
        )

    def get_encoder(self):
        """Helper to invoke the encoder and wrap the results in the expected structure"""

        def encode(**kwargs:Any):
            input_ids = kwargs["input_ids"]
            attention_mask = kwargs.get("attention_mask", torch.ones_like(input_ids))
            (output,) = self.encoder(input_ids, attention_mask)
            return BaseModelOutput(
                last_hidden_state=output,
            )

        return encode

    def __call__(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_outputs: BaseModelOutput,
        current_length: int,
        **kwargs: Any,
    ):
        """Helper to invoke the decoder and wrap the results in the expected structure"""
        logits = self.decoder(
            input_ids, attention_mask, encoder_outputs, current_length
        )
        return Seq2SeqLMOutput(logits=logits)

    # ------------------------------------------------------------------------
    # Serialization
    # ------------------------------------------------------------------------

    def save_pretrained(self, directory: str):
        if os.path.isfile(directory):
            print(f"Provided path ({directory}) should be a directory, not a file")
            return
        os.makedirs(directory, exist_ok=True)
        torch.jit.save(self.encoder, os.path.join(directory, "encoder.pt"))
        torch.jit.save(self.decoder, os.path.join(directory, "decoder.pt"))
        self.config.save_pretrained(directory)

    @classmethod
    def from_pretrained(cls, directory: str):
        config = MT5Config.from_pretrained(directory)
        obj = cls(config)
        obj.main_input_name = "input_ids"
        obj.encoder = torch.jit.load(os.path.join(directory, "encoder.pt"))
        obj.encoder.main_input_name = "input_ids"
        obj.decoder = torch.jit.load(os.path.join(directory, "decoder.pt"))
        obj.decoder.main_input_name = "decoder_input_ids"
        return obj

    @property
    def device(self):
        return torch.device("cpu")

model_cpu = cast(MT5ForConditionalGeneration, MT5ForConditionalGeneration.from_pretrained(model_id))
tokenizer_cpu = cast(MT5Tokenizer, MT5Tokenizer.from_pretrained(model_id))
model_neuron = NeuronGeneration(model_cpu.config)

model_neuron.trace(
    model=model_cpu,
    num_texts=num_texts,
    num_beams=num_beams,
    max_encoder_length=max_encoder_length,
    max_decoder_length=max_decoder_length,
)
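trace() instantiates NeuronEncoder, but that class is not part of the posted code. The version below is a hypothetical reconstruction, assuming the wrapper only forwards input_ids and attention_mask to model.encoder and returns a 1-tuple, since get_encoder() unpacks the traced encoder's result with `(output,) = self.encoder(...)`. It runs on CPU with a tiny, randomly initialized MT5Config (sizes picked only for this smoke test), so no checkpoint download, torch_xla, or torch_neuronx is needed.

```python
import torch
from transformers import MT5Config, MT5ForConditionalGeneration


class NeuronEncoder(torch.nn.Module):
    """Hypothetical encoder wrapper: fixed-shape (input_ids, attention_mask) -> (hidden,)."""

    def __init__(self, model: MT5ForConditionalGeneration):
        super().__init__()
        self.encoder = model.encoder

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=False,
        )
        # Return a 1-tuple so callers can unpack it as `(output,) = encoder(...)`
        return (outputs[0],)


# Tiny random config purely for a CPU smoke test (not the real checkpoint sizes)
config = MT5Config(
    vocab_size=100, d_model=16, d_kv=4, d_ff=32,
    num_layers=1, num_decoder_layers=1, num_heads=4,
)
model = MT5ForConditionalGeneration(config).eval()
encoder = NeuronEncoder(model)

input_ids = torch.ones((1, 8), dtype=torch.long)
attention_mask = torch.ones((1, 8), dtype=torch.long)
with torch.no_grad():
    (hidden,) = encoder(input_ids, attention_mask)
print(hidden.shape)  # torch.Size([1, 8, 16])
```

An instance built this way can be handed to torch_neuronx.trace(encoder, inputs) the same way trace() does; the class the author actually used may differ in details.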

And the inference code is:

def infer(model: NeuronGeneration, tokenizer: MT5Tokenizer, text: str):
    # Truncate and pad to the maximum encoder length so the token count matches the fixed-size traced encoder (not necessary for pure CPU execution)
    batch = tokenizer(
        text,
        max_length=max_encoder_length,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )
    # with torch.inference_mode():
    output = model.generate(
        inputs=batch["input_ids"],
        #**batch,
        max_length=max_decoder_length,
        num_beams=num_beams,
        num_return_sequences=num_beams,
    )
    print(output)
    results = [tokenizer.decode(t, skip_special_tokens=True) for t in output]

    print("Texts:")
    for i, summary in enumerate(results):
        print(i + 1, summary)
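Inside model.generate(), each decoding step goes through prepare_inputs_for_generation, which right-pads the decoder ids to max_decoder_length and passes current_length - 1 as the index, and reduce() then keeps only the hidden state at that position. The CPU-only sketch below mirrors those two steps with plain tensors; the helper names pad_decoder_ids and select_position are hypothetical, a random tensor stands in for the decoder output, and the mask uses an integer arange instead of the original float16 one.

```python
import torch
import torch.nn.functional as F

MAX_DECODER_LENGTH = 128  # same fixed length the decoder is traced with


def pad_decoder_ids(input_ids: torch.Tensor, max_length: int):
    """Right-pad decoder ids to the traced length; return (padded ids, index of last real token)."""
    current_length = input_ids.shape[1]
    padded = F.pad(input_ids, (0, max_length - current_length))
    return padded, torch.tensor(current_length - 1)


def select_position(hidden: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Mask-and-sum equivalent of hidden[:, index:index + 1, :] (same idea as reduce())."""
    _, n_length, _ = hidden.shape
    mask = (torch.arange(n_length) == index).view(1, -1, 1)
    # hidden * mask zeroes every other position; summing over dim 1 keeps the selected one
    return torch.sum(hidden * mask, dim=1).unsqueeze(1)


# Three generated tokens so far for a batch of 4 beams (token values are arbitrary)
decoder_ids = torch.tensor([[0, 17, 42]] * 4)
padded, index = pad_decoder_ids(decoder_ids, MAX_DECODER_LENGTH)
print(padded.shape, int(index))  # torch.Size([4, 128]) 2

hidden = torch.randn(4, MAX_DECODER_LENGTH, 16)  # stand-in for decoder output
selected = select_position(hidden, index)
print(selected.shape)  # torch.Size([4, 1, 16])
```

Because T5's decoder self-attention is causal, the pad tokens after the index should not influence the selected hidden state. One caveat of the mask-and-sum form: multiplying by zero does not clear NaN or inf values at other positions (0 * inf is NaN), so torch.where(mask, hidden, torch.zeros_like(hidden)) is a safer way to express the same selection.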

text = 'Sky News announces slate of special programming for the appointment of the UK’s new Prime Minister.\nSky News’ political programming will expand ahead of a momentous week in UK politics with the impending announcement of the new Prime Minister. Sky News’ key political programmes will return to bring audiences in-depth discussion and analysis of all the latest news with live coverage from Downing Street and Westminster.\nHead of Sky News, John Ryley:\n“This is a momentous week in British politics, where a new Prime Minister will take on an in-tray bursting with crunch decisions. Sky News will be live in Westminster and hearing from voters in key constituencies to bring our audiences the latest news, and analyse the impact this new government will have on households across the UK.”\nSky News’ slate of dedicated political programming will kick off from 8.30 am on Sunday 4th September with Sophy Ridge on Sunday, focusing on the impending result of the Conservative Party leadership election.\nOn Monday 5th, Tuesday 6th and Wednesday 7th September, Sky News will bring live coverage to audiences from Downing Street and Westminster as power is handed over from outgoing Prime Minister, Boris Johnson to his successor, either Liz Truss or Rishi Sunak.\nSophy Ridge’s The Take will return on Wednesday 7th September and on Thursday 8th September Beth Rigby Interviews… will also return to explore the result of the leadership election.\nOther special programming on Sky News will cover the key moments for the new Prime Minster in their first days in office, including their first meeting with her Majesty the Queen on Tuesday – during which they’ll seek permission to form a government – and their first major statement as Prime Minister from the steps of Downing Street. 
With millions of households across the UK asking questions about the cost-of-living crisis, this moment will be pivotal for the new Prime Minister and Sky News will bring audiences the full story on-air, online, on our app, and via our podcasts.\nThe first Prime Minister’s Questions will also be broadcast live from the House of Commons on Sky News on Wednesday 7th September. This will be the first time that the new PM faces Labour’s leader across the despatch box, and they will be expected to face questions about future government policy including the possible timing of a general election.\nCoverage of all these events will be available on the Sky News Politics Hub and will continue throughout September and October from the Labour and Conservative Party Conferences from 25th-28th September and 2nd-5th October respectively. Sky News will also be hosting pop-up radio stations at both party conferences.'

infer(model_neuron,tokenizer_cpu,text)
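NeuronDecoder computes logits as F.linear(hidden, model.shared.weight), which only matches the Hugging Face T5 head when input and output embeddings are tied. In the transformers implementation, T5ForConditionalGeneration rescales the decoder output by d_model ** -0.5 when config.tie_word_embeddings is True and then applies model.lm_head; MT5Config defaults to tie_word_embeddings=False, so mT5 checkpoints typically carry a separate lm_head.weight. The sketch below compares both projections against the model's own logits on CPU. The helper name project_logits is hypothetical, and a tiny random config stands in for the real checkpoint.

```python
import torch
import torch.nn.functional as F
from transformers import MT5Config, MT5ForConditionalGeneration


def project_logits(model: MT5ForConditionalGeneration, hidden: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper mirroring the T5/mT5 LM head applied to decoder hidden states."""
    if model.config.tie_word_embeddings:
        # Tied T5 variants rescale before reusing the shared embedding matrix
        hidden = hidden * (model.config.d_model ** -0.5)
    # lm_head.weight is the shared matrix when tied and a separate matrix otherwise
    return F.linear(hidden, model.lm_head.weight)


torch.manual_seed(0)
config = MT5Config(
    vocab_size=100, d_model=16, d_kv=4, d_ff=32,
    num_layers=1, num_decoder_layers=1, num_heads=4,
)  # tie_word_embeddings defaults to False for MT5Config
model = MT5ForConditionalGeneration(config).eval()

input_ids = torch.randint(2, 100, (1, 8))
decoder_ids = torch.tensor([[0, 5, 9]])
with torch.no_grad():
    reference = model(input_ids=input_ids, decoder_input_ids=decoder_ids).logits
    encoder_out = model.encoder(input_ids=input_ids, return_dict=False)[0]
    hidden = model.decoder(
        input_ids=decoder_ids,
        encoder_hidden_states=encoder_out,
        use_cache=False,
        return_dict=False,
    )[0]
    with_lm_head = project_logits(model, hidden)
    with_shared = F.linear(hidden, model.shared.weight)  # what NeuronDecoder does

print(torch.allclose(with_lm_head, reference, atol=1e-5))  # True
print(torch.allclose(with_shared, reference, atol=1e-5))   # False
```

If the checkpoint's config.json shows "tie_word_embeddings": false, projecting with model.lm_head.weight instead of model.shared.weight would be the corresponding change in NeuronDecoder; this is a check worth running, not a confirmed diagnosis of this issue.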

Expected outcome: a summary of the above text in English.

Actual outcome:

01.01.201sqftПÐ밡밡ksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksuksu

Environment information: python==3.8, torch==1.13.1+cu117, transformers==4.28.1, numpy==1.21.6. The code runs on an m5.16xlarge EC2 instance.

Do you know how to fix this problem? Thank you in advance.

shebbur-aws commented 1 year ago

Thank you for reporting the issue. We are trying to repro this on our end and will get back to you shortly.

chanhosong commented 1 year ago

Has this issue been resolved? I too am having the same problem.

mrnikwaws commented 11 months ago

Here are two links to work with:

This ticket has been open for some time - so you may already have resolved the issue. If I don't hear back by early next week I will plan to close the ticket.

aws-taylor commented 9 months ago

Hello @yzGao22,

We haven't heard from you in a while, so I'm going to resolve this issue, but don't hesitate to reach out if you have further questions or issues.