google / yggdrasil-decision-forests

A library to train, evaluate, interpret, and productionize decision forest models such as Random Forest and Gradient Boosted Decision Trees.
https://ydf.readthedocs.io/
Apache License 2.0

Usage on microcontroller (ARM Cortex-M4) / LiteRT #139

Open patrickjedlicka opened 1 month ago

patrickjedlicka commented 1 month ago

Hi, does anyone have experience with making predictions on a microcontroller? Is the YDF C++ library compatible with microcontroller architectures?

To my understanding, YDF models are not yet compatible with TFLite (now LiteRT). Is this true, or has something changed here?

Every help is welcome!

Best regards,

/P

rstz commented 1 month ago

For GradientBoostedTrees, you can (experimentally) convert the model to JAX and then convert the JAX function to LiteRT. I've posted this in another issue (and should probably add it to the documentation):

# Imports needed by this snippet.
import ydf
import tensorflow as tf
from jax.experimental import jax2tf

# Train a Gradient Boosted Trees model.
gbt_learner = ydf.GradientBoostedTreesLearner(label='class')
gbt_model = gbt_learner.train(df_fft_general)

# Convert the model to Jax
jax_model = gbt_model.to_jax_function(compatibility="TFL")

# Convert a Jax model to a TensorFlow model.
tf_model = tf.Module()
tf_model.predict = tf.function(
    jax2tf.convert(jax_model.predict, with_gradient=False),
    jit_compile=True,
    autograph=False,
)

# Convert the TensorFlow model to a TFLite model. `test_ds` is a held-out
# test DataFrame with the same columns as the training data.
selected_examples = test_ds[:1].drop(gbt_model.label(), axis=1)
input_values = jax_model.encoder(selected_examples)
tf_input_specs = {
    k: tf.TensorSpec.from_tensor(tf.constant(v), name=k)
    for k, v in input_values.items()
}
concrete_predict = tf_model.predict.get_concrete_function(tf_input_specs)
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [concrete_predict], tf_model
)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS,  # enable TensorFlow ops.
]
tflite_model = converter.convert()
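Once you have the flatbuffer, you can exercise it with the standard TFLite interpreter. A minimal sketch of a helper for that (the `run_tflite` name is mine, and the tensor-name handling assumes the converter keeps the names from the `TensorSpec`s above):

```python
import numpy as np
import tensorflow as tf


def run_tflite(tflite_bytes, input_values):
    """Run a TFLite flatbuffer on a dict of named numpy inputs."""
    interpreter = tf.lite.Interpreter(model_content=tflite_bytes)
    interpreter.allocate_tensors()
    for detail in interpreter.get_input_details():
        # Converter tensor names typically look like "name:0"; strip the
        # ":0" suffix to match the keys of the input dict.
        key = detail["name"].split(":")[0]
        interpreter.set_tensor(detail["index"], np.asarray(input_values[key]))
    interpreter.invoke()
    out = interpreter.get_output_details()[0]
    return interpreter.get_tensor(out["index"])
```

With the objects from the snippet above, `run_tflite(tflite_model, input_values)` should return the same predictions as `gbt_model.predict` on those examples.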

For Random Forests and Isolation Forests, we haven't had time to implement this conversion, but I don't see any blockers.

The faster and probably more elegant approach is to compile the YDF inference code directly for your architecture. People have done this for some architectures, e.g. the Raspberry Pi, but we can't make any promises about it. For reference, here are the (since deleted) instructions: https://github.com/google/yggdrasil-decision-forests/blob/4bdedd31c041706a3d022313f1edaf494dea53c1/documentation/installation.md#compilation-on-and-for-raspberry-pi
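For reference, embedding the inference code in a C++ binary follows YDF's serving API and looks roughly like this. This is an untested sketch: the model path and the feature name "f1" are placeholders, and error handling is reduced to `QCHECK_OK`/`value()`.

```cpp
#include <iostream>
#include <memory>
#include <vector>

#include "yggdrasil_decision_forests/model/model_library.h"
#include "yggdrasil_decision_forests/serving/example_set.h"

namespace ydf = yggdrasil_decision_forests;

int main() {
  // Load a model trained and saved from Python (path is a placeholder).
  std::unique_ptr<ydf::model::AbstractModel> model;
  QCHECK_OK(ydf::model::LoadModel("/path/to/model", &model));

  // Compile the model into a fast inference engine.
  auto engine = model->BuildFastEngine().value();
  const auto& features = engine->features();

  // Allocate a batch of one example and set its feature values.
  auto examples = engine->AllocateExamples(1);
  examples->FillMissing(features);
  const auto f1 = features.GetNumericalFeatureId("f1").value();  // placeholder
  examples->SetNumerical(/*example_idx=*/0, f1, 1.5f, features);

  // Run inference.
  std::vector<float> predictions;
  engine->Predict(*examples, /*num_examples=*/1, &predictions);
  std::cout << predictions[0] << std::endl;
  return 0;
}
```

The fast-engine path avoids the generic (slower) `AbstractModel::Predict` and is the code you'd want to cross-compile for a Cortex-M target.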

Note that we're now using Bazel 5.3.0 (and will be migrating to Bazel 6 or 7 at some point).

In the compilation step, building just the predict tool should be enough, i.e.,

${BAZEL} build //yggdrasil_decision_forests/cli:predict \
  --config=linux_cpp17 --features=-fully_static_link --host_javabase=@local_jdk//:jdk
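Once built, running the tool looks something like the following; the paths are placeholders, and you should check `predict --help` for the exact flags of your YDF version.

```shell
# Batch inference with the compiled predict tool (paths are placeholders).
./bazel-bin/yggdrasil_decision_forests/cli/predict \
  --model=/path/to/model \
  --dataset=csv:/path/to/test.csv \
  --output=csv:/path/to/predictions.csv
```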

If you're successful, please let us know; we'd be happy to include an updated guide in the repo.