noahzhy closed this 7 months ago
Thanks for raising this!
Indeed, JAX has no primitive Mean operation. This is common in JAX: rather than providing many specialized kernels, JAX relies on a compiler to fuse primitives into performant kernels. So if an interpreter, rather than a compiler, sits downstream of a JAX function, we can get inefficiencies like this one. (It's an open question whether we can have our cake and eat it too, e.g. by staging out a program representation that carries both the Mean function and its implementation in terms of primitives, so that downstream interpreters can use their own direct Mean implementations instead of interpreting the primitive-level one.)
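A minimal sketch, assuming a recent JAX install, that makes this visible: tracing jnp.mean with jax.make_jaxpr shows it expressed in terms of more basic primitives (in recent versions, a reduce_sum followed by a div) rather than as a single mean primitive. The exact printed jaxpr may differ between JAX versions, and the input array here is arbitrary.

```python
import jax
import jax.numpy as jnp

# Arbitrary example input (shape and values chosen only for illustration).
x = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)

# Trace jnp.mean without executing it; the jaxpr lists the primitives used.
closed_jaxpr = jax.make_jaxpr(jnp.mean)(x)
# In recent JAX versions this shows a reduce_sum followed by a div.
print(closed_jaxpr)

# The same computation written by hand from a sum and a division.
manual_mean = jnp.sum(x) / x.size
print(float(manual_mean), float(jnp.mean(x)))
```

A downstream interpreter that maps primitives one-to-one (such as a TFLite converter) would therefore see two operations here instead of one.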
I think this is working-as-intended, but I'm going to tag @gnecula and @superbobry to check. Any thoughts to add?
Description
I converted the same model into a TFLite file, but the Mean layer was converted into two different nodes. This will certainly affect the inference speed of the TFLite model.
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.26 jaxlib: 0.4.23 numpy: 1.26.4 python: 3.11.7 (main, Dec 4 2023, 18:10:11) [Clang 15.0.0 (clang-1500.1.0.2.5)] jax.devices (1 total, 1 local): [METAL(id=0)]