tensorflow / decision-forests

A collection of state-of-the-art algorithms for the training, serving and interpretation of Decision Forest models in Keras.
Apache License 2.0

Can I load and use a trained TF-DF model in Java? #81

YujieW0201 closed this issue 2 years ago

YujieW0201 commented 2 years ago

Hi, I trained my TF-DF model in Python and want to use it in Java for production. For a conventional NN model, we can load the model with SavedModelBundle and get predictions:

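```java
import java.nio.FloatBuffer;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

public class PredictExample {
    private static final int NUM_PREDICTIONS = 1;

    public static void main(String[] args) {
        try (SavedModelBundle b = SavedModelBundle.load("/tmp/model", "serve")) {
            // Create the session from the bundle.
            Session sess = b.session();
            // Create an input tensor of shape [NUM_PREDICTIONS], value = 2.0f.
            try (Tensor<Float> x = Tensor.create(
                         new long[] {NUM_PREDICTIONS},
                         FloatBuffer.wrap(new float[] {2.0f}));
                 // Run the model: feed input "x", fetch output "y".
                 Tensor<?> result = sess.runner().feed("x", x).fetch("y").run().get(0)) {
                float[] y = result.copyTo(new float[NUM_PREDICTIONS]);
                // Print out the result.
                System.out.println(y[0]);
            }
        }
    }
}
```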

I'm currently trying to use my TF-DF model and am wondering whether TF-DF currently supports loading and inference in Java. Will the model's graph and other useful information be loaded? I'm still trying to load it; does anyone have a clue? Thank you so much!

achoum commented 2 years ago

Hi AudreyW0201,

We have not yet experimented with using TF-DF in Java, so take what I say with a grain of salt.

It seems that TF-Java relies on the core TF runtime; in other words, TF-Java uses the same implementation as TF-Python. This means that TF-DF models can likely run in TF-Java seamlessly.

The most likely issue (if any) will be configuring TF-Java to use the TF-DF custom op for models (which is compatible with the TF runtime). In TF-Python, this is done automatically when importing the TF-DF library. In TF-Java, you might have to check the documentation or ask the TF-Java maintainers.

If you get an error, don't hesitate to post it here. Maybe we can help figure it out.

Alternatively, depending on your needs (and those of other users), it could be interesting to implement the TF-DF inference code directly in Java. A simple version of this code would be small (probably fewer than 20 lines of code; see the c++ example).
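As a rough illustration of what such inference code involves, here is a minimal Java sketch that walks binary trees of axis-aligned numerical splits and sums their outputs, gradient-boosting style. The `SimpleForest` and `Node` classes are hypothetical: they do not read the TF-DF / Yggdrasil Decision Forests model format, and a real port would also need to parse the model files and handle categorical conditions, missing values, and the model's output activation (for example, a sigmoid for binary classification). A random forest would average the leaf values instead of summing them.

```java
import java.util.Arrays;
import java.util.List;

/**
 * Hypothetical, simplified in-memory forest used for illustration only.
 * It does NOT read the TF-DF / Yggdrasil Decision Forests model format.
 */
public class SimpleForest {

    /** A node is either a leaf holding a value, or a split "example[feature] >= threshold". */
    static final class Node {
        final int feature;       // index of the numerical feature tested; -1 marks a leaf
        final float threshold;   // go to `positive` when example[feature] >= threshold
        final Node positive;
        final Node negative;
        final float leafValue;   // output value, only meaningful for leaves

        private Node(int feature, float threshold, Node positive, Node negative, float leafValue) {
            this.feature = feature;
            this.threshold = threshold;
            this.positive = positive;
            this.negative = negative;
            this.leafValue = leafValue;
        }

        static Node leaf(float value) {
            return new Node(-1, 0f, null, null, value);
        }

        static Node split(int feature, float threshold, Node positive, Node negative) {
            return new Node(feature, threshold, positive, negative, 0f);
        }

        boolean isLeaf() {
            return feature < 0;
        }
    }

    /** Walks a single tree from the root down to a leaf and returns the leaf value. */
    static float predictTree(Node root, float[] example) {
        Node node = root;
        while (!node.isLeaf()) {
            node = example[node.feature] >= node.threshold ? node.positive : node.negative;
        }
        return node.leafValue;
    }

    /** Gradient-boosting style: bias plus the sum of all tree outputs (raw score, no activation). */
    static float predictForest(List<Node> trees, float bias, float[] example) {
        float score = bias;
        for (Node tree : trees) {
            score += predictTree(tree, example);
        }
        return score;
    }

    public static void main(String[] args) {
        Node tree1 = Node.split(0, 1.5f, Node.leaf(1.0f), Node.leaf(-1.0f));
        Node tree2 = Node.split(1, 0.0f, Node.leaf(0.5f), Node.leaf(-0.5f));
        float score = predictForest(Arrays.asList(tree1, tree2), 0.0f, new float[] {2.0f, -3.0f});
        System.out.println(score); // prints 0.5
    }
}
```

The loop-based traversal avoids recursion, so deep trees cannot overflow the stack; an optimized engine would typically flatten the nodes into arrays instead of using objects.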

YujieW0201 commented 2 years ago

Thanks @achoum for your reply! I tried to use Java's SavedModelBundle.load("model_path") to load the model but got the error message: Op type not registered 'SimpleMLCreateModelResource' in binary. Make sure the Op and Kernel are registered in the binary running in this process.

I also contacted the TF-Java people and tried their suggested method, but it doesn't work: https://github.com/tensorflow/java/issues/419#issuecomment-1045480656

Regarding your second suggestion: if I cannot load the model successfully in Java, how can I run inference on it? Thank you!

achoum commented 2 years ago

Hi AudreyW0201,

@Craigacp suggested using TensorFlow.loadLibrary. Could you share what happens when it does not work?

I think the solution should look like this:

1. Download the TF-DF pip package corresponding to your OS from: https://pypi.org/project/tensorflow-decision-forests/#files

For example, if you are working on Linux, download any version that ends in manylinux_2_12_x86_64.manylinux2010_x86_64.whl

Alternatively, the file should already be on your computer if you installed TF-DF in Python.

2. Open the .whl file. It is a standard Zip archive, so any Zip tool (for example, 7-Zip) can open it.

3. Extract the file `/tensorflow_decision_forests/tensorflow/ops/inference/inference.so` from the archive. This is the "custom op library" that needs to be loaded in Java.

4. In your Java code, run the following before calling SavedModelBundle.load:

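```java
// Load the TF-DF custom op library (org.tensorflow.TensorFlow); a non-empty op list means its ops were registered.
byte[] opList = TensorFlow.loadLibrary(path_to_library_so);
if (opList.length == 0) throw new IllegalStateException("TF-DF custom ops were not registered");
```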

where path_to_library_so is the path to the .so file you extracted earlier.
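If you would rather script steps 2 and 3 than extract the file by hand, a minimal sketch using only the JDK's `java.util.zip` could look like the following. The class name `ExtractTfdfOps` and its command-line arguments are hypothetical; the entry name assumes the wheel layout from step 3, written without the leading slash because Zip entry names are relative. If the entry is not found, list the archive's contents to check the layout for your TF-DF version.

```java
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;

/** Hypothetical helper: extracts the TF-DF custom op library from a downloaded .whl (a Zip archive). */
public class ExtractTfdfOps {
    // Entry path from step 3, without the leading slash (Zip entry names are relative).
    static final String ENTRY = "tensorflow_decision_forests/tensorflow/ops/inference/inference.so";

    /** Copies `entryName` out of `wheel` into `destDir` and returns the path of the written file. */
    static Path extractEntry(Path wheel, String entryName, Path destDir) throws IOException {
        try (ZipFile zip = new ZipFile(wheel.toFile())) {
            ZipEntry entry = zip.getEntry(entryName);
            if (entry == null) {
                throw new IOException("Entry not found in " + wheel + ": " + entryName);
            }
            // Keep only the file name (e.g. "inference.so") in the destination directory.
            Path out = destDir.resolve(Paths.get(entryName).getFileName());
            try (InputStream in = zip.getInputStream(entry)) {
                Files.copy(in, out, StandardCopyOption.REPLACE_EXISTING);
            }
            return out;
        }
    }

    public static void main(String[] args) throws IOException {
        // args[0]: path to the .whl file; args[1] (optional): destination directory.
        Path wheel = Paths.get(args[0]);
        Path destDir = Paths.get(args.length > 1 ? args[1] : ".");
        Path so = extractEntry(wheel, ENTRY, destDir);
        // This is the path to pass to TensorFlow.loadLibrary before SavedModelBundle.load.
        System.out.println(so.toAbsolutePath());
    }
}
```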

For the second solution, you (or a contributor) would have to port the inference code to Java. It is likely more work, though.

YujieW0201 commented 2 years ago

Thank you @achoum and @Craigacp for your detailed instructions! I'm able to load the custom op library now.

However, when I try to load my model using SavedModelBundle.load, it throws an exception: TFInvalidArgumentException: No shape inference function exists for op 'SimpleMLLoadModelFromPathWithHandle', did you forget to define it? Should I load another library or other ops before loading the model? Here is a more detailed error log:

```
external/org_tensorflow/tensorflow/cc/saved_model/reader.cc:107] Reading meta graph with tags { serve }
external/org_tensorflow/tensorflow/cc/saved_model/reader.cc:148] Reading SavedModel debug info (if present) from: /Users/
external/org_tensorflow/tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:210] Restoring SavedModel bundle.
external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:194] Running initialization op on SavedModel bundle at path: /Users/
[INFO kernel.cc:1153] Loading model from path
[INFO quick_scorer_extended.cc:824] The binary was compiled without AVX2 support, but your CPU supports it. Enable it for faster model inference.
[INFO abstract_model.cc:1063] Engine "GradientBoostedTreesQuickScorerExtended" built
[INFO kernel.cc:1001] Use fast generic engine
I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:283] SavedModel load for tags { serve }; Status: success: OK. Took 1266572 microseconds.
Exception in thread "main" org.tensorflow.exceptions.TFInvalidArgumentException: No shape inference function exists for op 'SimpleMLLoadModelFromPathWithHandle', did you forget to define it?
    at org.tensorflow.internal.c_api.AbstractTF_Status.throwExceptionIfNotOK(AbstractTF_Status.java:87)
    at org.tensorflow.SavedModelBundle.load(SavedModelBundle.java:623)
    at org.tensorflow.SavedModelBundle.access$000(SavedModelBundle.java:67)
    at org.tensorflow.SavedModelBundle$Loader.load(SavedModelBundle.java:97)
    at org.tensorflow.SavedModelBundle.load(SavedModelBundle.java:357)
```

Thanks!

nicolas-kim-reddit commented 2 years ago

Hey @AudreyW0201, just following up to see whether you've found a solution or workaround for the above issue?

Craigacp commented 2 years ago

This looks like an issue in TF-DF. The TF C API requires ops to have shape inference functions, but this requirement is disabled by the TF Python API (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/ops.py#L3211). As that call isn't part of the public TF C API, we can't use it in TF-Java, so every op needs a full shape inference function.

achoum commented 2 years ago

All the TF-DF ops now have shape inference functions (example). Could you try again? :)

MatthewZholud commented 2 years ago

Thank you achoum!

https://github.com/tensorflow/decision-forests/issues/81#issuecomment-1049839308 helped me!

achoum commented 2 years ago

Awesome :)

binalj-spotify commented 7 months ago

If I were running this on Dataflow, how would I do the same?