iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

LLaMA/Vicuna 7B fp16 model precision issue on CPU #14014

Open yzhang93 opened 1 year ago

yzhang93 commented 1 year ago

What happened?

The LLaMA/Vicuna 7B fp16 model produces all-zero results on CPU with both the local-sync and local-task devices.

Steps to reproduce your issue

  1. Download the .mlir from https://storage.googleapis.com/shark-public/vivian/vicuna_fp16.mlir
  2. Compile the file with:
     iree-compile --iree-input-type=none --iree-hal-target-backends=llvm-cpu --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 --iree-llvmcpu-target-cpu-features=host --iree-llvmcpu-target-triple=x86_64-linux-gnu vicuna_fp16.mlir -o vicuna.vmfb
     Alternatively, download the compiled vmfb directly from https://storage.googleapis.com/shark_tank/dan/out16.vmfb
  3. Download the inputs: inp1: https://storage.googleapis.com/shark-public/prashant/traced_vicuna/inp1.npy and inp2: https://storage.googleapis.com/shark-public/prashant/traced_vicuna/inp2.npy
  4. Run with iree-run-module:
     iree-run-module --device=local-task --function=forward --input=@inp1.npy --input=@inp2.npy --module=vicuna.vmfb

Note: the first run of the module may take a long time to produce results. I also tried --iree-flow-break-dispatch=@forward: on different dispatches, but every one of them produced all-zero results.
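For anyone reproducing this, a small NumPy helper can confirm that the downloaded inputs loaded correctly and can flag the all-zero symptom in a result. This is only a sketch: the `summarize` helper is a hypothetical name, and it assumes the array under inspection is available as a `.npy` file in the working directory.

```python
import os

import numpy as np


def summarize(path):
    """Load a .npy file and report shape, dtype, and whether it is all zeros."""
    arr = np.load(path)
    is_float = np.issubdtype(arr.dtype, np.floating)
    return {
        "shape": arr.shape,
        "dtype": str(arr.dtype),
        # True when every element equals zero (the symptom reported above).
        "all_zero": not np.any(arr),
        # NaNs are only meaningful for floating-point arrays.
        "has_nan": bool(np.isnan(arr).any()) if is_float else False,
    }


if __name__ == "__main__":
    # File names taken from the reproduction steps; skipped if not downloaded.
    for name in ("inp1.npy", "inp2.npy"):
        if os.path.exists(name):
            print(name, summarize(name))
```

The same helper can be pointed at a result array saved as `.npy` to check for all-zero output without printing the full tensor.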

What component(s) does this issue relate to?

No response

Version information

No response

Additional context

No response

yzhang93 commented 1 year ago

Update: We made some changes to the fx graph transform, and we are now getting non-zero results. We have validated the outputs at the fx graph level. However, the IREE outputs currently have a precision issue: they don't match the PyTorch ones, and the first row in particular shows large differences.

pytorch output:
[[[ -3.9557, -29.2099,   2.2765,  ...,  -2.0065,  -1.2465,  -2.1745],
  [ -8.9753, -14.2879,   2.0630,  ...,  -3.9730,  -4.7902,  -3.5822],
  [ -8.4618, -11.2923,   4.9130,  ...,  -1.1730,  -4.3431,  -3.7198],
  ...,
  [ -5.7914,  -5.1273,  11.0914,  ...,  -0.0792,  -1.0262,  -1.4256],
  [ -6.0333,  -7.8192,  11.8842,  ...,   1.3139,  -0.2129,  -0.4712],
  [-11.1994, -24.6870,   6.9853,  ...,  -4.7405,  -5.7670,  -2.4376]]],

iree output:
[[[ -5.758  -31.48     2.834  ...  -3.117   -2.6     -3.223 ]
  [ -8.75   -14.28     2.102  ...  -3.77    -4.66    -3.377 ]
  [ -8.29   -11.71     4.97   ...  -0.9478  -4.08    -3.252 ]
  ...
  [ -5.957   -5.125   11.12   ...  -0.321   -1.041   -1.514 ]
  [ -6.08    -7.953   11.99   ...   1.359   -0.3594  -0.384 ]
  [-11.27   -24.58     6.895  ...  -4.78    -5.906   -2.6   ]]]
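To put a number on the mismatch, the sketch below compares only the entries visible in the two printouts above for the first two rows (the elided `...` columns are not reconstructed). The variable names are illustrative, and the fp16 machine epsilon is printed for scale only: error accumulated across many layers can legitimately exceed one epsilon, so this is not a pass/fail threshold.

```python
import numpy as np

# Visible entries of rows 0 and 1, copied from the PyTorch printout above.
pytorch_rows = np.array([
    [-3.9557, -29.2099, 2.2765, -2.0065, -1.2465, -2.1745],
    [-8.9753, -14.2879, 2.0630, -3.9730, -4.7902, -3.5822],
], dtype=np.float32)

# Same positions, copied from the IREE printout above.
iree_rows = np.array([
    [-5.758, -31.48, 2.834, -3.117, -2.6, -3.223],
    [-8.75, -14.28, 2.102, -3.77, -4.66, -3.377],
], dtype=np.float32)

# Element-wise absolute and relative error against the PyTorch reference.
abs_err = np.abs(iree_rows - pytorch_rows)
rel_err = abs_err / np.abs(pytorch_rows)

# Worst-case error in each row.
max_abs_per_row = abs_err.max(axis=1)
max_rel_per_row = rel_err.max(axis=1)

# Machine epsilon of fp16, shown only as a reference scale.
fp16_eps = float(np.finfo(np.float16).eps)

for i in range(len(max_abs_per_row)):
    print(f"row {i}: max abs err = {max_abs_per_row[i]:.4f}, "
          f"max rel err = {max_rel_per_row[i]:.4f}")
print(f"fp16 eps = {fp16_eps:.6f}")
```

On these visible entries the worst-case error in the first row is about ten times larger than in the second row, which matches the observation that the first row is the outlier.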