onnx / tensorflow-onnx

Convert TensorFlow, Keras, TensorFlow.js and TFLite models to ONNX
Apache License 2.0

Support for `jax.lax.switch` (takes a list of functions as input; throws `You passed in an iterable attribute but I cannot figure out its applicable type`) #2070

Status: Open. josephrocca opened this issue 2 years ago.

josephrocca commented 2 years ago

Describe the bug

The following function can't be converted to ONNX (via TF + jax2tf) because it uses `jax.lax.switch`:

```python
import jax


def test_jax(n, operand):

  def fn1(a):
    return a+2

  def fn2(a):
    return a*2

  # Dispatch to fn1 or fn2 based on the integer index n.
  result = jax.lax.switch(
    n,
    [fn1, fn2],
    operand,
  )

  return result
```

Seems like tf2onnx doesn't like lists of functions?

It throws this error:

```
ERROR:tf2onnx.tf_utils:pass1 convert failed for name: "jax2tf_test_jax_/switch_case/indexed_case"
op: "StatelessCase"
input: "jax2tf_test_jax_/clip_by_value"
input: "operand"
attr {
  key: "Tin"
  value {
    list {
      type: DT_UINT32
    }
  }
}
attr {
  key: "Tout"
  value {
    list {
      type: DT_UINT32
    }
  }
}
attr {
  key: "_read_only_resource_inputs"
  value {
    list {
    }
  }
}
attr {
  key: "_xla_propagate_compile_time_consts"
  value {
    b: true
  }
}
attr {
  key: "branches"
  value {
    list {
      func {
        name: "jax2tf_test_jax__switch_case_indexed_case_branch0_241"
      }
      func {
        name: "jax2tf_test_jax__switch_case_indexed_case_branch1_242"
      }
      func {
        name: "jax2tf_test_jax__switch_case_indexed_case_branch2_243"
      }
      func {
        name: "jax2tf_test_jax__switch_case_indexed_case_branch3_244"
      }
    }
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
      }
    }
  }
}
, ex=You passed in an iterable attribute but I cannot figure out its applicable type.
```
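Two details in the log line up with how `jax.lax.switch` is documented to behave: it clamps the index into `[0, len(branches) - 1]` before dispatching (which plausibly accounts for the `clip_by_value` input feeding `StatelessCase`), and each branch is a separate function, which ends up in the `branches` list attribute. A minimal pure-Python model of those semantics, checked against JAX itself, is sketched below; it assumes `jax` is installed, and `switch_reference` is a hypothetical name used only for illustration.

```python
import jax


def switch_reference(index, branches, *operands):
    """Pure-Python model of jax.lax.switch's documented semantics."""
    # Out-of-range indices are clamped into [0, len(branches) - 1], not rejected.
    index = min(max(index, 0), len(branches) - 1)
    return branches[index](*operands)


def fn1(a):
    return a + 2


def fn2(a):
    return a * 2


for n in (-1, 0, 1, 5):
    expected = switch_reference(n, [fn1, fn2], 3)
    actual = int(jax.lax.switch(n, [fn1, fn2], 3))
    print(f"n={n}: reference={expected}, jax={actual}")
```

With `operand=3`, both `n=-1` and `n=0` select `fn1` (result 5), while `n=1` and `n=5` both select `fn2` (result 6).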

Urgency

No hard deadlines. This is part of my effort to get Stable Diffusion running in web browsers.

System information

To Reproduce

Here's a minimal reproduction:

https://colab.research.google.com/gist/josephrocca/db146f00593f86e86a0ecc87c49453d3

josephrocca commented 1 year ago

Hey @fatcat-z, sorry to ping. Have you had a chance to look at this? It looks like this problem might be best solved upstream, since the TFLite converter complains that the `tf.Case` op is missing (minimal notebook example). So maybe I should report this to the jax2tf team to see whether they can lower this kind of construct to a set of more common ops.

fatcat-z commented 1 year ago

The error is thrown by the ONNX code here. An ONNX node attribute cannot be a list of functions, so the conversion fails.

I agree that we probably need to discuss this with the jax2tf team to see whether there is another way to implement this function.