elixir-nx / nx

Multi-dimensional arrays (tensors) and numerical definitions for Elixir

Investigate conversion from StableHLO to Nx.Defn.Expr #1547

Status: Open. polvalente opened this issue 4 weeks ago.

polvalente commented 4 weeks ago

It might be possible to use Beaver (Elixir bindings for MLIR) to parse StableHLO modules into an Elixir data structure and then convert that into the equivalent Nx.Defn.Expr graph.

This should let us operate on PyTorch and JAX exports, not only to execute them through Nx but also to differentiate them with grad and, in the future, to shard them.

Note: the first draft of this feature can be worked on in this repository, but the general implementation will most likely live in a separate library.

josevalim commented 4 weeks ago

I believe this is pretty unlikely to work in general, because Nx.Defn.Expr only covers a subset of StableHLO. In addition, the StableHLO emitted for some backends may include generic MLIR operations as well.

polvalente commented 4 weeks ago

I think this might be possible if we add a special kind of node to represent the StableHLO operations that have no Nx.Defn.Expr equivalent.
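A minimal sketch of the fallback-node idea, written in Python for brevity rather than Elixir. Ops with a known equivalent become regular nodes, and anything else (for example `stablehlo.custom_call`) becomes an opaque node that carries the original op name and attributes along. `HloOp`, `Expr`, `SUPPORTED`, `convert` and the `"opaque"` node kind are all hypothetical. The sketch does not use Beaver's or Nx's actual APIs.

```python
from dataclasses import dataclass, field


@dataclass
class HloOp:
    """Hypothetical, flattened view of one parsed StableHLO op."""
    name: str                                   # e.g. "stablehlo.add"
    operands: list[int]                         # indices of earlier ops whose results are used
    attrs: dict = field(default_factory=dict)   # op attributes, kept as plain data


@dataclass
class Expr:
    """Hypothetical expression node, loosely modelled on an Nx.Defn.Expr node."""
    op: str                                     # node kind, e.g. "add" or "opaque"
    args: list                                  # child Expr nodes
    meta: dict = field(default_factory=dict)    # extra data, used by "opaque" nodes


# Hypothetical table of ops that have a direct expression equivalent.
# "arg" stands in for a function argument (a simplification: in MLIR these are block arguments).
SUPPORTED = {
    "arg": "parameter",
    "stablehlo.add": "add",
    "stablehlo.multiply": "multiply",
    "stablehlo.exponential": "exp",
}


def convert(ops: list[HloOp]) -> list[Expr]:
    """Convert ops in program order; each op may only reference earlier results."""
    results: list[Expr] = []
    for op in ops:
        args = [results[i] for i in op.operands]
        kind = SUPPORTED.get(op.name)
        if kind is None:
            # No equivalent: wrap the op in an "opaque" node that keeps its original
            # name and attributes, so a backend could still emit it verbatim.
            results.append(Expr("opaque", args, {"name": op.name, "attrs": op.attrs}))
        else:
            results.append(Expr(kind, args))
    return results


ops = [
    HloOp("arg", []),
    HloOp("arg", []),
    HloOp("stablehlo.add", [0, 1]),
    HloOp("stablehlo.custom_call", [2], {"call_target_name": "my_kernel"}),
]
root = convert(ops)[-1]
print(root.op, root.meta["name"])  # opaque stablehlo.custom_call
```

This flat, op-by-op walk deliberately ignores regions. Ops such as `stablehlo.while` or `stablehlo.case` carry nested regions, which would have to be converted recursively or wrapped whole in an opaque node.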

My main concern is how regions can nest: I don't think nested regions can be split apart.

And even if we don't support all models, being able to support some models might already be a win.