patrickjedlicka opened this issue 1 month ago
For Gradient Boosted Trees models, you can (experimentally) convert to JAX and then convert to LiteRT. I've posted this in another issue (and should probably add it to the documentation):
```python
import tensorflow as tf
import ydf
from jax.experimental import jax2tf

# df_fft_general / test_ds: your training and test pandas DataFrames.

# Train a Gradient Boosted Trees model
gbt_learner = ydf.GradientBoostedTreesLearner(label='class')
gbt_model = gbt_learner.train(df_fft_general)

# Convert the model to JAX (compatibility="TFL" targets TFLite/LiteRT conversion).
jax_model = gbt_model.to_jax_function(compatibility="TFL")

# Convert the JAX model to a TensorFlow model.
tf_model = tf.Module()
tf_model.predict = tf.function(
    jax2tf.convert(jax_model.predict, with_gradient=False),
    jit_compile=True,
    autograph=False,
)

# Convert the TensorFlow model to a TFLite (LiteRT) model.
# One example, without the label column, is used to derive the input signature.
selected_examples = test_ds[:1].drop(gbt_model.label(), axis=1)
input_values = jax_model.encoder(selected_examples)
tf_input_specs = {
    k: tf.TensorSpec.from_tensor(tf.constant(v), name=k)
    for k, v in input_values.items()
}
concrete_predict = tf_model.predict.get_concrete_function(tf_input_specs)

converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [concrete_predict], tf_model
)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS,  # enable TensorFlow ops.
]
tflite_model = converter.convert()
```
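Since the original question is about microcontrollers: TensorFlow Lite Micro (LiteRT for Microcontrollers) usually has no filesystem to load `model.tflite` from, so the flatbuffer is typically compiled into the firmware as a byte array, conventionally generated with `xxd -i model.tflite`. Below is a minimal, stdlib-only sketch of the same step in Python, assuming `tflite_model` is the `bytes` object returned by `converter.convert()` above; the helper name `to_c_array` and the default symbol name `g_model` are hypothetical. One caveat worth checking first: the snippet above enables `SELECT_TF_OPS`, and those ops are not available in the microcontroller runtime, so the conversion would need to succeed with `TFLITE_BUILTINS` alone.

```python
def to_c_array(data: bytes, var_name: str = "g_model", per_line: int = 12) -> str:
    """Renders a serialized model as C++ source, similar in spirit to `xxd -i`.

    `var_name` is a hypothetical symbol name; pick whatever your firmware expects.
    """
    if not data:
        # An empty initializer list is not a valid array definition.
        raise ValueError("model data is empty")
    lines = []
    for i in range(0, len(data), per_line):
        chunk = data[i:i + per_line]
        # One row of comma-separated hex bytes, e.g. "  0x1c, 0x00, ...".
        lines.append("  " + ", ".join(f"0x{b:02x}" for b in chunk) + ",")
    body = "\n".join(lines)
    # TFLite Micro examples typically align model data to 16 bytes.
    return (
        f"alignas(16) const unsigned char {var_name}[] = {{\n"
        f"{body}\n"
        f"}};\n"
        f"const unsigned int {var_name}_len = {len(data)};\n"
    )
```

For example, writing `to_c_array(tflite_model)` to a file such as `model_data.cc` produces a source file that can be added to the firmware build.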
For Random Forests and Isolation Forests, we haven't had time to implement it, but I don't see any blockers.
The faster and probably more elegant approach is to compile the YDF inference code directly for your architecture. People have done this for some architectures (e.g. Raspberry Pi), but we can't make any promises about it. For reference, here are the (since deleted) instructions: https://github.com/google/yggdrasil-decision-forests/blob/4bdedd31c041706a3d022313f1edaf494dea53c1/documentation/installation.md#compilation-on-and-for-raspberry-pi
Note that we're now using Bazel 5.3.0 (and will be migrating to Bazel 6 or 7 at some point).
In the compilation step, building just the `predict` tool should be enough, i.e.:
```shell
${BAZEL} build //yggdrasil_decision_forests/cli:predict \
  --config=linux_cpp17 --features=-fully_static_link --host_javabase=@local_jdk//:jdk
```
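For this native route, the compiled `predict` tool reads models in YDF's own format, which is what `model.save()` writes from Python, so a model trained as above can be saved on the host and copied to the device. A minimal sketch of assembling and running the command line, assuming Bazel's usual `bazel-bin/<package>/<target>` output location and the `--model`, `--dataset` and `--output` flags of the YDF CLI; the helpers `build_predict_command` and `run_predict` and all file names are hypothetical:

```python
import shlex
import subprocess

# Assumed output path of the `//yggdrasil_decision_forests/cli:predict` target.
DEFAULT_PREDICT_BIN = "bazel-bin/yggdrasil_decision_forests/cli/predict"


def build_predict_command(
    model_dir: str,
    dataset_csv: str,
    output_csv: str,
    predict_bin: str = DEFAULT_PREDICT_BIN,
) -> list[str]:
    """Returns the argv for running the YDF `predict` CLI on a CSV dataset."""
    return [
        predict_bin,
        f"--model={model_dir}",
        # YDF CLI tools take typed paths, e.g. the "csv:" prefix for CSV files.
        f"--dataset=csv:{dataset_csv}",
        f"--output=csv:{output_csv}",
    ]


def run_predict(
    model_dir: str,
    dataset_csv: str,
    output_csv: str,
    predict_bin: str = DEFAULT_PREDICT_BIN,
) -> None:
    """Runs the predict tool and raises CalledProcessError if it fails."""
    cmd = build_predict_command(model_dir, dataset_csv, output_csv, predict_bin)
    print("Running:", shlex.join(cmd))
    subprocess.run(cmd, check=True)
```

On the host, `gbt_model.save("gbt_model")` writes the model directory; on a Linux-capable device (e.g. a Raspberry Pi) you could then call `run_predict("gbt_model", "test.csv", "predictions.csv", predict_bin="./predict")`. A bare microcontroller has no process model, so there the C++ inference code would have to be linked into the firmware instead.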
If you're successful, please let us know; we'd be happy to include an updated guide in the repo.
Hi, does anyone have experience with making predictions on a microcontroller? Is the YDF C++ library compatible with microcontroller architectures?
As I understand it, YDF models are not yet compatible with TFLite (now LiteRT). Is this true, or has something changed here?
Any help is welcome!
Best regards,
/P