
xpu: efficientnet inference underperforms ipex #132176

Open dvrogozh opened 1 month ago

dvrogozh commented 1 month ago

Running efficientnet inference in eager mode using the sample at https://github.com/intel/ai-reference-models/tree/a6a9365c0e94a78ea38bbeacc14f5b43bd5fd577/models_v2/pytorch/efficientnet/inference/gpu on an Intel(R) Data Center GPU Max 1100, I observe that the XPU PyTorch backend underperforms IPEX, reaching only ~60% of its throughput.

A similar difference shows up with other precisions and batch sizes.

The example can be executed as follows:

git clone https://github.com/intel/ai-reference-models.git
cd ai-reference-models
export PYTHONPATH=$(pwd)/models_v2/common/:$PYTHONPATH

# run with XPU pytorch backend (note: --ipex no):
cd models_v2/pytorch/efficientnet/inference/gpu
PLATFORM=Max ./run_model.sh --ipex no --dummy --min-test-duration 30 --max-test-duration 30 --precision fp32

# run with IPEX:
docker build -f docker/flex-gpu/pytorch-efficientnet-inference/pytorch-flex-series-efficientnet-inference.Dockerfile -t enet .
docker run -it --rm --ipc=host --cap-add SYS_NICE --device /dev/dri/ -e OUTPUT_DIR=/opt/outputs -e PLATFORM=Flex enet \
    ./run_model.sh --dummy --min-test-duration 30 --max-test-duration 30 --precision fp32

Versions used:

CC: @gujinghui @EikanWang @fengyuan14 @guangyey @jgong5 @vlad-penkin

EikanWang commented 1 month ago

@fengyuan14, there should be no implementation difference. Could you help investigate the performance gap?

EikanWang commented 1 month ago

In addition, it is a conv-based model. Please check whether the model is using blocked format or channels-last.
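
For reference, here is a minimal sketch of how the layout could be checked and forced at the PyTorch level (assuming an XPU-enabled PyTorch build, with torchvision's efficientnet_b0 as a stand-in for the benchmarked model):

import torch
import torchvision.models as models

# Assumption: torchvision's efficientnet_b0 stands in for the actual model.
model = models.efficientnet_b0().eval().to("xpu")
x = torch.randn(1, 3, 224, 224, device="xpu")

# Convert weights and input to channels-last (NHWC); oneDNN convolution
# kernels generally prefer this layout over the default contiguous (NCHW) one.
model = model.to(memory_format=torch.channels_last)
x = x.contiguous(memory_format=torch.channels_last)
print(x.is_contiguous(memory_format=torch.channels_last))  # True

with torch.no_grad():
    y = model(x)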

fengyuan14 commented 1 month ago

> @fengyuan14, there should be no implementation difference. Could you help investigate the performance gap?

Yes, I will take it.

fengyuan14 commented 1 month ago

Hi @dvrogozh, I checked the model script: https://github.com/intel/ai-reference-models/blob/a6a9365c0e94a78ea38bbeacc14f5b43bd5fd577/models_v2/pytorch/efficientnet/inference/gpu/run_model.sh#L23. The IPEX execution might be using torch.jit.trace by default for inference. Could you try another IPEX execution with --jit none?

torch.jit.trace is a legacy graph solution supported by IPEX. With its graph optimization, some fusion recipes are enabled, such as convolution + relu.
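
As an illustration, a rough sketch of that legacy JIT path (a toy conv + relu module standing in for EfficientNet's blocks; assumes IPEX and an XPU device are available):

import torch
import torch.nn as nn
import intel_extension_for_pytorch as ipex  # importing IPEX registers its JIT fusion passes

# Toy stand-in for EfficientNet's conv + activation blocks.
class ConvRelu(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3, padding=1)

    def forward(self, x):
        return torch.relu(self.conv(x))

model = ConvRelu().eval().to("xpu")
example = torch.randn(1, 3, 224, 224, device="xpu")

model = ipex.optimize(model)  # IPEX weight-layout and kernel optimizations
with torch.no_grad():
    traced = torch.jit.trace(model, example)
    traced = torch.jit.freeze(traced)  # freezing enables fusion passes such as conv + relu
    out = traced(example)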

dvrogozh commented 1 month ago

The script supports --jit trace|script|none, with trace being the default. Here are my results with the different --jit options and a few other variables I can vary. Channels-last is the default; the option is not exposed at the shell-script level, but it is available at the PyTorch level, so it can be changed in the model script:

With IPEX:

  jit      fps   fps, IPEX_XPU_ONEDNN_LAYOUT=OFF   fps, channels-last=False
  trace    320   315                               317
  script   323   324                               328
  none     208   190                               200

With the XPU PyTorch backend:

  jit      fps   fps, channels-last=False
  trace    198   180
  script   260   228
  none     196   180

With the XPU backend, --jit script performs better, but it still underperforms IPEX.

fengyuan14 commented 1 month ago

  1. With CL=False, XPU is 10% behind IPEX, and that gap should be investigated. BTW, it is on par with what we got in our internal tests. Since there is no abnormal gap there, we will follow our existing plan. You can keep the issue open; I will keep it updated if there is any progress.
  2. With jit=script, XPU gets 260 FPS. I cannot imagine where the improvement comes from. We will treat this as low priority, since the JIT path is not recommended by PyTorch due to its many usage limitations.
    • When IPEX is not imported, the XPU-specific JIT fusion recipes (convolution- and matmul-related) are not registered, so there should be no benefit from them.
    • There is no NVFuser- or NNC-like graph compiler on the JIT graph path for XPU, so there should be no elementwise/reduction fusion.
    • If we are talking about convolution + batch_norm folding, jit=trace should also benefit from it, yet it shows the same numbers as jit=none (a verification sketch follows this list).
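
As a side note, the folding point in the last bullet can be checked with a small sketch that inspects the frozen TorchScript graph (CPU-only toy model; assumes the standard torch.jit.freeze optimization passes):

import torch
import torch.nn as nn

# Conv + BN + ReLU stand-in; in eval mode, torch.jit.freeze is expected
# to fold the batch_norm into the convolution weights.
m = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).eval()
frozen = torch.jit.freeze(torch.jit.trace(m, torch.randn(1, 3, 32, 32)))

# After folding, no aten::batch_norm node should remain in the graph.
print(frozen.graph)
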
fengyuan14 commented 1 month ago

@dvrogozh If you want performance on par with torch.jit.trace/script + IPEX, you can look forward to torch.compile in PT 2.6. We plan to support the same fusion recipes in Inductor. You can keep this issue open, or close it and open a new one to track XPU performance in graph mode.
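
For reference, the torch.compile path would look roughly like this (a sketch, assuming an XPU-enabled build and torchvision's efficientnet_b0 as a stand-in for the actual model):

import torch
import torchvision.models as models

model = models.efficientnet_b0().eval().to("xpu")
model = model.to(memory_format=torch.channels_last)
compiled = torch.compile(model)  # Inductor is the default backend

x = torch.randn(16, 3, 224, 224, device="xpu").contiguous(memory_format=torch.channels_last)
with torch.no_grad():
    y = compiled(x)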

dvrogozh commented 3 weeks ago

With torch.compile() I am currently getting 217 fps, which is lower than IPEX. I will continue tracking this towards PT 2.6 and hope we will see performance improvements by that milestone. See https://github.com/intel/ai-reference-models/pull/189 for torch.compile support in the samples I tried.
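
For context, a hypothetical throughput loop of the kind behind the fps numbers above (torch.xpu.synchronize is assumed to be available, as in recent XPU-enabled PyTorch builds):

import time
import torch

def measure_fps(model, x, iters=100, warmup=10):
    # Warm-up excludes compilation and first-run overhead from the measurement.
    with torch.no_grad():
        for _ in range(warmup):
            model(x)
        torch.xpu.synchronize()  # assumption: XPU device synchronization API
        start = time.time()
        for _ in range(iters):
            model(x)
        torch.xpu.synchronize()
    return iters * x.shape[0] / (time.time() - start)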