ASEM000 / pytreeclass

Visualize, create, and operate on pytrees in the most intuitive way possible.
https://pytreeclass.rtfd.io/en/latest
Apache License 2.0

Comparison with `equinox` #66

Open adam-hartshorne opened 1 year ago

adam-hartshorne commented 1 year ago

First, sorry to put it in issues, but there is no discussion tab. Second, I don't want this to sound negative or critical, rather I am genuinely interested in the reasoning (and it looks like your library has some nice unique features e.g. visualising PyTrees).

There are already a number of mature JAX libraries, such as Equinox, that handle the idea of constructing classes as PyTrees and layer convenient manipulation methods on top (plus I notice you've written a NN library that builds further on this). I was wondering why another set of libraries? What are the advantages of PyTreeClass and Serket over something like Equinox?

ASEM000 commented 1 year ago

Hello Adam, thank you for your question.

For background, there are a couple of libraries with similar ideas (a PyTorch-like API) that predate Treex and Equinox and, as you have seen, libraries with seemingly similar ideas that postdate them; each of these libraries has its reason to exist. The landscape of neural network libraries is even more diverse, with Google and DeepMind alone having several such libraries, including objax, flax, haiku, and oryx.nn. Each of these libraries represents a slightly different conceptual model, so as you delve deeper into this landscape, you will likely discover a variant that aligns with your specific needs and preferences.

Now, let me explain why PyTreeClass exists when Equinox/Treex/simple_pytree already exist; since you mentioned Equinox, I will focus mostly on it.

1- One of the core ideas of equinox is filtered transformations, where you filter your pytrees at the function level, while in PyTreeClass you filter at the pytree level by masking. This is a deliberate decision: it avoids the need for automatic decorators like `equinox.filter_{...}` that parallel the jax API.

I believe that mirroring an API can be a risky strategy (although it can be cleverly implemented, as in jax.numpy) because it can lead to confusion and errors due to inconsistent behaviour. Additionally, it requires meticulous maintenance to keep up with updates to the original API (you can see examples of `filter_`-related issues in the Equinox issue tab). Moreover, debugging can be more challenging because you need to understand beforehand which nodes are frozen and which are under training. If not handled carefully, this approach can introduce bugs and unexpected behaviours when interacting with pure JAX or other libraries. For these reasons, I prefer a more explicit method through masking, where it is possible to see which nodes are frozen before passing the tree to a function. This helps prevent unforeseen outcomes.
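To make the contrast concrete, here is a minimal sketch of the masking style, assuming the `pytc.freeze`/`jax.tree_util` pattern shown in the pytreeclass README: the freezing happens on the tree itself, and frozen leaves are simply hidden from plain jax machinery, so no `filter_`-style wrapper is needed.

```python
import jax
import pytreeclass as pytc

class Tree(pytc.TreeClass):
    a: float = 1.0
    b: float = 2.0

tree = Tree()

# freeze `a` on the tree itself; frozen leaves are treated as static,
# so plain jax transformations simply do not see them
frozen = tree.at["a"].apply(pytc.freeze)

print(jax.tree_util.tree_leaves(frozen))  # only `b`'s value remains a leaf
```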

2- IMO, PyTreeClass has better functional tree manipulation (~lens-like); you can easily do a couple of things, for example building per-leaf optimizer masks:

```python
import optax
import pytreeclass as pytc
import jax

class Tree(pytc.TreeClass):
    a: float = 1.0
    b: float = 2.0
    c: float = 3.0

tree = Tree()

a_mask = tree.at[...].set(False).at["a"].set(True)
b_mask = tree.at[...].set(False).at["b"].set(True)
c_mask = tree.at[...].set(False).at["c"].set(True)

optim = optax.chain(
    # update `a` with sgd of learning rate 1
    optax.masked(optax.sgd(learning_rate=1), a_mask),
    # update `b` with sgd of learning rate -1
    optax.masked(optax.sgd(learning_rate=-1), b_mask),
    # update `c` with sgd of learning rate 0
    optax.masked(optax.sgd(learning_rate=0), c_mask),
)
```
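For completeness, here is one way such masks could be consumed; this is just the standard optax init/update loop, with the toy quadratic loss below standing in for a real objective:

```python
import jax

state = optim.init(tree)

@jax.grad
def loss_fn(tree):
    # toy quadratic loss, purely for illustration
    return tree.a**2 + tree.b**2 + tree.c**2

grads = loss_fn(tree)
updates, state = optim.update(grads, state)
tree = optax.apply_updates(tree, updates)
```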
And you can call mutating methods functionally; `at[method_name](...)` returns the method's output together with a new, modified tree:

```python
import pytreeclass as pytc

class Tree(pytc.TreeClass):
    a: float = 1.0
    b: float = 2.0
    c: float = 3.0

    def add_leaf(self, name: str, value):
        setattr(self, name, value)

tree = Tree()
# Tree(a=1.0, b=2.0, c=3.0)

_, tree_with_d = tree.at["add_leaf"]("d", 4.0)

tree_with_d
# Tree(a=1.0, b=2.0, c=3.0, d=4.0)
```
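Note that the original `tree` is left untouched by this call; `at[...]` always returns a new copy, which preserves the immutability that jax transformations expect.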

3- Debugging: all my viz tools are geared towards debugging, so you always have helpful information whenever you interact with trees. For deep and nested networks, I usually resort to the tree_diagram function with the depth argument to navigate the network.
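A minimal sketch of that workflow (assuming `pytc.tree_diagram` accepts the `depth` argument described above; the exact rendering varies by version):

```python
import pytreeclass as pytc

class Block(pytc.TreeClass):
    w: float = 1.0
    b: float = 0.0

class Net(pytc.TreeClass):
    def __init__(self):
        self.block1 = Block()
        self.block2 = Block()

# summarize only the first level of the nested tree
print(pytc.tree_diagram(Net(), depth=1))
```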

4- More advanced features, like tree_map_with_trace, let you filter based on the type path; this is useful if you want to freeze leaves with certain parent types (Dropout layer leaves, for example). This is a unique feature of PyTreeClass; a hypothetical sketch follows.
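Here is what such type-path filtering could look like; the exact shape of the trace passed to the callback is an assumption on my part, so treat this as pseudocode and check the pytreeclass docs:

```python
import pytreeclass as pytc

def freeze_dropout_leaves(trace, leaf):
    # assumption: the trace carries the chain of names and parent types
    # leading to this leaf (hypothetical layout)
    names, types = trace
    if any(getattr(t, "__name__", "") == "Dropout" for t in types):
        return pytc.freeze(leaf)
    return leaf

# tree = pytc.tree_map_with_trace(freeze_dropout_leaves, tree)
```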

5- Data model: pytreeclass blends the idea of a pytree of arrays with that of a single array (optionally, via `leafwise=True`):

```python
import jax
import jax.numpy as jnp
import pytreeclass as pytc

class Tree(pytc.TreeClass, leafwise=True):
    a: int = 1
    b: tuple[float, float] = (2.0, 3.0)
    c: jax.Array = jnp.array([4.0, 5.0, 6.0])

tree = Tree()

print(pytc.bcmap(jnp.where)(tree > 2, tree + 100, 0))
# Tree(a=0, b=(0.0, 103.0), c=[104. 105. 106.])

print(tree.at[tree > 1].apply(lambda x: x + 100))
# Tree(a=1, b=(102.0, 103.0), c=[104. 105. 106.])
```
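The point of `leafwise=True` is that arithmetic and comparison operators apply leaf-by-leaf, which is what makes expressions like `tree > 2` and `tree + 100` meaningful; `pytc.bcmap` then lifts a regular function such as `jnp.where` so it maps over the corresponding leaves of its pytree arguments.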


6- Module design: this is where all the other `PyTree` libraries have their own flavour. I will again focus on `Equinox` to explain my point, using the example I found [here](https://docs.kidger.site/equinox/examples/mnist/):

```python

class CNN(eqx.Module):
    layers: list

    def __init__(self, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        # Standard CNN setup: convolutional layer, followed by flattening,
        # with a small MLP on top.
        self.layers = [
            eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1),
            eqx.nn.MaxPool2d(kernel_size=2),
            jax.nn.relu,
            jnp.ravel,
            eqx.nn.Linear(1728, 512, key=key2),
            jax.nn.sigmoid,
            eqx.nn.Linear(512, 64, key=key3),
            jax.nn.relu,
            eqx.nn.Linear(64, 10, key=key4),
            jax.nn.log_softmax,
        ]

    def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
        for layer in self.layers:
            x = layer(x)
        return x
```

In equinox, you need to declare your trainable parameters as type-hinted fields at the top of your class. So, if you want the previous example to have `nn.conv1` point to the first convolution layer, you have to do something like this:

```python
class CNN(eqx.Module):
    conv1:eqx.nn.Conv2d
    pool1:eqx.nn.MaxPool2d
    linear1:eqx.nn.Linear
    linear2:eqx.nn.Linear

    def __init__(self, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        # Standard CNN setup: convolutional layer, followed by flattening,
        # with a small MLP on top.
        self.conv1 = eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1)
        self.pool1 = eqx.nn.MaxPool2d(kernel_size=2)
        self.linear1 = eqx.nn.Linear(1728, 512, key=key2)
        self.linear2 = eqx.nn.Linear(512, 10, key=key3)

    def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
        x = self.conv1(x)
        x = self.pool1(x)
        x = jax.nn.relu(x)
        x = jnp.ravel(x)
        x = self.linear1(x)
        x = jax.nn.sigmoid(x)
        x = self.linear2(x)
        x = jax.nn.log_softmax(x)
        return x
```

IMO, this is a repetitive design. The example above escapes this repetition by using a mutable container (a list) to wrap all the layers, but then you must write `nn.layers[0]` instead of `nn.conv1` to fetch your first layer, which hurts ergonomics. Moreover, by doing so you lose immutability (try `nn.layers.pop()`), which is essential for correct behaviour under JAX. Another reason to avoid a tuple/list as a layer container is that you lose the layer/leaf names that would otherwise be available through `jax.tree_util.tree_map_with_path`; see the sketch below.
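A small illustration of that last point, using a plain dict as a stand-in for named fields versus a list container:

```python
import jax

named = {"conv1": 1.0, "linear1": 2.0}   # named fields keep their keys
listed = [1.0, 2.0]                      # a list only has positions

# paths here include the dict keys 'conv1' / 'linear1'
print(jax.tree_util.tree_map_with_path(lambda path, x: str(path), named))
# paths here only carry positional indices 0 / 1
print(jax.tree_util.tree_map_with_path(lambda path, x: str(path), listed))
```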

In pytreeclass, all attributes are leaves by default. If you want to filter out non-trainable parameters, use a mask, as shown in the readme.

7- Finally, I am a user of Equinox myself; I use Equinox's internal tools (equinox.internal), and I think my library must play nicely with others in the jax ecosystem. This is why pytreeclass has no special treatment for non-pytreeclass instances: you can use all these tools with any library you like (e.g. flax/equinox/haiku).

So for the CNN example, you can inherit all the pytreeclass pros by doing something like this:

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import pytreeclass as pytc

class CNN(pytc.TreeClass):
    def __init__(self, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        # Standard CNN setup: convolutional layer, followed by flattening,
        # with a small MLP on top.
        self.conv1 = eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1)
        self.pool1 = eqx.nn.MaxPool2d(kernel_size=2)
        self.linear1 = eqx.nn.Linear(1728, 512, key=key2)
        self.linear2 = eqx.nn.Linear(512, 10, key=key3)

    def __call__(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        x = jax.nn.relu(x)
        x = jnp.ravel(x)
        x = self.linear1(x)
        x = jax.nn.sigmoid(x)
        x = self.linear2(x)
        x = jax.nn.log_softmax(x)
        return x
```

With serket, you inherit the tools and mental model of pytreeclass while staying 100% compatible with other libraries, including equinox. If you are a user of eqx.nn, you can use serket layers that do not exist in equinox (like FFT convolution) inside equinox code if you like.
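As a rough sketch of that kind of mixing; note the serket layer name and constructor signature below are assumptions on my part, so check the serket docs for the real API:

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import serket as sk  # layer name and signature below are assumed

class Hybrid(eqx.Module):
    fft_conv: sk.nn.FFTConv2D  # serket layer stored on an equinox Module
    linear: eqx.nn.Linear

    def __init__(self, key):
        k1, k2 = jax.random.split(key)
        # assumed constructor: (in_features, out_features, kernel_size, key=...)
        self.fft_conv = sk.nn.FFTConv2D(1, 3, kernel_size=4, key=k1)
        self.linear = eqx.nn.Linear(3 * 25 * 25, 10, key=k2)

    def __call__(self, x):
        # both layers are pytrees, so equinox treats the serket layer
        # like any other sub-pytree of parameters
        return self.linear(jnp.ravel(self.fft_conv(x)))
```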

Let me know if this answers your question.

adam-hartshorne commented 1 year ago

Thank you for the extremely in-depth response. It will take me some time to consider all that has been stated, but my interest has definitely been piqued.

One other quick question: do you have any benchmarking of your implementation vs, say, Equinox for a range of uses? Obviously, I saw your charts for flatten/unflatten, which look very good. I wonder how it performs in terms of memory/speed for various common NN architectures (as I have found over the years, JAX can be very sensitive to small changes in code when it comes to using things like vmap; this is obviously down to how JAX/XLA optimisation is conducted).

ASEM000 commented 1 year ago

~~Except for flax.struct, I think most pytree libraries should behave similarly regarding memory/speed. PyTreeClass is slightly faster because no logic (for static fields) is done when flattening/unflattening.~~ Check the readme for benchmark links.

ASEM000 commented 1 year ago

For reference:
[1] Pytree-based implementations: one that predates equinox/treex (flax PyTreeNode), and one that postdates them (pax)
[2] equinox tree_at sample issue
[3] filter inconsistent-behaviour sample issues (1, 2)