elixir-nx / nx

Multi-dimensional arrays (tensors) and numerical definitions for Elixir

Use generic API for creating MLIR operations #1477

jonatanklosko closed 2 months ago

jonatanklosko commented 2 months ago

Currently we use the C++ API for each individual MLIR operation, which requires a separate NIF per operation that delegates to the corresponding "builder" function. Each of those NIFs also needs to parse any configuration we pass to it.

This PR adds a single generic EXLA.NIF.mlir_op(...), to which we pass the operation name, operands, attributes, etc. This approach lets us remove ~3k lines of C++ without inflating the Elixir code. Attributes and types are passed as strings (in the same human-readable MLIR syntax), and we call the MLIR API to parse them.
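To illustrate what "attributes as strings" can look like, here is a minimal sketch. The `GenericOpSketch` module and all of its functions are hypothetical and not part of EXLA; the operation name, the attribute name and the `:operand_ref` placeholder are only examples. The code renders Elixir values in MLIR's textual attribute syntax and collects the pieces a generic operation would need as plain data.

```elixir
defmodule GenericOpSketch do
  # Hypothetical helpers, not part of EXLA: they only render Elixir values
  # in MLIR's human-readable attribute syntax.

  # An integer attribute with an explicit type, e.g. "1 : i64".
  def attr_i64(value) when is_integer(value), do: "#{value} : i64"

  # A dense array attribute, e.g. "array<i64: 0, 1>".
  def attr_array_i64([]), do: "array<i64>"

  def attr_array_i64(values) when is_list(values) do
    "array<i64: " <> Enum.join(values, ", ") <> ">"
  end

  # Plain data describing one generic operation: name, operands,
  # explicit result types and stringified attributes.
  def op(name, operands, result_types, attributes) do
    %{name: name, operands: operands, result_types: result_types, attributes: attributes}
  end
end

# :operand_ref stands in for a reference to an existing MLIR value.
op =
  GenericOpSketch.op(
    "stablehlo.transpose",
    [:operand_ref],
    ["tensor<3x2xf32>"],
    permutation: GenericOpSketch.attr_array_i64([1, 0])
  )

IO.inspect(op.attributes)
```

Because everything is a string or a reference, the native side only needs one entry point that parses the strings and builds the operation, rather than one builder per operation.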

One major difference between the two approaches is that most of the C++ operation APIs don't require passing the result types, because there is C++ logic that does inference based on the operands. With the generic operation this is no longer the case (since we don't use a specific operation class). Consequently, we need to explicitly pass the result types. In most cases it's trivial, since we can just take the type from the corresponding Defn.Expr; in others it's easy to compute.
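As an example of a result type that has to be computed rather than read off the expression, consider a reduction over some axes. The sketch below is a hypothetical helper (`ResultTypeSketch` is not EXLA code) that assumes shapes are plain tuples and derives the result shape by dropping the reduced axes.

```elixir
defmodule ResultTypeSketch do
  # Hypothetical helper, not EXLA code: the result shape of a reduction
  # over `axes` is the operand shape with those axes dropped.
  def reduce_shape(shape, axes) when is_tuple(shape) and is_list(axes) do
    shape
    |> Tuple.to_list()
    |> Enum.with_index()
    |> Enum.reject(fn {_dim, axis} -> axis in axes end)
    |> Enum.map(fn {dim, _axis} -> dim end)
    |> List.to_tuple()
  end
end

# Reducing a 2x3x4 operand over axis 1 leaves a 2x4 result.
IO.inspect(ResultTypeSketch.reduce_shape({2, 3, 4}, [1]))
```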

MLIR types for tensors effectively consist of both the numeric type and the shape, so I added EXLA.Typespec to make it easier to pass this information around. I also removed EXLA.Shape entirely and replaced all of its usages with EXLA.Typespec. EXLA.Shape uses a resource, which I don't think makes sense for this kind of small data; in many places we would call EXLA.Shape.make_shape(...) (a NIF) just to pass the type/shape information around, or right before passing it to another NIF. EXLA.Typespec can be passed directly to NIFs using a compact representation, and we only do that in a couple of places. I intentionally chose the name "typespec" to use different wording from "type" and "shape" (with EXLA.Shape it was ambiguous what a shape argument/variable meant).
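A rough sketch of the idea, assuming a struct that simply pairs an Nx-style type tuple with a shape tuple. `TypespecSketch` is a hypothetical stand-in, so the real EXLA.Typespec fields and functions may differ, and the string rendering shown here is not necessarily the compact representation used by the NIFs.

```elixir
defmodule TypespecSketch do
  # Hypothetical stand-in for the idea behind EXLA.Typespec: plain data
  # holding the numeric type and the shape together, with no NIF resource.
  defstruct [:type, :shape]

  def tensor(type, shape) when is_tuple(type) and is_tuple(shape) do
    %__MODULE__{type: type, shape: shape}
  end

  # Renders the typespec in MLIR's textual tensor syntax,
  # e.g. "tensor<2x3xf32>", or "tensor<f32>" for a scalar.
  def to_mlir(%__MODULE__{type: type, shape: shape}) do
    dims = shape |> Tuple.to_list() |> Enum.map(fn dim -> "#{dim}x" end)
    "tensor<" <> Enum.join(dims) <> element_type(type) <> ">"
  end

  # Assumed mapping from Nx-style type tuples to MLIR element types.
  defp element_type({:f, size}), do: "f#{size}"
  defp element_type({:bf, 16}), do: "bf16"
  defp element_type({:s, size}), do: "i#{size}"
  defp element_type({:u, size}), do: "ui#{size}"
end

typespec = TypespecSketch.tensor({:f, 32}, {2, 3})
IO.puts(TypespecSketch.to_mlir(typespec))
```

Keeping the type and shape in one ordinary struct means it can be built, inspected and pattern matched on the Elixir side without a round trip through native code.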

Sidenote: most changes related to the operations are in EXLA.MLIR.Value and MLIR.Defn. Many of the other changes come from replacing EXLA.Shape with EXLA.Typespec.

jonatanklosko commented 2 months ago

@polvalente updated!