uber / neuropod

A uniform interface to run deep learning models from multiple frameworks
https://neuropod.ai
Apache License 2.0
936 stars 77 forks

Is neuropod designed to support tf.Example or sparse tensors? #485

Open helinwang opened 3 years ago

helinwang commented 3 years ago

These are actually two separate questions:

  1. Is neuropod designed to support tf.Example? From the material I found, neuropod's design goal seems to be: as long as a model's input features are a subset of the dataset's features, Uber's production system can generate a single input and feed it to any model for that dataset. And that single input's format has to be a mapping from feature name to tensor value. So an encoded model input like tf.Example does not seem to align with the design, since an encoded model input is not strictly a subset of the dataset's features (it's an encoded subset, but not a subset).

  2. Does neuropod support sparse tensors? Neuropod's input spec seems to be for dense tensors only (at least in TF's case): a TF sparse tensor is actually three dense tensors (values, indices, dense_shape), but there seems to be no such support in neuropod. Is it true that neuropod currently doesn't support sparse tensors, and is there any plan to support them? And if sparse tensors are not supported, how does neuropod deal with missing values?
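For reference, the sparse representation mentioned in question 2 can be sketched in plain NumPy (an illustration of the COO-style layout, not neuropod or TF API): the three dense tensors together describe one logical sparse tensor, which can be densified by scattering the values.

```python
import numpy as np

# A TF-style sparse tensor is three dense tensors:
#   indices     - (nnz, ndims) positions of the stored values
#   values      - (nnz,) the stored values themselves
#   dense_shape - (ndims,) shape of the equivalent dense tensor
indices = np.array([[0, 1], [2, 0]])
values = np.array([3.0, 7.0], dtype=np.float32)
dense_shape = (3, 2)

# Densify by scattering values into a zero-filled tensor.
dense = np.zeros(dense_shape, dtype=values.dtype)
for (i, j), v in zip(indices, values):
    dense[i, j] = v
```

All positions not listed in `indices` come out as zero after densification, which is exactly why "missing" and "zero" become indistinguishable without sparse support.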

VivekPanyam commented 3 years ago

Hi! Thanks for your questions!

Question 1

Neuropod does not currently support tf.Example.

Based on my understanding of tf.train.Example, it's effectively a serializable dictionary mapping feature names to values (https://www.tensorflow.org/guide/data) and is often used alongside TFRecords for storing and loading datasets. Some models can also expect serialized tf.Examples as input (https://www.tensorflow.org/guide/estimator#savedmodels_from_estimators) instead of taking in individual tensors.

Because this is effectively an implementation detail of how a model takes in the input data, we should be able to implement this transparently in Neuropod (assuming there's a straightforward way to tell whether a model expects a serialized tf.Example or just tensors directly). We do something similar for TorchScript models where a model can expect input in one of several forms (Dict, Tensors, NamedTuple) and Neuropod transparently provides the input tensors in the format the model expects. See this test for more details.
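The "transparent" adaptation described above can be sketched in plain Python (a hypothetical helper, not the actual Neuropod implementation): inspect what form the model expects, then repackage the same name-to-tensor dict accordingly.

```python
import inspect

def call_model(model_fn, tensors):
    """Feed a dict of named tensors to a model that may expect either
    a single dict argument or individual keyword arguments."""
    params = list(inspect.signature(model_fn).parameters)
    if len(params) == 1 and params[0] == "inputs":
        # Model expects one dict argument (like a TorchScript Dict input).
        return model_fn(tensors)
    # Model expects individual named tensors.
    return model_fn(**tensors)

# Two toy "models" with different input conventions:
def dict_model(inputs):
    return inputs["x"] + inputs["y"]

def kwarg_model(x, y):
    return x * y
```

With this, `call_model(dict_model, {"x": 2, "y": 3})` and `call_model(kwarg_model, {"x": 2, "y": 3})` both work even though the models take input in different forms; a tf.Example path would be one more branch that serializes the dict before handing it to the model.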

Supporting tf.Example probably won't be particularly efficient though because it requires serializing all the tensors within Neuropod and then deserializing everything within the model (instead of just passing tensors directly to TF).

It seems like tf.estimator.export.build_raw_serving_input_receiver_fn can be used to export a model that takes in raw tensors rather than tf.train.Examples.

Other notes:

Neuropod currently does not have SavedModel support (as many of our initial TF use cases were with frozen graphs), but we're open to it. I'm not sure if using tf.Example requires SavedModel support, but I thought I'd mention it.

Question 2

We don't have sparse tensor support yet.

The main reason we didn't build it is that we haven't had a need for it yet (at least not one anyone's asked about). Definitely open to it though if you have a use case that needs it.

> And if sparse tensors are not supported, how does neuropod deal with missing values?

It depends on how the tensor was created. For example, if it was created using zeros and then individual items were set, all the values that weren't explicitly set will be zero. If it was created using empty and individual items were set, there is no guaranteed value of any item in the tensor that wasn't explicitly set.
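The zeros vs. empty distinction above can be illustrated with NumPy (standing in here for the tensor allocation behavior described):

```python
import numpy as np

# zeros: every element not explicitly set is guaranteed to be 0.
a = np.zeros(4, dtype=np.float32)
a[1] = 5.0
# a is now [0., 5., 0., 0.]

# empty: unset elements hold whatever bytes were already in memory,
# so their values are not guaranteed.
b = np.empty(4, dtype=np.float32)
b[1] = 5.0
# Only b[1] is defined; b[0], b[2], b[3] are arbitrary.
```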

Basically the normal behavior you'd expect for dense tensors.

Please let me know if there's anything I can clarify!

helinwang commented 3 years ago

Hi Vivek, thanks for the detailed answers! I've got enough information from your answers, no feature request at this time :) Here are some follow-up comments:

> we should be able to implement this transparently in Neuropod

I think this is a good direction. One benefit of tf.Example is that TF's operator that parses tf.Example can produce different kinds of tensors: dense, sparse, ragged. This could extend the range of tensors that neuropod supports. There are other commonly used on-wire formats (such as protobuf.Value) similar to tf.Example that solve some problems tf.Example has, e.g., that it does not support n-d arrays when n > 1.

> Definitely open to it though if you have a use case that needs it

One use case is that a sparse tensor can be used to identify missing values, so the model can do special handling of them. I guess the models that neuropod serves so far don't need to handle missing values, i.e., the missing values are already imputed before being fed to training / prediction.
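A minimal NumPy sketch of that use case (hypothetical names, not neuropod API): the indices of a sparse representation tell the model which entries were actually observed, so it can impute the rest rather than conflating "missing" with a literal zero.

```python
import numpy as np

dense_shape = (5,)
observed_idx = np.array([0, 3])  # positions that were actually present
observed_val = np.array([2.0, 4.0], dtype=np.float32)

# Build a mask of observed entries from the sparse indices.
mask = np.zeros(dense_shape, dtype=bool)
mask[observed_idx] = True

# Impute missing entries with the mean of the observed ones.
filled = np.full(dense_shape, observed_val.mean(), dtype=np.float32)
filled[observed_idx] = observed_val
```

Without the sparse indices (i.e. with a plain dense tensor), the model has no way to tell an observed 0.0 apart from a missing value.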

helinwang commented 3 years ago

Hi @VigneshPeriasami, may I ask some follow-up questions:

  1. Does neuropod support models trained by TensorFlow 2? I searched PyPI for neuropod-backend-tensorflow (https://pypi.org/search/?q=neuropod-backend-tensorflow), but haven't found a backend for TF 2. The context is, I tried:

```python
from neuropod.loader import load_neuropod

with load_neuropod("./neuropod") as neuropod:
    # Do something here
    pass
```

But it failed with the error `Loading the default backend for type 'tensorflow' failed. Error from dlopen: libneuropod_tensorflow_backend.so: cannot open shared object file: No such file or directory`, so I think it's a backend problem. And I tried to find a backend for TF 2, but couldn't find one.

2. How can neuropod package asset files like a vocabulary? I could be wrong, but it seems these files are dynamically loaded into the TF graph at runtime. For example, when trying to package the neuropod, I get the following error:
```bash
Traceback (most recent call last):
  File "/usr/local/google/home/helin/miniconda3/envs/tf_2_4/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/local/google/home/helin/miniconda3/envs/tf_2_4/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/usr/local/google/home/helin/miniconda3/envs/tf_2_4/lib/python3.7/site-packages/neuropod/loader.py", line 277, in <module>
    run_model(model)
  File "/usr/local/google/home/helin/miniconda3/envs/tf_2_4/lib/python3.7/site-packages/neuropod/loader.py", line 266, in run_model
    out = model.infer(input_data)
  File "/usr/local/google/home/helin/miniconda3/envs/tf_2_4/lib/python3.7/site-packages/neuropod/backends/neuropod_executor.py", line 184, in infer
    out = self.forward(inputs)
  File "/usr/local/google/home/helin/miniconda3/envs/tf_2_4/lib/python3.7/site-packages/neuropod/backends/tensorflow/executor.py", line 134, in forward
    outputs = self.sess.run(output_dict, feed_dict=feed_dict)
  File "/usr/local/google/home/helin/miniconda3/envs/tf_2_4/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 968, in run
    run_metadata_ptr)
  File "/usr/local/google/home/helin/miniconda3/envs/tf_2_4/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1191, in _run
    feed_dict_tensor, options, run_metadata)
  File "/usr/local/google/home/helin/miniconda3/envs/tf_2_4/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1369, in _do_run
    run_metadata)
  File "/usr/local/google/home/helin/miniconda3/envs/tf_2_4/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1394, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.FailedPreconditionError: Table not initialized.
     [[node transform/transform/apply_vocab_14/hash_table_Lookup/LookupTableFindV2 (defined at /site-packages/neuropod/backends/tensorflow/executor.py:78) ]]
```

The code is:

```python
import tensorflow.compat.v1 as tf
from neuropod.packagers import create_tensorflow_neuropod
import numpy as np

create_tensorflow_neuropod(
    neuropod_path="./neuropod",
    model_name="model",
    # graph_def=graph_def,
    frozen_graph_path="/tmp/frozen_test.pb",
    node_name_mapping={
        "x": "input_example_tensor:0",
        "out": "truediv:0",
    },
    input_spec=[
        {"name": "x", "dtype": "string", "shape": (None,)},
    ],
    output_spec=[
        {"name": "out", "dtype": "float32", "shape": (None,)},
    ],
    test_input_data={"x": np.array([""])},
    test_expected_out={"out": np.array([1.0])},
    package_as_zip=False,
)
```

Have you seen this kind of error before? And another strange thing (related to question 1): judging by the paths, the TF version here is TF 2, so does that mean neuropod supports TF 2?

Hope my questions make sense, could you kindly take a look? Thanks!