patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0
2.07k stars 137 forks

Comparisons to NNX #718

Open jlperla opened 5 months ago

jlperla commented 5 months ago

Sorry to post an issue, but I didn't see a discussions section on the repo.

What are your thoughts on https://flax.readthedocs.io/en/latest/experimental/nnx/index.html and its scope relative to Equinox? Do you see it as a possible replacement if it moves out of the experimental stage, or is it intended to do something very different?

patrick-kidger commented 5 months ago

I believe it's pretty much a simplified version of Flax. It still has some framework-y parts to it, i.e. it isn't compatible with arbitrary JAX code (due to their creation of NNX-specific concepts like pygraphs).

I think for this reason it's mostly competing against Flax, rather than Equinox.

jlperla commented 5 months ago

Thanks @patrick-kidger

So just to make sure I understand: they have made the tradeoff of having a walled garden for what code is possible, along the lines of Keras, to make coding with deep learning easier. But the tradeoff is that the new constructions (e.g., pygraphs) mean that it is no longer JAX code in some sense. So for cookie-cutter ML prediction tasks it may be perfect, but for research code it is easy to hit fundamental limits of the framework?

patrick-kidger commented 5 months ago

Pretty much!

cgarciae commented 3 months ago

So just to make sure I understand: they have made the tradeoff of having a walled garden for what code is possible, along the lines of Keras, to make coding with deep learning easier. But the tradeoff is that the new constructions (e.g., pygraphs) mean that it is no longer JAX code in some sense. So for cookie-cutter ML prediction tasks it may be perfect, but for research code it is easy to hit fundamental limits of the framework?

I don't think this is a fair characterization. NNX allows reference semantics at the expense of Modules not being pytrees; the upside is that state management (including RNG handling) becomes trivial, as it's relegated to Python, but extra care has to be taken at functional boundaries.

However, there is no limit to expressivity: by using the Functional API (which is similar to filtering in Equinox), one or more Modules can be converted into pytrees and back, making it possible to integrate with any JAX API. To make usage easier, NNX provides NNX Transforms, which are a reference-aware version of the JAX transforms (in the same spirit that Equinox includes filtered transforms). If needed, NNX Modules can also opt in to registering as pytrees (this breaks reference sharing at transform boundaries):

class MyModule(nnx.Module, experimental_pytree=True):
  ...
jlperla commented 3 months ago

Thanks @cgarciae, that helps a lot.

Just trying to get a feeling for the different spaces the two libraries operate in and whether they should be applied in different places. Do you see any limits of NNX relative to Equinox for certain types of applications? Are there places where relying on the NNX pygraphs will get you stuck for non-standard pipelines? (The stuff I do is much closer to scientific machine learning situations, e.g. differentiable ODE solvers, where you need a lot of flexibility.)

To help me form the mental model here: one big difference is that you are filtering based on nnx.Param when splitting and rebuilding classes, whereas Equinox is splitting and filtering along something a little more general (e.g. eqx.is_inexact_array)? So if I wanted to make something "learnable" I would need to wrap it in a Param? Are there any limitations there? The sorts of things I work on aren't just neural networks; they also involve joint optimization of multiple "functions" as well as parameters of different types. But I don't mind wrapping all of it in an nnx.Module with params and nnx.Modules as elements? If so, though, I wouldn't really be using a __call__ for that wrapper?

In my mind it is totally fine if there are multiple libraries with different tradeoffs, so just want to see where they overlap.

cgarciae commented 3 months ago

It's hard to give unbiased advice, so I'll just answer these two points:

So if I wanted to make something "learnable" I would need to wrap it in a Param?

Using nnx.Param is not needed: you can use bare Arrays and function filters. See Using Filters.
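Mechanically, filter-based splitting in either library is just partitioning a nested structure with a predicate and recombining it later. A library-free sketch of the idea (the helper names here are illustrative, not flax or Equinox API):

```python
def partition(tree, predicate):
    """Split a nested dict into (matching, rest), keeping the structure."""
    if isinstance(tree, dict):
        match, rest = {}, {}
        for k, v in tree.items():
            match[k], rest[k] = partition(v, predicate)
        return match, rest
    return (tree, None) if predicate(tree) else (None, tree)

def combine(match, rest):
    """Inverse of partition: recombine the two halves."""
    if isinstance(match, dict):
        return {k: combine(match[k], rest[k]) for k in match}
    return rest if match is None else match

# "Learnable" leaves are floats; strings/ints are static config.
model = {"w": 1.0, "cfg": {"name": "mlp", "width": 8}, "b": 0.5}
params, static = partition(model, lambda x: isinstance(x, float))
assert params == {"w": 1.0, "cfg": {"name": None, "width": None}, "b": 0.5}
assert combine(params, static) == model
```

Only the `params` half would be handed to a gradient transform; the `static` half rides along unchanged, which is what lets bare arrays be "learnable" without a wrapper type.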

Are there places where relying on the NNX pygraphs will get you stuck for non-standard pipelines

Early adopters tend to be doing non-standard stuff and it hasn't been an issue.