adam-hartshorne opened 1 year ago
Hello Adam, thank you for your question.
For background, there are a couple of libraries with similar ideas (PyTorch-like API) that predate Treex and Equinox, and, as you have seen, libraries with seemingly similar ideas that postdate them; each of these libraries has its own reason to exist. The landscape of neural network libraries is even more diverse, with Google and DeepMind alone maintaining several such libraries, including objax, flax, haiku, and oryx.nn. Each of these libraries represents a slightly different conceptual model, so as you delve deeper into this landscape, you will likely discover a variant that aligns with your specific needs and preferences.
Now, let me explain why PyTreeClass exists when Equinox/Treex/simple_pytree already exist; since you mentioned Equinox, I will focus more on it.
1- One of the core ideas of equinox is filtered transformations, where you filter your pytrees at the function level, while in PyTreeClass you filter at the pytree level by masking. This is a deliberate decision: it spares me from creating automatic decorators like `equinox.filter_{...}` that parallel the jax API.

I believe that mirroring an API can be a risky strategy (although it can be cleverly implemented, as in jax.numpy) because it can lead to confusion and errors due to inconsistent behaviour. Additionally, it requires meticulous maintenance to keep up with updates to the original API (you can see examples of filter_-related issues in the Equinox issue tab). Moreover, debugging can be more challenging because you need to know beforehand which nodes are frozen and which are being trained. If not handled carefully, this approach can introduce bugs and unexpected behaviours when interacting with pure Jax or other libraries. For these reasons, I prefer a more explicit method through masking, where you can see which nodes are frozen before passing the tree to a function. This helps prevent unforeseen outcomes.
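To make the masking idea concrete, here is a minimal sketch following the `pytc.freeze` pattern from the readme (treat it as illustrative; the exact API may differ between versions):

```python
import jax
import pytreeclass as pytc

class Tree(pytc.TreeClass):
    a: float = 1.0
    b: float = 2.0

tree = Tree()

# explicitly mask (freeze) `a`; frozen leaves are hidden from jax transformations
frozen = tree.at["a"].apply(pytc.freeze)
print(jax.tree_util.tree_leaves(frozen))  # [2.0] -> only `b` is visible

# unmask to restore the original leaves
unfrozen = jax.tree_util.tree_map(pytc.unfreeze, frozen, is_leaf=pytc.is_frozen)
```

Because the mask lives on the tree itself, any function (pure jax or third-party) sees only the unfrozen leaves, with no special decorator required.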
2- IMO, PyTreeClass has better functional tree manipulation (lens-like); you can do a couple of things easily:
```python
import jax
import optax
import pytreeclass as pytc

class Tree(pytc.TreeClass):
    a: float = 1.0
    b: float = 2.0
    c: float = 3.0

tree = Tree()

a_mask = tree.at[...].set(False).at["a"].set(True)
b_mask = tree.at[...].set(False).at["b"].set(True)
c_mask = tree.at[...].set(False).at["c"].set(True)

optim = optax.chain(
    # update `a` with sgd of learning rate 1
    optax.masked(optax.sgd(learning_rate=1), a_mask),
    # update `b` with sgd of learning rate -1
    optax.masked(optax.sgd(learning_rate=-1), b_mask),
    # update `c` with sgd of learning rate 0
    optax.masked(optax.sgd(learning_rate=0), c_mask),
)
```
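As a quick usage sketch (the quadratic `loss_fn` here is hypothetical), the masked optimizer then plugs into the usual optax loop:

```python
import jax

def loss_fn(tree):
    # toy loss over the three leaves
    return tree.a**2 + tree.b**2 + tree.c**2

grads = jax.grad(loss_fn)(tree)
state = optim.init(tree)
updates, state = optim.update(grads, state)
tree = optax.apply_updates(tree, updates)
```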
Another example: adding a leaf functionally through a method call:

```python
import pytreeclass as pytc

class Tree(pytc.TreeClass):
    a: float = 1.0
    b: float = 2.0
    c: float = 3.0

    def add_leaf(self, name: str, value):
        setattr(self, name, value)

tree = Tree()
# Tree(a=1.0, b=2.0, c=3.0)

_, tree_with_d = tree.at["add_leaf"]("d", 4.0)
tree_with_d
# Tree(a=1.0, b=2.0, c=3.0, d=4.0)
```

Note that calling a method through `.at` is functional: it returns the method's return value together with a new tree that has the mutation applied, leaving the original tree untouched.
3- Debugging: all my viz tools are geared towards debugging, so you always have helpful information whenever you interact with trees. For deep and nested networks, I usually resort to the `tree_diagram` function with its `depth` argument to navigate the network.
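For illustration, a small sketch of that workflow (the layer classes here are made up, and the exact rendering may vary between versions):

```python
import pytreeclass as pytc

class Linear(pytc.TreeClass):
    weight: float = 1.0
    bias: float = 0.0

class Net(pytc.TreeClass):
    def __init__(self):
        self.l1 = Linear()
        self.l2 = Linear()

# truncate the printed diagram at the first level of nesting
print(pytc.tree_diagram(Net(), depth=1))
```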
4- More advanced features, like `tree_map_with_trace`, let you filter based on the type path; this is useful if you want to freeze leaves with certain parent types (Dropout layer leaves, for example). This is a unique feature of PyTreeClass.
5- Data model: pytreeclass blends the idea of a pytree of arrays with the array itself (optionally, through `leafwise=True`). With `leafwise=True`, math operators apply leaf by leaf, so `tree > 2` yields a boolean tree of the same structure, and `pytc.bcmap` lifts a function like `jnp.where` to map over such trees:

```python
import jax
import jax.numpy as jnp
import pytreeclass as pytc

class Tree(pytc.TreeClass, leafwise=True):
    a: int = 1
    b: tuple[float] = (2.0, 3.0)
    c: jax.Array = jnp.array([4.0, 5.0, 6.0])

tree = Tree()

print(pytc.bcmap(jnp.where)(tree > 2, tree + 100, 0))
print(tree.at[tree > 1].apply(lambda x: x + 100))
```
6- Module design: this is where all the other `PyTree` libraries have their own flavour. I will focus on `Equinox` to explain my point, using the example I found [here](https://docs.kidger.site/equinox/examples/mnist/):
```python
import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float

class CNN(eqx.Module):
    layers: list

    def __init__(self, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        # Standard CNN setup: convolutional layer, followed by flattening,
        # with a small MLP on top.
        self.layers = [
            eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1),
            eqx.nn.MaxPool2d(kernel_size=2),
            jax.nn.relu,
            jnp.ravel,
            eqx.nn.Linear(1728, 512, key=key2),
            jax.nn.sigmoid,
            eqx.nn.Linear(512, 64, key=key3),
            jax.nn.relu,
            eqx.nn.Linear(64, 10, key=key4),
            jax.nn.log_softmax,
        ]

    def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
        for layer in self.layers:
            x = layer(x)
        return x
```
In equinox, you need to declare your fields as type-hinted attributes at the top of your class. So, if you want the previous example to expose `nn.conv1` pointing to the first convolution layer, for example, you have to do something like this:
```python
class CNN(eqx.Module):
    conv1: eqx.nn.Conv2d
    pool1: eqx.nn.MaxPool2d
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear

    def __init__(self, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        # Standard CNN setup: convolutional layer, followed by flattening,
        # with a small MLP on top.
        self.conv1 = eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1)
        self.pool1 = eqx.nn.MaxPool2d(kernel_size=2)
        self.linear1 = eqx.nn.Linear(1728, 512, key=key2)
        self.linear2 = eqx.nn.Linear(512, 10, key=key3)

    def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
        x = self.conv1(x)
        x = self.pool1(x)
        x = jax.nn.relu(x)
        x = jnp.ravel(x)
        x = self.linear1(x)
        x = jax.nn.sigmoid(x)
        x = self.linear2(x)
        x = jax.nn.log_softmax(x)
        return x
```
IMO, this is a repetitive design. The example above escapes this repetition by using a mutable container (a list) to wrap all the layers, but then you must use something like `nn.layers[0]` instead of `nn.conv1` to fetch your first layer, which hurts ergonomics. Moreover, by doing so you lose the immutability (try `nn.layers.pop()`) that is essential for correct behaviour under Jax. Another reason to avoid using a tuple/list as a layer container is that you miss out on the name of the layer/leaf, which could otherwise be accessed via `jax.tree_util.tree_map_with_path` from jax.
In pytreeclass, all class variables are leaves by default; if you want to filter out non-trainable parameters, use a mask, as seen in the readme (a sketch follows).
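For example, one possible masking rule (a sketch, reusing `pytc.freeze` from above; the field names here are made up) freezes every leaf that is not a float array, so integer counters and the like are hidden from jax transformations:

```python
import jax
import jax.numpy as jnp
import pytreeclass as pytc

class Tree(pytc.TreeClass):
    steps: int = 0                      # non-trainable bookkeeping field
    weight: jax.Array = jnp.ones((2,))  # trainable parameter

def freeze_non_inexact(x):
    # freeze anything that is not a float/complex array
    if jnp.issubdtype(jnp.asarray(x).dtype, jnp.inexact):
        return x
    return pytc.freeze(x)

masked = jax.tree_util.tree_map(freeze_non_inexact, Tree())
print(jax.tree_util.tree_leaves(masked))  # only `weight` remains visible
```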
On Equinox: I use Equinox's internal tools (`equinox.internal`), and I think my library must play nicely with others in the jax ecosystem. This is why pytreeclass gives no special treatment to non-pytreeclass instances; you can use all these tools with any library you like (e.g. flax/equinox/haiku). So, for the CNN example, you can inherit all the pytreeclass pros by doing something like this:
```python
import equinox as eqx
import jax
import jax.numpy as jnp
import pytreeclass as pytc

class CNN(pytc.TreeClass):
    def __init__(self, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        # Standard CNN setup: convolutional layer, followed by flattening,
        # with a small MLP on top.
        self.conv1 = eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1)
        self.pool1 = eqx.nn.MaxPool2d(kernel_size=2)
        self.linear1 = eqx.nn.Linear(1728, 512, key=key2)
        self.linear2 = eqx.nn.Linear(512, 10, key=key3)

    def __call__(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        x = jax.nn.relu(x)
        x = jnp.ravel(x)
        x = self.linear1(x)
        x = jax.nn.sigmoid(x)
        x = self.linear2(x)
        x = jax.nn.log_softmax(x)
        return x
```
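With this, the equinox layers become ordinary leaves of a pytreeclass tree, so the earlier viz and masking tools apply directly; a small sketch:

```python
cnn = CNN(jax.random.PRNGKey(0))
# navigate the nested tree one level at a time, as in the debugging point above
print(pytc.tree_diagram(cnn, depth=1))
```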
For serket, you inherit the tools and mental model of pytreeclass while remaining 100% compatible with other libraries, including equinox. If you are a user of `eqx.nn`, you can use serket layers that do not exist in equinox (FFT convolution, for example) within equinox if you like.
Let me know if this answers your question.
Thank you for the extremely in-depth response. It will take me some time to consider all that has been stated, but my interest has definitely been piqued.
One other quick question: do you have any benchmarking of your implementation vs, say, Equinox across a range of uses? Obviously, I saw your charts for flatten/unflatten, which look very good. I wonder how it performs in terms of memory/speed for various common NN architectures (as I have found over the years, JAX can be very sensitive to small changes in code when using things like vmap; this is obviously down to how the JAX/XLA optimisation is conducted).
~~Except for flax.struct, I think most pytree libraries should behave similarly regarding memory/speed. PyTreeClass is slightly faster because no logic (for static fields) is executed when flattening/unflattening.~~ Check the readme for benchmark links.
For reference:
[1] Pytree-based implementations: one that predates equinox/treex is flax's PyTreeNode; one that postdates them is pax.
[2] equinox `tree_at` sample issue.
[3] filter inconsistent behaviour: sample issues 1, 2.
First, sorry to put this in issues, but there is no discussion tab. Second, I don't want this to sound negative or critical; rather, I am genuinely interested in the reasoning (and it looks like your library has some nice unique features, e.g. visualising PyTrees).

There are already a number of mature JAX libraries, such as Equinox, that handle the idea of constructing classes as PyTrees and layer convenient methods on top to manipulate them (plus I notice you've written an NN library that builds further on this). I was wondering: why another set of libraries? What are the advantages of PyTreeClass and Serket over something like Equinox?