microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

The speed of running the onnx model is 6x slower than the pytorch model on Jetson TX2 #7233

Status: Open. GavinJiacheng opened this issue 3 years ago.

GavinJiacheng commented 3 years ago

Describe the bug

Running the ONNX model is about 6x slower than running the same model in PyTorch.

Urgency

April 20, 2020

System information

To Reproduce

I ran the following code to run the ONNX model:


import torch
import onnxruntime

batch_size = 1

# Inputs to the model: a left/right stereo pair in NCHW layout
x1 = torch.randn(batch_size, 3, 384, 192, requires_grad=True, device='cuda')
x2 = torch.randn(batch_size, 3, 384, 192, requires_grad=True, device='cuda')

# Request the CUDA execution provider explicitly, with CPU as the fallback
ort_session = onnxruntime.InferenceSession(
    "gwt_jetso.onnx", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# Copy each input to host memory, then into an OrtValue on CUDA device 0
ortvalue1 = onnxruntime.OrtValue.ortvalue_from_numpy(to_numpy(x1), 'cuda', 0)
ortvalue2 = onnxruntime.OrtValue.ortvalue_from_numpy(to_numpy(x2), 'cuda', 0)

ort_inputs = {ort_session.get_inputs()[0].name: ortvalue1,
              ort_session.get_inputs()[1].name: ortvalue2}
ort_outs = ort_session.run([], ort_inputs)
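
If the session ended up without the CUDA execution provider, the whole graph could run on the CPU, so one cheap first check is which providers the session actually registered. The sketch below assumes a GPU-enabled `onnxruntime` build; `onnxruntime.get_available_providers()` and `InferenceSession.get_providers()` are existing ORT calls, while `require_cuda` and `open_gpu_session` are hypothetical helper names.

```python
from typing import Sequence

CUDA_EP = "CUDAExecutionProvider"


def require_cuda(active_providers: Sequence[str]) -> str:
    """Raise if the CUDA execution provider is missing from the session's providers.

    Returns the first (highest-priority) provider in the list.
    """
    if CUDA_EP not in active_providers:
        raise RuntimeError(
            f"{CUDA_EP} is not active; the session is using {list(active_providers)}")
    return active_providers[0]


def open_gpu_session(model_path: str):
    import onnxruntime  # imported lazily so require_cuda() can be used on its own
    print("available providers:", onnxruntime.get_available_providers())
    sess = onnxruntime.InferenceSession(
        model_path, providers=[CUDA_EP, "CPUExecutionProvider"])
    require_cuda(sess.get_providers())
    return sess
```

For example, `ort_session = open_gpu_session("gwt_jetso.onnx")` fails loudly instead of quietly running on the CPU.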

The model I used is GwcNet, defined as follows:


from __future__ import print_function
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
# convbn, convbn_3d, BasicBlock, build_gwc_volume, build_concat_volume and
# disparity_regression come from models/submodule.py (not shown here)
from models.submodule import *

class feature_extraction(nn.Module):
    def __init__(self, concat_feature=False, concat_feature_channel=12):
        super(feature_extraction, self).__init__()
        self.concat_feature = concat_feature

        self.inplanes = 32
        self.firstconv = nn.Sequential(convbn(3, 32, 3, 2, 1, 1),
                                       nn.ReLU(inplace=True),
                                       convbn(32, 32, 3, 1, 1, 1),
                                       nn.ReLU(inplace=True),
                                       convbn(32, 32, 3, 1, 1, 1),
                                       nn.ReLU(inplace=True))

        self.layer1 = self._make_layer(BasicBlock, 32, 3, 1, 1, 1)
        self.layer2 = self._make_layer(BasicBlock, 64, 16, 2, 1, 1)
        self.layer3 = self._make_layer(BasicBlock, 128, 3, 1, 1, 1)
        self.layer4 = self._make_layer(BasicBlock, 128, 3, 1, 1, 2)

        if self.concat_feature:
            self.lastconv = nn.Sequential(convbn(320, 128, 3, 1, 1, 1),
                                          nn.ReLU(inplace=True),
                                          nn.Conv2d(128, concat_feature_channel, kernel_size=1, padding=0, stride=1,
                                                    bias=False))

    def _make_layer(self, block, planes, blocks, stride, pad, dilation):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion), )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, pad, dilation))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, 1, None, pad, dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.firstconv(x)
        x = self.layer1(x)
        l2 = self.layer2(x)
        l3 = self.layer3(l2)
        l4 = self.layer4(l3)

        gwc_feature = torch.cat((l2, l3, l4), dim=1)

        if not self.concat_feature:
            return {"gwc_feature": gwc_feature}
        else:
            concat_feature = self.lastconv(gwc_feature)
            return {"gwc_feature": gwc_feature, "concat_feature": concat_feature}

class hourglass(nn.Module):
    def __init__(self, in_channels):
        super(hourglass, self).__init__()

        self.conv1 = nn.Sequential(convbn_3d(in_channels, in_channels * 2, 3, 2, 1),
                                   nn.ReLU(inplace=True))

        self.conv2 = nn.Sequential(convbn_3d(in_channels * 2, in_channels * 2, 3, 1, 1),
                                   nn.ReLU(inplace=True))

        self.conv3 = nn.Sequential(convbn_3d(in_channels * 2, in_channels * 4, 3, 2, 1),
                                   nn.ReLU(inplace=True))

        self.conv4 = nn.Sequential(convbn_3d(in_channels * 4, in_channels * 4, 3, 1, 1),
                                   nn.ReLU(inplace=True))

        self.conv5 = nn.Sequential(
            nn.ConvTranspose3d(in_channels * 4, in_channels * 2, 3, padding=1, output_padding=1, stride=2, bias=False),
            nn.BatchNorm3d(in_channels * 2))

        self.conv6 = nn.Sequential(
            nn.ConvTranspose3d(in_channels * 2, in_channels, 3, padding=1, output_padding=1, stride=2, bias=False),
            nn.BatchNorm3d(in_channels))

        self.redir1 = convbn_3d(in_channels, in_channels, kernel_size=1, stride=1, pad=0)
        self.redir2 = convbn_3d(in_channels * 2, in_channels * 2, kernel_size=1, stride=1, pad=0)

    def forward(self, x):
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)

        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)

        conv5 = F.relu(self.conv5(conv4) + self.redir2(conv2), inplace=True)
        conv6 = F.relu(self.conv6(conv5) + self.redir1(x), inplace=True)

        return conv6

class GwcNet(nn.Module):
    def __init__(self, maxdisp, use_concat_volume=False):
        super(GwcNet, self).__init__()
        self.maxdisp = maxdisp
        self.use_concat_volume = use_concat_volume

        self.num_groups = 40

        if self.use_concat_volume:
            self.concat_channels = 12
            self.feature_extraction = feature_extraction(concat_feature=True,
                                                         concat_feature_channel=self.concat_channels)
        else:
            self.concat_channels = 0
            self.feature_extraction = feature_extraction(concat_feature=False)

        self.dres0 = nn.Sequential(convbn_3d(self.num_groups + self.concat_channels * 2, 32, 3, 1, 1),
                                   nn.ReLU(inplace=True),
                                   convbn_3d(32, 32, 3, 1, 1),
                                   nn.ReLU(inplace=True))

        self.dres1 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
                                   nn.ReLU(inplace=True),
                                   convbn_3d(32, 32, 3, 1, 1))

        self.dres2 = hourglass(32)

        self.dres3 = hourglass(32)

        self.dres4 = hourglass(32)

        self.classif0 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
                                      nn.ReLU(inplace=True),
                                      nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1, bias=False))

        self.classif1 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
                                      nn.ReLU(inplace=True),
                                      nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1, bias=False))

        self.classif2 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
                                      nn.ReLU(inplace=True),
                                      nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1, bias=False))

        self.classif3 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
                                      nn.ReLU(inplace=True),
                                      nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1, bias=False))

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.Conv3d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, left, right):
        features_left = self.feature_extraction(left)
        features_right = self.feature_extraction(right)

        gwc_volume = build_gwc_volume(features_left["gwc_feature"], features_right["gwc_feature"], self.maxdisp // 4,
                                      self.num_groups)
        if self.use_concat_volume:
            concat_volume = build_concat_volume(features_left["concat_feature"], features_right["concat_feature"],
                                                self.maxdisp // 4)
            volume = torch.cat((gwc_volume, concat_volume), 1)
        else:
            volume = gwc_volume

        cost0 = self.dres0(volume)
        cost0 = self.dres1(cost0) + cost0

        out1 = self.dres2(cost0)
        out2 = self.dres3(out1)
        out3 = self.dres4(out2)

        if self.training:
            cost0 = self.classif0(cost0)
            cost1 = self.classif1(out1)
            cost2 = self.classif2(out2)
            cost3 = self.classif3(out3)

            cost0 = F.interpolate(cost0, [self.maxdisp, left.size()[2], left.size()[3]], mode='trilinear')
            cost0 = torch.squeeze(cost0, 1)
            pred0 = F.softmax(cost0, dim=1)
            pred0 = disparity_regression(pred0, self.maxdisp)

            cost1 = F.interpolate(cost1, [self.maxdisp, left.size()[2], left.size()[3]], mode='trilinear')
            cost1 = torch.squeeze(cost1, 1)
            pred1 = F.softmax(cost1, dim=1)
            pred1 = disparity_regression(pred1, self.maxdisp)

            cost2 = F.interpolate(cost2, [self.maxdisp, left.size()[2], left.size()[3]], mode='trilinear')
            cost2 = torch.squeeze(cost2, 1)
            pred2 = F.softmax(cost2, dim=1)
            pred2 = disparity_regression(pred2, self.maxdisp)

            cost3 = F.interpolate(cost3, [self.maxdisp, left.size()[2], left.size()[3]], mode='trilinear')
            cost3 = torch.squeeze(cost3, 1)
            pred3 = F.softmax(cost3, dim=1)
            pred3 = disparity_regression(pred3, self.maxdisp)
            return [pred0, pred1, pred2, pred3]

        else:
            cost3 = self.classif3(out3)
            # F.upsample is deprecated; F.interpolate takes the same arguments
            cost3 = F.interpolate(cost3, [self.maxdisp, left.size()[2], left.size()[3]], mode='trilinear')
            cost3 = torch.squeeze(cost3, 1)
            pred3 = F.softmax(cost3, dim=1)
            pred3 = disparity_regression(pred3, self.maxdisp)
            return [pred3]

# Handle pre-trained weight and model.parallel
# https://discuss.pytorch.org/t/loading-weights-from-dataparallel-models/20570/2
class WrappedModel(nn.Module):
    def __init__(self, module):
        super(WrappedModel, self).__init__()
        self.module = module # that I actually define.
    def forward(self, x, y):
        return self.module(x, y)

def GwcNet_G(d):
    model = GwcNet(d, use_concat_volume=False)
    model = WrappedModel(model)
    return model
    # return GwcNet(d, use_concat_volume=False)

def GwcNet_GC(d):
    model = GwcNet(d, use_concat_volume=True)
    model = WrappedModel(model)
    return model
    # return GwcNet(d, use_concat_volume=True)

The code we used to convert the PyTorch model to ONNX is:


def convert_model(model):

    print('Converting')

    left = torch.randn(1, 3, 384, 192, requires_grad=True)
    right = torch.randn(1, 3, 384, 192, requires_grad=True)
    torch_out = model(left, right)

    # Export the model
    torch.onnx.export(model,                          # model being run
                      (left, right),                  # model inputs (a tuple for multiple inputs)
                      "GwcNet_1.onnx",                # where to save the model (a file or file-like object)
                      export_params=True,             # store the trained parameter weights inside the model file
                      opset_version=11,               # the ONNX opset version to export to
                      do_constant_folding=True,       # whether to execute constant folding for optimization
                      input_names=['left', 'right'],  # the model's input names
                      output_names=['output'],        # the model's output names
                      dynamic_axes={'left': [0, 2, 3],
                                    'right': [0, 2, 3],
                                    'output': [0, 2, 3],
                                    })

    print('Conversion successful')

The code we used to run the model on PyTorch is:


@make_nograd_func
def testsample():
    left = torch.randn(1, 3, 384, 192, requires_grad=True, dtype=torch.float32).cuda()
    right = torch.randn(1, 3, 384, 192, requires_grad=True, dtype=torch.float32).cuda()
    torch_out = model(left, right)

The download link of our ONNX file

Expected behavior

Running this model in PyTorch takes around 2.5 s per frame, but running the ONNX model takes around 14 s, roughly 6 times slower (14 / 2.5 = 5.6).

We checked the log and did not find the "fall back to CPU" warning.
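
Since the log shows no CPU fallback, a per-node profile is one way to see where the 14 s goes. ONNX Runtime can record one: set `SessionOptions.enable_profiling = True`, run the model, then call `InferenceSession.end_profiling()`, which returns the path of the JSON trace it wrote. The sketch below sums kernel time per operator type and provider. It assumes the trace's per-node events are named `<node>_kernel_time` with `dur` in microseconds and `op_name`/`provider` under `args`; that matches ORT's Chrome-trace profile as far as we know, but check it against the file your version produces. `summarize_profile` and `profile_session` are hypothetical names.

```python
import json
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple


def summarize_profile(events: Iterable[dict]) -> List[Tuple[str, str, float]]:
    """Total kernel time in ms per (op type, provider), slowest first.

    Assumed event shape: {"cat": "Node", "name": "<node>_kernel_time",
    "dur": <microseconds>, "args": {"op_name": ..., "provider": ...}}.
    """
    totals: Dict[Tuple[str, str], float] = defaultdict(float)
    for ev in events:
        if ev.get("cat") != "Node" or not ev.get("name", "").endswith("_kernel_time"):
            continue  # skip session-level events and fence events
        args = ev.get("args", {})
        key = (args.get("op_name", "?"), args.get("provider", "?"))
        totals[key] += ev.get("dur", 0) / 1000.0  # microseconds -> milliseconds
    rows = [(op, provider, ms) for (op, provider), ms in totals.items()]
    return sorted(rows, key=lambda row: row[2], reverse=True)


def profile_session(model_path: str, feeds: dict, runs: int = 3) -> List[Tuple[str, str, float]]:
    import onnxruntime  # lazy import so summarize_profile() works without ORT
    so = onnxruntime.SessionOptions()
    so.enable_profiling = True
    sess = onnxruntime.InferenceSession(
        model_path, sess_options=so,
        providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
    for _ in range(runs):
        sess.run(None, feeds)  # the first run may include one-time setup cost
    trace_path = sess.end_profiling()
    with open(trace_path) as f:
        return summarize_profile(json.load(f))
```

For example, `profile_session("gwt_jetso.onnx", {"left": left_np, "right": right_np})[:10]` would list the ten most expensive operator types, where `left_np` and `right_np` are hypothetical float32 NumPy inputs and the input names come from the export call above.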

pranavsharma commented 3 years ago

@jywu-msft @HectorSVC do any of you have access to the Jetson device for this investigation?

pranavsharma commented 3 years ago

I don't see the timing code here, so I'm not sure whether you included the time needed to create the ORT session; you should exclude that time. Also, it looks like x1 and x2 are first allocated on the GPU, then copied to the CPU (to_numpy), and then created on the GPU again (ortvalue_from_numpy). I hope this is not included in the timing?

GavinJiacheng commented 3 years ago


@pranavsharma I did use timing code; I just removed it before posting the code here. I used time.time() around ort_outs = ort_session.run([], ort_inputs) for ONNX Runtime and around torch_out = model(left, right) for PyTorch. That is where the 2.5 s and 14 s figures come from.

And yes, the conversion is not included in the timing; each measurement covers only that one line. Sorry the code is a bit messy; we ran many tests to find out why it is so slow.
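
A single time.time() around one call can mislead in both directions. The first ORT run() on the CUDA provider may include one-time setup such as cuDNN convolution-algorithm search, while PyTorch launches CUDA kernels asynchronously, so a timer read right after model(left, right) can stop before the GPU has finished. A minimal sketch of a fairer measurement, with untimed warm-up calls and an optional synchronization hook; `benchmark` is a hypothetical helper, and `torch.cuda.synchronize` is the existing PyTorch call one would pass as the hook.

```python
import statistics
import time
from typing import Callable, Dict, Optional


def benchmark(run: Callable[[], object],
              warmup: int = 3,
              iters: int = 10,
              sync: Optional[Callable[[], None]] = None) -> Dict[str, float]:
    """Time run() over `iters` calls after `warmup` untimed calls.

    `sync`, if given, is called after the warm-up and after every timed call,
    so queued asynchronous GPU work is finished before the clock is read.
    """
    for _ in range(warmup):
        run()  # untimed: absorbs one-time setup cost
    if sync is not None:
        sync()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        run()
        if sync is not None:
            sync()
        samples.append(time.perf_counter() - start)
    return {
        "mean_s": statistics.mean(samples),
        "median_s": statistics.median(samples),
        "min_s": min(samples),
    }
```

For ONNX Runtime this would be `benchmark(lambda: ort_session.run([], ort_inputs))`, with the session and inputs created beforehand; for PyTorch, inside `torch.no_grad()`, `benchmark(lambda: model(left, right), sync=torch.cuda.synchronize)`.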

pranavsharma commented 3 years ago

I don't have a Jetson TX2 to debug on. But one thing to check is whether any nodes in the graph were assigned to the CPU during the graph partitioning phase. Turn on verbose logging and look in the logs for this information.
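
Verbose logging can be enabled per session with `SessionOptions.log_severity_level = 0` (0 is VERBOSE, 4 is FATAL) or process-wide with `onnxruntime.set_default_logger_severity(0)`. Because verbose output is long, a small filter helps. The sketch below keeps lines that mention CPU and extracts arena-extension sizes, assuming the `bfc_arena.cc` line format shown in this issue; for instance, `rounded_bytes:35389440` is exactly 33.75 MiB. `cpu_lines`, `arena_extensions` and `make_verbose_session` are hypothetical names.

```python
import re
from typing import List, Tuple

# Assumed format: "... Extending BFCArena for <name>. bin_num:<n> rounded_bytes:<bytes>"
_ARENA_EXTEND = re.compile(r"Extending BFCArena for (\w+)\..*?rounded_bytes:(\d+)")


def cpu_lines(log_text: str) -> List[str]:
    """Lines mentioning CPU in any casing (e.g. 'Cpu', 'CUDA_CPU', 'CPUExecutionProvider')."""
    return [line for line in log_text.splitlines() if "cpu" in line.lower()]


def arena_extensions(log_text: str) -> List[Tuple[str, int]]:
    """(arena name, bytes) for every 'Extending BFCArena for ...' line."""
    return [(m.group(1), int(m.group(2))) for m in _ARENA_EXTEND.finditer(log_text)]


def make_verbose_session(model_path: str):
    import onnxruntime  # lazy import so the log helpers work without ORT
    so = onnxruntime.SessionOptions()
    so.log_severity_level = 0  # 0 = VERBOSE
    return onnxruntime.InferenceSession(
        model_path, sess_options=so,
        providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
```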

GavinJiacheng commented 3 years ago

@pranavsharma We checked the log. There is no "fall back to CPU" message; only a few lines contain the keyword "CPU":


2021-04-05 13:07:50.370189803 [I:onnxruntime:Default, bfc_arena.cc:23 BFCArena] Creating BFCArena for CUDA_CPU with following configs: initial_chunk_size_bytes: 1048576 max_dead_bytes_per_chunk: 134217728 memory limit: 18446744073709551615 arena_extend_strategy 0

2021-04-05 13:07:50.370303887 [I:onnxruntime:Default, bfc_arena.cc:23 BFCArena] Creating BFCArena for Cpu with following configs: initial_chunk_size_bytes: 1048576 max_dead_bytes_per_chunk: 134217728 memory limit: 18446744073709551615 arena_extend_strategy 0

2021-04-05 13:07:52.420892341 [I:onnxruntime:Default, bfc_arena.cc:280 AllocateRawInternal] Extending BFCArena for Cpu. bin_num:17 rounded_bytes:35389440