microsoft / onnxscript

ONNX Script enables developers to naturally author ONNX functions and models using a subset of Python.
https://onnxscript.ai/
MIT License

Linear from PyTorch must map to Gemm in ONNX #1089

Closed: baijumeswani closed this issue 1 year ago

baijumeswani commented 1 year ago

PyTorch Model:

import torch


class NeuralNet(torch.nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()

        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, input1):
        out = self.fc1(input1)
        out = self.relu(out)
        out = self.fc2(out)
        return out
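
A minimal sketch of how the two exporters might be invoked for comparison (the input sizes and file names here are made-up assumptions for illustration):

model = NeuralNet(input_size=10, hidden_size=20, num_classes=5)
example_input = torch.randn(4, 10)

# TorchScript-based exporter
torch.onnx.export(model, (example_input,), "neuralnet_torchscript.onnx")

# Dynamo-based exporter (PyTorch 2.x)
onnx_program = torch.onnx.dynamo_export(model, example_input)
onnx_program.save("neuralnet_dynamo.onnx")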

Exporting with the TorchScript-based exporter yields:

[image of the exported graph]

which makes sense: it is, after all, a Linear layer followed by a ReLU followed by another Linear layer.

Exporting the same model with the dynamo-based exporter yields:

[image of the exported graph]

Two levels beneath the linear layer, I find:

[image of the nested function subgraph]

It seems that the Gemm somehow manifests as a subgraph of MatMuls, Muls, Adds, and CastLikes. Digging deeper, I find that this definition comes from https://github.com/microsoft/onnxscript/blob/a981b8add9a4c7e67ad0d28622b23d4c6a55a76e/onnxscript/function_libs/torch_lib/ops/core.py#L220-L229
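
For reference, the linked definition is roughly the following decomposition (a paraphrased sketch, not the exact code at that commit; the CastLikes seen in the graph presumably come from onnxscript promoting the Python float attributes alpha and beta to tensors):

@torch_op("aten::addmm")
def aten_addmm(
    self: TReal, mat1: TReal, mat2: TReal, beta: float = 1.0, alpha: float = 1.0
) -> TReal:
    # alpha * (mat1 @ mat2) + beta * self, expressed with general ops
    mat1_mat2 = op.MatMul(mat1, mat2)
    scaled_mat1_mat2 = op.Mul(mat1_mat2, alpha)
    scaled_self = op.Mul(self, beta)
    return op.Add(scaled_self, scaled_mat1_mat2)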

It seems wasteful that an op as simple as a Gemm needs to be represented as this subgraph. Looking at this document, this seems to be a design choice.

We favor general ops like MatMul over specialized ops like Gemm in the function lib.

But imagine a model with thousands of Gemms. Each Gemm is now this subgraph, which means this optimization/fusion needs to run thousands of times to achieve something that could probably be achieved very easily at the source.

It would benefit ONNX Runtime (inference and training) and the larger ONNX community if this subgraph were represented as a Gemm node after export.

baijumeswani commented 1 year ago

cc: @BowenBao @justinchuby

justinchuby commented 1 year ago

Thanks for raising this issue! When we created the decomposition, I realized Gemm is a special case of the op addmm (https://github.com/pytorch/pytorch/blob/a6b452dfdcb484d5dfdbb577b74cecbd7021df2e/torch/onnx/symbolic_opset9.py#L645-L652). In the design of torchlib, we wanted the ONNX functions to mirror the aten ops' behavior as closely as possible, so that we preserve the richest information for downstream optimization (doc). To be able to use Gemm for addmm, we need to know the type and rank of the input, which are not assumed to be available at export time.

This kind of fusion should actually be simple for downstream optimization passes by design: we can look at the aten_addmm function and its input types and ranks when available, then make the substitution. We do need the type and rank information for this, though, which is not available in nested functions, as @BowenBao pointed out in https://github.com/onnx/onnx/issues/5487
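
For illustration, here is a rough sketch of what such a substitution pass could look like over an exported ModelProto. The domain name "pkg.onnxscript.torch_lib", the attribute handling, and the assumption that the call shows up as a single aten_addmm node are my guesses here, and verifying that mat1/mat2 are rank 2 (e.g. via shape inference) is left to the caller:

import onnx
from onnx import helper


def fuse_addmm_to_gemm(model: onnx.ModelProto) -> None:
    """Replace calls to the torchlib aten_addmm function with a single Gemm node."""
    new_nodes = []
    for node in model.graph.node:
        if node.op_type == "aten_addmm" and node.domain == "pkg.onnxscript.torch_lib":
            attrs = {a.name: helper.get_attribute_value(a) for a in node.attribute}
            self_input, mat1, mat2 = node.input[:3]
            # Gemm computes alpha * (A @ B) + beta * C, which matches
            # addmm(self, mat1, mat2, beta=beta, alpha=alpha) when mat1 and mat2 are rank 2.
            new_nodes.append(
                helper.make_node(
                    "Gemm",
                    inputs=[mat1, mat2, self_input],
                    outputs=list(node.output),
                    alpha=float(attrs.get("alpha", 1.0)),
                    beta=float(attrs.get("beta", 1.0)),
                )
            )
        else:
            # Copy untouched nodes so the graph's node list can be rebuilt safely.
            kept = onnx.NodeProto()
            kept.CopyFrom(node)
            new_nodes.append(kept)
    del model.graph.node[:]
    model.graph.node.extend(new_nodes)
    # The now-unused aten_addmm FunctionProto could also be pruned from model.functions.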

justinchuby commented 1 year ago

However, for this special case, we may be able to create an overload for supported types that conditionally chooses Gemm based on rank. Optimization passes will still need to fold the If branches for this.

Edit:

I tried (1):

@torch_op("aten::addmm")
def aten_addmm(
    self: TReal, mat1: TReal, mat2: TReal, beta: float = 1.0, alpha: float = 1.0
) -> TReal:
    """addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"""

    # Gemm only supports rank-2 inputs, so check the ranks of mat1 and mat2 at runtime.
    use_gemm = op.And(
        op.Equal(op.Size(op.Shape(mat1)), 2), op.Equal(op.Size(op.Shape(mat2)), 2)
    )
    if use_gemm:
        result = op.Gemm(mat1, mat2, self, alpha=alpha, beta=beta)
    else:
        # General decomposition: alpha * (mat1 @ mat2) + beta * self
        mat1_mat2 = op.MatMul(mat1, mat2)
        scaled_mat1_mat2 = op.Mul(mat1_mat2, alpha)
        scaled_self = op.Mul(self, beta)
        result = op.Add(scaled_self, scaled_mat1_mat2)
    return result

But apparently ORT implements Gemm only for float32 and not for the other types, so this would need to become (2):

@torch_op("aten::addmm")
def aten_addmm_gemm(
    self: FLOAT, mat1: FLOAT, mat2: FLOAT, beta: float = 1.0, alpha: float = 1.0
) -> FLOAT:
    """addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"""

    # float32 overload: use Gemm when both matrices are rank 2.
    use_gemm = op.And(
        op.Equal(op.Size(op.Shape(mat1)), 2), op.Equal(op.Size(op.Shape(mat2)), 2)
    )
    if use_gemm:
        result = op.Gemm(mat1, mat2, self, alpha=alpha, beta=beta)
    else:
        mat1_mat2 = op.MatMul(mat1, mat2)
        scaled_mat1_mat2 = op.Mul(mat1_mat2, alpha)
        scaled_self = op.Mul(self, beta)
        result = op.Add(scaled_self, scaled_mat1_mat2)
    return result


@torch_op("aten::addmm")
def aten_addmm(
    self: TNotFloat32, mat1: TNotFloat32, mat2: TNotFloat32, beta: float = 1.0, alpha: float = 1.0
) -> TNotFloat32:
    """addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"""

    # Overload for the remaining types: always use the general decomposition.
    mat1_mat2 = op.MatMul(mat1, mat2)
    scaled_mat1_mat2 = op.Mul(mat1_mat2, alpha)
    scaled_self = op.Mul(self, beta)
    result = op.Add(scaled_self, scaled_mat1_mat2)
    return result

But since Gemm is defined on {tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)}, it makes less sense for the exporter to make this specialization on FLOAT inputs.

Let me know what you think or if I am missing anything. Thanks!

justinchuby commented 1 year ago

So far it looks like the best path forward is for ORT to implement Gemm for the spec'd types and for us to use (1). This way we strike a balance among correctness, complexity, and the effort needed for fusion.

justinchuby commented 1 year ago

Although: If folding is not necessarily easier, and the model may run even slower with If branches when unoptimized. The assumption is that we don't want to specialize the function at conversion time, so we can't just use Gemm.

baijumeswani commented 1 year ago

None of the solutions offered here is very helpful, since they all require a subgraph computation/optimization (to be either folded away or fused into a Gemm).

Ideally, the rank/type of the input matrices as well as the values of alpha and beta are known at export time, which makes me feel that this should be dealt with at the source and not pushed downstream to another optimization pass at a later time.

This becomes particularly important for scenarios where the export is an inline operation (such as in ORTModule) and the export time, along with the other optimization times, results in a performance penalty for the scenario.

baijumeswani commented 1 year ago

cc @pranavsharma for awareness, as I think this would impact inference as well.

pranavsharma commented 1 year ago

Thanks @baijumeswani for adding me.

Exporter team: please try to fix this at export time, as ORT is not the only consumer of ONNX graphs. There is a whole ecosystem around ONNX, and such changes will break all of them.

ORT has not implemented Gemm for certain types because there was no production use case, and adding unnecessary types increases the binary size. Hence, it doesn't make sense for ORT to implement ops for all types. For the most frequently used types, can we emit Gemm? This way we're not penalizing the majority of use cases.

justinchuby commented 1 year ago

Thanks for this perspective! Happy to explore options here. One thing that comes to mind is that, as we build out AOT optimization capabilities for ONNX graphs, these types of patterns can be optimized away (by the exporter) before the runtime sees the graph. This way, tools in the ecosystem can choose to operate on graphs at different levels of generality based on the assumptions they are built against.

xadupre commented 1 year ago

Some thoughts related to these issues.

Models can be very big nowadays; anything we don't handle at export time must be taken care of at optimization time. That is OK for small models, but is it still OK on bigger models with thousands of operators? Looking for patterns in such graphs adds significant time. Maybe we should start tracking the converters' performance (converting time, optimizing time with onnxruntime).

Here is one particular case with onnx-script; it is rare, but it can happen:

if beta == 0:
    B = op.MatMul(X, np.array(...))
else:
    B = op.MatMul(X, np.array(....))

onnx-script will convert this into 3 operators (an If plus 2 MatMuls) and 2 initializers. Then an optimization pass will fold the constants and keep one operator and one initializer. But what if both initializers are very big? We would add unnecessary tensors to the model, making it unnecessarily big.

Here is another one; again, it is rare, but it is possible:

B = op.MatMul(A, op.CastLike(np.array([....], dtype=np.float32), A))

The ONNX model will always keep float32 tensors, but if the model is float16, these could be reduced by half and the exported model could be smaller.
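
For example, if the target dtype were already known at export time, the constant could be written directly in that dtype and the CastLike dropped (a hypothetical alternative, assuming a float16 model):

B = op.MatMul(A, np.array([....], dtype=np.float16))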

justinchuby commented 1 year ago

Thanks!

is it still OK on bigger models with thousands of operators?

Potentially; since we already have functions, there should be a clear boundary for us to match patterns against?

Maybe we should start tracking the converters' performance (converting time, optimizing time with onnxruntime)

A similar thing is tracked at https://github.com/microsoft/onnx-converters-private/issues/166#issuecomment-1764864419 (DORT, compilation time). From profiling, we have seen that the main delay is torch dynamo at the moment.

onnx-script will convert this into 3 operators (an If plus 2 MatMuls) and 2 initializers

I think a concrete usage will help the discussion here. Since aten operators take all large tensors as inputs, I don't see how we would duplicate large tensors in functions (the constants are more likely to be scalars).

The ONNX model will always keep float32 tensors, but if the model is float16, these could be reduced by half and the exported model could be smaller.

This presents a similar situation, I think, where all CastLike'd constants tend to be single-element tensors (scalars) that don't take up space. The exported initializers will be float16 if the model is dealing with float16 inputs.

gramalingam commented 1 year ago

A few comments:

justinchuby commented 1 year ago

Thanks @gramalingam. I can change to (1) in the implementation if there are no objections. Fortunately, the rest of the complexity for this op is no longer a concern because we realized Gemm can handle the op. But these points do help when we encounter new instances like this.