google / yggdrasil-decision-forests

A library to train, evaluate, interpret, and productionize decision forest models such as Random Forest and Gradient Boosted Decision Trees.
https://ydf.readthedocs.io/
Apache License 2.0

Serving SavedModel files with Tensorflow Java? #126

Open turtlemonvh opened 2 months ago

turtlemonvh commented 2 months ago

Your documentation mentions:

Run models in Python, C++, Go, JavaScript, and CLI.


What about TensorFlow Java (https://github.com/tensorflow/java/releases)? If it is supported, which versions? I was trying to find out which TensorFlow version is required for serving these models, but couldn't tell from the docs.

rstz commented 2 months ago

TL;DR: It's likely possible by first converting the model to JAX and then to TensorFlow, but our team has not tried it yet.

There are two paths for serving from TensorFlow:

Using the TensorFlow Decision Forests custom op

TensorFlow Decision Forests (TF-DF) defines a custom TensorFlow op that lets TensorFlow run tree models trained with YDF or with TF-DF itself. This means you can save a YDF model with model.to_tensorflow_saved_model("/tmp/ydf/tf_model", mode="tf") (details here) and run it in any TensorFlow environment that supports this custom op. Notably, the custom op is supported by TF-Serving and by Python environments that have TensorFlow Decision Forests installed.

To the best of my knowledge, the custom op is not available in TensorFlow Java, and no attempts have been made to include it there.

Pure TensorFlow models by converting YDF -> JAX -> TensorFlow

We recently added the ability to export YDF models to pure JAX functions. JAX functions can be converted to TensorFlow models as shown in this tutorial. The resulting SavedModel is a pure TensorFlow model and should be compatible with all (*) TensorFlow surfaces. We have not tried it with TensorFlow Java, but I'd be very interested if someone has the bandwidth to experiment with it. Note that exporting to JAX is currently implemented for Gradient Boosted Trees only.

(*) TFLite support is coming with the next version of YDF.

Craigacp commented 2 months ago

I believe people eventually got TF-DF models to work with TF-Java. There were a number of issues with how TF-DF exported its symbols, which were incompatible with TF-Java's build, but now that we're using the same binaries as Python it should work (with TF-Java 1.0.0-rc.1).

rstz commented 2 months ago

That's super interesting, thank you for letting us know!