polvalente opened 4 weeks ago
I believe this is pretty unlikely, because Nx covers only a subset of StableHLO. Also, the StableHLO exported for some backends may include generic MLIR operations.
I think it might become possible if we add a special kind of node to represent the StableHLO operations that Nx cannot express.
My main concern is how regions can nest: I don't think nested regions can be split apart.
Even if we can't support every model, being able to support some of them might already be a win.
It might be possible to use Beaver to parse StableHLO modules into an Elixir data structure, and convert that into the equivalent Nx.Defn.Expr.
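A minimal sketch of the conversion step, written in Python rather than Elixir for illustration only. It assumes the parser (Beaver or otherwise) has already flattened a function into a list of `(result, op_name, operands)` tuples. `Expr`, `OP_MAP`, `convert`, and `to_string` are hypothetical names, not Nx or Beaver APIs, and `Expr` only loosely mimics an `Nx.Defn.Expr` node. Operations without a mapping become an `opaque` node that keeps the original op name, which is one way to realize the special node for unrepresentable operations suggested above.

```python
from dataclasses import dataclass, field

# Hypothetical mapping from StableHLO op names to Nx-style op names.
OP_MAP = {
    "stablehlo.add": "add",
    "stablehlo.multiply": "multiply",
    "stablehlo.tanh": "tanh",
}


@dataclass
class Expr:
    """Toy expression node; loosely modeled on Nx.Defn.Expr (assumption)."""
    op: str
    args: list
    attrs: dict = field(default_factory=dict)


def convert(func):
    """Convert a flat, region-free function into an Expr graph."""
    env = {}
    # Function arguments become parameter nodes, indexed by position.
    for i, name in enumerate(func["args"]):
        env[name] = Expr("parameter", [], {"index": i})
    # Walk operations in order; SSA names resolve through env.
    for result, name, operands in func["ops"]:
        args = [env[o] for o in operands]
        if name in OP_MAP:
            env[result] = Expr(OP_MAP[name], args)
        else:
            # Unknown op: wrap it instead of failing the whole conversion.
            env[result] = Expr("opaque", args, {"stablehlo_op": name})
    return env[func["return"]]


def to_string(expr):
    """Render an Expr graph as a nested call string."""
    if expr.op == "parameter":
        return f"arg{expr.attrs['index']}"
    inner = ", ".join(to_string(a) for a in expr.args)
    if expr.op == "opaque":
        return f"opaque<{expr.attrs['stablehlo_op']}>({inner})"
    return f"{expr.op}({inner})"


# Usage: a toy function with one op that has no mapping.
func = {
    "args": ["%arg0", "%arg1"],
    "ops": [
        ("%0", "stablehlo.add", ["%arg0", "%arg1"]),
        ("%1", "stablehlo.custom_call", ["%0"]),
        ("%2", "stablehlo.tanh", ["%1"]),
    ],
    "return": "%2",
}
print(to_string(convert(func)))
# tanh(opaque<stablehlo.custom_call>(add(arg0, arg1)))
```

The sketch deliberately handles only a flat list of operations. Ops that carry nested regions, such as `stablehlo.while` or `stablehlo.if`, would need recursive conversion of each region, which is exactly where the concern about splitting nested regions applies.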
This should allow us to work with PyTorch and JAX exports, not only to execute them through Nx, but also to compute gradients (grad) and, in the future, to shard them.
Note: the first draft of this feature can be worked on in this repository, but in the long run it will most likely be implemented as a separate library.