MarioCavero closed this issue 5 months ago.
Hi @MarioCavero,
You are correct that the listed operations are not supported. This is because they are used to load weights that are stored outside the graph file, and the fix is to freeze the weights into the graph.
Rather than calling the neuron-cc compiler directly, the recommended approach to compiling a model is through the tensorflow-neuron tracing API. When you call this API, any Variable operations (such as ReadVariableOp) will be replaced with a constant equivalent.
For example:
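import tensorflow as tf
import tensorflow.neuron as tfn
# `model` is the model (or signature function) you want to compile.
# Integer dtypes require an explicit maxval; 256 covers 8-bit pixel values.
example_input = tf.random.uniform([1, 256, 256, 3], dtype=tf.dtypes.int32, maxval=256)
neuron_model = tfn.trace(model, example_input)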
Would you be able to check whether the tfn.trace API resolves this issue?
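The replacement of Variable operations with constants described above can be illustrated with a toy graph pass. This is a minimal sketch, not tensorflow-neuron's implementation: the dict-based node format, the `freeze_variables` helper and the variable values are all hypothetical (real GraphDefs are protobufs), and it assumes every VarHandleOp is read only through ReadVariableOp nodes.

```python
from typing import Dict, List


def freeze_variables(nodes: List[dict], variable_values: Dict[str, list]) -> List[dict]:
    """Replace ReadVariableOp nodes with Const nodes holding the variable's value,
    then drop the VarHandleOp nodes that only served to locate external weights."""
    handles = {n["name"]: n for n in nodes if n["op"] == "VarHandleOp"}
    frozen = []
    for node in nodes:
        if node["op"] == "ReadVariableOp":
            var_name = handles[node["inputs"][0]]["shared_name"]
            # Keep the node name so downstream consumers still resolve their inputs.
            frozen.append({"name": node["name"], "op": "Const",
                           "inputs": [], "value": variable_values[var_name]})
        elif node["op"] == "VarHandleOp":
            continue  # assumption: only ReadVariableOp nodes consume the handle
        else:
            frozen.append(node)
    return frozen


# Hypothetical graph: y = MatMul(x, w), where w is read from an external variable.
graph = [
    {"name": "x", "op": "Placeholder", "inputs": []},
    {"name": "w_handle", "op": "VarHandleOp", "inputs": [], "shared_name": "dense/kernel"},
    {"name": "w", "op": "ReadVariableOp", "inputs": ["w_handle"]},
    {"name": "y", "op": "MatMul", "inputs": ["x", "w"]},
]
frozen = freeze_variables(graph, {"dense/kernel": [[1.0], [2.0]]})
print([n["op"] for n in frozen])  # ['Placeholder', 'Const', 'MatMul']
```

After the pass no node needs to reach outside the graph for weights, which is why ops such as ReadVariableOp and VarHandleOp stop being a problem for the compiler.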
Hi, and thanks for the quick reply @jluntamazon. This is the first thing I tried; I just didn't include it in the question, my bad! I will post the full log, but the most important message is this one:
2023-11-04 15:42:35.665809: I tensorflow/neuron/grappler/convert/segment.cc:456] There are 14 ops of 6 different types in the graph that are not compiled by neuron-cc: GatherNd, FloorDiv, ArgMax, ResizeBilinear, NoOp, Placeholder, (For more information see https://awsdocs-neuron.readthedocs-hosted.com/en/latest/release-notes/neuron-cc-ops/neuron-cc-ops-tensorflow.html).
And this one (mentioned in the question), which reports a negative percentage of compiled operators:
WARNING:tensorflow:Warning: Your traced model has -11200.0% of operators compiled to neuron.
This is the code I used, loading the MoveNet Thunder model, which can be downloaded from TF Hub:
import tensorflow as tf
import tensorflow.neuron as tfn
import sys
import os
# model_path and compiled_model_dir are set elsewhere in the script (not shown)
movenet = tf.saved_model.load(model_path)
model = movenet.signatures["serving_default"]
model_func = model
example_input_shape = (1, 256, 256, 3)
# # In case encapsulating is needed
# def model_func(input):
#     return model(input)
# Note: tf.random.uniform returns float32 unless a dtype is given
example_input = tf.random.uniform(example_input_shape)
print(f'Example input shape: {example_input.shape}')
model_neuron = tfn.trace(model_func, example_input)
model_neuron.save(compiled_model_dir)
And the complete log:
python compile_model.py models/movenet_singlepose_thunder_4/
2023-11-04 15:42:21.409595: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-11-04 15:42:21.526974: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-11-04 15:42:21.527004: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
2023-11-04 15:42:21.552209: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-11-04 15:42:22.252577: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-11-04 15:42:22.252650: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-11-04 15:42:22.252664: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
2023-11-04 15:42:22.726762: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2023-11-04 15:42:22.726791: W tensorflow/stream_executor/cuda/cuda_driver.cc:263] failed call to cuInit: UNKNOWN ERROR (303)
2023-11-04 15:42:22.726812: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (ip-172-31-34-116.eu-west-3.compute.internal): /proc/driver/nvidia/version does not exist
2023-11-04 15:42:23.503221: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
Example input shape: (1, 256, 256, 3)
2023-11-04 15:42:34.039337: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
2023-11-04 15:42:34.039489: I tensorflow/core/grappler/clusters/single_machine.cc:358] Starting new session
2023-11-04 15:42:35.401120: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
2023-11-04 15:42:35.401251: I tensorflow/core/grappler/clusters/single_machine.cc:358] Starting new session
2023-11-04 15:42:35.665809: I tensorflow/neuron/grappler/convert/segment.cc:456] There are 14 ops of 6 different types in the graph that are not compiled by neuron-cc: GatherNd, FloorDiv, ArgMax, ResizeBilinear, NoOp, Placeholder, (For more information see https://awsdocs-neuron.readthedocs-hosted.com/en/latest/release-notes/neuron-cc-ops/neuron-cc-ops-tensorflow.html).
2023-11-04 15:42:35.978566: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
2023-11-04 15:42:35.978703: I tensorflow/core/grappler/clusters/single_machine.cc:358] Starting new session
2023-11-04 15:42:36.348731: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
2023-11-04 15:42:36.348851: I tensorflow/core/grappler/clusters/single_machine.cc:358] Starting new session
2023-11-04 15:42:36.359685: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
2023-11-04 15:42:36.359752: I tensorflow/core/grappler/clusters/single_machine.cc:358] Starting new session
2023-11-04 15:42:36.370634: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
2023-11-04 15:42:36.370722: I tensorflow/core/grappler/clusters/single_machine.cc:358] Starting new session
2023-11-04 15:42:36.396288: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
2023-11-04 15:42:36.396357: I tensorflow/core/grappler/clusters/single_machine.cc:358] Starting new session
2023-11-04 15:42:36.605126: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-11-04 15:42:36.627233: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
....
Compiler status PASS
2023-11-04 15:44:20.363853: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-11-04 15:44:20.366976: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
.
Compiler status PASS
2023-11-04 15:44:25.333002: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-11-04 15:44:25.336163: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
.
Compiler status PASS
2023-11-04 15:44:30.377428: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-11-04 15:44:30.380709: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
.
Compiler status PASS
2023-11-04 15:44:38.887911: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-11-04 15:44:38.896465: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
.
Compiler status PASS
WARNING:tensorflow:Warning: Your traced model has -11200.0% of operators compiled to neuron.
Model compiled and saved to: models/movenet_singlepose_thunder_4/converted_
To add to this: doing the same for a .h5 Keras model I trained leads to something similar. The model was loaded differently, as appropriate for a Keras model:
model = tf.keras.models.load_model(model_path, compile=False)
I also used a matching example_input. This gives:
WARNING:tensorflow:Warning: Your traced model has -280.7692307692308% of operators compiled to neuron.
And:
I tensorflow/neuron/grappler/convert/segment.cc:456] There are 92 ops of 21 different types in the graph that are not compiled by neuron-cc: TensorListSetItem, Tanh, Sigmoid, Enter, TensorListFromTensor, Less, Merge, NoOp, Placeholder, MatMul, LoopCond, Exit, Switch, NextIteration, TensorListStack, Identity, Mul, TensorListGetItem, AddV2, Split, BiasAdd, (For more information see https://awsdocs-neuron.readthedocs-hosted.com/en/latest/release-notes/neuron-cc-ops/neuron-cc-ops-tensorflow.html).
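When comparing traces, it can help to pull the unsupported op types out of these log lines programmatically. A minimal sketch using only the standard library; `parse_unsupported_ops` is a hypothetical helper (not part of tensorflow-neuron), and its regular expression is written against the exact wording of the segment.cc message in these logs, which may differ between releases.

```python
import re
from typing import List, Tuple

# Assumed format, copied from the segment.cc message in the logs in this thread.
_UNSUPPORTED_RE = re.compile(
    r"There are (\d+) ops of (\d+) different types in the graph that are not "
    r"compiled by neuron-cc: (.*?),? \(For more information"
)


def parse_unsupported_ops(line: str) -> Tuple[int, List[str]]:
    """Return (number of unsupported ops, list of unsupported op types)."""
    match = _UNSUPPORTED_RE.search(line)
    if match is None:
        raise ValueError("line does not contain the unsupported-ops message")
    op_count, type_count, types_text = match.groups()
    op_types = [t.strip() for t in types_text.split(",") if t.strip()]
    # Sanity check: the message itself states how many distinct types it lists.
    if len(op_types) != int(type_count):
        raise ValueError(f"expected {type_count} op types, found {len(op_types)}")
    return int(op_count), op_types
```

On the MoveNet line this yields 14 ops and the six listed types; on the Keras line it yields 92 ops across 21 types, several of which (Enter, Exit, Merge, Switch, NextIteration, LoopCond) are TensorFlow while-loop control-flow ops.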
Hello, thank you for the model reference!
Using the model we were able to reproduce the logs that you see. The log message of the number of supported operators is incorrect and we will look into implementing a fix.
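Whatever the exact accounting inside tensorflow-neuron, a valid share of compiled operators has to lie between 0% and 100%, so a value like -11200.0% can only be a reporting bug. A minimal sketch of a bounded calculation, assuming every op is either compiled to Neuron or left on CPU; `compiled_op_percentage` is hypothetical, and the 200-op total in the usage line is made up, since the logs do not report a total.

```python
def compiled_op_percentage(total_ops: int, unsupported_ops: int) -> float:
    """Percentage of operators placed on Neuron, assuming each op is either
    compiled to Neuron or left on CPU (hypothetical accounting)."""
    if total_ops <= 0:
        raise ValueError("total_ops must be positive")
    if not 0 <= unsupported_ops <= total_ops:
        raise ValueError("unsupported_ops must be between 0 and total_ops")
    return 100.0 * (total_ops - unsupported_ops) / total_ops


print(compiled_op_percentage(200, 14))  # 93.0
```

Validating the inputs rather than the result means an inconsistent count (more unsupported ops than ops in total) surfaces as an error instead of a nonsensical percentage.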
The warning that a set of operators is unsupported is expected when a model uses operators that are not well supported on Neuron hardware. One example is the GatherNd operation, which we offload to CPU by default.
Otherwise, the model should be working and able to produce valid results.
One warning regarding the second trace: it has more unsupported operators, which may lead to poor performance if they are spread throughout the model. When the model is partitioned, Neuron executes on hardware, copies intermediate data to the CPU to execute the unsupported operations, and then transfers the data back to Neuron. If this data movement is frequent, the compute performance gains are lost to data-copying overhead.
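The cost described above depends less on how many ops are unsupported than on how they are interleaved. A toy model, assuming a purely linear chain of ops (real graphs are DAGs, and Neuron's partitioner is more involved), groups consecutive ops by device and counts the boundaries, each of which needs a data copy. The op sequences and the supported set are hypothetical.

```python
from itertools import groupby
from typing import Iterable, List, Set, Tuple


def partition_ops(ops: Iterable[str], supported: Set[str]) -> List[Tuple[str, List[str]]]:
    """Group a linear sequence of ops into consecutive Neuron/CPU segments."""
    return [("neuron" if on_neuron else "cpu", list(group))
            for on_neuron, group in groupby(ops, key=lambda op: op in supported)]


def count_transfers(segments: List[Tuple[str, List[str]]]) -> int:
    """Each boundary between adjacent segments implies one device-to-device copy."""
    return max(len(segments) - 1, 0)


supported = {"Conv2D", "Relu"}
clustered = ["Conv2D", "Relu", "Conv2D", "Relu", "GatherNd", "ArgMax"]
interleaved = ["Conv2D", "GatherNd", "Relu", "Conv2D", "ArgMax", "Relu"]
print(count_transfers(partition_ops(clustered, supported)))    # 1
print(count_transfers(partition_ops(interleaved, supported)))  # 4
```

Both sequences contain the same two unsupported ops, but the interleaved one forces four copies instead of one.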
For now, we will keep this ticket open and update when we have a fix for the logging error. Let us know if you have any other problems with running the model.
Thank you again, @jluntamazon! I should have posted everything in the first post. The first time I converted the model, I ignored the negative % of compiled operators and ran it anyway. The reason I decided to use the neuron-cc command was to specify the config file, because of errors I was getting when trying to run inference with the converted model. MoveNet needs ints as input, but after converting, the expected tensors are floats. If I convert the inputs to floats, another error shows up stating that the model expects ints, but floats were given. The logs follow. The following code worked on CPU and GPU with the data pipeline I am using. In this case, images are 256x256.
import tensorflow as tf

def run_inference(model, input_size, image):
    # image is an HxWx3 array (e.g. NumPy), so it supports .shape and .reshape
    image_width, image_height = image.shape[1], image.shape[0]
    assert image_width == input_size and image_height == input_size, "Image dimensions do not match the expected input size!"
    input_image = image.reshape(-1, input_size, input_size, 3)
    input_image = tf.cast(input_image, dtype=tf.int32)
    print('Running movenet inference!')
    outputs = model(input_image)  # Removed "input= "
    return outputs
Running this gives the following error message:
Python inputs incompatible with input_signature:
inputs: (
tf.Tensor(
[[[[ 15 8 2]
[ 14 7 1]
[ 16 9 3]
...
...
[172 172 171]
[170 170 168]
[165 165 163]]]], shape=(1, 256, 256, 3), dtype=int32))
input_signature: (
TensorSpec(shape=(None, 256, 256, 3), dtype=tf.float32, name='input_1')).
If I change the cast to float and normalise, i.e. input_image = tf.cast(input_image, dtype=tf.float32) / 255.0, I get:
Failed due to Exception encountered when calling layer "aws_neuron_model" " f"(type AwsNeuronModel).
Graph execution error:
Expects arg[0] to be int32 but float is provided
[[{{node StatefulPartitionedCall}}]] [Op:__inference_restored_function_body_393]
Call arguments received by layer "aws_neuron_model" " f"(type AwsNeuronModel):
• args=('tf.Tensor(shape=(1, 256, 256, 3), dtype=float32)',)
• kwargs=<class 'inspect._empty'>
I assumed that, without the config file, compilation automatically set the input type to float. Either I am doing something wrong or the compilation is doing something strange with data types! Thanks again, and I hope this output and these logs help. Let me know if you need more testing/debugging!
If your neuron_model is expecting floats, it may be because the example_input you passed to trace was also a float tensor. To make it work, you need to pass trace inputs of the same dtype that the original model expects. For example, we are able to run this script successfully by making example_input an int32 tensor:
import tensorflow as tf
import tensorflow.neuron as tfn
import tensorflow_hub as hub
movenet = hub.load('https://tfhub.dev/google/movenet/singlepose/thunder/4')
model = movenet.signatures["serving_default"]
model_func = model
example_input_shape = (1, 256, 256, 3)
example_input = tf.random.uniform(example_input_shape, dtype=tf.dtypes.int32, maxval=256)
model_neuron = tfn.trace(model_func, example_input)
for i in range(1000):
    print(model_neuron(example_input))
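The dtype point can also be shown with a toy model of an input signature: as suggested above, the dtype of the example input at trace time becomes the dtype the traced function accepts. This is only an illustration; `ToySpec` and `spec_from_example` are hypothetical stand-ins for `tf.TensorSpec` and for what tracing records, not tensorflow-neuron code, and the relaxed batch dimension simply mirrors the `(None, 256, 256, 3)` signature in the error message earlier in this thread.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class ToySpec:
    """Stand-in for a TensorSpec: a shape (None = any size) and a dtype name."""
    shape: Tuple[Optional[int], ...]
    dtype: str

    def accepts(self, shape: Tuple[int, ...], dtype: str) -> bool:
        if dtype != self.dtype or len(shape) != len(self.shape):
            return False
        return all(s is None or s == d for s, d in zip(self.shape, shape))


def spec_from_example(shape: Tuple[int, ...], dtype: str) -> ToySpec:
    """Record the example input's dtype and shape, leaving the batch size open."""
    return ToySpec((None,) + tuple(shape[1:]), dtype)


traced_float = spec_from_example((1, 256, 256, 3), "float32")
print(traced_float.accepts((1, 256, 256, 3), "int32"))  # False
traced_int = spec_from_example((1, 256, 256, 3), "int32")
print(traced_int.accepts((1, 256, 256, 3), "int32"))  # True
```

Tracing with a float32 example produces a signature that rejects the int32 images MoveNet actually needs, which matches the first error reported above.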
Hi MarioCavero, did aws-rhsoln's comment resolve your issue?
Thanks for the follow-up, @aws-donkrets. Both models produce inferences but are not fully optimized! I was waiting to close the issue until the earlier point ("The log message of the number of supported operators is incorrect and we will look into implementing a fix.") had a fix. But I guess the issue can be closed! Feel free to close it and/or reopen it to post when the operator messages are fixed.
@MarioCavero a better log message will be available in an upcoming release. Closing the ticket for now. Thanks.
I wanted to convert two models for use on Inf1: the MoveNet model, and another model that was saved from Keras as a .h5 file. The MoveNet model is a TensorFlow model (if I am not mistaken), saved in a folder with the following contents:
assets saved_model.pb variables
I ran a script to compile them. After some time it seemed to work for both, although tracing reported a huge negative percentage of operators compiled, so the compiled models were not working.
I dug a bit deeper and, taking into account MoveNet's input and output data structure, decided to use the Neuron command-line tool:
neuron-cc compile models/movenet_singlepose_thunder_4 --framework TENSORFLOW --output models/compiled_movenet_singlepose_thunder_4 --io-config config/io_config.json
And io_config.json, after inspecting the model:
The neuron-cc command gives:
Failed to parse model /home/ec2-user/inference_server/models/movenet_singlepose_thunder_4: The following operators are not implemented: {'ReadVariableOp', 'VarHandleOp', 'StatefulPartitionedCall'} (NotImplementedError)
Some information: I used two scripts for installation and setup, with code from the docs:
and:
To check for version information:
And:
Are these operators not implemented yet, or is there a problem with my Neuron setup?
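Relating the folder layout above to the neuron-cc error: in this layout the weights live in the variables/ directory rather than inside saved_model.pb, which is consistent with the graph reading them through VarHandleOp and ReadVariableOp (the first reply in this thread suggests freezing them, for example via tfn.trace). A small standard-library check of a model directory; both helpers are hypothetical, and a non-empty variables/ folder is only a heuristic for external weights.

```python
import os
from typing import Dict


def describe_saved_model_dir(path: str) -> Dict[str, bool]:
    """Report which of the entries listed above exist in `path`."""
    return {
        "saved_model.pb": os.path.isfile(os.path.join(path, "saved_model.pb")),
        "variables": os.path.isdir(os.path.join(path, "variables")),
        "assets": os.path.isdir(os.path.join(path, "assets")),
    }


def has_external_weights(path: str) -> bool:
    """Heuristic: a non-empty variables/ directory suggests weights are stored
    outside the graph file and read through variable ops."""
    var_dir = os.path.join(path, "variables")
    return os.path.isdir(var_dir) and bool(os.listdir(var_dir))
```

For example, `has_external_weights("models/movenet_singlepose_thunder_4")` would be expected to return True for a folder laid out as above, provided its variables/ directory contains files.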