Open josephrocca opened 2 years ago
@marcvanzee @ferev FYI
The actual error you are getting is not with jax2tf, so this bug should be filed with TFLite. I cannot tell if this is easy to fix or not on their side.
Yep, as I mentioned, this is a feature request, since neither ONNX, TFLite, nor TF.js can handle this operation. If there's no way to turn this op into some equivalent that's "more supported" in downstream formats, then this issue can be closed. Thanks!
@ferev this issue seems to be on the TFLite side, perhaps you could reroute it to someone from the TFLite team?
@marcvanzee apologies for the ping - please ignore this if nothing can be done at this point. A member of the ONNX team said:
An array of functions is not allowed to be an attribute of an ONNX node, so this failed. I agree with you that probably we need to discuss with jax2tf team to see if there is another way to implement such function.
Wondering if there is any viable workaround here? Could the test_jax function in my original post perhaps be re-arranged in some way such that it gets exported in a way that is compatible with ONNX/TFLite/TF.js?
I've filed an issue on the tensorflow repo here.
(Again, please feel free to ignore and/or close this.)
I am pretty surprised that TFLiteConverter does not support tf.Case, and this is pretty much one of the simplest use cases I can think of. I do not remember seeing these problems in our past experiments with TFLiteConverter.
@ferev do you have any insight here?
Hi @josephrocca, does setting converter.allow_custom_ops = True unblock you? (On my end this seems to make it convert).
For some reason this is interpreted as a custom op. Will ask the team to investigate why this is the case. I agree this is weird behavior for Case.
Hey @Ferev, that does make the conversion successful, but I haven't played with custom ops before so not sure just yet if this is viable for me, especially since the end-goal is to get it working in the browser, and I don't think the tflite browser runtime can be configured for custom ops (although I believe that is changing, since the tfjs-tflite team is opening up their wasm tflite build process).
For now I think I've managed to work around this by breaking up the model into two models - so I'll do the first stage in the first model, then jump out (into Python or JS) to do the stuff described in the original post, and then take the outputs of that process and put it into the second tflite model, if that makes sense. Yet to properly test this, but seems like it should work, and hopefully is performant enough.
In any case, looking forward to the result of this investigation. Thanks!
Edit: Oh, btw, just in case it's helpful for the investigation, this is what the resulting model (using your allow_custom_ops=True approach) looks like in Netron:
Trying to run it in Python TFLite runtime gives RuntimeError: Encountered unresolved custom op: StatelessCase (which I assume is expected behavior here).
FWIW, it is often the case that an e2e JAX program will contain numerical pieces written in JAX glued together with Python control flow, because often that control flow does not fit naturally into the JAX programming model. Whether this is performant enough depends on how load-bearing the control flow is.
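To make the "two models with host-side control flow in between" idea concrete, here is a minimal sketch. Everything in it is hypothetical: `stage_one`, `host_switch`, `stage_two`, and the three branches are stand-ins, not the code from the original post. It assumes the branch index can be pulled out of the first stage as a concrete value, and that each jitted stage would be exported as its own model.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def stage_one(x):
    # Hypothetical first model: produces features plus a branch index.
    features = jnp.tanh(x)
    branch_index = jnp.argmax(features)
    return features, branch_index


def host_switch(index, value):
    # Plain host-language dispatch standing in for lax.switch.
    # In the browser this would be ordinary JS between the two models.
    branches = (
        lambda v: v * 2.0,  # hypothetical branch 0
        lambda v: -v,       # hypothetical branch 1
        lambda v: v * v,    # hypothetical branch 2
    )
    index = min(max(int(index), 0), len(branches) - 1)  # clamp like lax.switch
    return branches[index](np.asarray(value))


@jax.jit
def stage_two(y):
    # Hypothetical second model: consumes the output of the chosen branch.
    return jnp.sum(y)


def run_pipeline(x):
    features, branch_index = stage_one(x)
    switched = host_switch(branch_index, features)  # leaves JAX here
    return stage_two(jnp.asarray(switched))


result = run_pipeline(jnp.array([0.1, 0.9, 0.3]))
```

The cost of this arrangement is a device-to-host round trip at the switch point, so it is only attractive when the switch runs rarely relative to the numerical work on either side.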
I have the following minimal code (see this Colab notebook) that uses jax.lax.switch and tries to convert the model to TFLite:

That TFLite conversion code outputs:
I've also tried converting it to ONNX and I get a related error - it's documented here.
I also tried converting to TF.js (minimal Colab notebook) and got Unsupported Ops in the model before optimization: _SwitchN.

And I tried using experimental_from_jax (Colab notebook), but for some reason it is crashing my notebook.

The reason why I filed this as a feature request is because it seems like this could be seen as just a problem with op support in the down-stream formats (TFLite, ONNX, etc.) - but I was wondering if (since it seems to be unsupported in basically all the non-JAX/SavedModel formats) it would be possible to convert lax.switch into ops that are supported in the other formats? A switch is a pretty "fundamental" operation so I wouldn't be surprised if this is simply not possible, but I figured I'd file a request just in case, since this is blocking me on a project.

Thanks!
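One possible user-side rewrite, sketched here under assumptions rather than as anything jax2tf does: if every branch is cheap, side-effect free, and returns the same shape and dtype, you can evaluate all branches and pick one with a mask, so no control-flow op is emitted at all. The branch functions below are hypothetical stand-ins for the ones in the original post, and whether the resulting elementwise ops convert cleanly to TFLite, ONNX, or TF.js has not been verified here.

```python
import jax
import jax.numpy as jnp
from jax import lax


# Hypothetical branches; the real ones live in the original post's code.
def branch_double(x):
    return x * 2.0


def branch_negate(x):
    return -x


def branch_square(x):
    return x * x


BRANCHES = (branch_double, branch_negate, branch_square)


def with_switch(index, x):
    # Reference version: uses lax.switch (a control-flow primitive).
    return lax.switch(index, BRANCHES, x)


def with_select(index, x):
    # Evaluate every branch, then keep only the one selected by `index`.
    n = len(BRANCHES)
    index = jnp.clip(index, 0, n - 1)               # lax.switch also clamps
    stacked = jnp.stack([f(x) for f in BRANCHES])   # shape (n, *x.shape)
    mask = (jnp.arange(n) == index).reshape((n,) + (1,) * x.ndim)
    # where() rather than multiplication, so NaN/inf values in unselected
    # branches do not leak into the result.
    return jnp.sum(jnp.where(mask, stacked, 0.0), axis=0)


x = jnp.array([1.0, -2.0, 3.0])
for i in range(len(BRANCHES)):
    assert jnp.allclose(with_switch(i, x), with_select(i, x))
```

The trade-off is that every branch runs on every call, so this only makes sense when the branches are small; for expensive branches, splitting the model at the switch point is likely the better option.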