tensorforce / tensorforce

Tensorforce: a TensorFlow library for applied reinforcement learning
Apache License 2.0

tf2 branch: unable to use "saved_model" #704

Closed: bezineb5 closed this issue 4 years ago

bezineb5 commented 4 years ago

Hi,

I've started looking at the saved_model export in the tf2 branch and I'm facing some issues. First, I had to change line 121 of tensorforce/core/utils/dicts.py to accept all data types, since it seems that TensorFlow tries to rebuild dictionaries in the process: value_type=(tf.IndexedSlices, tf.Tensor, tf.Variable, object)
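A simplified sketch of why a strict value_type check breaks here. It assumes that SavedModel tracing rebuilds the dictionary with spec objects in place of tensors (consistent with the later finding that the extra type needed was TensorSpec), so an isinstance-based check rejects them unless the spec type is allowed. TypedDict and TensorSpecStandIn are hypothetical stand-ins, not Tensorforce's or TensorFlow's classes.

```python
class TensorSpecStandIn:
    """Hypothetical stand-in for tf.TensorSpec so the sketch runs without TensorFlow."""

    def __init__(self, shape, dtype):
        self.shape = shape
        self.dtype = dtype


class TypedDict(dict):
    """Simplified sketch of a dict that only accepts values of the types in value_type.

    Only the constructor and item assignment are checked; update()/setdefault()
    bypass the check in this sketch.
    """

    def __init__(self, *args, value_type=(float,), **kwargs):
        super().__init__()
        self.value_type = value_type
        for key, value in dict(*args, **kwargs).items():
            self[key] = value  # goes through __setitem__ below

    def __setitem__(self, key, value):
        if not isinstance(value, self.value_type):
            raise TypeError(f"invalid type {type(value).__name__} for key {key!r}")
        super().__setitem__(key, value)


spec = TensorSpecStandIn(shape=(None, 24), dtype="float32")

try:
    TypedDict(state=spec, value_type=(float,))  # strict check: the spec is rejected
except TypeError as error:
    print(error)  # invalid type TensorSpecStandIn for key 'state'

# Allowing the spec type (the fix that was eventually applied) accepts the rebuilt dict.
relaxed = TypedDict(state=spec, value_type=(float, TensorSpecStandIn))
print(relaxed["state"].shape)  # (None, 24)
```

Allowing the spec type specifically, rather than object, keeps the check meaningful for every other value.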

Then, in tensorforce/core/models/model.py, line 678, I got errors caused by the signature:

ValueError: Got non-flat outputs '(TensorDict(main_sail_angle=Tensor("StatefulPartitionedCall:1", shape=(None,), dtype=float32), jib_angle=Tensor("StatefulPartitionedCall:0", shape=(None,), dtype=float32), rudder_angle=Tensor("StatefulPartitionedCall:2", shape=(None,), dtype=float32)), TensorDict())' from 'b'__inference_function_graph_2203'' for SavedModel signature 'serving_default'. Signatures have one Tensor per output, so to have predictable names Python functions used to generate these signatures should avoid outputting Tensors in nested structures.

I tried removing the signature from the saved_model.save call, and then ran into trouble with tensorforce/core/module.py: the function tf_function builds function_graphs with keys that are tuples, and TensorFlow doesn't like that. I converted them to strings and could save a file, but it's totally unusable.

So I'm stuck here, I'd need more help: what is tf_function doing exactly? Why don't you use tf.function instead?

Thanks! Ben

AlexKuhnle commented 4 years ago

Hey,

Thanks for this guidance; with it I managed to get this working, at least based on the unittest, which so far just stores a saved-model (see the latest tf2 commit). In the final version, the additional value_type needed was TensorSpec (presumably saved-model does some sort of type/shape tracing). The flattened output was only necessary for independent_act, so I just returned everything as one flat dictionary (with corresponding keys actions/... and internals/...), which should be okay. Moreover, as you suggested, the function_graphs keys are now always turned into strings.
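A minimal sketch of the flattening idea, assuming plain Python dicts in place of Tensorforce's TensorDict (flatten and unflatten are hypothetical helpers, not Tensorforce functions). Nested outputs become one flat dict with slash-separated keys, which is the form a SavedModel signature accepts, and the inverse restores the nesting on the consumer side. An empty internals dict, like the TensorDict() in the error above, simply contributes no keys.

```python
def flatten(nested, prefix="", sep="/"):
    """Flatten nested dicts into one dict with sep-joined keys."""
    flat = {}
    for key, value in nested.items():
        path = f"{prefix}{sep}{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten(value, path, sep))  # recurse into sub-dicts
        else:
            flat[path] = value
    return flat


def unflatten(flat, sep="/"):
    """Inverse of flatten; assumes the original keys never contain sep."""
    nested = {}
    for path, value in flat.items():
        *parents, leaf = path.split(sep)
        node = nested
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


outputs = {"actions": {"jib_angle": 0.1, "rudder_angle": 0.3}, "internals": {}}
print(flatten(outputs))  # {'actions/jib_angle': 0.1, 'actions/rudder_angle': 0.3}
```

Keys that themselves contain the separator would make the flattening ambiguous, so the separator has to be reserved.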

tf_function is an extension of tf.function, which is called here. It does a few things, in particular: it distinguishes between static and tensor arguments and keeps track of the different graph versions per function in function_graphs; it assigns a fully specified signature per graph instance to avoid TensorFlow creating multiple versions; and it wraps and unwraps arguments, since tf.function seems to require / turn everything into nested tuples (or something like that, if I remember correctly). This may not be necessary anymore at some point, as I expect tf.function to be improved in that direction.
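A rough sketch of the "one graph version per combination of static arguments" bookkeeping described above, written in plain Python without TensorFlow. GraphCache, build_fn and build_act are hypothetical names: build_fn stands in for tracing a tf.function, and the sketch leaves out the signature assignment and the argument wrapping.

```python
class GraphCache:
    """Hypothetical sketch of the caching idea behind tf_function (not Tensorforce code).

    build_fn stands in for tracing a tf.function: it receives the static (Python)
    arguments and returns a callable that only takes the tensor arguments.
    """

    def __init__(self, build_fn):
        self.build_fn = build_fn
        self.function_graphs = {}  # string key -> "traced" callable
        self.num_traces = 0

    @staticmethod
    def make_key(static_args):
        # Serialize the static arguments to a string key; tuple keys caused
        # trouble when saving, so this thread switched to strings.
        return ",".join(f"{name}={value!r}" for name, value in sorted(static_args.items()))

    def __call__(self, static_args, **tensor_args):
        key = self.make_key(static_args)
        if key not in self.function_graphs:
            # First call with these static arguments: "trace" a new graph version.
            self.num_traces += 1
            self.function_graphs[key] = self.build_fn(**static_args)
        return self.function_graphs[key](**tensor_args)


def build_act(independent):
    # Hypothetical act function whose behavior depends on a static flag.
    def act(states):
        return [s * 2 for s in states] if independent else [s + 1 for s in states]
    return act


act = GraphCache(build_act)
print(act({"independent": True}, states=[1, 2]))  # [2, 4]
print(act({"independent": False}, states=[1]))    # [2]
print(act({"independent": True}, states=[3]))     # [6], reuses the cached version
print(act.num_traces)                             # 2
```

Keying by static arguments only means repeated calls with new tensor values reuse an existing version instead of re-tracing.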

(Please re-open if the model can't be loaded.)

bezineb5 commented 4 years ago

Actually, I'm sorry, but I think there's something wrong with the exported model, as I cannot load it afterward. For example, the TFLite converter and the ONNX converter both raise this error when loading the file:

(...)
  File "/usr/local/lib/python3.7/site-packages/tensorflow/lite/python/lite.py", line 399, in from_saved_model
    saved_model = _load(saved_model_dir, tags)
  File "/usr/local/lib/python3.7/site-packages/tensorflow/python/saved_model/load.py", line 578, in load
    return load_internal(export_dir, tags)
  File "/usr/local/lib/python3.7/site-packages/tensorflow/python/saved_model/load.py", line 604, in load_internal
    export_dir)
  File "/usr/local/lib/python3.7/site-packages/tensorflow/python/saved_model/load.py", line 116, in __init__
    meta_graph.graph_def.library))
  File "/usr/local/lib/python3.7/site-packages/tensorflow/python/saved_model/function_deserialization.py", line 311, in load_function_def_library
    func_graph = function_def_lib.function_def_to_graph(copy)
  File "/usr/local/lib/python3.7/site-packages/tensorflow/python/framework/function_def_to_graph.py", line 63, in function_def_to_graph
    importer.import_graph_def_for_function(graph_def, name="")
  File "/usr/local/lib/python3.7/site-packages/tensorflow/python/framework/importer.py", line 412, in import_graph_def_for_function
    graph_def, validate_colocation_constraints=False, name=name)
  File "/usr/local/lib/python3.7/site-packages/tensorflow/python/framework/importer.py", line 501, in _import_graph_def_internal
    raise ValueError(str(e))
ValueError: The inner -1 dimensions of input.shape=[] must match the inner 1 dimensions of updates.shape=[?,24]: Shapes must be equal rank, but are 0 and 1 for '{{node estimator/ResourceScatterNdUpdate}} = ResourceScatterNdUpdate[T=DT_FLOAT, Tindices=DT_INT64, _output_shapes=[], use_locking=true](estimator_cond_input_8:0, estimator/ExpandDims:0, args_0:0, ^estimator/Gather, ^estimator/cond)' with input shapes: [], [?,1], [?,24].

And when I inspect it using Netron, it's not at all what I expected:

[Screenshot: Netron view of the SavedModel exported with 0.6]

Whereas in version 0.5.5, it was much more meaningful:

[Screenshot: Netron view of the model exported with 0.5.5]

It seems that the actual operations (matrix multiplication, ...) are not exported in 0.6.

Does that make any sense?

Thanks a lot! Ben

bezineb5 commented 4 years ago

Actually, Netron might not fully support SavedModels from TF 2.0, so the display might be incorrect. Still, the model can't be imported by TF2 itself. Another strange thing is that the old files used to be <100kB, whereas now they are about 2.2MB.

AlexKuhnle commented 4 years ago

Well, I guess in that case the issue is not yet closed. There are moments when I wonder whether moving away from TF1 was a good idea... :-/

One irritating thing is that the exception you mention above complains about the estimator, which shouldn't be included in the act-only model at all. I will try loading it with TF SavedModel to see what's going on. Regarding the graph visualization, I think that may be due to the use of tf.function, which means that the graph consists of lots of independent and nested subgraphs. The TensorBoard graph visualization looks similar. However, maybe there is a way to improve this (like creating independent_act completely separately, without nested tf.functions).

bezineb5 commented 4 years ago

I know that feeling perfectly... but soon everything will work fine, and anyway, the migration to TF2 has to be done one day. Can I be of any help investigating the issue?

Thanks a lot for your hard work!

AlexKuhnle commented 4 years ago

Hehe yes. Well, your investigation above was already very helpful, at least the saved-model saving "compiles" now. :-) What I would do now to continue debugging is to try to load the model again via the saved-model Python API (which I think you did already?). Since the "estimator" exception is unexpected, I would try to find out why the estimator appears in this model -- basically, independent=True should not involve the estimator here.

AlexKuhnle commented 4 years ago

Ultimately, the aim is to update/modify the ActonlyAgent to be compatible with the new saved-model format (here), so that it can be ensured that everything works in principle (although I expect a more likely use case for saved-model is to load in another language / deployment setting).

bezineb5 commented 4 years ago

Indeed, I got the error while converting to either TFLite or ONNX. I think the TFLite converter is installed with TensorFlow; just invoke it using:

tflite_convert --saved_model_dir xxxx --output_file=yyyyy

As you guessed, my goal is to export the trained model for inference. In my case, I'm targeting a low-power device (namely, a Raspberry Pi Zero) where it's not even possible to install TensorFlow.

AlexKuhnle commented 4 years ago

Since that was always a driving motivation for implementing the entire agent architecture in TensorFlow, it would be really great to see this working (otherwise, what's the point of all this overhead work ;-). What you could try, though, is what happens if you use https://www.tensorflow.org/api_docs/python/tf/saved_model/load instead -- does this throw a more meaningful exception? Does it maybe even work? That's where I would start, at least. If you do, and there is an exception, can you share the relevant script lines and stacktrace?

AlexKuhnle commented 4 years ago

Made some progress on this. First, I was wrongly assuming that by specifying the signature, the saved-model would only include that part of the graph. I'm not sure whether TF offers functionality to extract subgraphs (it certainly did in TF1). Apart from that, the exception you mention is also what I encounter, and it seems that the only problem is a few scatter_nd_update operations in agent.act(). That should work very soon; then the only remaining problem is to reduce the saved-model to just independent_act().
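For reference, here is the shape rule behind the ResourceScatterNdUpdate error reported earlier. With indices of shape [..., K], the updates must have shape indices.shape[:-1] + ref.shape[K:]. Given indices [?, 1] and updates [?, 24], the variable would need a shape like [capacity, 24], where capacity is a hypothetical buffer size. The reported shape [] (rank 0) cannot satisfy this, which suggests (an assumption, not confirmed in the thread) that the loaded graph lost the variable's static shape. The helper below is a plain-Python check of the rule, not a TensorFlow API.

```python
def scatter_nd_updates_shape(ref_shape, indices_shape):
    """Return the updates shape a scatter_nd_update expects.

    Rule: updates.shape == indices.shape[:-1] + ref.shape[indices.shape[-1]:]
    None marks an unknown (batch) dimension.
    """
    depth = indices_shape[-1]  # number of leading ref dimensions each index addresses
    if depth > len(ref_shape):
        raise ValueError(
            f"index depth {depth} exceeds rank {len(ref_shape)} of ref shape {list(ref_shape)}"
        )
    return list(indices_shape[:-1]) + list(ref_shape[depth:])


# Hypothetical buffer of 100 slots holding 24-dim states: updates must be [?, 24].
print(scatter_nd_updates_shape([100, 24], [None, 1]))  # [None, 24]

# A rank-0 ref, as in the error message, cannot accept depth-1 indices.
try:
    scatter_nd_updates_shape([], [None, 1])
except ValueError as error:
    print(error)  # index depth 1 exceeds rank 0 of ref shape []
```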

bezineb5 commented 4 years ago

Oops, sorry, I didn't see that update. I'll do that and keep you informed.

AlexKuhnle commented 4 years ago

No worries, I think it's basically working now... will soon update.

AlexKuhnle commented 4 years ago

It seems to work now, see unittest. However, the saved-model still includes the full model, not just the independent-act graph, and it's a bit awkward to interface right now (in particular getting the correct graph from the _* attribute). This will still be improved.

AlexKuhnle commented 4 years ago

Did you get a chance to try it out?

bezineb5 commented 4 years ago

Sorry, I was busy the last couple of days. Yes, I did try it out and it's working! Actually, after converting to ONNX, I don't see the training part of the graph. And I tried to evaluate it, only to find that it works perfectly!!!

The only things lost since version 0.5 are the names of the inputs and outputs. The input is now named "args_0:0" and the outputs are "Identity:0", "Identity_1:0",... I'll have a look at where that comes from.

Thanks a lot, awesome work!

AlexKuhnle commented 4 years ago

Great. So either ONNX only loads the signature-relevant parts, or it only looks like the full model is saved. But your observation that the model files are bigger now suggests that the full model is probably saved. Anyway, I need to check a bit more.

The naming is a good point... I don't know to what degree the user has control over it after using tf.function (and I thought the whole saved-model signature stuff is taking care of that?), but at least the outputs could definitely be named again. Do you still have to feed and retrieve tensors via names TF1-style when using saved-model with ONNX?

bezineb5 commented 4 years ago

Indeed, the ONNX converter extracts the inference subgraph using the signature (hence the sharp reduction in file size: ONNX ~90kB vs. pb model ~2MB). From what I can find online (e.g. https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/guide/saved_model.ipynb#scrollTo=6VYAiQmLUiox), the output names should indeed come from the signature. They mention: "To control the names of outputs, modify your tf.function to return a dictionary that maps output names to outputs", which is exactly what you're doing...
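The size difference fits the converter keeping only what the signature outputs depend on. Here is a toy sketch of that pruning idea. It assumes a graph given as a dict from node name to input names, the node names are hypothetical, and it is not the ONNX converter's actual code. A backward traversal from the signature outputs drops training-only nodes such as the loss and the optimizer update.

```python
from collections import deque


def inference_subgraph(graph, outputs):
    """Return the set of nodes needed to compute `outputs`.

    graph maps each node name to the names of the nodes it consumes.
    Nodes that the outputs do not depend on (e.g. training ops) are dropped.
    """
    needed = set()
    queue = deque(outputs)
    while queue:
        node = queue.popleft()
        if node in needed:
            continue
        needed.add(node)
        queue.extend(graph.get(node, []))
    return needed


# Hypothetical toy graph: one inference path plus training-only nodes.
graph = {
    "states": [],
    "dense/kernel": [],
    "dense/MatMul": ["states", "dense/kernel"],
    "jib_angle": ["dense/MatMul"],
    "reward": [],
    "loss": ["dense/MatMul", "reward"],
    "optimizer/update": ["loss", "dense/kernel"],
}

print(sorted(inference_subgraph(graph, ["jib_angle"])))
# ['dense/MatMul', 'dense/kernel', 'jib_angle', 'states']
```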

I'll continue to investigate.

bezineb5 commented 4 years ago

So, here are my findings:

1) For the input name (which is a completely minor issue, for sure!), it's just that the TensorSpecs don't have a name. This could be fixed by giving them meaningful names when they are created, or by renaming them in tf_function with something like this:

graph_signature = SignatureDict(**{k: tf.TensorSpec.from_spec(v, name=k) for k, v in graph_signature.items()})

2) For the output names, I tried different things. The saved_model actually contains the names:

saved_model_cli show --signature_def serving_default --tag_set serve --dir ./export/

The given SavedModel SignatureDef contains the following input(s):
  inputs['states/state'] tensor_info:
      dtype: DT_FLOAT
      shape: (-1, 24)
      name: serving_default_states/state:0
The given SavedModel SignatureDef contains the following output(s):
  outputs['jib_angle'] tensor_info:
      dtype: DT_FLOAT
      shape: (-1)
      name: StatefulPartitionedCall:0
  outputs['main_sail_angle'] tensor_info:
      dtype: DT_FLOAT
      shape: (-1)
      name: StatefulPartitionedCall:1
  outputs['rudder_angle'] tensor_info:
      dtype: DT_FLOAT
      shape: (-1)
      name: StatefulPartitionedCall:2

However, they are lost when converting to ONNX or TFLite. The "good" news is that there is an issue in tensorflow for this: https://github.com/tensorflow/tensorflow/issues/32180
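Until that is fixed, the signature keys can at least be recovered from the saved_model_cli output, for example to label the outputs of a converted model by hand. This is a small sketch that assumes the text format shown above, which may differ between TensorFlow versions; parse_signature_def is a hypothetical helper. It only recovers the SavedModel tensor names, and how they correspond to the Identity:0, Identity_1:0, ... names in the converted file is not guaranteed.

```python
import re


def parse_signature_def(text):
    """Parse `saved_model_cli show --signature_def ...` text into a nested dict.

    Returns {"inputs": {key: {"dtype", "shape", "name"}}, "outputs": {...}}.
    """
    result = {"inputs": {}, "outputs": {}}
    current = None
    for line in text.splitlines():
        header = re.match(r"\s*(inputs|outputs)\['([^']+)'\] tensor_info:", line)
        if header:
            current = result[header.group(1)].setdefault(header.group(2), {})
            continue
        field = re.match(r"\s*(dtype|shape|name): (.+)", line)
        if field and current is not None:
            current[field.group(1)] = field.group(2).strip()
    return result


cli_output = """\
  outputs['jib_angle'] tensor_info:
      dtype: DT_FLOAT
      shape: (-1)
      name: StatefulPartitionedCall:0
"""
signature = parse_signature_def(cli_output)
print({info["name"]: key for key, info in signature["outputs"].items()})
# {'StatefulPartitionedCall:0': 'jib_angle'}
```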

So as far as I'm concerned, Tensorforce works fine! I still don't understand why it saves the training model, though. Should we close this ticket and open a new one dedicated to that issue?

Anyway, thanks a lot for the good work!

AlexKuhnle commented 4 years ago

Ah, interesting, so TensorSpecs can be named. I will add this to the signatures soon.

Yes, let's open a separate issue for that, probably low priority right now, since the main saved-model stuff is working now.

If there's a simple example you could add to the examples folder illustrating saved-model usage, that would be a very welcome PR. But no worries if it's not straightforward; otherwise I can probably take something from the saving unittest.

bezineb5 commented 4 years ago

For the example, what about adding an export section to the temperature controller notebook? I'd add examples of exporting to TFLite and/or ONNX, as I think it's a meaningful use case, then running inference either in Python or in JavaScript.

AlexKuhnle commented 4 years ago

Yes, that sounds good. Maybe then also worth mentioning it in the docs, so that people don't miss that part (I can do that). The loading would then be in a separate file, or still in the notebook?

AlexKuhnle commented 4 years ago

FYI: the SavedModel handling has changed and improved in the latest version, including I think input tensor naming, see also the example script here. Are there any issues left, or can this be closed?

AlexKuhnle commented 4 years ago

Actually, let's close this and reopen a new issue if required.

bezineb5 commented 4 years ago

No, I agree, this can be closed; there's no issue left on my side. Sorry for not closing it. Thanks a lot!