cgarciae / treex

A Pytree Module system for Deep Learning in JAX
https://cgarciae.github.io/treex/
MIT License

Comparing Treex with Equinox #72

Closed nalzok closed 2 years ago

nalzok commented 2 years ago

I think it is natural to compare Treex with Equinox, as both are PyTree-based libraries. The README currently says

Other Pytree-based approaches like Parallax and Equinox do not have a total state management solution to handle complex states as encountered in Flax. Treex has the Filter and Update API, which is very expressive and can effectively handle systems with a complex state.

I assume the total state management solution refers to the Kind system in Treeo. However, the recent RFC indicates that we cannot use that with higher-level frameworks like Elegy. If I want to use Elegy for training loop automation, is there any reason I should prefer Treex over Equinox?

cgarciae commented 2 years ago

It's a very good question. I've been following Equinox closely, as Treex was inspired by it, and I talk every once in a while with Patrick. Ideally we would have only a single Pytree Module library, but as it stands, here is the situation:

If you don't need BatchNorm or other potentially stateful layers (anything that needs a 'cache'), then Equinox is a good choice, and I highly encourage you to try it out. Otherwise, Treex was designed to deal with these cases (although the state management situation for pytrees is not perfect). I will definitely try to add Equinox support for Elegy.

nalzok commented 2 years ago

Thanks. That's very clear! Hopefully, JAX will provide native support for PyDAG with tagged leaves one day...

nalzok commented 2 years ago

I just looked at Equinox, and it apparently also has some support for stateful layers. Could you elaborate on why we must store metadata in the leaves to support so-called "lifted transformations"?

For example, we could naively use key suffixes to indicate the "kind" of a leaf: xxx_param marks a parameter and xxx_state marks a model state, as in {"weight_param": *, "bias_param": *, "batch_stat_state": *}. I think this gives you everything you need to filter for a certain kind of PyTree node.
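The suffix convention proposed above can be sketched in plain Python. This is only an illustration of the idea in the comment, not an API from Treex, Treeo, or Equinox: the helper name `split_by_suffix` and the example keys are hypothetical, and nested dicts with list leaves stand in for a real pytree of arrays.

```python
from typing import Any, Dict, Tuple


def split_by_suffix(
    tree: Dict[str, Any], suffix: str
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split a nested dict into (leaves whose key ends with `suffix`, the rest).

    Hypothetical helper: the "kind" of a leaf is encoded only in its key name.
    """
    selected: Dict[str, Any] = {}
    rest: Dict[str, Any] = {}
    for key, value in tree.items():
        if isinstance(value, dict):
            # Recurse into sub-trees and keep only the non-empty halves.
            sub_selected, sub_rest = split_by_suffix(value, suffix)
            if sub_selected:
                selected[key] = sub_selected
            if sub_rest:
                rest[key] = sub_rest
        elif key.endswith(suffix):
            selected[key] = value
        else:
            rest[key] = value
    return selected, rest


# Example tree using the naming scheme from the comment (values are placeholders).
variables = {
    "dense": {"weight_param": [[1.0, 2.0]], "bias_param": [0.0]},
    "norm": {"scale_param": [1.0], "batch_stat_state": [0.5]},
}

params, states = split_by_suffix(variables, "_param")
```

Note that this only works because the container is a dict whose keys are visible to the filter; a filter that sees nothing but leaf values would have no suffix to inspect.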

Additionally, what is the relationship between filtering and state management? The relevant page in the Treex documentation doesn't really mention filtering, so I cannot understand how these two concepts are related...

cgarciae commented 2 years ago

Stateful layers like BatchNorm are experimental in Equinox, but I don't think this in itself is a problem. The real issue is that Equinox is designed to work with arbitrary pytrees, so its filters can really only take leaf values into account. This means you don't have access to any metadata, including the field's name, which in turn means you probably can't express something like "select all the batch stats and cache leaves".

Treeo filters, on the other hand, operate over FieldInfo objects, which carry the following information:

class FieldInfo:
    name: Optional[str]
    value: Any
    kind: Type[Any]
    module: Optional[Tree]

So if a leaf is part of a treeo.Tree, the filter has access to the field's name, kind, and the actual module instance (tree would've been a better name for this), and of course the leaf value itself. This makes treeo/treex filters a bit more powerful.
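As a rough illustration of why this metadata matters, here is a minimal sketch of a filter expressed as a predicate over FieldInfo-like records. It does not import or call Treeo: the `FieldInfo` dataclass only mirrors the fields listed above (with `module` typed as `Any`), and the `Parameter`, `BatchStat`, and `Cache` marker classes and the `select` helper are hypothetical stand-ins, so the real Treeo API may look different.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Type


# Hypothetical "kind" markers; stand-ins for whatever kinds a library defines.
class Parameter:
    pass


class BatchStat:
    pass


class Cache:
    pass


@dataclass
class FieldInfo:
    """Local mirror of the fields listed above, for illustration only."""

    name: Optional[str]
    value: Any
    kind: Type[Any]
    module: Optional[Any]


def select(
    fields: List[FieldInfo], predicate: Callable[[FieldInfo], bool]
) -> List[FieldInfo]:
    """Keep the fields for which the predicate returns True."""
    return [f for f in fields if predicate(f)]


fields = [
    FieldInfo(name="kernel", value=1.0, kind=Parameter, module=None),
    FieldInfo(name="mean", value=0.0, kind=BatchStat, module=None),
    FieldInfo(name="var", value=1.0, kind=BatchStat, module=None),
    FieldInfo(name="key_cache", value=0.0, kind=Cache, module=None),
]

# "Select all the batch stats and cache leaves": decided by kind, not by value.
stateful = select(fields, lambda f: issubclass(f.kind, (BatchStat, Cache)))
stateful_names = [f.name for f in stateful]

# A value-only predicate cannot tell "mean" apart from "key_cache" (both 0.0).
zero_valued_names = [f.name for f in select(fields, lambda f: f.value == 0.0)]
```

Here the kind-based predicate separates the stateful leaves cleanly, while the value-based one mixes a batch statistic with a cache entry, which is the limitation described for filters that see only leaf values.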

nalzok commented 2 years ago

I see. Thanks for the explanation!