nalzok closed this issue 2 years ago.
It's a very good question. I've been following Equinox closely, as Treex was inspired by it, and I talk every once in a while with Patrick. Ideally we would only have a single Pytree Module library, but as it stands, here is the situation:
Treex (through Treeo) adds `Kind`s to Pytrees that are able to emulate Flax collections. This is needed for the advanced lifted transformations you see in Flax (`nn.{scan, vmap, etc}`), and it also simplifies a lot of code since, e.g., selecting all `BatchStats` becomes a trivial filter. If you don't need `BatchNorm` or other potentially stateful layers (anything that needs a 'cache'), then Equinox is a good choice and I highly encourage you to try it out; otherwise, Treex was designed to try to deal with these cases (although the state management situation for pytrees is not perfect). I will definitely try to add Equinox support to Elegy.
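A plain-Python sketch of the idea behind kinds, not Treex's actual API: each field carries a kind class, so selecting every batch statistic reduces to an `issubclass` check, much like picking a collection in Flax. `Kind`, `Parameter`, `BatchStat`, `KIND_OF`, and `filter_by_kind` are hypothetical names chosen for illustration.

```python
from typing import Any, Dict, Type


# Hypothetical kind markers (not Treex's real classes), loosely mirroring Flax collections.
class Kind:
    pass


class Parameter(Kind):
    pass


class BatchStat(Kind):
    pass


# Hypothetical BatchNorm-like state: field name -> leaf value,
# plus a table recording which kind each field belongs to.
LEAVES: Dict[str, Any] = {
    "scale": [1.0, 1.0],
    "bias": [0.0, 0.0],
    "mean": [0.1, -0.2],
    "var": [0.9, 1.1],
}
KIND_OF: Dict[str, Type[Kind]] = {
    "scale": Parameter,
    "bias": Parameter,
    "mean": BatchStat,
    "var": BatchStat,
}


def filter_by_kind(
    leaves: Dict[str, Any], kinds: Dict[str, Type[Kind]], kind: Type[Kind]
) -> Dict[str, Any]:
    """Keep only the fields whose kind is `kind` or a subclass of it."""
    return {name: value for name, value in leaves.items() if issubclass(kinds[name], kind)}


batch_stats = filter_by_kind(LEAVES, KIND_OF, BatchStat)
params = filter_by_kind(LEAVES, KIND_OF, Parameter)
print(sorted(batch_stats))  # ['mean', 'var']
print(sorted(params))  # ['bias', 'scale']
```

Because kinds are ordinary classes here, filtering on the base `Kind` returns every field, while more specific kinds narrow the selection.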
Thanks. That's very clear! Hopefully, JAX will provide native support for PyDAGs with tagged leaves one day...
I just looked at Equinox, and it apparently also has some support for stateful layers. Could you elaborate on why we must store metadata in the leaves to support so-called "lifted transformations"?
For example, maybe we could just naively use suffixes in the keys to indicate the "kind" of a leaf: `xxx_param` means it's a parameter and `xxx_state` means it's model state, e.g. `{"weight_param": *, "bias_param": *, "batch_stat_state": *}`. I think this gives you everything you need to filter for a certain kind of PyTree node.
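A minimal sketch of that suffix convention on a flat dict, with placeholder values standing in for the `*` entries; `select_by_suffix` is a hypothetical helper, not part of any library.

```python
from typing import Any, Dict


def select_by_suffix(tree: Dict[str, Any], suffix: str) -> Dict[str, Any]:
    """Keep the entries whose key ends with `suffix`, e.g. "_param" or "_state"."""
    return {key: value for key, value in tree.items() if key.endswith(suffix)}


# The naming convention proposed above, with placeholder leaf values.
model = {"weight_param": 1.0, "bias_param": 0.0, "batch_stat_state": 0.5}

print(select_by_suffix(model, "_param"))  # {'weight_param': 1.0, 'bias_param': 0.0}
print(select_by_suffix(model, "_state"))  # {'batch_stat_state': 0.5}
```

This sketch relies on the container exposing string keys; for nested pytrees the suffix would have to be checked along each key path, and a filter that only receives leaf values never sees the keys at all.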
Additionally, what is the relationship between filtering and state management? The relevant page in the Treex documentation doesn't really mention filtering, so I don't understand how these two concepts are related...
Stateful layers like `BatchNorm` are experimental in Equinox, but I don't think this by itself is a problem. The real issue is that Equinox is designed to work with arbitrary pytrees, so its filters can really only take leaf values into account. This means you don't have access to any metadata, including the field's name, which in turn means you probably can't express something like "select all the batch stats and cache leaves".
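To make the limitation concrete, here is a plain-Python sketch (not Equinox's actual filter API) of a filter that only ever sees leaf values. `flatten_values` and `filter_values` are hypothetical helpers; the point is that once field names are discarded, a running mean and a weight of the same type look identical to the predicate.

```python
from typing import Any, Callable, List


def flatten_values(tree: Any) -> List[Any]:
    """Return the leaves of nested dicts/lists/tuples, discarding all keys.

    Dict keys are visited in sorted order, as JAX does when flattening dicts.
    """
    if isinstance(tree, dict):
        return [leaf for key in sorted(tree) for leaf in flatten_values(tree[key])]
    if isinstance(tree, (list, tuple)):
        return [leaf for child in tree for leaf in flatten_values(child)]
    return [tree]


def filter_values(tree: Any, predicate: Callable[[Any], bool]) -> List[Any]:
    """Apply a predicate that, by construction, only sees leaf values."""
    return [leaf for leaf in flatten_values(tree) if predicate(leaf)]


# Hypothetical model: one group of parameters and one group of batch statistics.
model = {
    "linear": {"weight": 0.3, "bias": 0.0},
    "batch_norm": {"mean": 0.1, "var": 1.0},
}

print(flatten_values(model))  # [0.1, 1.0, 0.0, 0.3]

# The predicate can only inspect values, and a running mean is just a float
# like any weight, so a type-based filter selects everything.
print(filter_values(model, lambda leaf: isinstance(leaf, float)))  # [0.1, 1.0, 0.0, 0.3]
```

With real arrays the situation is the same: a BatchNorm scale and its running mean can share shape and dtype, so nothing in the value says which one is state.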
Treeo filters, on the other hand, operate over `FieldInfo` objects, which have access to the following information:
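```python
from typing import Any, Optional, Type


class FieldInfo:
    name: Optional[str]  # the field's name
    value: Any  # the leaf value itself
    kind: Type[Any]  # the field's kind
    module: Optional["Tree"]  # the treeo.Tree instance that owns the field
```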
So if a leaf is part of a `treeo.Tree`, the filter has access to the field's `name`, its `kind`, the actual `module` instance (`tree` would've been a better name for this), and of course the leaf `value` itself. This makes `treeo`/`treex` filters a bit more powerful.
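A rough sketch of why access to this metadata matters, using a local dataclass with the same four fields rather than Treeo's own class. The kind classes, the `select` helper, and the module placeholders are hypothetical; the point is that "select all the batch stats and cache leaves" becomes a one-line predicate once the kind is available.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Type


# Hypothetical kind markers for illustration.
class Parameter:
    pass


class BatchStat:
    pass


class Cache:
    pass


@dataclass
class FieldInfo:
    # Local stand-in with the same four fields as the class shown above.
    name: Optional[str]
    value: Any
    kind: Type[Any]
    module: Optional[Any]  # placeholder for the owning tree instance


def select(fields: List[FieldInfo], predicate: Callable[[FieldInfo], bool]) -> List[FieldInfo]:
    """Keep the fields whose FieldInfo satisfies the predicate."""
    return [info for info in fields if predicate(info)]


# Hypothetical module instances; plain objects are enough to compare identity.
linear, batch_norm, attention = object(), object(), object()

fields = [
    FieldInfo("weight", 0.3, Parameter, linear),
    FieldInfo("mean", 0.1, BatchStat, batch_norm),
    FieldInfo("var", 1.0, BatchStat, batch_norm),
    FieldInfo("key_cache", [0.0, 0.0], Cache, attention),
]

# "Select all the batch stats and cache leaves", expressed through the kind.
stateful = select(fields, lambda info: issubclass(info.kind, (BatchStat, Cache)))
print([info.name for info in stateful])  # ['mean', 'var', 'key_cache']

# The owning module can be used the same way.
in_batch_norm = select(fields, lambda info: info.module is batch_norm)
print([info.name for info in in_batch_norm])  # ['mean', 'var']
```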
I see. Thanks for the explanation!
I think it is natural to compare Treex with Equinox, as both are PyTree-based libraries. The README currently says
I assume the total state management solution refers to the `Kind` system in Treeo. However, the recent RFC indicates that we cannot use that with higher-level frameworks like Elegy. Suppose I want to use Elegy for training loop automation: is there any reason I should prefer Treex over Equinox?