facebookincubator / AITemplate

AITemplate is a Python framework that renders neural networks into high-performance CUDA/HIP C++ code, specialized for FP16 TensorCore (NVIDIA GPU) and MatrixCore (AMD GPU) inference.
Apache License 2.0
4.55k stars 368 forks

Low performance from unnecessary permutations #936

Open jonpryai opened 1 year ago

jonpryai commented 1 year ago

I'm using fx2ait to load an ONNX graph. After optimization, the compiled model is slower than PyTorch eager:

BS: 1, PT Eager time per iter: 0.01654841552734375ms, PT Eager QPS: 60.43, FX2AIT time per iter: 0.024108586425781252ms, FX2AIT Eager QPS: 41.48, Speedup: 0.69

And that's before comparing against TensorRT. I profiled the optimized graph with nsys and found that a permute kernel dominates the GPU kernel time:

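| Time (%) | Total Time (ns) | Instances | Avg (ns) | Med (ns) | Min (ns) | Max (ns) | StdDev (ns) | Name |
|---|---|---|---|---|---|---|---|---|
| 61.9 | 11,952,884,520 | 64,800 | 184,458.1 | 123,999.0 | 18,112 | 1,017,919 | 216,853.6 | void ::PermuteKernel<(unsigned long)4, (unsigned long)2, int>(::PermuteKernelPara… |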

Analyzing this in nsys, I see that the graph is consistently doing:

permute -> elementwise addition -> permute.

These permutations don't do anything useful, because an elementwise operator doesn't care about the memory layout of its inputs.
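A quick numerical check of this claim, as a sketch that uses NumPy as a stand-in for the AIT kernels (the shapes come from the graph dump later in this thread; the permutation constants and random data are illustrative): permuting an NHWC tensor to NCHW, applying an elementwise op, and permuting back gives exactly the same tensor as applying the op directly in NHWC. The same holds for a binary op such as add when both inputs receive the same permute.

```python
import numpy as np

NHWC_TO_NCHW = (0, 3, 1, 2)  # axis orders for np.transpose
NCHW_TO_NHWC = (0, 2, 3, 1)


def relu(t):
    return np.maximum(t, 0)


rng = np.random.default_rng(0)
# Stand-ins for NHWC conv2d outputs, e.g. conv2d_bias_1_0 with shape [1, 112, 112, 32]
a = rng.standard_normal((1, 112, 112, 32)).astype(np.float16)
b = rng.standard_normal((1, 112, 112, 32)).astype(np.float16)

# What the lowered graph does: permute -> elementwise -> permute back
sandwiched = np.transpose(relu(np.transpose(a, NHWC_TO_NCHW)), NCHW_TO_NHWC)
sandwiched_add = np.transpose(
    np.transpose(a, NHWC_TO_NCHW) + np.transpose(b, NHWC_TO_NCHW), NCHW_TO_NHWC
)

# What it could do instead: apply the elementwise ops to the NHWC tensors directly
direct = relu(a)
direct_add = a + b

print(np.array_equal(sandwiched, direct), np.array_equal(sandwiched_add, direct_add))
```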

How can this be fixed?

ColinPeppler commented 1 year ago

Hi @jonpryai, thanks for flagging this. It does seem like at least one of the permutes could be redundant, but without a minimal repro it's hard to determine whether they should be removed and whether we need a pass to handle this case.

Do you mind sharing details on how to reproduce this? Thanks!

jonpryai commented 1 year ago

I use this to compile the model:

import onnx
import torch
from onnx2torch import convert
from fx2ait.example.benchmark_utils import benchmark_function

batch_size = 1


class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Convert the ONNX graph into an equivalent torch.nn.Module
        onnx_model = onnx.load("test.onnx")
        self.mod = convert(onnx_model)

    def forward(self, x):
        return self.mod(x)


model = TestModule().cuda().half()
inputs = [torch.randn(batch_size, 16, 224, 224).half().cuda()]
# Lower with fx2ait and compare against PyTorch eager
# (`self` is undefined at module level, so pass the class name directly)
benchmark_function(
    TestModule.__name__,
    100,
    model,
    inputs,
)

Profiling the script above is difficult because it also does all the compiling and kernel profiling, so I locate the generated test.so in /tmp, copy it to the current directory, and then run:

import os

import torch
from aitemplate.compiler import Model


def benchmark(model_name, batch_size, mod=None, graph_mode=True):
    # Load the compiled model (test.so, copied from the /tmp workdir)
    if mod is None:
        mod = Model(os.path.join(".", "test.so"))

    # Prepare contiguous fp16 input/output tensors; the output is [N, 64, 56, 56]
    x_input = torch.randn([batch_size, 16, 224, 224]).cuda().half().contiguous()
    y_output = torch.zeros([batch_size, 64, 56, 56]).cuda().half().contiguous()

    # Warm up
    mod.benchmark_with_tensors(
        [x_input],
        [y_output],
        count=100,
        repeat=4,
        graph_mode=graph_mode,
    )
    # Benchmark; the first return value is the mean latency
    t, _, _ = mod.benchmark_with_tensors(
        [x_input],
        [y_output],
        count=100,
        repeat=4,
        graph_mode=graph_mode,
    )
    print(f"batch_size: {batch_size}, latency: {t}")

    # Append the result to a per-device log file
    dev_flag = os.environ.get("HIP_VISIBLE_DEVICES", "-1").replace(",", "_")
    with open(f"{model_name}_ait_benchmark_dev_{dev_flag}.txt", "a") as f:
        f.write(f"batch_size: {batch_size}, latency: {t}\n")


if __name__ == "__main__":
    benchmark("test", 1)
jonpryai commented 1 year ago

Example ONNX section:

[test.zip](https://github.com/facebookincubator/AITemplate/files/12752757/test.zip)

ColinPeppler commented 1 year ago

For security reasons, I'm unable to download external files. I hope you understand.

It'll be easier to reproduce if you can share the model graph that AIT dumps automatically via dump_graph_debug_str_to_file. To get the dumps, set the environment variable LOGLEVEL=DEBUG when you compile the model; the files will appear in your workdir.

Once you do that, could you share the contents of memory_planning_pseudo_code.txt?

jonpryai commented 1 year ago

(Tensor(name=permute_0_0, shape=[1, 224, 224, 16])) 
= permute()(
Tensor(name=x, shape=[1, 16, 224, 224]))

# conv2d_bias_1
(Tensor(name=conv2d_bias_1_0, shape=[1, 112, 112, 32])) 
= conv2d_bias(dilate=1, group=1, pad=1, stride=2)(
Tensor(name=permute_0_0, shape=[1, 224, 224, 16]), Tensor(name=mod_level1_level1_0_Conv_weight, shape=[32, 3, 3, 16], data=(9216 bytes)), Tensor(name=mod_level1_level1_0_Conv_bias, shape=[32], data=(64 bytes)))

# permute_2
(Tensor(name=permute_2_0, shape=[1, 32, 112, 112])) 
= permute()(
Tensor(name=conv2d_bias_1_0, shape=[1, 112, 112, 32]))

# fused_elementwise_19
(Tensor(name=elementwise_3_0, shape=[1, 32, 112, 112])) 
= fused_elementwise(func=[<FuncEnum.RELU: 18>])(
Tensor(name=permute_2_0, shape=[1, 32, 112, 112]))

# permute_4
(Tensor(name=permute_4_0, shape=[1, 112, 112, 32])) 
= permute()(
Tensor(name=elementwise_3_0, shape=[1, 32, 112, 112]))

# permute_4
(Tensor(name=permute_5_0, shape=[1, 112, 112, 32])) 
= permute()(
Tensor(name=elementwise_3_0, shape=[1, 32, 112, 112]))

# conv2d_bias_6
(Tensor(name=conv2d_bias_6_0, shape=[1, 56, 56, 64])) 
= conv2d_bias(dilate=1, group=1, pad=1, stride=2)(
Tensor(name=permute_5_0, shape=[1, 112, 112, 32]), Tensor(name=mod_level2_tree1_conv1_Conv_weight, shape=[64, 3, 3, 32], data=(36864 bytes)), Tensor(name=mod_level2_tree1_conv1_Conv_bias, shape=[64], data=(128 bytes)))

# permute_7
(Tensor(name=permute_7_0, shape=[1, 64, 56, 56])) 
= permute()(
Tensor(name=conv2d_bias_6_0, shape=[1, 56, 56, 64]))

# fused_elementwise_20
(Tensor(name=elementwise_8_0, shape=[1, 64, 56, 56])) 
= fused_elementwise(func=[<FuncEnum.RELU: 18>])(
Tensor(name=permute_7_0, shape=[1, 64, 56, 56]))

# permute_9
(Tensor(name=permute_9_0, shape=[1, 56, 56, 64])) 
= permute()(
Tensor(name=elementwise_8_0, shape=[1, 64, 56, 56]))

# conv2d_bias_10
(Tensor(name=conv2d_bias_10_0, shape=[1, 56, 56, 64])) 
= conv2d_bias(dilate=1, group=1, pad=1, stride=1)(
Tensor(name=permute_9_0, shape=[1, 56, 56, 64]), Tensor(name=mod_level2_tree1_conv2_Conv_weight, shape=[64, 3, 3, 64], data=(73728 bytes)), Tensor(name=mod_level2_tree1_conv2_Conv_bias, shape=[64], data=(128 bytes)))

# permute_7
(Tensor(name=permute_11_0, shape=[1, 64, 56, 56])) 
= permute()(
Tensor(name=conv2d_bias_10_0, shape=[1, 56, 56, 64]))

# max_pool2d_12
(Tensor(name=max_pool2d_12_0, shape=[1, 56, 56, 32])) 
= max_pool2d(stride=2, pad=0, kernel_size=2, reduce_func=max)(
Tensor(name=permute_4_0, shape=[1, 112, 112, 32]))

# conv2d_bias_15
(Tensor(name=conv2d_bias_15_0, shape=[1, 56, 56, 64])) 
= conv2d_bias(dilate=1, group=1, pad=0, stride=1)(
Tensor(name=max_pool2d_12_0, shape=[1, 56, 56, 32]), Tensor(name=mod_level2_project_project_0_Conv_weight, shape=[64, 1, 1, 32], data=(4096 bytes)), Tensor(name=mod_level2_project_project_0_Conv_bias, shape=[64], data=(128 bytes)))

# permute_7
(Tensor(name=permute_16_0, shape=[1, 64, 56, 56])) 
= permute()(
Tensor(name=conv2d_bias_15_0, shape=[1, 56, 56, 64]))

# fused_elementwise_21
(Tensor(name=output_0, shape=[1, 64, 56, 56])) 
= fused_elementwise(func=[<FuncEnum.ADD: 1>, <FuncEnum.RELU: 18>])(
Tensor(name=permute_11_0, shape=[1, 64, 56, 56]), Tensor(name=permute_16_0, shape=[1, 64, 56, 56]))
jonpryai commented 1 year ago

This image of the network might be helpful.

Screenshot from 2023-10-05 09-44-11

ColinPeppler commented 1 year ago

It does seem like either permute_2 or permute_4 can be removed here. It'll be easier to remove permute_2, IMO.

And sorry for the delay, but this is what I believe we need:

  1. Find the conditions for removing permute_2.
    • We can make the conditions specific to your graph (i.e. only when the middle op is an elementwise relu/gelu/etc.).
  2. Remove the first permute.
  3. Take the first permute's input (conv2d_bias_1_0) and make it the new input for the middle op (fused_elementwise_19).
  4. Confirm the shapes are correct for the middle op and the remaining permute.
  5. Write a test case and confirm its accuracy.
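The steps above could be sketched on a deliberately tiny, hypothetical graph representation. `Node`, `remove_permute_sandwiches`, `UNARY_ELEMENTWISE` and the op names are made up for illustration and are not AIT's actual graph classes or pass API; a real pass would also have to update tensor metadata, handle graph outputs more carefully, and be covered by the accuracy test from step 5.

```python
from dataclasses import dataclass, field

UNARY_ELEMENTWISE = {"relu", "gelu", "sigmoid", "tanh"}  # hypothetical op names


@dataclass(eq=False)  # compare nodes by identity, like graph nodes
class Node:
    op: str
    inputs: list = field(default_factory=list)
    shape: tuple = ()
    dims: tuple = ()  # permutation; only used when op == "permute"


def is_inverse(p, q):
    """True if transposing by p and then by q restores the original layout."""
    return len(p) == len(q) and all(p[q[i]] == i for i in range(len(p)))


def remove_permute_sandwiches(nodes):
    """Rewrite permute -> unary elementwise -> inverse permute(s) as just the elementwise op.

    `nodes` must be topologically sorted; returns the surviving nodes in order.
    """
    users = {id(n): [] for n in nodes}
    for n in nodes:
        for x in n.inputs:
            users[id(x)].append(n)

    dead = set()
    for ew in nodes:
        # Step 1: a unary elementwise op fed by a permute ...
        if ew.op not in UNARY_ELEMENTWISE or ew.inputs[0].op != "permute":
            continue
        first, trailing = ew.inputs[0], users[id(ew)]
        # ... whose consumers are all permutes that undo it (and are not graph outputs)
        if not trailing or not all(
            u.op == "permute" and users[id(u)] and is_inverse(first.dims, u.dims)
            for u in trailing
        ):
            continue
        # Step 3: feed the first permute's input straight into the elementwise op
        src = first.inputs[0]
        ew.inputs[0] = src
        users[id(src)].append(ew)
        # Step 4: the elementwise output now keeps the source (channel-last) shape
        ew.shape = src.shape
        # The trailing permutes become no-ops: point their consumers at ew
        new_users = []
        for p in trailing:
            for c in users[id(p)]:
                c.inputs = [ew if i is p else i for i in c.inputs]
                new_users.append(c)
            dead.add(id(p))
        users[id(ew)] = new_users
        # Step 2: drop the first permute once nothing else reads it
        users[id(first)].remove(ew)
        if not users[id(first)]:
            dead.add(id(first))
    return [n for n in nodes if id(n) not in dead]


# Toy version of the start of the dumped graph: conv -> permute -> relu -> 2x permute
x = Node("input", shape=(1, 16, 224, 224))
p0 = Node("permute", [x], (1, 224, 224, 16), (0, 2, 3, 1))
conv1 = Node("conv2d_bias", [p0], (1, 112, 112, 32))
p2 = Node("permute", [conv1], (1, 32, 112, 112), (0, 3, 1, 2))
relu = Node("relu", [p2], (1, 32, 112, 112))
p4 = Node("permute", [relu], (1, 112, 112, 32), (0, 2, 3, 1))
p5 = Node("permute", [relu], (1, 112, 112, 32), (0, 2, 3, 1))
conv6 = Node("conv2d_bias", [p5], (1, 56, 56, 64))
pool = Node("max_pool2d", [p4], (1, 56, 56, 32))
graph = remove_permute_sandwiches([x, p0, conv1, p2, relu, p4, p5, conv6, pool])
print([n.op for n in graph])
# ['input', 'permute', 'conv2d_bias', 'relu', 'conv2d_bias', 'max_pool2d']
```

Restricting the match to unary elementwise ops keeps the rewrite safe without any layout analysis; binary ops such as the final ADD would additionally need every input to carry the same permute.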

Here are some pointers:

Let me know if there are any questions there.

jonpryai commented 1 year ago

I'm not very familiar with the code, so I could be wrong, but my first impression is that while the optimizer can consider different layouts (NHWC and NCHW) for conv2d, for some reason it is tied to NCHW for the elementwise ops, and maybe doesn't take the cost of the permutations into account.

I think both permute_2 and permute_4 can be removed. There are also two copies of permute_4 that yield exactly the same tensor. What is happening here is:

conv2d(NHWC) -> toNCHW -> elementWise -> toNHWC
                                      -> toNHWC

which is the same thing as conv2d(NHWC) -> elementWise
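Merging the two identical permute_4 copies is a classic common-subexpression-elimination (CSE) step. A minimal sketch over a hypothetical op list of `(name, op_type, input_names, attrs)` tuples; this representation, `dedupe_ops`, and `CSE_OPS` are illustrative assumptions, not AIT's graph API. CSE is restricted to permutes here so that, for example, two distinct graph inputs with identical attributes are never merged.

```python
CSE_OPS = {"permute"}  # only merge ops known to be pure and cheap to compare


def dedupe_ops(ops):
    """Common-subexpression elimination over a topologically sorted op list.

    Each op is a (name, op_type, input_names, attrs) tuple. Two ops in CSE_OPS
    with the same type, inputs and attrs compute the same tensor, so the later
    one is dropped and its consumers are redirected to the earlier one.
    """
    seen = {}   # (op_type, inputs, attrs) -> name of the op that is kept
    alias = {}  # name of a dropped op -> name of the kept duplicate
    kept = []
    for name, op_type, inputs, attrs in ops:
        inputs = tuple(alias.get(i, i) for i in inputs)
        sig = (op_type, inputs, attrs)
        if op_type in CSE_OPS:
            if sig in seen:
                alias[name] = seen[sig]
                continue
            seen[sig] = name
        kept.append((name, op_type, inputs, attrs))
    return kept


ops = [
    ("x", "input", (), ()),
    ("relu_3", "relu", ("x",), ()),
    ("permute_4", "permute", ("relu_3",), (0, 2, 3, 1)),
    ("permute_5", "permute", ("relu_3",), (0, 2, 3, 1)),  # duplicate of permute_4
    ("conv2d_bias_6", "conv2d_bias", ("permute_5",), ()),
    ("max_pool2d_12", "max_pool2d", ("permute_4",), ()),
]
for op in dedupe_ops(ops):
    print(op)
```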

ColinPeppler commented 1 year ago

Ah, I see. Both permutes can definitely be removed in that case. I'm not sure which pass introduces them in the first place, though.

Do you still have the dumped graphs in your directory? We can see which pass adds the permutes by looking at the {passname}_pseudo_code.txt.
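One way to see which pass first introduces the permutes is to count the permute ops in every dump. A minimal sketch, assuming (from the dumps pasted in this thread) that every file ends in `_pseudo_code.txt` and that each permute op line starts with `= permute(`; `count_permutes` and `permutes_per_pass` are hypothetical helper names.

```python
import re
from pathlib import Path

SUFFIX = "_pseudo_code.txt"
# Op lines in the dumps look like "= permute()(" (see the graphs above)
PERMUTE_OP = re.compile(r"^=\s*permute\(")


def count_permutes(text):
    """Number of permute ops in one pseudo-code dump."""
    return sum(1 for line in text.splitlines() if PERMUTE_OP.match(line.strip()))


def permutes_per_pass(workdir):
    """Map each dumped pass name to the number of permute ops in its dump."""
    return {
        path.name[: -len(SUFFIX)]: count_permutes(path.read_text())
        for path in sorted(Path(workdir).glob("*" + SUFFIX))
    }
```

For example, `permutes_per_pass("./tmp/test")` (substitute your actual workdir) returns one count per pass; a nonzero count for `toposort` would mean the permutes already exist before any optimization pass runs.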

jonpryai commented 1 year ago

They are present in every dump except toposort_pseudo_code.txt. So is the bind_constants pass causing it?

jonpryai commented 1 year ago

Actually, that's not true. They are in the toposort dump as well; the nodes just haven't been annotated yet.

(Tensor(name=None, shape=[1, 224, 224, 16])) 
= permute()(
Tensor(name=x, shape=[1, 16, 224, 224]))

# None
(Tensor(name=None, shape=[1, 112, 112, 32])) 
= conv2d_bias(dilate=1, group=1, pad=1, stride=2)(
Tensor(name=None, shape=[1, 224, 224, 16]), Tensor(name=mod_level1_level1_0_Conv_weight, shape=[32, 3, 3, 16], data=(9216 bytes)), Tensor(name=mod_level1_level1_0_Conv_bias, shape=[32], data=(64 bytes)))

# None
(Tensor(name=None, shape=[1, 32, 112, 112])) 
= permute()(
Tensor(name=None, shape=[1, 112, 112, 32]))

# None
(Tensor(name=None, shape=[1, 32, 112, 112])) 
= elementwise(func=FuncEnum.RELU)(
Tensor(name=None, shape=[1, 32, 112, 112]))

# None
(Tensor(name=None, shape=[1, 112, 112, 32])) 
= permute()(
Tensor(name=None, shape=[1, 32, 112, 112]))

# None
(Tensor(name=None, shape=[1, 112, 112, 32])) 
= permute()(
Tensor(name=None, shape=[1, 32, 112, 112]))

# None
(Tensor(name=None, shape=[1, 56, 56, 64])) 
= conv2d_bias(dilate=1, group=1, pad=1, stride=2)(
Tensor(name=None, shape=[1, 112, 112, 32]), Tensor(name=mod_level2_tree1_conv1_Conv_weight, shape=[64, 3, 3, 32], data=(36864 bytes)), Tensor(name=mod_level2_tree1_conv1_Conv_bias, shape=[64], data=(128 bytes)))

# None
(Tensor(name=None, shape=[1, 64, 56, 56])) 
= permute()(
Tensor(name=None, shape=[1, 56, 56, 64]))

# None
(Tensor(name=None, shape=[1, 64, 56, 56])) 
= elementwise(func=FuncEnum.RELU)(
Tensor(name=None, shape=[1, 64, 56, 56]))

# None
(Tensor(name=None, shape=[1, 56, 56, 64])) 
= permute()(
Tensor(name=None, shape=[1, 64, 56, 56]))

# None
(Tensor(name=None, shape=[1, 56, 56, 64])) 
= conv2d_bias(dilate=1, group=1, pad=1, stride=1)(
Tensor(name=None, shape=[1, 56, 56, 64]), Tensor(name=mod_level2_tree1_conv2_Conv_weight, shape=[64, 3, 3, 64], data=(73728 bytes)), Tensor(name=mod_level2_tree1_conv2_Conv_bias, shape=[64], data=(128 bytes)))

# None
(Tensor(name=None, shape=[1, 64, 56, 56])) 
= permute()(
Tensor(name=None, shape=[1, 56, 56, 64]))

# None
(Tensor(name=None, shape=[1, 56, 56, 32])) 
= max_pool2d(stride=2, pad=0, kernel_size=2, reduce_func=max)(
Tensor(name=None, shape=[1, 112, 112, 32]))

# None
(Tensor(name=None, shape=[1, 32, 56, 56])) 
= permute()(
Tensor(name=None, shape=[1, 56, 56, 32]))

# None
(Tensor(name=None, shape=[1, 56, 56, 32])) 
= permute()(
Tensor(name=None, shape=[1, 32, 56, 56]))

# None
(Tensor(name=None, shape=[1, 56, 56, 64])) 
= conv2d_bias(dilate=1, group=1, pad=0, stride=1)(
Tensor(name=None, shape=[1, 56, 56, 32]), Tensor(name=mod_level2_project_project_0_Conv_weight, shape=[64, 1, 1, 32], data=(4096 bytes)), Tensor(name=mod_level2_project_project_0_Conv_bias, shape=[64], data=(128 bytes)))

# None
(Tensor(name=None, shape=[1, 64, 56, 56])) 
= permute()(
Tensor(name=None, shape=[1, 56, 56, 64]))

# None
(Tensor(name=None, shape=[1, 64, 56, 56])) 
= elementwise(func=FuncEnum.ADD)(
Tensor(name=None, shape=[1, 64, 56, 56]), Tensor(name=None, shape=[1, 64, 56, 56]))

# None
(Tensor(name=output_0, shape=[1, 64, 56, 56])) 
= elementwise(func=FuncEnum.RELU)(
Tensor(name=None, shape=[1, 64, 56, 56]))

Is it possible these nodes are being inserted by fx2ait?

ColinPeppler commented 1 year ago

It could be fx2ait but it may also be onnx2torch.

I'm curious whether replicating the model in PyTorch and then lowering it with fx2ait gives us the same graph. If not, I'd assume it's onnx2torch.

jonpryai commented 1 year ago

model gv

The permutes do not appear in the converted PyTorch model, but they are present in the AITModel after the trace is performed.

ColinPeppler commented 1 year ago

You're right, the permutes are being added in fx2ait. The result from each conv2d is being permuted via ait_nhwc2nchw (here).

AIT does that because PyTorch takes channel-first (NCHW) tensors for conv, maxpool, etc., whereas AIT takes channel-last (NHWC) tensors.
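The same channel-last convention is visible in the weights of the dumped graph: mod_level1_level1_0_Conv_weight has shape [32, 3, 3, 16] and 9216 bytes of data, which matches PyTorch's fp16 [out, in, kH, kW] = [32, 16, 3, 3] weight with the input-channel axis moved last. A NumPy sketch of that relayout (NumPy stands in for whatever fx2ait does internally; the random values are placeholders):

```python
import numpy as np

# PyTorch Conv2d weight layout: [out_channels, in_channels, kH, kW]
w_oihw = np.random.default_rng(0).standard_normal((32, 16, 3, 3)).astype(np.float16)

# Channel-last layout seen in the dump: [out_channels, kH, kW, in_channels]
w_ohwi = np.ascontiguousarray(np.transpose(w_oihw, (0, 2, 3, 1)))

print(w_ohwi.shape, w_ohwi.nbytes)  # (32, 3, 3, 16) 9216
```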

A potential workaround is to add a permute after each Conv2D? cc: @chenyang78

jonpryai commented 1 year ago

Would it be possible to just make all the elementwise ops do the permutation as well? Then we would end up with a graph like

toNCHW -> conv2d -> toNHWC -> toNCHW -> elementWise -> toNHWC

and then the pass that removes permutations would find the redundant permutes.
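Assuming each wrapper lowers to a plain permute, the cancellation this proposal relies on is simple to state: compose neighbouring permutations and drop any that compose to the identity. A sketch over a hypothetical linear op chain; `compose`, `cancel_adjacent_permutes`, and the chain format are illustrative, and whether AIT's existing permute handling does exactly this is not confirmed here.

```python
TO_NCHW = (0, 3, 1, 2)  # NHWC -> NCHW
TO_NHWC = (0, 2, 3, 1)  # NCHW -> NHWC


def compose(p, q):
    """Axis order equivalent to transposing by p and then by q."""
    return tuple(p[i] for i in q)


def is_identity(p):
    return all(axis == i for i, axis in enumerate(p))


def cancel_adjacent_permutes(chain):
    """Fold back-to-back permutes in a linear op chain and drop identities.

    Ops are plain names (e.g. "relu") or ("permute", dims) tuples.
    """
    out = []
    for op in chain:
        if isinstance(op, tuple) and out and isinstance(out[-1], tuple):
            merged = compose(out.pop()[1], op[1])
            if not is_identity(merged):
                out.append(("permute", merged))
        else:
            out.append(op)
    return out


# conv2d (NHWC) -> fx2ait's toNCHW -> proposed toNHWC -> relu
# -> proposed toNCHW -> fx2ait's toNHWC -> conv2d (NHWC)
chain = ["conv2d", ("permute", TO_NCHW), ("permute", TO_NHWC), "relu",
         ("permute", TO_NCHW), ("permute", TO_NHWC), "conv2d"]
print(cancel_adjacent_permutes(chain))  # ['conv2d', 'relu', 'conv2d']
```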

ColinPeppler commented 1 year ago

It sounds like that could work.

But would it be possible to try this?

  1. Permute your tensor so it's channel-last
  2. Set set_tensor_layout_policy(false) before lowering your model -- this avoids the permutes after conv2d
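Step 1 amounts to permuting the input once, outside the compiled graph; in PyTorch that would be `x.permute(0, 2, 3, 1).contiguous()`. The NumPy sketch below shows the same relayout for the [1, 16, 224, 224] input used in this thread. It only illustrates the data layout; `set_tensor_layout_policy` from step 2 is not modelled here.

```python
import numpy as np

# Model input in PyTorch's channel-first layout [N, C, H, W]
x_nchw = np.random.default_rng(0).standard_normal((1, 16, 224, 224)).astype(np.float16)

# Step 1: move channels last ([N, H, W, C]) and materialize a contiguous copy,
# since a transposed view only changes strides, not the memory order
x_nhwc = np.ascontiguousarray(np.transpose(x_nchw, (0, 2, 3, 1)))

print(x_nhwc.shape, x_nhwc.flags["C_CONTIGUOUS"])  # (1, 224, 224, 16) True
```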
xmfbit commented 11 months ago

@jonpryai hi, have you solved the problem?

jonpryai commented 11 months ago

@xmfbit No, not really. I'm just trying to get a quick idea of a model's inference performance with AITemplate. I wonder whether an FX graph would work correctly instead of an ONNX model. Otherwise, it may actually be easier to write the code that builds the AITemplate model directly than to fix fx2ait.

Importing a typical dla34 model is a good example of these issues.