elixir-nx / nx

Multi-dimensional arrays (tensors) and numerical definitions for Elixir

Add NIF for loading custom plugins #1519

Open · seanmor5 opened 3 months ago

seanmor5 commented 3 months ago

WIP

josevalim commented 3 months ago

The general idea looks neat to me, although we still need to figure out how they would be specified via defn.

We probably want to use optional callbacks for this. Maybe we allow anyone to define optional callbacks, and then allow those callbacks to be registered dynamically when the plugin is registered? Then we would have the "built-in" optional callbacks and the third-party ones. Thoughts?
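A minimal sketch of the dispatch described above, written in Python purely as a language-neutral illustration: each backend keeps its built-in optional callbacks plus third-party ones registered dynamically when a plugin loads, and a call falls back to a default implementation when the backend has neither. OptionalRegistry, register_builtin, register_plugin, call and the "exla"/"logsumexp" keys are hypothetical names, not the Nx API, and the rule that built-ins take precedence over plugins is an assumption.

```python
import math
from typing import Callable, Dict, Sequence


class OptionalRegistry:
    """Hypothetical per-backend registry of optional callbacks (not the Nx API)."""

    def __init__(self) -> None:
        # backend -> {callback name -> native implementation}
        self._builtin: Dict[str, Dict[str, Callable]] = {}
        self._plugins: Dict[str, Dict[str, Callable]] = {}

    def register_builtin(self, backend: str, name: str, impl: Callable) -> None:
        """Callbacks that ship with the backend itself."""
        self._builtin.setdefault(backend, {})[name] = impl

    def register_plugin(self, backend: str, name: str, impl: Callable) -> None:
        """Third-party callbacks, added dynamically when a plugin is loaded."""
        self._plugins.setdefault(backend, {})[name] = impl

    def call(self, backend: str, name: str, args: Sequence, default: Callable):
        """Run the backend's implementation of `name`, else the default one."""
        # Assumption: built-in callbacks take precedence over plugin ones.
        impl = self._builtin.get(backend, {}).get(name)
        if impl is None:
            impl = self._plugins.get(backend, {}).get(name, default)
        return impl(*args)


def logsumexp_default(xs: Sequence[float]) -> float:
    """Portable default implementation that any backend can fall back to."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


registry = OptionalRegistry()
# Nothing registered yet for this backend: the default implementation runs.
print(registry.call("exla", "logsumexp", [[0.0, 0.0]], logsumexp_default))
# A plugin registers its own (here trivially tagged) implementation at load time.
registry.register_plugin("exla", "logsumexp", lambda xs: "plugin logsumexp")
print(registry.call("exla", "logsumexp", [[0.0, 0.0]], logsumexp_default))
```

Keeping the default implementation at the call site is what lets every other backend keep working without knowing that the plugin exists.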

seanmor5 commented 3 months ago

Yeah, I think it should be similar to optional. I talked to Paulo a bit yesterday, and we are going to move the custom call registration to JIT time so we can provide shape and type information to the implementer. That would require the user to register a library and then pass the library name plus the symbol to the custom call.

Then we could have a fallback, which would be a regular Nx function.
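A hedged sketch of JIT-time resolution, assuming a loaded library can be modelled as a resolver function: at compile time the compiler hands the resolver the symbol plus the input shapes and types, and the resolver either returns a kernel specialised for them or declines, in which case the fallback (standing in for the Nx function) is used. Spec, Jit, register_library, compile_custom_call, my_lib and double_fallback are hypothetical names invented for this illustration, not EXLA APIs.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Sequence, Tuple


@dataclass(frozen=True)
class Spec:
    """Shape and type information known only at JIT time."""
    shape: Tuple[int, ...]
    dtype: str


# Stand-in for a user-registered library: given a symbol and the input specs,
# return a kernel specialised for them, or None if it cannot handle them.
Resolver = Callable[[str, Sequence[Spec]], Optional[Callable]]


class Jit:
    def __init__(self) -> None:
        self._libraries: Dict[str, Resolver] = {}

    def register_library(self, name: str, resolver: Resolver) -> None:
        self._libraries[name] = resolver

    def compile_custom_call(self, library: str, symbol: str,
                            specs: Sequence[Spec], fallback: Callable) -> Callable:
        """Resolve (library, symbol) for these specs, else use the fallback."""
        resolver = self._libraries.get(library)
        kernel = resolver(symbol, specs) if resolver is not None else None
        return kernel if kernel is not None else fallback


def my_lib(symbol: str, specs: Sequence[Spec]) -> Optional[Callable]:
    # Only handles 1-D f32 inputs; declining lets the fallback take over.
    if symbol == "double" and all(s.dtype == "f32" and len(s.shape) == 1 for s in specs):
        return lambda xs: [2.0 * x for x in xs]
    return None


def double_fallback(xs: List[float]) -> List[float]:
    """Plays the role of the portable Nx fallback function."""
    return [x + x for x in xs]


jit = Jit()
jit.register_library("my_lib", my_lib)
kernel = jit.compile_custom_call("my_lib", "double", [Spec((3,), "f32")], double_fallback)
print(kernel([1.0, 2.0, 3.0]))  # [2.0, 4.0, 6.0]
```

Because resolution happens per compilation, the same (library, symbol) pair can yield different kernels, or the fallback, depending on the shapes and types it is called with.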

josevalim commented 3 months ago

Sounds good! Happy to discuss sketches of the API any time!

polvalente commented 3 months ago

> The general idea looks neat to me, although we still need to figure out how they would be specified via defn.
>
> We probably want to use optional callbacks for this. Maybe we allow anyone to define optional callbacks, and then allow those callbacks to be registered dynamically when the plugin is registered? Then we would have the "built-in" optional callbacks and the third-party ones. Thoughts?

I was thinking more along the lines of adding something like Nx.Defn.Expr.block(name, expr), which would mark a given subgraph with a specific name. EXLA could then use the name and arity to delegate blocks to custom calls whenever relevant.

The mapping could even be block name -> custom_call name, stored under an EXLA config key.
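One possible reading of the block proposal, sketched in Python with hypothetical node types (Param, Op, Block) rather than Nx.Defn.Expr: a Block carries a name, its inputs, and the default subgraph, and lowering consults a config keyed by (block name, arity). A match emits a custom call; otherwise the block's own subgraph is lowered as usual. The textual output format, the config shape and the "my_softmax_kernel" target are assumptions made for illustration.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, Tuple, Union


@dataclass(frozen=True)
class Param:
    """A graph input."""
    name: str


@dataclass(frozen=True)
class Op:
    """A regular operation applied to argument nodes."""
    name: str
    args: Tuple[Node, ...]


@dataclass(frozen=True)
class Block:
    """A named subgraph: `args` are its inputs, `body` is the default expression."""
    name: str
    args: Tuple[Node, ...]
    body: Node


Node = Union[Param, Op, Block]

# Hypothetical config: (block name, arity) -> custom call target.
CustomCalls = Dict[Tuple[str, int], str]


def lower(node: Node, custom_calls: CustomCalls) -> str:
    """Lower an expression to a textual IR, delegating mapped blocks to custom calls."""
    if isinstance(node, Param):
        return node.name
    if isinstance(node, Block):
        target = custom_calls.get((node.name, len(node.args)))
        if target is None:
            # No mapping for this name/arity: compile the subgraph itself.
            return lower(node.body, custom_calls)
        head = f"custom_call[{target}]"
    else:
        head = node.name
    args = ", ".join(lower(arg, custom_calls) for arg in node.args)
    return f"{head}({args})"


x = Param("x")
exp_x = Op("exp", (x,))
softmax = Block("softmax", (x,), Op("divide", (exp_x, Op("sum", (exp_x,)))))

print(lower(softmax, {}))  # divide(exp(x), sum(exp(x)))
print(lower(softmax, {("softmax", 1): "my_softmax_kernel"}))  # custom_call[my_softmax_kernel](x)
```

Since the block body is always present, backends without a mapping simply compile it, so the block marker is harmless everywhere else.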

josevalim commented 3 months ago

> I was thinking more along the lines of adding something like Nx.Defn.Expr.block(name, expr), which would mark a given subgraph with a specific name. EXLA could then use the name and arity to delegate blocks to custom calls whenever relevant.

When I was thinking about this yesterday, I came to the conclusion that we need the function name, the input and output types, and a default implementation for other backends, and that's exactly what optional provides. That's why I thought about using it instead of coming up with a new construct.

josevalim commented 3 months ago

Yeah, and the main issue with ETS is that the data you read gets copied. If the data is a reference, it should be plenty fast. Even a process should be good enough.