intel / torch-xpu-ops

Apache License 2.0

[Release/2.5.0] Timm model jx_nest_base amp_fp16 inference got fail_accuracy #900

Status: Open · opened by mengfei25 3 weeks ago

mengfei25 commented 3 weeks ago

šŸ› Describe the bug

xpu eval jx_nest_base
[WARNING] Failed to create Level Zero tracer: 2013265921
(I): Detected 1024 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 8192 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 8192 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 4096 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 4096 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
E0912 00:16:10.029000 3264502 site-packages/torch/_dynamo/utils.py:1802] RMSE (res-fp64): 0.00087, (ref-fp64): 0.00036 and shape=torch.Size([8, 1000]). res.dtype: torch.float16, multiplier: 2.000000, tol: 0.001000, use_larger_multiplier_for_smaller_tensor: 0
fail_accuracy
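The E0912 line reports two RMS errors measured against an fp64 reference: res-fp64 for the compiled output and ref-fp64 for the eager output. The sketch below re-checks the logged numbers under an assumption: that the criterion in `torch/_dynamo/utils.py` is `res_error <= multiplier * ref_error + tol / 10`. That formula is our reading of the PyTorch code, not something stated in this issue, and `rmse` and `passes_fp64_check` are hypothetical helpers, not PyTorch APIs. Under that assumption the threshold is 2 × 0.00036 + 0.001 / 10 = 0.00082, so the logged 0.00087 misses it narrowly.

```python
import math


def rmse(a, b):
    """Root-mean-square error between two equal-length, non-empty sequences."""
    if len(a) != len(b) or not a:
        raise ValueError("inputs must be non-empty and the same length")
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)) / len(a))


def passes_fp64_check(res_error, ref_error, multiplier=2.0, tol=1e-3):
    """Hypothetical helper. ASSUMED criterion: res_error <= multiplier * ref_error + tol / 10.

    Returns (passed, threshold).
    """
    threshold = multiplier * ref_error + tol / 10.0
    return res_error <= threshold, threshold


# Values copied from the failing E0912 log line (multiplier 2.0, tol 0.001).
passed, threshold = passes_fp64_check(res_error=0.00087, ref_error=0.00036)
print(f"threshold={threshold:.5f} passed={passed}")  # threshold=0.00082 passed=False
```

Because the logged errors are rounded to five decimal places, the computed margin is only approximate.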

Versions

pytorch: 2.5.0-rc1 (https://download.pytorch.org/whl/test/xpu)
torch-xpu-ops: https://github.com/intel/torch-xpu-ops/commit/12065904d4c3c870059d746eb0fb45a0459f1d6d (main)

jianyizh commented 1 week ago

@mengfei25 This test passes for me locally on a PVC 1550:

xpu eval jx_nest_base
(I): Detected 5504 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 5504 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 2304 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 1792 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
Compiled module path: /home/sdp/jianyi/pytorch/inductor_log/timm_models/jx_nest_base/inference/amp_fp16/22/tmpdnoznzfe/ew/cewlskfnrewrjvkri3yjlsyn2wgwtgmzpz3s2ihdmqao5ack4a3p.py
W0926 19:40:25.206000 1566446 torch/_inductor/debug.py:434] [0/0] jx_nest_base__0_inference_0 debug trace: /home/sdp/jianyi/pytorch/inductor_log/timm_models/jx_nest_base/inference/amp_fp16/22/torch_compile_debug/torch_compile_debug/run_2024_09_26_19_39_25_399754-pid_1566446/torchinductor/jx_nest_base__0_inference_0.0
pass

pytorch-triton-xpu 3.0.0+cc981feba1
torch 2.5.0a0+git4a3dabd /home/sdp/jianyi/pytorch