galah92 opened 1 year ago
@gnecula Is this bug still current?
Yes, it's still current, with no apparent way of overcoming it. I'd be happy to contribute a fix if you can guide me through it.
@ferev, can you take a look to see if there is a fix?
@gnecula @Ferev Hi, I got back to this today and found out that `experimental_from_jax` is being deprecated and `jax2tf.convert` is the way to go now, which means I don't have much time before I have to migrate.
Can I help in some way with this issue?
@ferev What is the status of TFLite support for StableHLO? I would like to deprecate the `enable_xla=False` path altogether.
Hi @gnecula, TFLite has migrated to use native serialization as the default. We can simply save a TF SavedModel via `tf.saved_model.save` containing StableHLO ops as the intermediate format between JAX and TFLite. The TFLite converter API will then pick that up:

```python
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()
```

We will update our public documentation pretty soon.
Description
I've got a JAX function that I'm trying to fuse with a Keras model and convert everything to TFLite. To my understanding, the best way to do that is to convert my function to a TensorFlow concrete function, merge it with my Keras model into another concrete function, and finally convert it to TFLite using `TFLiteConverter.from_concrete_functions`. The problem is that getting a concrete function from my JAX function fails with the errors below. The weird thing is that converting this function to TFLite directly with `TFLiteConverter.experimental_from_jax` works! I assume it has something to do with converting directly to HLO (reference), but still, is it possible to get the same behavior for `from_concrete_functions`?
Note: if I disable XLA (`enable_xla=False`), the conversion seemingly works, but the final model I get from `from_concrete_functions` has an empty graph.

Thank you.
What jax/jaxlib version are you using?
jax==0.4.8 jaxlib==0.4.7
Which accelerator(s) are you using?
CPU
Additional system info
WSL2 over Windows 11
NVIDIA GPU info
No response