pytorch / executorch

On-device AI across mobile, embedded and edge for PyTorch
https://pytorch.org/executorch/

Slower inference time running MobileNet V3 when compared to PyTorch Mobile #4005

Open laoluani opened 1 week ago

laoluani commented 1 week ago

I'm currently assessing the performance of ExecuTorch compared to PyTorch Mobile. When running the MobileNet V3 Small model from the PyTorch vision model hub, I find that ExecuTorch is slower than PyTorch Mobile v1.12.1 on Android on a Samsung S20 Ultra. I have taken the Android ExecuTorch demo app, integrated PyTorch Mobile, and run the model side by side.

On average, I get a forward pass time of around 30ms on ExecuTorch and 8ms on PyTorch Mobile.

PyTorch Mobile set up

Here's the script used to do the conversion:

from torchvision import models
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

mv3_small = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1)
input = torch.randn((1, 3, 224, 224))
mv3_small.eval()
torchscript_model = torch.jit.trace(mv3_small, input)
torchscript_model._save_for_lite_interpreter("mv3.ptl")

I'm not applying any extra optimisations in the mobile model conversion

ExecuTorch set up

This follows the XNNPACK demo, branched from v0.2.0 but altered to use MobileNet V3 Small. Here's the output from the XNNPACK lowering process:

graph():
    %arg210_1 : [num_users=1] = placeholder[target=arg210_1]
    %lowered_module_0 : [num_users=1] = get_attr[target=lowered_module_0]
    %executorch_call_delegate : [num_users=1] = call_function[target=torch.ops.higher_order.executorch_call_delegate](args = (%lowered_module_0, %arg210_1), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate, 0), kwargs = {})
    %aten_view_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.view_copy.default](args = (%getitem, [1, 576]), kwargs = {})
    %lowered_module_1 : [num_users=1] = get_attr[target=lowered_module_1]
    %executorch_call_delegate_1 : [num_users=1] = call_function[target=torch.ops.higher_order.executorch_call_delegate](args = (%lowered_module_1, %aten_view_copy_default), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate_1, 0), kwargs = {})
    return (getitem_1,)
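One way to read the dump above is to count how many nodes are delegate calls and which ops were left outside the delegate. The following is a minimal sketch that parses the printed text with a regular expression. It is not an ExecuTorch API: the helper name `summarize_graph` is hypothetical, and the graph text is copied from the output above with only the `call_function` lines kept.

```python
import re

# call_function lines copied from the lowered graph printed above.
GRAPH_DUMP = """
    %executorch_call_delegate : [num_users=1] = call_function[target=torch.ops.higher_order.executorch_call_delegate](args = (%lowered_module_0, %arg210_1), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate, 0), kwargs = {})
    %aten_view_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.view_copy.default](args = (%getitem, [1, 576]), kwargs = {})
    %executorch_call_delegate_1 : [num_users=1] = call_function[target=torch.ops.higher_order.executorch_call_delegate](args = (%lowered_module_1, %aten_view_copy_default), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate_1, 0), kwargs = {})
"""

TARGET_RE = re.compile(r"call_function\[target=([\w.]+)\]")


def summarize_graph(dump: str):
    """Return (number of delegate calls, list of ops left outside any delegate)."""
    delegate_calls = 0
    non_delegated = []
    for target in TARGET_RE.findall(dump):
        if target.endswith("executorch_call_delegate"):
            delegate_calls += 1
        elif target == "operator.getitem":
            # Glue node that unpacks a delegate's output tuple; not a compute op.
            continue
        else:
            # Keep only the short op name after the edge dialect prefix.
            non_delegated.append(target.split("_ops.")[-1])
    return delegate_calls, non_delegated


delegate_calls, non_delegated = summarize_graph(GRAPH_DUMP)
print(f"delegate calls: {delegate_calls}, non-delegated ops: {non_delegated}")
```

For this graph, the model is split into two delegated subgraphs, with a single `view_copy` between them running outside the delegate.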

I assume that in these circumstances both PyTorch Mobile and ExecuTorch are using the XNNPACK backend where possible, so I would expect forward pass times to be similar. ExecuTorch has the added advantage of applying extra preprocessing to the model when lowering.

I'm guessing the general optimisations that torch.jit.trace applies to the model are outperforming the specific optimisations of the XNNPACK preprocess step.

Any recommendations for speed ups?

guangy10 commented 1 week ago

@laoluani Thank you for sharing the perf experiment with us.

both PyTorch Mobile and ExecuTorch are using the XNNPACK backend where possible

I don't even see you call optimize_for_mobile for the PyTorch Mobile model, so it seems the script just traces the model. @kimishpatel or @cccclai can clarify how the old TorchScript-based solution works.
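For reference, here is a minimal sketch of what a conversion that does call optimize_for_mobile could look like. According to the PyTorch documentation, its default passes include Conv2d + BatchNorm fusion and rewriting supported ops to XNNPACK-backed prepacked ops. The helper name `convert_for_mobile` and the tiny stand-in model are hypothetical and are used only to keep the example self-contained; in the setup described in this issue the model would be `mobilenet_v3_small`.

```python
import os
import tempfile

import torch
from torch import nn
from torch.utils.mobile_optimizer import optimize_for_mobile


def convert_for_mobile(model: nn.Module, example_input: torch.Tensor, path: str) -> None:
    """Trace a model, run the default mobile optimisation passes, and save it."""
    model.eval()  # eval mode is required for the Conv + BatchNorm folding pass
    traced = torch.jit.trace(model, example_input)
    optimized = optimize_for_mobile(traced)  # default backend is CPU
    optimized._save_for_lite_interpreter(path)


# Hypothetical stand-in model so the sketch runs without downloading weights.
tiny_model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 4),
)

out_path = os.path.join(tempfile.mkdtemp(), "tiny_optimized.ptl")
convert_for_mobile(tiny_model, torch.randn(1, 3, 32, 32), out_path)
print(f"saved {os.path.getsize(out_path)} bytes to {out_path}")
```

Without the optimize_for_mobile call, the saved model is only the traced graph, so whether XNNPACK kernels are used is not guaranteed by the conversion script alone.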

cbilgin commented 2 days ago

I'm guessing the general optimisations that torch.jit.trace applies to the model are outperforming the specific optimisations of the XNNPACK preprocess step.

You should actually be seeing speedups with ExecuTorch for MV3 compared to TorchScript. Can you share how you lowered to XNNPACK as well?

laoluani commented 3 minutes ago

I realised that I hadn't turned on the release flag for my ExecuTorch builds. Once it was enabled, I got results much closer to PyTorch Mobile.
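A quick way to confirm which build type a CMake build directory was configured with is to read `CMAKE_BUILD_TYPE` from its `CMakeCache.txt`, where entries are stored as `NAME:TYPE=VALUE`. This is a minimal sketch: the helper name `read_cmake_build_type` is hypothetical, and the demo writes a sample cache file to a temporary directory rather than reading a real build tree.

```python
import os
import tempfile
from pathlib import Path
from typing import Optional


def read_cmake_build_type(cache_path: str) -> Optional[str]:
    """Return the CMAKE_BUILD_TYPE value from a CMakeCache.txt, or None if absent."""
    for line in Path(cache_path).read_text().splitlines():
        # Cache entries look like NAME:TYPE=VALUE, e.g. CMAKE_BUILD_TYPE:STRING=Release
        if line.startswith("CMAKE_BUILD_TYPE:"):
            return line.split("=", 1)[1].strip()
    return None


# Sample cache content for demonstration only; a real file has many more entries.
sample_path = os.path.join(tempfile.mkdtemp(), "CMakeCache.txt")
Path(sample_path).write_text(
    "# This is the CMakeCache file.\n"
    "CMAKE_BUILD_TYPE:STRING=Release\n"
    "EXECUTORCH_BUILD_XNNPACK:BOOL=ON\n"
)

build_type = read_cmake_build_type(sample_path)
print(f"configured build type: {build_type!r}")
```

An empty or missing value means no build type was set at configure time. With single-config generators that typically means no optimisation flags are added.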

I ran a test over 500 iterations, with new random inputs on every pass:

pytorch mobile avg: 26.05, max: 57, min: 8, std deviation: 11.24
executorch avg: 26.32, max: 75, min: 5, std deviation: 10.22
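The measurement pattern described above (many iterations, a fresh random input per pass, then avg/max/min/std) can be sketched as follows. The actual test ran inside an Android app, so this Python version only illustrates the pattern. The `benchmark` helper, the warm-up phase, the median, and the stand-in `fake_forward` workload are all assumptions added for illustration.

```python
import random
import statistics
import time
from typing import Any, Callable, Dict


def benchmark(
    forward: Callable[[Any], Any],
    make_input: Callable[[], Any],
    iterations: int = 500,
    warmup: int = 10,
) -> Dict[str, float]:
    """Time forward() over fresh inputs and summarise the latencies in ms."""
    for _ in range(warmup):
        forward(make_input())  # warm-up passes are not recorded

    timings_ms = []
    for _ in range(iterations):
        sample = make_input()  # input creation stays outside the timed region
        start = time.perf_counter()
        forward(sample)
        timings_ms.append((time.perf_counter() - start) * 1000.0)

    return {
        "avg": statistics.fmean(timings_ms),
        "max": max(timings_ms),
        "min": min(timings_ms),
        "std": statistics.pstdev(timings_ms),
        "median": statistics.median(timings_ms),
    }


# Stand-in workload; a real run would call the model's forward pass instead.
def fake_forward(values):
    return sum(v * v for v in values)


def fake_input():
    return [random.random() for _ in range(1000)]


stats = benchmark(fake_forward, fake_input, iterations=50, warmup=5)
print(stats)
```

Reporting the median next to the average helps when a few slow passes dominate, as the large max values above suggest.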

Here's how I lowered the model:

import logging

import torch
import torchvision.models as models
from torch.export import export, ExportedProgram
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import EdgeProgramManager, to_edge

# Without this, logging.info() output is suppressed at the default WARNING level.
logging.basicConfig(level=logging.INFO)

mobilenet_v3 = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1).eval()
sample_inputs = (torch.randn(1, 3, 224, 224), )

exported_program: ExportedProgram = export(mobilenet_v3, sample_inputs)
edge: EdgeProgramManager = to_edge(exported_program)

logging.info(f"Exported graph:\n{edge.exported_program()}")

edge = edge.to_backend(XnnpackPartitioner())
logging.info(f"Lowered graph:\n{edge.exported_program()}")

exec_prog = edge.to_executorch()

with open("mv3-xnnpack.pte", "wb") as file:
    exec_prog.write_to_file(file)

Here are the build commands:

cmake . -DCMAKE_INSTALL_PREFIX=/build \
    -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \
    -DANDROID_ABI="$ANDROID_ABI" \
    -DEXECUTORCH_BUILD_XNNPACK=ON \
    -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
    -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
    -DEXECUTORCH_BUILD_OPTIMIZED=ON \
    -DCMAKE_BUILD_TYPE=Release \
    -B/build

find "/build" -type f -exec sed -i 's/-fno-exceptions/-fexceptions/g' {} +

cmake --build /build -j16 --target install --config Release

cmake extension/android \
  -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}"/build/cmake/android.toolchain.cmake \
  -DANDROID_ABI="${ANDROID_ABI}" \
  -DCMAKE_INSTALL_PREFIX=/build \
  -DCMAKE_BUILD_TYPE=Release \
  -B/build/extension/android

cmake --build /build/extension/android -j16 --config Release

Are these the expected results and are there any improvements that can be made?
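As a rough check on whether the two averages reported earlier differ at all, the gap between them can be compared with the standard error implied by the reported standard deviations. This is a minimal sketch, assuming each runtime was measured over 500 independent passes and that the figures are in milliseconds. The helper name `mean_difference_z` is hypothetical.

```python
import math


def mean_difference_z(
    mean_a: float, std_a: float, n_a: int,
    mean_b: float, std_b: float, n_b: int,
) -> float:
    """Difference of two sample means divided by its standard error."""
    standard_error = math.sqrt(std_a ** 2 / n_a + std_b ** 2 / n_b)
    return (mean_a - mean_b) / standard_error


# Figures copied from the results above: ExecuTorch first, then PyTorch Mobile.
z = mean_difference_z(26.32, 10.22, 500, 26.05, 11.24, 500)
print(f"z = {z:.2f}")
```

The result is about 0.4, well under the usual threshold of roughly 2. Under these assumptions the 0.27 gap between the averages is within measurement noise, so the two runtimes look indistinguishable in this test.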