cc: @BowenBao @justinchuby
Thanks for raising this issue! When we created the decomposition, I realized Gemm is a special case of the op addmm (https://github.com/pytorch/pytorch/blob/a6b452dfdcb484d5dfdbb577b74cecbd7021df2e/torch/onnx/symbolic_opset9.py#L645-L652). In the design of torchlib, we wanted the ONNX functions to mirror the aten ops' behavior as closely as possible, so that we preserve the richest information for downstream optimization (doc). To be able to use Gemm for addmm, we need to know the type and rank of the input, which are not assumed to be available at export time.
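For context, aten::addmm computes beta * self + alpha * (mat1 @ mat2), where mat1 and mat2 are 2-D and self only needs to broadcast to the (n, p) product; ONNX Gemm expresses the same computation when C is unidirectionally broadcastable to (M, N). A quick PyTorch check of the semantics (shapes chosen arbitrarily):

```python
import torch

# addmm semantics: beta * self + alpha * (mat1 @ mat2).
# mat1/mat2 are 2-D; self broadcasts to the (2, 3) result (here a 1-D bias-like tensor).
self_ = torch.randn(3)
mat1 = torch.randn(2, 4)
mat2 = torch.randn(4, 3)

expected = 0.5 * self_ + 2.0 * (mat1 @ mat2)
actual = torch.addmm(self_, mat1, mat2, beta=0.5, alpha=2.0)
assert torch.allclose(expected, actual)
```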
This kind of fusion should actually be simple for downstream optimization passes, by design. We can look at the aten_addmm function and its input types and ranks when available, then make the substitution. We do need the type and rank information for this, which, as @BowenBao pointed out in https://github.com/onnx/onnx/issues/5487, is not available in nested functions.
However, for this special case, we may be able to create an overload for supported types that conditionally chooses Gemm based on rank. Optimization passes will still need to fold the If branches for this.
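As a rough illustration of what such a pass could look like (this is a sketch, not an existing optimizer API; it assumes the exported graph calls a function whose node op_type is aten_addmm and that shape inference has produced rank information for its matrix inputs):

```python
import onnx
from onnx import helper, shape_inference


def fuse_addmm_to_gemm(model: onnx.ModelProto) -> onnx.ModelProto:
    """Rewrite aten_addmm function calls to Gemm when both matrix inputs are rank-2."""
    inferred = shape_inference.infer_shapes(model)
    ranks = {
        vi.name: len(vi.type.tensor_type.shape.dim)
        for vi in list(inferred.graph.value_info)
        + list(inferred.graph.input)
        + list(inferred.graph.output)
        if vi.type.HasField("tensor_type")
    }

    new_nodes = []
    for node in model.graph.node:
        # aten_addmm inputs are (self, mat1, mat2); alpha and beta are attributes.
        if node.op_type == "aten_addmm" and all(
            ranks.get(name) == 2 for name in node.input[1:3]
        ):
            attrs = {a.name: helper.get_attribute_value(a) for a in node.attribute}
            new_nodes.append(
                helper.make_node(
                    "Gemm",
                    inputs=[node.input[1], node.input[2], node.input[0]],  # A, B, C
                    outputs=list(node.output),
                    alpha=float(attrs.get("alpha", 1.0)),
                    beta=float(attrs.get("beta", 1.0)),
                )
            )
        else:
            new_nodes.append(node)

    del model.graph.node[:]
    model.graph.node.extend(new_nodes)
    return model
```

In practice the function call node lives in a custom domain (e.g. the torchlib domain), and rank information may not be available inside nested functions, which is exactly the limitation discussed above.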
Edit:
I tried (1):
```python
@torch_op("aten::addmm")
def aten_addmm(
    self: TReal, mat1: TReal, mat2: TReal, beta: float = 1.0, alpha: float = 1.0
) -> TReal:
    """addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"""

    use_gemm = op.And(op.Equal(op.Size(op.Shape(mat1)), 2), op.Equal(op.Size(op.Shape(mat2)), 2))
    if use_gemm:
        result = op.Gemm(mat1, mat2, self, alpha=alpha, beta=beta)
    else:
        mat1_mat2 = op.MatMul(mat1, mat2)
        scaled_mat1_mat2 = op.Mul(mat1_mat2, alpha)
        scaled_self = op.Mul(self, beta)
        result = op.Add(scaled_self, scaled_mat1_mat2)
    return result
```
But apparently ORT only has Gemm for float32 and not other types, so this needs to become (2):
```python
@torch_op("aten::addmm")
def aten_addmm_gemm(
    self: FLOAT, mat1: FLOAT, mat2: FLOAT, beta: float = 1.0, alpha: float = 1.0
) -> FLOAT:
    """addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"""

    use_gemm = op.And(op.Equal(op.Size(op.Shape(mat1)), 2), op.Equal(op.Size(op.Shape(mat2)), 2))
    if use_gemm:
        result = op.Gemm(mat1, mat2, self, alpha=alpha, beta=beta)
    else:
        mat1_mat2 = op.MatMul(mat1, mat2)
        scaled_mat1_mat2 = op.Mul(mat1_mat2, alpha)
        scaled_self = op.Mul(self, beta)
        result = op.Add(scaled_self, scaled_mat1_mat2)
    return result


@torch_op("aten::addmm")
def aten_addmm(
    self: TNotFloat32, mat1: TNotFloat32, mat2: TNotFloat32, beta: float = 1.0, alpha: float = 1.0
) -> TNotFloat32:
    """addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"""

    mat1_mat2 = op.MatMul(mat1, mat2)
    scaled_mat1_mat2 = op.Mul(mat1_mat2, alpha)
    scaled_self = op.Mul(self, beta)
    result = op.Add(scaled_self, scaled_mat1_mat2)
    return result
```
But since Gemm is defined on {tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)}, it makes less sense for the exporter to make this specialization on FLOAT inputs.
Let me know what you think or if I am missing anything. Thanks!
So far it looks like the best path forward is for ORT to implement Gemm on spec'ed types and use (1). This way we strike a balance on correctness, complexity and the effort needed for fusion.
Although: if folding turns out not to be easier, the model may run even slower with If branches when unoptimized. The assumption is that we don't want to specialize the function at conversion time, so we can't just use Gemm.
None of the solutions offered here is very helpful, since they all require a subgraph computation/optimization (to be either folded away or fused into Gemm).
Ideally, the information about the rank/type of the input matrices, as well as the values of alpha and beta, is known at export time. This makes me feel that this should be dealt with at the source and not pushed down to another optimization pass at a later time.
This becomes particularly important in scenarios where export is an inline operation (such as in ORTModule) and the export time, along with other optimization time, results in a performance penalty for the scenario.
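To illustrate the point: when rank, alpha, and beta are known at lowering time, the choice collapses to emitting a single Gemm node at the source. A runnable sketch (names and the decision rule are illustrative, not actual exporter internals):

```python
from onnx import helper


def lower_addmm(self_name, mat1_name, mat2_name, out_name, *,
                mat1_rank, mat2_rank, alpha=1.0, beta=1.0):
    """Export-time lowering sketch: with rank/alpha/beta known here, a single Gemm
    can be emitted; otherwise fall back to the generic decomposition."""
    if mat1_rank == 2 and mat2_rank == 2:
        return [helper.make_node("Gemm", [mat1_name, mat2_name, self_name], [out_name],
                                 alpha=float(alpha), beta=float(beta))]
    # Generic decomposition (the alpha/beta constant tensors are omitted for brevity).
    return [
        helper.make_node("MatMul", [mat1_name, mat2_name], ["mm"]),
        helper.make_node("Mul", ["mm", "alpha"], ["scaled_mm"]),
        helper.make_node("Mul", [self_name, "beta"], ["scaled_self"]),
        helper.make_node("Add", ["scaled_self", "scaled_mm"], [out_name]),
    ]
```

For example, lower_addmm("bias", "x", "w", "y", mat1_rank=2, mat2_rank=2) yields exactly one Gemm node, with no If branch left for a later pass to fold.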
cc @pranavsharma for awareness, as I think this would impact inference as well.
Thanks @baijumeswani for adding me.
Exporter team: Please try to fix this at export time, as ORT is not the only consumer of ONNX graphs. There is a whole ecosystem around ONNX, and such changes will break all of them.
ORT has not implemented Gemm for certain types because there was no production use case, and adding unnecessary types increases the binary size. Hence, it doesn't make sense for ORT to implement ops for all types. For the most frequently used types, can we emit Gemm? This way we're not penalizing the majority of use cases.
Thanks for this perspective! Happy to explore options here. One thing that comes to mind is that, as we build out AOT optimization capabilities for ONNX graphs, these types of patterns can be optimized away (by the exporter) before the runtime sees the graph. That way, tools in the ecosystem can choose to operate on graphs at different levels of generality, based on the assumptions they are built against.
Some thoughts related to these issues.
Models can be very big nowadays; anything we don't handle at export time must be taken care of at optimization time. That is okay for small models, but is it still okay for bigger models with thousands of operators? Looking for patterns in such graphs adds significant time. Maybe we should start tracking the converters' performance (conversion time, optimization time with onnxruntime).
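For example, a minimal way to start collecting those numbers (toy model; timing ORT's offline optimization via optimized_model_filepath is just one option):

```python
import time

import onnxruntime as ort
import torch

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).eval()

t0 = time.perf_counter()
torch.onnx.export(model, (torch.randn(1, 16),), "exported.onnx")
export_seconds = time.perf_counter() - t0

# Time ORT graph optimization by asking it to write the optimized model to disk.
so = ort.SessionOptions()
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
so.optimized_model_filepath = "exported.optimized.onnx"
t0 = time.perf_counter()
ort.InferenceSession("exported.onnx", so, providers=["CPUExecutionProvider"])
optimize_seconds = time.perf_counter() - t0

print(f"export: {export_seconds:.2f}s, ORT optimization: {optimize_seconds:.2f}s")
```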
One particular case with onnx-script; it is rare, but it can happen:
```python
if beta == 0:
    B = op.MatMul(X, np.array(...))
else:
    B = op.MatMul(X, np.array(....))
```
onnx-script will convert this into 3 operators (If + 2x MatMul) and 2 initializers. Then an optimization pass will fold the constants and keep one operator and one initializer. What if both initializers are very big? We would add unnecessary tensors to the model, making it unnecessarily big.
Another one; again, it is rare but it is possible:
```python
B = op.MatMul(A, op.CastLike(np.array([....], dtype=np.float32), B))
```
The ONNX model will always keep float32 tensors, but if the model is float16, these could be reduced by half and the exported model could be smaller.
Thanks!
> is it still okay for bigger models with thousands of operators?
Potentially. Since we already have functions, there should be a clear boundary for us to match on?
> Maybe we should start tracking the converters' performance (conversion time, optimization time with onnxruntime)
A similar thing is tracked at https://github.com/microsoft/onnx-converters-private/issues/166#issuecomment-1764864419 (dort, compilation time). From profiling, we have seen that the main delay is torch dynamo at the moment.
> onnx-script will convert this into 3 operators (If + 2x MatMul) and 2 initializers
I think a concrete usage example will help the discussion here. Since aten operators take all large tensors as inputs, I don't see how we would duplicate large tensors in functions (the constants are more likely scalars).
> The ONNX model will always keep float32 tensors, but if the model is float16, these could be reduced by half and the exported model could be smaller
This, I think, presents a similar situation: all CastLike'd constants tend to be single-element tensors (scalars) that don't take up much space. The exported initializers will be float16 if the model is dealing with float16 inputs.
A few comments:
Thanks @gramalingam. I can change to (1) in the implementation if there are no objections. Fortunately, the rest of the complexity for this op is no longer a concern because we realized Gemm can handle the op. But these points do help when we encounter new instances like this.
PyTorch Model:
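For illustration, a module with the Linear → ReLU → Linear structure described below (dimensions made up):

```python
import torch


class NeuralNet(torch.nn.Module):  # illustrative stand-in for the original model
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(16, 32)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(32, 8)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))
```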
Exporting with the TorchScript-based exporter yields:
which makes sense. It is, after all, a Linear layer followed by a ReLU followed by another Linear layer.
Exporting the same model with the torch dynamo-based exporter yields:
Two levels beneath the linear layer, I find:
It seems like the Gemm is somehow manifested as a subgraph of MatMuls, Muls, Adds, and CastLikes. Digging deeper, I find that this definition comes from https://github.com/microsoft/onnxscript/blob/a981b8add9a4c7e67ad0d28622b23d4c6a55a76e/onnxscript/function_libs/torch_lib/ops/core.py#L220-L229
It seems wasteful that an op as simple as a Gemm needs to be represented as this subgraph. Looking at this document, this seems to be a design choice.
But imagine a model having thousands of Gemms. Each Gemm is now this subgraph, which means this optimization/fusion needs to run thousands of times to achieve something that can probably be achieved very easily at the source. It would benefit ONNX Runtime (inference and training) and the larger ONNX community if this subgraph were represented as a Gemm node after export.
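For reference, the fused form being asked for is a single node per linear layer, along the lines of what the TorchScript exporter emits for nn.Linear; a minimal sketch with onnx.helper (tensor names are illustrative):

```python
from onnx import helper

# One Gemm per linear layer instead of the MatMul/Mul/Add/CastLike subgraph:
# y = alpha * (x @ w^T) + beta * b, with the transpose folded into transB
# because nn.Linear stores its weight as (out_features, in_features).
gemm = helper.make_node(
    "Gemm",
    inputs=["input", "fc1.weight", "fc1.bias"],
    outputs=["fc1_out"],
    alpha=1.0,
    beta=1.0,
    transB=1,
)
print(helper.printable_node(gemm))
```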