keras-team / tf-keras

The TensorFlow-specific implementation of the Keras API, which was the default Keras from 2019 to 2023.
Apache License 2.0

Support for StableHLO generation with JAX backend #26

Open ashutosh-arm opened 1 year ago

ashutosh-arm commented 1 year ago

System information.

TensorFlow version (you are using): N/A
Are you willing to contribute it (Yes/No): Not immediately.

Describe the feature and the current behavior/state.

Some of the Keras Core examples use the JAX backend. IIUC, this flow relies on JIT compilation via XLA. I assume the lowering generates StableHLO as an IR before it is consumed by XLA. If that is the case, is it viable to emit StableHLO when using JAX as the backend? That would make it possible to compile the same model with IREE instead of XLA.

Existing flow: Keras Core model --> JAX.JIT (XLA)
Desired flow: Keras Core model --> JAX.JIT --> side output: StableHLO --> IREE
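For a plain JAX function, the "side output" step of the desired flow can already be approximated with JAX's ahead-of-time lowering API. The sketch below is not Keras Core API: it assumes a recent JAX release where `jax.jit(...).lower(...)` returns a `Lowered` object whose `as_text(dialect="stablehlo")` prints the module as StableHLO MLIR. `dense_layer`, its shapes, and the output filename are hypothetical stand-ins for a model's forward pass.

```python
import jax
import jax.numpy as jnp


def dense_layer(params, x):
    # Hypothetical stand-in for a model forward pass (not a Keras Core model).
    w, b = params
    return jnp.tanh(x @ w + b)


params = (jnp.ones((4, 2), jnp.float32), jnp.zeros((2,), jnp.float32))
x = jnp.ones((1, 4), jnp.float32)

# Trace and lower the jitted function for these argument shapes/dtypes.
# Lowering stops before XLA compilation, so nothing is compiled here.
lowered = jax.jit(dense_layer).lower(params, x)

# Render the lowered module as StableHLO (MLIR text).
stablehlo_text = lowered.as_text(dialect="stablehlo")

# Write it out as a file that another compiler (e.g. IREE) could consume.
with open("dense_layer.mlir", "w") as f:
    f.write(stablehlo_text)

print(stablehlo_text)
```

The resulting `.mlir` file could then, in principle, be handed to IREE's `iree-compile` tool; the exact input and target flags depend on the IREE version, so check its documentation rather than relying on this sketch.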

Will this change the current api? How?

I am not sure.

Who will benefit from this feature?

IREE users.


divyashreepathihalli commented 1 year ago

Thanks for filing the issue. We have model.export support right now. We plan to add ONNX and StableHLO support in the future.

ashutosh-arm commented 1 year ago

Can the export support generate PyTorch and JAX programs from the original Keras Core models?