I use the code below to generate the ONNX training model for Llama2-7b (with Lora) using torchtune:
import torch
from torchtune.models.llama2 import llama2_7b, lora_llama2_7b
import onnxruntime.training.onnxblock as onnxblock
from onnxruntime.training.api import CheckpointState, Module, Optimizer
from onnxruntime.training import artifacts
from onnxruntime import InferenceSession

lora_model = lora_llama2_7b(lora_attn_modules=["q_proj", "v_proj"])

# Example batch and export names (simplified here; the real values match my tokenizer and sequence length)
batch = torch.randint(0, 32_000, (2, 128))  # dummy token IDs of shape [batch, seq_len]
input_names = ["input"]
output_names = ["output"]
dynamic_axes = {"input": {0: "batch", 1: "seq_len"}, "output": {0: "batch", 1: "seq_len"}}

torch.onnx.export(
    lora_model,
    batch,
    "torchtune_lora_llama2.onnx",
    input_names=input_names,
    output_names=output_names,
    opset_version=14,
    do_constant_folding=False,
    training=torch.onnx.TrainingMode.TRAINING,
    dynamic_axes=dynamic_axes,
    export_params=True,
    keep_initializers_as_inputs=False,
)

# Only the LoRA adapter parameters are trainable; the base weights stay frozen
requires_grad = [name for name, param in lora_model.named_parameters() if param.requires_grad]
frozen_params = [name for name, param in lora_model.named_parameters() if not param.requires_grad]

artifacts.generate_artifacts(
    "torchtune_lora_llama2.onnx",
    optimizer=artifacts.OptimType.AdamW,
    loss=artifacts.LossType.CrossEntropyLoss,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    artifact_directory="torchtune_lora_llama2",
    additional_output_names=["output"],
)
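For reference, this is a minimal way to check which ops end up in the generated training graph and where the fused node appears (it assumes the default training_model.onnx file name that generate_artifacts writes into the artifact directory):

import onnx

# Load the training graph produced by generate_artifacts (default file name assumed)
training_model = onnx.load("torchtune_lora_llama2/training_model.onnx")

# List every node related to the simplified layer norm fusion
for node in training_model.graph.node:
    if "SimplifiedLayerNorm" in node.op_type or "SimplifiedLayerNorm" in node.name:
        print(node.op_type, node.domain, node.name)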
It works fine. But when I compile the same ONNX model with ONNX-MLIR, the RMSNorm node, which appears here as a SimplifiedLayerNormalization, is not handled correctly:

%7:2 = "onnx.Custom"(%6, %arg4) {axis = -1 : si64, domain_name = "", epsilon = 9.99999997E-7 : f32, function_name = "SimplifiedLayerNormalization", onnx_node_name = "/rmsnorm/Mul_1/SimplifiedLayerNormFusion/", stash_type = 1 : si64} : (tensor<2x3x3xf32>, tensor<3xf32>) -> (tensor<*xf32>, **none**)
%11:2 = "onnx.Custom"(%10, %6, %arg4, %7#1) {axis = -1 : si64, domain_name = "com.microsoft", epsilon = 9.99999997E-7 : f32, function_name = "SimplifiedLayerNormalizationGrad", onnx_node_name = "/rmsnorm/Mul_1/SimplifiedLayerNormFusion/_Grad/SimplifiedLayerNormalizationGrad_0", stash_type = 1 : si64} : (tensor<2x3x3xf32>, tensor<2x3x3xf32>, tensor<3xf32>, none) -> (tensor<2x3x3xf32>, **none**)
The same issue occurs for a simple MLP with RMSNorm implemented as below (Llama2's custom RMSNorm), so the problem is not Llama2-specific but related to RMSNorm:
import torch
import torch.nn as nn

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.
        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.
        """
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
class SimpleNetRMS(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.rmsnorm = RMSNorm(3)  # Llama2 custom RMSNorm defined above

    def forward(self, x):
        x = self.linear1(x)
        x = self.rmsnorm(x)
        return x

inputs = torch.randn(2, 3, 3)
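For completeness, the toy model goes through the same export and artifact-generation flow as above; a rough sketch (the dims, loss type, and file names here are placeholders for the repro):

# Same export + generate_artifacts flow as for the Llama2 model, applied to the toy net
model = SimpleNetRMS(input_dim=3, hidden_dim=3)  # dims chosen to match the (2, 3, 3) input

torch.onnx.export(
    model,
    inputs,
    "simple_rmsnorm.onnx",
    input_names=["input"],
    output_names=["output"],
    opset_version=14,
    do_constant_folding=False,
    training=torch.onnx.TrainingMode.TRAINING,
    export_params=True,
    keep_initializers_as_inputs=False,
)

artifacts.generate_artifacts(
    "simple_rmsnorm.onnx",
    optimizer=artifacts.OptimType.AdamW,
    loss=artifacts.LossType.MSELoss,  # placeholder loss for the toy repro
    requires_grad=[name for name, p in model.named_parameters() if p.requires_grad],
    artifact_directory="simple_rmsnorm",
    additional_output_names=["output"],
)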
What is causing the fused node to output a None? The None flows through the graph and creates more issues.
Why does the RMSNorm go through a SimplifiedLayerNormFusion in the first place? Can the fusion be modified or bypassed in some way?
Really appreciate the help, thank you!