jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

[jax2tf] Converting a model with `lax.switch` to TFLite/TF.js/ONNX fails #13416

Open josephrocca opened 2 years ago

josephrocca commented 2 years ago

I have the following minimal code (see this Colab notebook) that uses `jax.lax.switch` and tries to convert the model to TFLite:

import jax
from jax.experimental import jax2tf
import tensorflow as tf

def test_jax(n, operand):

  def fn1(a):
    return a + 2

  def fn2(a):
    return a * 2

  # Dispatch to fn1 or fn2 based on the runtime value of `n`.
  result = jax.lax.switch(
    n,
    [fn1, fn2],
    operand,
  )

  return result

jax.jit(test_jax)(1, 2)  # returns `4`

my_model = tf.Module()
my_model.f = tf.function(jax2tf.convert(test_jax, enable_xla=False), jit_compile=True, autograph=False, input_signature=[
    tf.TensorSpec([], tf.uint32, name="n"),
    tf.TensorSpec([], tf.uint32, name="operand"),
])
tf.saved_model.save(my_model, './test')

converter = tf.lite.TFLiteConverter.from_saved_model('./test')
converter.target_spec.supported_ops = [
  tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
  tf.lite.OpsSet.SELECT_TF_OPS,    # enable TensorFlow ops.
]

tflite_model = converter.convert()

with open('test.tflite', 'wb') as f:
  f.write(tflite_model)

That TFLite conversion code outputs:

ConverterError: <unknown>:0: error: loc(callsite(callsite(fused["StatelessCase:", "jax2tf_test_jax_/switch_case/indexed_case@__inference_converted_fun_tf_58"] at fused["PartitionedCall:", "PartitionedCall@__inference_signature_wrapper_66"]) at fused["PartitionedCall:", "PartitionedCall"])): 'tf.Case' op is neither a custom op nor a flex op
<unknown>:0: note: loc(fused["PartitionedCall:", "PartitionedCall"]): called from
<unknown>:0: note: loc(callsite(callsite(fused["StatelessCase:", "jax2tf_test_jax_/switch_case/indexed_case@__inference_converted_fun_tf_58"] at fused["PartitionedCall:", "PartitionedCall@__inference_signature_wrapper_66"]) at fused["PartitionedCall:", "PartitionedCall"])): Error code: ERROR_NEEDS_CUSTOM_OPS
<unknown>:0: error: failed while converting: 'main': 
Some ops in the model are custom ops, See instructions to implement custom ops: https://www.tensorflow.org/lite/guide/ops_custom 
Custom ops: StatelessCase
Details:
    tf.Case(tensor<i32>, tensor<ui32>) -> (tensor<*xui32>) : {Tin = [ui32], Tout = [ui32], _read_only_resource_inputs = [], _xla_propagate_compile_time_consts = true, branches = [@jax2tf_test_jax__switch_case_indexed_case_branch0_220, @jax2tf_test_jax__switch_case_indexed_case_branch1_230, @jax2tf_test_jax__switch_case_indexed_case_branch2_240, @jax2tf_test_jax__switch_case_indexed_case_branch3_250], device = "", is_stateless = true}

I've also tried converting it to ONNX and got a related error; it's documented here.

I also tried converting to TF.js (minimal Colab notebook) and got `Unsupported Ops in the model before optimization: _SwitchN`.

And I tried using `experimental_from_jax` (Colab notebook), but for some reason it crashes my notebook.

I filed this as a feature request because it could be seen as just an op-support problem in the downstream formats (TFLite, ONNX, etc.), but since `lax.switch` seems to be unsupported in basically all of the non-JAX/SavedModel formats, I was wondering whether it would be possible to convert it into ops that those formats do support. A switch is a pretty "fundamental" operation, so I wouldn't be surprised if this is simply not possible, but I figured I'd file a request just in case, since this is blocking me on a project.

Thanks!

gnecula commented 2 years ago

@marcvanzee @ferev FYI

The actual error you are getting is not in jax2tf, so this bug should be filed with TFLite. I cannot tell whether it is easy to fix on their side.

josephrocca commented 2 years ago

Yep, as I mentioned, this is a feature request, since none of ONNX, TFLite, or TF.js can handle this operation. If there's no way to turn this op into some equivalent that is better supported in downstream formats, then this issue can be closed. Thanks!

marcvanzee commented 1 year ago

@ferev this issue seems to be on the TFLite side; perhaps you could reroute it to someone on the TFLite team?

josephrocca commented 1 year ago

@marcvanzee apologies for the ping - please ignore this if nothing can be done at this point. A member of the ONNX team said:

An array of functions is not allowed to be an attribute of an ONNX node, so this failed. I agree with you that we probably need to discuss with the jax2tf team to see if there is another way to implement such a function.

I'm wondering if there is any viable workaround here. Could the `test_jax` function in my original post perhaps be rearranged so that it gets exported in a form compatible with ONNX/TFLite/TF.js?
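
For example (just thinking out loud, and only viable when the branch count is small and the branches are cheap): evaluate every branch unconditionally and pick the result with `jnp.where`, which should lower to plain elementwise ops rather than `tf.Case`. A sketch only, with `test_jax_no_switch` as an illustrative name; I haven't verified it against the converters:

import jax
import jax.numpy as jnp

def test_jax_no_switch(n, operand):
  # Compute both branches unconditionally...
  branch0 = operand + 2
  branch1 = operand * 2
  # ...then select the result elementwise, so the exported graph
  # contains a select instead of a control-flow op.
  return jnp.where(n == 0, branch0, branch1)

jax.jit(test_jax_no_switch)(1, 2)  # returns `4`, matching the original

The trade-off is that all branches always execute, so this only makes sense when they are cheap.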

I've filed an issue on the tensorflow repo here.

(Again, please feel free to ignore and/or close this.)

gnecula commented 1 year ago

I am pretty surprised that TFLiteConverter does not support `tf.Case`; this is pretty much one of the simplest use cases I can think of. I do not remember seeing these problems in our past experiments with TFLiteConverter.

@ferev do you have any insight here?

Ferev commented 1 year ago

Hi @josephrocca, does setting `converter.allow_custom_ops = True` unblock you? (On my end this seems to make it convert.)

For some reason this is interpreted as a custom op. Will ask the team to investigate why this is the case. I agree this is weird behavior for Case.
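
Concretely, on your repro that is one extra line on the converter (a sketch of the same conversion code as above, with the flag added):

import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model('./test')
converter.target_spec.supported_ops = [
  tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
  tf.lite.OpsSet.SELECT_TF_OPS,    # enable TensorFlow ops.
]
# Let the converter pass StatelessCase through as a custom op
# instead of rejecting the model.
converter.allow_custom_ops = True
tflite_model = converter.convert()

Note that whatever runtime loads the model will then need to supply an implementation for that custom op.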

josephrocca commented 1 year ago

Hey @Ferev, that does make the conversion succeed, but I haven't worked with custom ops before, so I'm not sure yet whether this is viable for me. The end goal is to get this working in the browser, and I don't think the TFLite browser runtime can be configured for custom ops (although I believe that is changing, since the tfjs-tflite team is opening up their WASM TFLite build process).

For now I think I've managed to work around this by splitting the model in two: the first model does the first stage, then I jump out (into Python or JS) to do the branching described in the original post, and then feed the outputs of that step into the second TFLite model, if that makes sense. I have yet to properly test this, but it seems like it should work and hopefully is performant enough.
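
Roughly like this, to show the shape of the split (a sketch with made-up stage functions; each stage would be exported to its own TFLite model via jax2tf, since none of them contains a switch):

import jax

def stage1(x):
  # Everything up to the data-dependent branch goes into model 1.
  return x * 3

def branch_add(x):  # exported as one variant of model 2
  return x + 2

def branch_mul(x):  # exported as the other variant of model 2
  return x * 2

def run(n, x):
  # The branch decision happens in host Python (or JS in the browser),
  # so no tf.Case / StatelessCase ever appears in an exported graph.
  y = jax.jit(stage1)(x)
  branch = jax.jit(branch_add) if n == 0 else jax.jit(branch_mul)
  return branch(y)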

In any case, looking forward to the result of this investigation. Thanks!

Edit: Oh, btw, just in case it's helpful for the investigation, this is what the resulting model (using your `allow_custom_ops=True` approach) looks like in Netron:

[Netron screenshot of the converted TFLite model]

Trying to run it in the Python TFLite runtime gives `RuntimeError: Encountered unresolved custom op: StatelessCase` (which I assume is expected behavior here).

gnecula commented 1 year ago

FWIW, it is often the case that an e2e JAX program contains numerical pieces written in JAX glued together with Python control flow, because that control flow often does not fit naturally into the JAX programming model. Whether this is performant enough depends on how load-bearing the control flow is.