onnx / onnx-mlir

Representation and Reference Lowering of ONNX Models in MLIR Compiler Infrastructure

ONNX-MLIR fails to compile Llama2-7B model: Issue with RMSNorm #2924

Closed srijanie03 closed 2 months ago

srijanie03 commented 3 months ago

I use the code below to generate the ONNX training model for Llama2-7B (with LoRA) from torchtune:

import torch
from torchtune.models.llama2 import llama2_7b, lora_llama2_7b
import onnxruntime.training.onnxblock as onnxblock
from onnxruntime.training.api import CheckpointState, Module, Optimizer
from onnxruntime.training import artifacts
from onnxruntime import InferenceSession

lora_model = lora_llama2_7b(lora_attn_modules=["q_proj", "v_proj"])

# batch, input_names, output_names, and dynamic_axes are defined earlier
# in my script (omitted here for brevity).
torch.onnx.export(
    lora_model,
    batch,
    "torchtune_lora_llama2.onnx",
    input_names=input_names,
    output_names=output_names,
    opset_version=14,
    do_constant_folding=False,
    training=torch.onnx.TrainingMode.TRAINING,
    dynamic_axes=dynamic_axes,
    export_params=True,
    keep_initializers_as_inputs=False,
)

requires_grad = [name for name, param in lora_model.named_parameters() if param.requires_grad]
frozen_params = [name for name, param in lora_model.named_parameters() if not param.requires_grad]

artifacts.generate_artifacts(
    "torchtune_lora_llama2.onnx",
    optimizer=artifacts.OptimType.AdamW,
    loss=artifacts.LossType.CrossEntropyLoss,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    artifact_directory="torchtune_lora_llama2",
    additional_output_names=["output"],
)
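
As a quick sanity check (a minimal sketch, assuming only the file name from the export call above), the base exported graph itself is valid ONNX:

import onnx

# Sanity check: the base export passes the ONNX checker. The
# SimplifiedLayerNormalization fusion appears only later, during
# generate_artifacts, so this graph should contain only standard ops.
model = onnx.load("torchtune_lora_llama2.onnx")
onnx.checker.check_model(model)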

It works fine. But when I compile the same ONNX model with ONNX-MLIR, the RMSNorm node, represented here as SimplifiedLayerNormalization, is not handled correctly:

%7:2 = "onnx.Custom"(%6, %arg4) {axis = -1 : si64, domain_name = "", epsilon = 9.99999997E-7 : f32, function_name = "SimplifiedLayerNormalization", onnx_node_name = "/rmsnorm/Mul_1/SimplifiedLayerNormFusion/", stash_type = 1 : si64} : (tensor<2x3x3xf32>, tensor<3xf32>) -> (tensor<*xf32>, **none**)

%11:2 = "onnx.Custom"(%10, %6, %arg4, %7#1) {axis = -1 : si64, domain_name = "com.microsoft", epsilon = 9.99999997E-7 : f32, function_name = "SimplifiedLayerNormalizationGrad", onnx_node_name = "/rmsnorm/Mul_1/SimplifiedLayerNormFusion/_Grad/SimplifiedLayerNormalizationGrad_0", stash_type = 1 : si64} : (tensor<2x3x3xf32>, tensor<2x3x3xf32>, tensor<3xf32>, none) -> (tensor<2x3x3xf32>, **none**)

The same issue occurs for a simple MLP that uses Llama2's custom RMSNorm implementation (below), so the problem is not Llama2-specific but tied to RMSNorm:

import torch
import torch.nn as nn

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.

        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.

        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.

        """
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

class SimpleNetRMS(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(SimpleNetRMS, self).__init__()

        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.rmsnorm = RMSNorm(3)  # Llama2's custom RMSNorm

    def forward(self, x):
        x = self.linear1(x)
        x = self.rmsnorm(x)
        return x

inputs = torch.randn(2, 3, 3)
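
The export for this small repro follows the same TRAINING-mode path as the Llama2 export above (a minimal sketch; the file name and input/output names are arbitrary, and the dimensions match the shapes used above):

model = SimpleNetRMS(input_dim=3, hidden_dim=3)
torch.onnx.export(
    model,
    inputs,
    "simplenet_rms.onnx",
    input_names=["input"],
    output_names=["output"],
    opset_version=14,
    do_constant_folding=False,
    training=torch.onnx.TrainingMode.TRAINING,
)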
Two questions:

1. What is causing the function to output a None? The None flows through the graph and creates further issues.
2. Why does the graph go through a SimplifiedLayerNormFusion? Can it be modified or bypassed in some way?

Really appreciate the help, thank you!