tensorflow / agents

TF-Agents: A reliable, scalable and easy to use TensorFlow library for Contextual Bandits and Reinforcement Learning.

is it possible to convert tf-agents to tf-lite and run on android device #280

ssujr opened this issue 4 years ago

ssujr commented 4 years ago

We want to implement RL on an Android device. Just wondering if it is possible to run TF-Agents on Android or to convert TF-Agents models to TFLite. It would be great if someone could share some experience. Thank you!

ebrevdo commented 4 years ago

Yes; you should be able to do this. I'm guessing you care about inference (running a policy) more than training (since TFLite doesn't support training anyway).

See the PolicySaver class. You can use it to export a SavedModel. You can then use the TFLite converter to convert that SavedModel to a TFLite model.
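
For illustration, a minimal sketch of that two-step flow, assuming a trained agent exposing agent.policy (the agent variable and all paths are placeholders, not from this thread):

import tensorflow as tf
from tf_agents.policies import policy_saver

# `agent` is assumed to be an already-trained TF-Agents agent.
# Step 1: export the policy as a SavedModel.
saver = policy_saver.PolicySaver(agent.policy)
saver.save('exported_policy')

# Step 2: convert the SavedModel into a TFLite flatbuffer.
converter = tf.lite.TFLiteConverter.from_saved_model('exported_policy')
tflite_model = converter.convert()
with open('policy.tflite', 'wb') as f:
    f.write(tflite_model)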

Please report back and let us know if this works for you!

ssujr commented 4 years ago

Actually, we plan to do both training and inference on device. Do you have plans to support on-device training in the near future? Thank you for the response.

dvdhfnr commented 4 years ago

Hi!

> Yes; you should be able to do this. I'm guessing you care about inference (running a policy) more than training (since TFLite doesn't support training anyway).
>
> See the PolicySaver class. You can use it to export a SavedModel. You can then use the TFLite converter to convert that SavedModel to a TFLite model.
>
> Please report back and let us know if this works for you!

We tried this (using DqnAgent). However, we get the following error when trying to convert the saved model (policy): "ValueError: This converter can only convert a single ConcreteFunction. Converting multiple functions is under development."

@ebrevdo Any suggestions? (If required, further details can be provided.)

Thanks!

ebrevdo commented 4 years ago

For "only convert a single ConcreteFunction" this is cause it's trying to use the new MLIR converter. I suggest filing a repro separately with the TensorFlow Issues so they can see this feature is required. @aselle @jdduke fyi.

Separately, for now you should be able to use the "old-style" converter (it should work fine). Try passing --enable_v1_converter when you call tflite_convert and report back :)
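
If you prefer to stay in Python, a rough equivalent of that old-style CLI path is the compat.v1 converter sketched below (the SavedModel directory is a placeholder, and the signature_key argument anticipates the signature discussion further down):

import tensorflow as tf

# tf.compat.v1.lite.TFLiteConverter corresponds to
# tflite_convert --enable_v1_converter.
converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(
    'exported_policy', signature_key='action')
tflite_model = converter.convert()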

ebrevdo commented 4 years ago

For training on device, you cannot do this with TFLite. You must either use the standard TF runtime, or try the (less well supported) path of using the new saved_model_cli aot_compile_cpu approach, which does not support dynamic shapes and is a lot more manual, but would allow you to train on device. Unfortunately there's no tutorial (yet) on how to do this. If you're interested in that, we can ask the TF team to maybe write something about this approach.

ebrevdo commented 4 years ago

(For aot_compile_cpu you will need the most recent TF 2.2 RC; it's not in TF 2.1.)

dvdhfnr commented 4 years ago

> --enable_v1_converter

Thanks for the fast response!

--enable_v1_converter works "better", but leads to a different error: ValueError: No 'serving_default' in the SavedModel's SignatureDefs. Possible values are 'get_initial_state,__saved_model_init_op,get_train_step,action'.

(We do not require training on the device.)

ebrevdo commented 4 years ago

We can add a TODO to be able to create SavedModels out of the Agent.train() method; but my comments above still apply...

ebrevdo commented 4 years ago

The tflite_convert CLI help doesn't seem to show it, but you can pass a --saved_model_signature_key flag (see https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/python/tflite_convert.py#L385); you probably want to point it to "action". If you have an RNN in the model, you'll also want to create a separate TFLite model for "get_initial_state", which you would use to initialize the RNN at the beginning of an episode/sequence and whose output you would pass as the initial state to "action".
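
As a sketch of that two-model setup with a recent TF2 converter (the SavedModel directory is a placeholder; whether you need the second model depends on your policy actually having an RNN):

import tensorflow as tf

# One TFLite model per exported signature.
action_converter = tf.lite.TFLiteConverter.from_saved_model(
    'exported_policy', signature_keys=['action'])
init_converter = tf.lite.TFLiteConverter.from_saved_model(
    'exported_policy', signature_keys=['get_initial_state'])

with open('action.tflite', 'wb') as f:
    f.write(action_converter.convert())
with open('get_initial_state.tflite', 'wb') as f:
    f.write(init_converter.convert())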

dvdhfnr commented 4 years ago

Great. Thanks.

tflite_convert --saved_model_dir saveDir --enable_v1_converter --saved_model_signature_key action --output_file out.tflite --allow_custom_ops seems to work for the conversion.

(Still need to investigate if this tflite model runs as expected on the Android device. I will try to report back.)

Thanks.
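
A quick way to sanity-check the converted out.tflite in Python before deploying is to run one step through the TFLite interpreter; a sketch with placeholder zero inputs follows (real shapes and dtypes come from the policy's specs, and unresolved custom ops will fail at tensor-allocation time, as the next comment shows):

import numpy as np
import tensorflow as tf

# Load the converted policy and run one inference step.
interpreter = tf.lite.Interpreter(model_path='out.tflite')
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Feed zeros of the right shape/dtype for each input (placeholder data).
for detail in input_details:
    interpreter.set_tensor(
        detail['index'], np.zeros(detail['shape'], dtype=detail['dtype']))
interpreter.invoke()

print('action:', interpreter.get_tensor(output_details[0]['index']))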

maslovay commented 4 years ago

@dvdhfnr how are things going with running your TF-Agents-trained network on Android? I am hitting this error:

"RuntimeError: Encountered unresolved custom op: BroadcastArgs.Node number 0 (BroadcastArgs) failed to prepare."

Here is the case: https://stackoverflow.com/questions/61715154/tflite-model-load-error-runtimeerror-encountered-unresolved-custom-op-broadca

@ebrevdo

ebrevdo commented 4 years ago

@jdduke @raziel any suggestions?

dvdhfnr commented 4 years ago

When converting with the flag --allow_custom_ops, you need to implement the ops that are not supported by TFLite yourself: see e.g. https://www.tensorflow.org/lite/guide/ops_custom

Try converting without --allow_custom_ops. Then you will see a list of the ops that are not supported. Unfortunately, it seems we will have to implement those ourselves.
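
The conversion error quoted below explicitly mentions the flex runtime, so an alternative to hand-writing custom ops may be to let the unsupported ops fall back to the TensorFlow kernels via Select TF ops; a sketch under that assumption (this grows the binary and requires the Select TF ops delegate library in the Android app):

import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model(
    'exported_policy', signature_keys=['action'])
# Let ops without TFLite builtins (e.g. BroadcastArgs, BroadcastTo)
# run through the TensorFlow flex kernels at runtime.
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
tflite_model = converter.convert()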

maslovay commented 4 years ago

@dvdhfnr you are right, the problem is these ops:

Exception: <unknown>:0: error: loc(fused["Deterministic_1/sample/BroadcastArgs@__inference_action_11129549", "StatefulPartitionedCall/StatefulPartitionedCall/StatefulPartitionedCall/StatefulPartitionedCall/Deterministic_1/sample/BroadcastArgs"]): 'tf.BroadcastArgs' op is neither a custom op nor a flex op
<unknown>:0: error: loc(fused["ActorDistributionNetwork/TanhNormalProjectionNetwork/MultivariateNormalDiag/shapes_from_loc_and_scale/prefer_static_broadcast_shape/BroadcastArgs@__inference_action_11129549", "StatefulPartitionedCall/StatefulPartitionedCall/StatefulPartitionedCall/StatefulPartitionedCall/ActorDistributionNetwork/TanhNormalProjectionNetwork/MultivariateNormalDiag/shapes_from_loc_and_scale/prefer_static_broadcast_shape/BroadcastArgs"]): 'tf.BroadcastArgs' op is neither a custom op nor a flex op
<unknown>:0: error: loc(fused["Deterministic_1/sample/BroadcastArgs_1@__inference_action_11129549", "StatefulPartitionedCall/StatefulPartitionedCall/StatefulPartitionedCall/StatefulPartitionedCall/Deterministic_1/sample/BroadcastArgs_1"]): 'tf.BroadcastArgs' op is neither a custom op nor a flex op
<unknown>:0: error: loc(fused["Deterministic_1/sample/BroadcastTo@__inference_action_11129549", "StatefulPartitionedCall/StatefulPartitionedCall/StatefulPartitionedCall/StatefulPartitionedCall/Deterministic_1/sample/BroadcastTo"]): 'tf.BroadcastTo' op is neither a custom op nor a flex op
<unknown>:0: error: failed while converting: 'main': Ops that can be supported by the flex runtime (enabled via setting the -emit-select-tf-ops flag): BroadcastArgs,BroadcastArgs,BroadcastArgs,BroadcastTo.

dvdhfnr commented 4 years ago

Currently, I am using the following pipeline:

import tensorflow as tf
from tf_agents.policies.policy_saver import PolicySaver

# Export the policy, then convert its 'action' signature to TFLite.
policy_saver = PolicySaver(policy)
policy_saver.save('tmp')
converter = tf.lite.TFLiteConverter.from_saved_model('tmp', signature_keys=["action"])
tflite_policy = converter.convert()

Since I am actually not interested in saving the policy to a file, I tried to replace the save and from_saved_model steps with

converter = tf.lite.TFLiteConverter.from_concrete_functions([policy_saver._signatures['action'].get_concrete_function()])

I noticed that this changes the order of the input tensors. Do I need to take care of other side effects, or is this method safe to use? Moreover, do I need to use the PolicySaver at all, or can I just directly create a concrete function ('action') and convert from that? (The PolicySaver code looks quite sophisticated, so I cannot fully get an overview of what is done and why.)

Thanks for your comments!

ebrevdo commented 3 years ago

There is now a unit test showing how to use PolicySaver with the TFLite converter in policy_saver_test.py. Does it help?

soldierofhell commented 3 years ago

Hi @ebrevdo,

There's a short note in the code: https://github.com/tensorflow/agents/blob/3448c9e88fbe48d515c85ffab2b96e9f429a3b7d/tf_agents/policies/policy_saver_test.py#L358-L359

I guess this "native support for RNG ops, atan, etc." relates to the unsupported BroadcastArgs and BroadcastTo ops. Could you please provide more details on the root cause of the problem (e.g. where those broadcasts are coming from)? Maybe it's possible to change something in the tf_agents code? Or maybe we can somehow contribute to improving something on the TFLite side?

Thanks in advance. Regards,

ebrevdo commented 3 years ago

This has nothing to do with TF-Agents; it depends on the TFLite team. @jdduke FYI. Is there a relevant issue open on TF's side?

ebrevdo commented 3 years ago

I'm not sure where the BroadcastArgs ops are coming from; possibly from TF Probability? Here's where we use broadcast_to, but I don't think these are the real places it's coming from. It probably comes from a library we're using, as I mentioned.

jdduke commented 3 years ago

@thaink is actively working to support this. I'm not sure if there's a corresponding TF issue, but we do have an internal issue tracking this.

thaink commented 3 years ago

@ebrevdo I think the BroadcastArgs op may come from using broadcast_to on a tensor with a dynamic shape. I am working on supporting BroadcastArgs now.

soldierofhell commented 3 years ago

Thanks guys; please leave a comment here when BroadcastArgs becomes available.

soldierofhell commented 3 years ago

@thaink any ETA for this BroadcastArgs issue? :)

thaink commented 3 years ago

Unfortunately, it is still under review.

thaink commented 3 years ago

@soldierofhell BroadcastArgs has been added to the master branch. You can try it using the nightly build now.

windmaple commented 3 years ago

I can convert the model now. Thanks to @thaink for the work.