amd / RyzenAI-SW

MIT License

Will only some models run on IPU? #92

Closed rhenry74 closed 4 months ago

rhenry74 commented 4 months ago

How does one pick a pretrained model that will run on the IPU? I'm trying to find one that returns embeddings. I got close to getting it working, but the runtime does not assign anything to the IPU and falls back to CPU. Can anyone give advice on getting this model to work, or recommend one that will?

from pathlib import Path

import numpy
import onnxruntime as ort
from transformers import AutoTokenizer
import vai_q_onnx

# Path to the original, unquantized ONNX model.
quantize_input_model_path = "UAE-Large-V1/onnx/model.onnx"

# Path where the quantized model will be saved.
quantize_output_model_path = "UAE-Large-V1/onnx/angle_QDQ_MinMax_QInt8.onnx"

vai_q_onnx.quantize_static(
    quantize_input_model_path,
    quantize_output_model_path,
    calibration_data_reader=None,  # None -> vai_q_onnx calibrates on random data
    quant_format=vai_q_onnx.QuantFormat.QDQ,
    calibrate_method=vai_q_onnx.CalibrationMethod.MinMax,
    #calibrate_method=vai_q_onnx.PowerOfTwoMethod.MinMSE,
    activation_type=vai_q_onnx.QuantType.QInt8,
    weight_type=vai_q_onnx.QuantType.QInt8,
    enable_ipu_cnn=True,
    #enable_ipu_transformer=True,
    #extra_options={'ActivationSymmetric': True}
)
print('Calibrated and quantized model saved at:', quantize_output_model_path)

ipu = True

providers = ['VitisAIExecutionProvider'] if ipu else ['CPUExecutionProvider']
cache_dir = Path(__file__).parent.resolve()
provider_options = [{
    'config_file': 'vaip_config.json',
    'cacheDir': str(cache_dir),
    'cacheKey': 'modelcachekey'
}] if ipu else [{}]

session_options = ort.SessionOptions()
session_options.enable_profiling = True

session = ort.InferenceSession(quantize_output_model_path,
                               providers=providers,
                               sess_options=session_options,
                               provider_options=provider_options)

model_inputs = session.get_inputs()

tokenizer = AutoTokenizer.from_pretrained('UAE-Large-V1')

tokens = tokenizer.encode_plus(text='what are we going to do today',
                               return_attention_mask=True,
                               return_token_type_ids=True)

print(model_inputs[0])
print(model_inputs[1])
print(model_inputs[2])

input_ids = numpy.array([tokens['input_ids']], dtype=numpy.int64)
attention_mask = numpy.array([tokens['attention_mask']], dtype=numpy.int64)
token_type_ids = numpy.array([tokens['token_type_ids']], dtype=numpy.int64)

feeds = {'input_ids': input_ids,
         'attention_mask': attention_mask,
         'token_type_ids': token_type_ids}

outputs = session.run(None, feeds)

print(outputs)
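Note that with calibration_data_reader=None, vai_q_onnx calibrates on random data, as the log below confirms. A minimal sketch of a real reader, assuming vai_q_onnx accepts the standard onnxruntime CalibrationDataReader interface and that a few representative sentences are available:

from onnxruntime.quantization import CalibrationDataReader
from transformers import AutoTokenizer
import numpy

class SentenceDataReader(CalibrationDataReader):
    # Feeds tokenized example sentences instead of random calibration data.
    def __init__(self, sentences, tokenizer_path='UAE-Large-V1'):
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        encoded = [tokenizer.encode_plus(s, return_attention_mask=True,
                                         return_token_type_ids=True)
                   for s in sentences]
        self._batches = iter([{
            'input_ids': numpy.array([t['input_ids']], dtype=numpy.int64),
            'attention_mask': numpy.array([t['attention_mask']], dtype=numpy.int64),
            'token_type_ids': numpy.array([t['token_type_ids']], dtype=numpy.int64),
        } for t in encoded])

    def get_next(self):
        # Return the next feed dict, or None when the data is exhausted.
        return next(self._batches, None)

An instance such as SentenceDataReader(['what are we going to do today']) would then be passed as calibration_data_reader in the quantize_static call above.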
Microsoft Windows [Version 10.0.22631.3593]
(c) Microsoft Corporation. All rights reserved.

(ryzenai-1.1-20240519-203238) C:\IPU\RyzenAI-SW\tutorial\getting_started_resnet>C:/Users/rhenr/.conda/envs/ryzenai-1.1-20240519-203238/python.exe c:/IPU/RyzenAI-SW/tutorial/getting_started_resnet/angleOnnxA.py
[VAI_Q_ONNX_INFO]: Time information:
2024-06-07 00:41:20.352031
[VAI_Q_ONNX_INFO]: OS and CPU information:
                                        system --- Windows
                                          node --- miniai
                                       release --- 10
                                       version --- 10.0.22631
                                       machine --- AMD64
                                     processor --- AMD64 Family 25 Model 116 Stepping 1, AuthenticAMD
[VAI_Q_ONNX_INFO]: Tools version information:
                                        python --- 3.9.18
                                          onnx --- 1.16.0
                                   onnxruntime --- 1.15.1
                                    vai_q_onnx --- 1.16.0+69bc4f2
[VAI_Q_ONNX_INFO]: Quantized Configuration information:
                                   model_input --- UAE-Large-V1/onnx/model.onnx
                                  model_output --- UAE-Large-V1/onnx/angle_QDQ_MinMax_QInt8.onnx
                       calibration_data_reader --- None
                                  quant_format --- QDQ
                                   input_nodes --- []
                                  output_nodes --- []
                          op_types_to_quantize --- []
                random_data_reader_input_shape --- []
                                   per_channel --- False
                                  reduce_range --- False
                               activation_type --- QInt8
                                   weight_type --- QInt8
                             nodes_to_quantize --- []
                              nodes_to_exclude --- []
                                optimize_model --- True
                      use_external_data_format --- False
                              calibrate_method --- CalibrationMethod.MinMax
                           execution_providers --- ['CPUExecutionProvider']
                                enable_ipu_cnn --- True
                        enable_ipu_transformer --- False
                                    debug_mode --- False
                          convert_fp16_to_fp32 --- False
                          convert_nchw_to_nhwc --- False
                                   include_cle --- False
                               include_fast_ft --- False
                                 extra_options --- {}
INFO:vai_q_onnx.quantize:calibration_data_reader is None, using random data for calibration
INFO:vai_q_onnx.quant_utils:The input ONNX model UAE-Large-V1/onnx/model.onnx can create InferenceSession successfully
INFO:vai_q_onnx.quant_utils:Random input name input_ids shape [1, 16] type <class 'numpy.int64'> 
INFO:vai_q_onnx.quant_utils:Random input name attention_mask shape [1, 16] type <class 'numpy.int64'>
INFO:vai_q_onnx.quant_utils:Random input name token_type_ids shape [1, 16] type <class 'numpy.int64'>
INFO:vai_q_onnx.quant_utils:Obtained calibration data with 1 iters
INFO:vai_q_onnx.quant_utils:The input ONNX model UAE-Large-V1/onnx/model.onnx can run inference successfully
INFO:vai_q_onnx.quantize:Removed initializers from input
INFO:vai_q_onnx.quantize:Loading model...
INFO:vai_q_onnx.quantize:enable_ipu_cnn is True, optimize the model for better hardware compatibility.
INFO:vai_q_onnx.quantize:Start calibration...
INFO:vai_q_onnx.quantize:Start collecting data, runtime depends on your model size and the number of calibration dataset.
INFO:vai_q_onnx.qdq_quantizer:Remove QuantizeLinear & DequantizeLinear on certain operations(such as conv-relu).
┏━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Op Type              ┃ Float Model                                   ┃
┡━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ Unsqueeze            │ 195                                           │
│ Cast                 │ 1                                             │
│ Constant             │ 561                                           │
│ Sub                  │ 50                                            │
│ Mul                  │ 98                                            │
│ Shape                │ 193                                           │
│ Gather               │ 196                                           │
│ Slice                │ 1                                             │
│ Add                  │ 340                                           │
│ ReduceMean           │ 98                                            │
│ Pow                  │ 49                                            │
│ Sqrt                 │ 49                                            │
│ Div                  │ 97                                            │
│ MatMul               │ 192                                           │
│ Concat               │ 96                                            │
│ Reshape              │ 96                                            │
│ Transpose            │ 96                                            │
│ Softmax              │ 24                                            │
│ Erf                  │ 24                                            │
├──────────────────────┼───────────────────────────────────────────────┤
│ Quantized model path │ UAE-Large-V1/onnx/angle_QDQ_MinMax_QInt8.onnx │
└──────────────────────┴───────────────────────────────────────────────┘
Calibrated and quantized model saved at: UAE-Large-V1/onnx/angle_QDQ_MinMax_QInt8.onnx
2024-06-07 00:42:32.7123550 [W:onnxruntime:Default, vitisai_provider_factory.cc:48 onnxruntime::VitisAIProviderFactory::CreateProvider] Construting a FlexML EP instance in Vitis AI EP
2024-06-07 00:42:32.7157647 [W:onnxruntime:Default, vitisai_execution_provider.cc:117 onnxruntime::VitisAIExecutionProvider::SetFlexMLEPPtr] Assigning the FlexML EP pointer in Vitis AI EP
2024-06-07 00:42:33.8510556 [W:onnxruntime:Default, vitisai_execution_provider.cc:137 onnxruntime::VitisAIExecutionProvider::GetCapability] Trying FlexML EP GetCapability
2024-06-07 00:42:33.8547276 [W:onnxruntime:Default, flexml_execution_provider.cc:180 onnxruntime::FlexMLExecutionProvider::GetCapability] FlexMLExecutionProvider::GetCapability, C:\amd\voe\binary-modules\ResNet.flexml\flexml_bm.signature can't not be found!
2024-06-07 00:42:33.8589939 [W:onnxruntime:Default, vitisai_execution_provider.cc:153 onnxruntime::VitisAIExecutionProvider::GetCapability] FlexML EP ignoring a non-ResNet50 graph
WARNING: Logging before InitGoogleLogging() is written to STDERR
I20240607 00:42:33.861953 21852 vitisai_compile_model.cpp:346] Vitis AI EP Load ONNX Model Success
I20240607 00:42:33.861953 21852 vitisai_compile_model.cpp:347] Graph Input Node Name/Shape (3)
I20240607 00:42:33.861953 21852 vitisai_compile_model.cpp:351]   input_ids : [-1x-1]
I20240607 00:42:33.862957 21852 vitisai_compile_model.cpp:351]   attention_mask : [-1x-1]
I20240607 00:42:33.862957 21852 vitisai_compile_model.cpp:351]   token_type_ids : [-1x-1]
I20240607 00:42:33.862957 21852 vitisai_compile_model.cpp:357] Graph Output Node Name/Shape (1)
I20240607 00:42:33.862957 21852 vitisai_compile_model.cpp:361]   last_hidden_state : [-1x-1x1024]
I20240607 00:42:33.865455 21852 vitisai_compile_model.cpp:232] use cache key modelcachekey
I20240607 00:42:34.251026 21852 pass_main.cpp:245] [VITIS AI EP] This model is not a supported CNN model which will not be compiled with DPU.
[Vitis AI EP] No. of Operators :   CPU  4434 
NodeArg(name='input_ids', type='tensor(int64)', shape=['batch_size', 'sequence_length'])
NodeArg(name='attention_mask', type='tensor(int64)', shape=['batch_size', 'sequence_length'])
NodeArg(name='token_type_ids', type='tensor(int64)', shape=['batch_size', 'sequence_length'])
[array([[[ 0.04070915,  0.6309919 , -0.46815526, ..., -1.2823383 ,
          0.20354576, -0.18319118],
        [-0.06106373,  0.6309919 , -0.38673696, ..., -1.20092   ,
          0.3460278 ,  0.        ],
        [ 0.04070915,  0.6309919 , -0.38673696, ..., -1.1602108 ,
          0.3460278 ,  0.        ],
        ...,
        [ 0.04070915,  0.6920556 , -0.14248204, ..., -1.3841112 ,
          0.20354576, -0.06106373],
        [ 0.04070915,  0.5088644 , -0.32567322, ..., -1.1602108 ,
          0.04070915,  0.        ],
        [-0.02035458,  0.6920556 , -0.28496408, ..., -1.2823383 ,
          0.20354576,  0.        ]]], dtype=float32)]
2024-06-07 00:42:35.3144377 [W:onnxruntime:Default, vitisai_execution_provider.cc:74 onnxruntime::VitisAIExecutionProvider::~VitisAIExecutionProvider] Releasing the FlexML EP pointer in Vitis AI EP

(ryzenai-1.1-20240519-203238) C:\IPU\RyzenAI-SW\tutorial\getting_started_resnet>
uday610 commented 4 months ago

Hi @rhenry74 ,

Here is how you can run that model on NPU:

Quantization Step

Assuming you have model.onnx in the current folder

quantize.py

from optimum.onnxruntime import ORTQuantizer, AutoQuantizationConfig

# Dynamic quantization of the MatMul ops only, with symmetric activations.
dqconfig = AutoQuantizationConfig.avx512_vnni(
    is_static=False,
    per_channel=False,
    use_symmetric_activations=True,
    operators_to_quantize=["MatMul"],
)

quantizer = ORTQuantizer.from_pretrained(".")
quantizer.quantize(save_dir="./", quantization_config=dqconfig)

This produces the quantized model model_quantized.onnx.
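As an optional sanity check (a sketch, not part of the flow), you can confirm that dynamic quantization only touched the MatMul ops by counting op types in the output model:

import onnx
from collections import Counter

# Weight-only dynamic quantization of MatMul should surface as
# MatMulInteger (or DynamicQuantizeMatMul) nodes in the graph.
model = onnx.load("model_quantized.onnx")
print(Counter(node.op_type for node in model.graph.node).most_common(10))

This is consistent with the Vitis AI EP report later in the thread, which shows MATMULINTEGER 144.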

Inference

infer.py

import onnxruntime as ort
from transformers import AutoTokenizer
from pathlib import Path

import numpy
import builtins

quantize_output_model_path = "model_quantized.onnx"

#providers = ['CPUExecutionProvider']
#provider_options = [{}]

providers = ['VitisAIExecutionProvider']
cache_dir = Path(__file__).parent.resolve()
provider_options = [{
    'config_file': 'vaip_config.json',
    'cacheDir': str(cache_dir),
    'cacheKey': 'modelcachekey'
}]

session_options = ort.SessionOptions()
session_options.enable_profiling = True

# Flags read by the RyzenAI transformers (opt-onnx) flow.
builtins.impl = "v0"
builtins.quant_mode = "w8a8"

session = ort.InferenceSession(quantize_output_model_path,
                               providers=providers,
                               sess_options=session_options,
                               provider_options=provider_options)

model_inputs = session.get_inputs()

tokenizer = AutoTokenizer.from_pretrained('UAE-Large-V1')

tokens = tokenizer.encode_plus(text='what are we going to do today',
                               return_attention_mask=True,
                               return_token_type_ids=True)

print(model_inputs[0])
print(model_inputs[1])
print(model_inputs[2])

input_ids = numpy.array([tokens['input_ids']], dtype=numpy.int64)
attention_mask = numpy.array([tokens['attention_mask']], dtype=numpy.int64)
token_type_ids = numpy.array([tokens['token_type_ids']], dtype=numpy.int64)

feeds = {'input_ids': input_ids,
         'attention_mask': attention_mask,
         'token_type_ids': token_type_ids}

outputs = session.run(None, feeds)

print(outputs)
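Since the model's only output is last_hidden_state (shape [batch, seq, 1024]), getting a sentence embedding still requires a pooling step. A minimal mean-pooling sketch, continuing from infer.py above (so outputs, attention_mask, and numpy are already defined); check the model card for the pooling the checkpoint expects, as AnglE-style models often take the CLS token (outputs[0][:, 0]) instead:

# Mean-pool the token vectors, masking out padding positions.
last_hidden = outputs[0]                       # [1, seq_len, 1024]
mask = attention_mask[..., None]               # [1, seq_len, 1]
embedding = (last_hidden * mask).sum(axis=1) / mask.sum(axis=1)
# L2-normalize so cosine similarity reduces to a dot product.
embedding = embedding / numpy.linalg.norm(embedding, axis=1, keepdims=True)
print(embedding.shape)                         # (1, 1024)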
  1. Make sure you have installed the following from the ryzen-ai-sw-1.1.zip (a quick check of the install appears after these steps):
cd ryzen-ai-sw-1.1\ryzen-ai-sw-1.1\voe-4.0-win_amd64
pip install voe-0.1.0-cp39-cp39-win_amd64.whl --force-reinstall
pip install onnxruntime_vitisai-1.15.1-cp39-cp39-win_amd64.whl --force-reinstall
python installer.py
  2. Set up the environment:
git lfs install
git clone https://github.com/amd/RyzenAI-SW.git
cd RyzenAI-SW\example\transformers
.\setup.bat
cd models\opt-onnx
.\set_opt_onnx_env.bat
cd <your previous folder where infer.py is>
  3. Copy vaip_config.json from the RyzenAI-SW\example\transformers\models\opt-onnx folder into the folder you run infer.py from.

  4. Run infer.py; before that, delete the modelcachekey folder if it exists.
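Before step 4, a quick way to confirm the wheels from step 1 actually registered the Vitis AI EP (a minimal sketch using the standard onnxruntime API, not part of the original flow):

import onnxruntime as ort

# The onnxruntime_vitisai wheel should register this provider;
# if it is missing, revisit step 1.
print(ort.get_available_providers())
assert 'VitisAIExecutionProvider' in ort.get_available_providers()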

rhenry74 commented 4 months ago

I tried to get this working from my functional RyzenAI-SW\tutorial\getting_started_resnet dir/env, and I came close, but the best I got was some activity on the IPU and then a silent (no error message) failure on outputs = session.run(None, feeds).

So I tried starting from scratch:

I started with conda create -n "embeddings39" python=3.9 ipython because I think I read somewhere that Ryzen AI wants Python 3.9.

(embeddings39) C:\IPU\ryzen-ai-sw-1.1\voe-4.0-win_amd64>pip install voe-0.1.0-cp39-cp39-win_amd64.whl --force-reinstall

worked fine.

(embeddings39) C:\IPU\ryzen-ai-sw-1.1\voe-4.0-win_amd64>pip install onnxruntime_vitisai-1.15.1-cp39-cp39-win_amd64.whl --force-reinstall failed with:

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torch 2.3.0 requires fsspec, which is not installed.

So I ran pip install fsspec, and this time pip install onnxruntime_vitisai-1.15.1-cp39-cp39-win_amd64.whl --force-reinstall worked.

The rest worked until I tried to run quantize.py and it complained:

(embeddings39) C:\IPU\Embeddings>python quantize.py
Traceback (most recent call last):
  File "C:\IPU\Embeddings\quantize.py", line 1, in <module>
    from optimum.onnxruntime import ORTQuantizer, AutoQuantizationConfig
ModuleNotFoundError: No module named 'optimum'

So I ran (embeddings39) C:\IPU\Embeddings>pip install optimum and tried quantize.py again. This time I got:

 (embeddings39) C:\IPU\Embeddings>python quantize.py
Traceback (most recent call last):
  File "C:\IPU\Embeddings\quantize.py", line 1, in <module>
    from optimum.onnxruntime import ORTQuantizer, AutoQuantizationConfig
  File "C:\Users\rhenr\.conda\envs\embeddings39\lib\site-packages\optimum\onnxruntime\__init__.py", line 16, in <module>
    from transformers.utils import OptionalDependencyNotAvailable, _LazyModule
  File "C:\Users\rhenr\.conda\envs\embeddings39\lib\site-packages\transformers\__init__.py", line 26, in <module>
    from . import dependency_versions_check
  File "C:\Users\rhenr\.conda\envs\embeddings39\lib\site-packages\transformers\dependency_versions_check.py", line 16, in <module>
    from .utils.versions import require_version, require_version_core
  File "C:\Users\rhenr\.conda\envs\embeddings39\lib\site-packages\transformers\utils\__init__.py", line 33, in <module>
    from .generic import (
  File "C:\Users\rhenr\.conda\envs\embeddings39\lib\site-packages\transformers\utils\generic.py", line 461, in <module>
    import torch.utils._pytree as _torch_pytree
  File "C:\Users\rhenr\AppData\Roaming\Python\Python39\site-packages\torch\__init__.py", line 141, in <module>
    raise err
OSError: [WinError 126] The specified module could not be found. Error loading "C:\Users\rhenr\AppData\Roaming\Python\Python39\site-packages\torch\lib\shm.dll" or one of its dependencies.

I had solved this one in the past by installing PyTorch as suggested on its website: conda install pytorch torchvision torchaudio cpuonly -c pytorch

Ran quantize.py again and:

(embeddings39) C:\IPU\Embeddings>python quantize.py
Traceback (most recent call last):
  File "C:\Users\rhenr\.conda\envs\embeddings39\lib\site-packages\transformers\utils\import_utils.py", line 1535, in _get_module
    return importlib.import_module("." + module_name, self.__name__)
  File "C:\Users\rhenr\.conda\envs\embeddings39\lib\importlib\__init__.py", line 127, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 1030, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1007, in _find_and_load
  File "<frozen importlib._bootstrap>", line 986, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 680, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 850, in exec_module
  File "<frozen importlib._bootstrap>", line 228, in _call_with_frames_removed
  File "C:\Users\rhenr\.conda\envs\embeddings39\lib\site-packages\optimum\onnxruntime\quantization.py", line 22, in <module>
    import onnx
ModuleNotFoundError: No module named 'onnx'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "C:\IPU\Embeddings\quantize.py", line 1, in <module>
    from optimum.onnxruntime import ORTQuantizer, AutoQuantizationConfig
  File "<frozen importlib._bootstrap>", line 1055, in _handle_fromlist
  File "C:\Users\rhenr\.conda\envs\embeddings39\lib\site-packages\transformers\utils\import_utils.py", line 1525, in __getattr__
    module = self._get_module(self._class_to_module[name])
  File "C:\Users\rhenr\.conda\envs\embeddings39\lib\site-packages\transformers\utils\import_utils.py", line 1537, in _get_module
    raise RuntimeError(
RuntimeError: Failed to import optimum.onnxruntime.quantization because of the following error (look up to see its traceback):
No module named 'onnx'

So I did pip install onnx and tried quantize.py again; this time it successfully created the quantized model.
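(In hindsight, the three missing packages collapse into one command in the same environment, pip install fsspec optimum onnx, alongside the conda PyTorch install above.)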

Cool, time to run infer.py...


(embeddings39) C:\IPU\Embeddings>python infer.py
2024-06-10 22:05:09.3856519 [W:onnxruntime:Default, vitisai_provider_factory.cc:48 onnxruntime::VitisAIProviderFactory::CreateProvider] Construting a FlexML EP instance in Vitis AI EP
2024-06-10 22:05:09.3901217 [W:onnxruntime:Default, vitisai_execution_provider.cc:117 onnxruntime::VitisAIExecutionProvider::SetFlexMLEPPtr] Assigning the FlexML EP pointer in Vitis AI EP
2024-06-10 22:05:09.6417864 [W:onnxruntime:Default, vitisai_execution_provider.cc:137 onnxruntime::VitisAIExecutionProvider::GetCapability] Trying FlexML EP GetCapability
2024-06-10 22:05:09.6452556 [W:onnxruntime:Default, flexml_execution_provider.cc:180 onnxruntime::FlexMLExecutionProvider::GetCapability] FlexMLExecutionProvider::GetCapability, C:\amd\voe\binary-modules\ResNet.flexml\flexml_bm.signature can't not be found!
2024-06-10 22:05:09.6492889 [W:onnxruntime:Default, vitisai_execution_provider.cc:153 onnxruntime::VitisAIExecutionProvider::GetCapability] FlexML EP ignoring a non-ResNet50 graph
WARNING: Logging before InitGoogleLogging() is written to STDERR
I20240610 22:05:09.654853  7368 vitisai_compile_model.cpp:346] Vitis AI EP Load ONNX Model Success
I20240610 22:05:09.654853  7368 vitisai_compile_model.cpp:347] Graph Input Node Name/Shape (3)
I20240610 22:05:09.654853  7368 vitisai_compile_model.cpp:351]   input_ids : [-1x-1]
I20240610 22:05:09.654853  7368 vitisai_compile_model.cpp:351]   attention_mask : [-1x-1]
I20240610 22:05:09.654853  7368 vitisai_compile_model.cpp:351]   token_type_ids : [-1x-1]
I20240610 22:05:09.654853  7368 vitisai_compile_model.cpp:357] Graph Output Node Name/Shape (1)
I20240610 22:05:09.654853  7368 vitisai_compile_model.cpp:361]   last_hidden_state : [-1x-1x1024]
I20240610 22:05:09.656965  7368 vitisai_compile_model.cpp:232] use cache key ipucachekey
[Vitis AI EP] No. of Operators :   CPU  1847 MATMULINTEGER   144
[Vitis AI EP] No. of Subgraphs :MATMULINTEGER   144
2024-06-10 22:05:27.8509118 [W:onnxruntime:, session_state.cc:1169 onnxruntime::VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.
2024-06-10 22:05:27.8576983 [W:onnxruntime:, session_state.cc:1171 onnxruntime::VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.
NodeArg(name='input_ids', type='tensor(int64)', shape=['batch_size', 'sequence_length'])
NodeArg(name='attention_mask', type='tensor(int64)', shape=['batch_size', 'sequence_length'])
NodeArg(name='token_type_ids', type='tensor(int64)', shape=['batch_size', 'sequence_length'])

(embeddings39) C:\IPU\Embeddings>

Close, but something is not right. It just fails with no error. I did get some activity on the IPU according to HWiNFO64.

The ONNX model is here: WhereIsAI/UAE-Large-V1 on Hugging Face.

A zip of the dir, with everything but the model and the quantized model, is on my Google Drive.

(embeddings39) C:\IPU\Embeddings>set
ALLUSERSPROFILE=C:\ProgramData
APPDATA=C:\Users\rhenr\AppData\Roaming
AWQ_CACHE=C:\IPU\RyzenAI-SW\example\transformers\\ext\awq_cache\
CommonProgramFiles=C:\Program Files\Common Files
CommonProgramFiles(x86)=C:\Program Files (x86)\Common Files
CommonProgramW6432=C:\Program Files\Common Files
COMPUTERNAME=MINIAI
ComSpec=C:\Windows\system32\cmd.exe
CONDA_DEFAULT_ENV=embeddings39
CONDA_EXE=C:\ProgramData\anaconda3\condabin\..\Scripts\conda.exe
CONDA_EXES="C:\ProgramData\anaconda3\condabin\..\Scripts\conda.exe"
CONDA_PREFIX=C:\Users\rhenr\.conda\envs\embeddings39
CONDA_PREFIX_1=C:\ProgramData\anaconda3
CONDA_PROMPT_MODIFIER=(embeddings39)
CONDA_PYTHON_EXE=C:\ProgramData\anaconda3\python.exe
CONDA_SHLVL=2
DEVICE=phx
DriverData=C:\Windows\System32\Drivers\DriverData
EFC_7740=1
HOMEDRIVE=C:
HOMEPATH=\Users\rhenr
LOCALAPPDATA=C:\Users\rhenr\AppData\Local
LOGONSERVER=\\MINIAI
NUMBER_OF_PROCESSORS=16
OneDrive=C:\Users\rhenr\OneDrive
OneDriveConsumer=C:\Users\rhenr\OneDrive
OS=Windows_NT
Path=C:\Users\rhenr\.conda\envs\embeddings39;C:\Users\rhenr\.conda\envs\embeddings39\Library\mingw-w64\bin;C:\Users\rhenr\.conda\envs\embeddings39\Library\usr\bin;C:\Users\rhenr\.conda\envs\embeddings39\Library\bin;C:\Users\rhenr\.conda\envs\embeddings39\Scripts;C:\Users\rhenr\.conda\envs\embeddings39\bin;C:\ProgramData\anaconda3\condabin;C:\Program Files\Python39\Scripts;C:\Program Files\Python39;C:\Windows\System32\AMD;C:\Windows\system32;C:\Windows;C:\Windows\System32\Wbem;C:\Windows\System32\WindowsPowerShell\v1.0;C:\Windows\System32\OpenSSH;C:\Program Files\dotnet;C:\Users\rhenr\AppData\Local\Microsoft\WindowsApps;C:\Users\rhenr\.dotnet\tools;C:\Program Files\CMake\bin;C:\Program Files\Microsoft SQL Server\130\Tools\Binn;C:\Program Files\Microsoft SQL Server\Client SDK\ODBC\170\Tools\Binn;C:\Program Files\Microsoft VS Code\bin;C:\Program Files\Git\cmd;C:\Users\rhenr\AppData\Local\Microsoft\WindowsApps;C:\Users\rhenr\.dotnet\tools;C:\IPU\RyzenAI-SW\example\transformers\third_party\lib;C:\IPU\RyzenAI-SW\example\transformers\third_party\bin;C:\IPU\RyzenAI-SW\example\transformers\ops\cpp;C:\IPU\RyzenAI-SW\example\transformers\third_party
PATHEXT=.COM;.EXE;.BAT;.CMD;.VBS;.VBE;.JS;.JSE;.WSF;.WSH;.MSC;.PY;.PYW
PROCESSOR_ARCHITECTURE=AMD64
PROCESSOR_IDENTIFIER=AMD64 Family 25 Model 116 Stepping 1, AuthenticAMD
PROCESSOR_LEVEL=25
PROCESSOR_REVISION=7401
ProgramData=C:\ProgramData
ProgramFiles=C:\Program Files
ProgramFiles(x86)=C:\Program Files (x86)
ProgramW6432=C:\Program Files
PROMPT=(embeddings39) $P$G
PSModulePath=C:\Program Files\WindowsPowerShell\Modules;C:\Windows\system32\WindowsPowerShell\v1.0\Modules
PUBLIC=C:\Users\Public
PWD=C:\IPU\RyzenAI-SW\example\transformers\
PYTHONPATH=;C:\IPU\RyzenAI-SW\example\transformers\\third_party\lib;C:\IPU\RyzenAI-SW\example\transformers\\third_party\bin;C:\IPU\RyzenAI-SW\example\transformers\\third_party;C:\IPU\RyzenAI-SW\example\transformers\\ops\python;C:\IPU\RyzenAI-SW\example\transformers\\onnx-ops\python;C:\IPU\RyzenAI-SW\example\transformers\\tools;C:\IPU\RyzenAI-SW\example\transformers\\ext\smoothquant\smoothquant;C:\IPU\RyzenAI-SW\example\transformers\\ext\smoothquant\smoothquant;C:\IPU\RyzenAI-SW\example\transformers\\ext\llm-awq;C:\IPU\RyzenAI-SW\example\transformers\\ext\llm-awq\awq\quantize;C:\IPU\RyzenAI-SW\example\transformers\\ext\llm-awq\awq\utils;C:\IPU\RyzenAI-SW\example\transformers\\ext\llm-awq\awq\kernels
PYTORCH_AIE_PATH=C:\IPU\RyzenAI-SW\example\transformers\
SESSIONNAME=Console
SSL_CERT_DIR=C:\Users\rhenr\.conda\envs\embeddings39\Library\ssl\certs
SSL_CERT_FILE=C:\ProgramData\anaconda3\Library\ssl\cacert.pem
SystemDrive=C:
SystemRoot=C:\Windows
TEMP=C:\Users\rhenr\AppData\Local\Temp
THIRD_PARTY=C:\IPU\RyzenAI-SW\example\transformers\\third_party
TMP=C:\Users\rhenr\AppData\Local\Temp
TVM_DLL_NUM=2
TVM_GEMM_M=1,8,
TVM_LIBRARY_PATH=C:\IPU\RyzenAI-SW\example\transformers\\third_party\lib;C:\IPU\RyzenAI-SW\example\transformers\\third_party\bin
TVM_MODULE_PATH=C:\IPU\RyzenAI-SW\example\transformers\models\opt-onnx\\..\..\dll\phx\qlinear\libGemmQnnAie_1x2048_2048x2048.dll,C:\IPU\RyzenAI-SW\example\transformers\models\opt-onnx\\..\..\dll\phx\qlinear\libGemmQnnAie_8x2048_2048x2048.dll,
USERDOMAIN=MINIAI
USERDOMAIN_ROAMINGPROFILE=MINIAI
USERNAME=rhenr
USERPROFILE=C:\Users\rhenr
windir=C:\Windows
XLNX_VART_FIRMWARE=C:\IPU\RyzenAI-SW\example\transformers\/xclbin/phx
XML_CATALOG_FILES=file:///C:/Users/rhenr/.conda/envs/embeddings39/etc/xml/catalog
XRT_PATH=C:\IPU\RyzenAI-SW\example\transformers\\third_party\xrt-ipu
__CONDA_OPENSLL_CERT_FILE_SET="1"
__CONDA_OPENSSL_CERT_DIR_SET="1"

What should I try next? Is there an option I can set to get it to generate a log, maybe?
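For the logging question: onnxruntime itself exposes verbosity controls that apply to any EP, including Vitis AI. A minimal sketch of the standard knobs:

import onnxruntime as ort

# Global logger: 0 = VERBOSE, 1 = INFO, 2 = WARNING, 3 = ERROR, 4 = FATAL.
ort.set_default_logger_severity(0)

session_options = ort.SessionOptions()
session_options.log_severity_level = 0   # per-session severity threshold
session_options.log_verbosity_level = 1  # extra detail at VERBOSE level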

uday610 commented 4 months ago

Ok, I see the problem. I can send you a voe package that you can install and try:

pip install voe-0.1.0-cp39-cp39-win_amd64.whl --force-reinstall

Then make sure to follow steps 2, 3, and 4 of my previous message.

If you let me know your email, I can share the voe package.

Thanks

rhenry74 commented 4 months ago

OK, yes, I think it is working! Now my challenge is to package this into an installable product. Thank you for your help. Will this team make this package available to the community at large? I hope so.

uday610 commented 4 months ago

Yes, in the next version of the release.

qz233 commented 4 months ago

Hi @uday610

I just encountered the same issue: the code returns nothing, yet no error occurs. Can you send me the updated voe package? Thanks

uday610 commented 4 months ago

What model and flow are you running?

qz233 commented 4 months ago

I am working on a TTS model, GPT-SoVITS (https://github.com/RVC-Boss/GPT-SoVITS). Among its components, I only plan to optimize the transformer that predicts latent speech tokens, the most time-consuming part. It has a pure GPT-2 structure, but with a KV-cache.

(P.S. I am kind of worried that inputs with zero length on a certain axis might cause errors during optimization. On the first forward pass, the input k and v have zero sequence length.)
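One way to probe that concern up front (a hypothetical sketch; the input-name prefix, head count, and head dimension are illustrative assumptions for a generic GPT-2-style export, not taken from GPT-SoVITS):

import numpy
import onnxruntime as ort

# A zero-length sequence axis is a legal numpy shape, so the empty
# first-step KV cache can be constructed and fed directly.
n_head, head_dim = 16, 64   # hypothetical values
empty_past = numpy.zeros((1, n_head, 0, head_dim), dtype=numpy.float32)

session = ort.InferenceSession("model_quantized.onnx",
                               providers=['CPUExecutionProvider'])
# Feed empty_past for each past_key_values.* input the export declares,
# together with the regular token inputs, to see if the graph tolerates it.
past_feeds = {i.name: empty_past for i in session.get_inputs()
              if i.name.startswith('past_key_values')}
print(list(past_feeds))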

qz233 commented 4 months ago

@uday610 I am following the opt-onnx flow and have already produced the dynamically quantized ONNX model, and I also end up with execution failing with no error message. Judging from rhenry74's case, am I just one step behind the goal?

uday610 commented 4 months ago

Yes. If you completely follow the OPT-ONNX flow, including quantizing with

dqconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False, use_symmetric_activations=True, operators_to_quantize=["MatMul"],)

and then setting the environment like the OPT-ONNX flow, then let me know your email and I will share that EP version.

qz233 commented 4 months ago

@uday610 Sorry, I forgot to set my email as public. It's a2576658523@163.com. Thanks for your help!