Closed francois-rozet closed 1 month ago
Sure, I'd be happy to take a PR on that!
Hi @patrick-kidger, I have created the PR!
On a related subject, I took the opportunity to do a few speed tests. In Equinox, the partitions are trees where some leaves are replaced by None
(or another sentinel). In Inox, the partitions are path-leaf mappings (i.e. flat dictionaries), similar to state_dict
in PyTorch.
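To make the difference concrete, here is a rough sketch of what the two partition styles look like for a tiny tree. This is a hypothetical illustration, not the libraries' actual output: plain floats stand in for JAX arrays, and the tuple-of-names key format for Inox's flat mappings is an assumption.

```python
# Hypothetical sketch of the two partition representations; floats stand in
# for arrays, and the tuple-of-names key format is an assumption.
tree = {"weight": 1.0, "name": "layer"}

# Equinox-style: two trees with the *same structure* as the input, where
# non-selected leaves are replaced by a sentinel (None).
eqx_style_arrays = {"weight": 1.0, "name": None}
eqx_style_others = {"weight": None, "name": "layer"}

# Inox-style: flat path-to-leaf mappings (like a PyTorch state_dict); the
# original tree structure would be kept separately as a treedef.
inox_style_arrays = {("weight",): 1.0}
inox_style_others = {("name",): "layer"}
```

Merging the two flat mappings and re-nesting them along the treedef recovers the original tree, just as combining the two sentinel trees does.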
There are three advantages to using dictionaries: they are more readable, they are easier to work with, and they are safer (for instance when serializing). A small example follows to illustrate the difference.
```python
import equinox as eqx
import inox
import jax

tree = None

for i in range(100):
    tree = {
        "child": tree,
        "index": i,
        "array": jax.numpy.ones(i + 1),
        "meta": ("irrelevant", "information"),
    }

def f(tree):
    total = 0
    while tree:
        total = total + tree["array"][tree["index"]]
        tree = tree["child"]
    return total

## Equinox
arrays, others = eqx.partition(tree, eqx.is_array)

def g(arrays):
    tree = eqx.combine(arrays, others)
    return f(tree)

g_jit = jax.jit(g)

%timeit g_jit(arrays)  # 67 µs ± 6.64 µs per loop

## Inox
treedef, arrays, others = inox.tree_partition(tree, jax.Array)

def g(arrays):
    tree = inox.tree_combine(treedef, arrays, others)
    return f(tree)

g_jit = jax.jit(g)

%timeit g_jit(arrays)  # 32.4 µs ± 1.19 µs per loop
```
If this difference in efficiency matters to you, I see two options:

1. Modifying `eqx.partition` and `eqx.combine` to return flat (dictionary) partitions.
2. Adding new functions `eqx.tree_partition` and `eqx.tree_combine` (or `eqx.leaf_*`) that return flat (dictionary) partitions, and possibly updating the tutorials to drop the mentions of `eqx.partition` and `eqx.combine`.
I'm not Patrick, but I would like to chime in :)
I don't think it is of significant importance to optimize the speed of partitioning (even when partitioning happens inside constructs like `scan`). To sum up, partitioning is usually executed only a constant (and small) number of times per mention of `eqx.partition` in the source code, so an overhead of 30 microseconds is unlikely to play any role.
Btw, I think your example is not a perfect illustration of the performance differences: you use a static for-loop to generate these large trees (it seems), which might contribute a lot to the compilation time. I wonder if there is a better example out there :smile: So, you might totally have a point if the differences are (much) larger in reality.
Addressing the other points: more readable, sure; easier to work with / not breaking backwards compatibility, not so sure. About safety I have no comment; I am not sure whether dictionaries are safer, so I shall withhold my judgment.
Hi @knyazer, thank you for jumping in. My analysis is not about the speed of partitioning. It is about the speed of flattening (already built) partitions. Indeed, partitioning happens once. However, flattening happens at every call of a jitted function. JAX must flatten the inputs to check if the function has already been compiled for the inputs' static structure (treedef) or not.
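To see why the representation matters here, consider a toy model of tree flattening (this is an illustrative sketch, not JAX's actual implementation): collecting the leaves of a sentinel-style partition still visits every node of the original structure, while a flat dictionary carries only the leaves of interest.

```python
# Toy sketch (not JAX's real implementation) of flattening cost: a
# sentinel-style partition is walked node by node, while a flat dictionary
# partition exposes its few leaves directly.
def count_visits(tree):
    """Collect leaves from a nested tree, counting the nodes visited."""
    visits = 0
    leaves = []
    stack = [tree]
    while stack:
        node = stack.pop()
        visits += 1
        if isinstance(node, dict):
            stack.extend(node.values())
        elif isinstance(node, list):
            stack.extend(node)
        elif node is not None:  # treat None as an empty subtree, like JAX
            leaves.append(node)
    return leaves, visits

# Sentinel-style "arrays" partition: one real leaf, 100 None placeholders.
sentinel = {"array": 1.0, "meta": [None] * 100}
# Flat-dict "arrays" partition: only the one leaf, keyed by its path.
flat = {("array",): 1.0}

_, sentinel_visits = count_visits(sentinel)
_, flat_visits = count_visits(flat)
print(sentinel_visits, flat_visits)  # 103 visits vs 2 visits
```

Both partitions yield the same single leaf, but the sentinel-style one pays a traversal cost proportional to the whole original structure on every flatten.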
My analysis is also not about the speed of compilation. In my example, the compilation happens only once. The time difference is due entirely to the speed of flattening the inputs. If the trees were larger or deeper, the difference between the two implementations would be even larger.
Regarding safety, what I meant is that pickling/unpickling a dictionary of arrays does not execute any code.
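That safety point is generic Python behavior rather than anything specific to either library: unpickling an arbitrary object can execute a callable of the pickle author's choosing via `__reduce__`, whereas a flat dictionary of plain values round-trips without invoking custom code. A minimal demonstration:

```python
import pickle

# Unpickling an arbitrary object may run code chosen by whoever produced the
# pickle: __reduce__ tells pickle which callable to invoke at load time.
class Sneaky:
    def __reduce__(self):
        # At load time, pickle calls list(range(3)) -- a benign stand-in
        # for arbitrary (potentially malicious) code.
        return (list, (range(3),))

payload = pickle.dumps(Sneaky())
print(pickle.loads(payload))  # → [0, 1, 2], not a Sneaky instance

# A flat dictionary of plain values round-trips without custom code running.
state = {("layer", "weight"): 1.0}
print(pickle.loads(pickle.dumps(state)))  # → {('layer', 'weight'): 1.0}
```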
Here is a more extreme example where the ratio between array leaves and static leaves is very small (1 array leaf, 100 000 static leaves). The function is extremely fast to compile. The performance gap comes entirely from flattening. In this case, Equinox's approach is 1000 times slower than Inox's.
```python
import equinox as eqx
import inox
import jax

tree = {
    "array": jax.numpy.ones(()),
    "meta": list(range(100000)),
}

def f(tree):
    return tree["array"]

## Equinox
arrays, others = eqx.partition(tree, eqx.is_array)

def g(arrays):
    tree = eqx.combine(arrays, others)
    return f(tree)

g_jit = jax.jit(g)
g_jit(arrays)

%timeit g_jit(arrays)  # 3.72 ms ± 87.3 µs per loop

## Inox
treedef, arrays, others = inox.tree_partition(tree, jax.Array)

def g(arrays):
    tree = inox.tree_combine(treedef, arrays, others)
    return f(tree)

g_jit = jax.jit(g)
g_jit(arrays)

%timeit g_jit(arrays)  # 2.7 µs ± 41.3 ns per loop
```
The choice of dictionaries vs keeping the tree structure was one I considered in the early days of Equinox. I decided against it for several reasons:
(a) In the early days of JAX, there was no way to express the path to a location in a PyTree. This would only have worked for certain kinds of PyTrees. (b) It's less readable, IMO. I much prefer still having the actual object of the appropriate type.
Ultimately this is far too breaking a change to revisit now. It's not one that ever really affects speed: almost all programs are large enough that the speed of the JIT'd region dominates the flattening (your example has a tiny function that does no work), and in the event that even that is too much, it is possible to telescope the flattening.
Note that I also proposed to add new functions, which would not break anything.
> Adding new functions `eqx.tree_partition` and `eqx.tree_combine` (or `eqx.leaf_*`) that return flat (dictionary) partitions.
I regularly need to split modules into more than 2 partitions: for example, linear parameters, conv parameters, and others. This is not easy to do with the current `eqx.partition` function.

In Inox, `inox.tree_partition` allows any number of partitions. If a single predicate (filter) is provided, it has the same behavior as `eqx.partition`. If more than one is provided, ambiguities are resolved by the order of the predicates.

I think it would be easy to adapt this logic to `eqx.partition`. I can do it if you think this feature is worth adding.
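As a rough sketch of the multi-predicate semantics described above, operating on a flat path-to-leaf mapping; the helper name and the `(path, leaf)` predicate signature are assumptions for illustration, not Inox's actual API:

```python
# Sketch of multi-predicate partitioning (assumed semantics): each leaf goes
# to the partition of the *first* predicate that matches it, and leaves
# matching no predicate fall through to a trailing "others" partition.
def tree_partition_multi(leaves, *predicates):
    partitions = [{} for _ in range(len(predicates) + 1)]
    for path, leaf in leaves.items():
        for i, predicate in enumerate(predicates):
            if predicate(path, leaf):
                partitions[i][path] = leaf
                break  # predicate order resolves ambiguities
        else:
            partitions[-1][path] = leaf
    return partitions

leaves = {
    ("linear", "weight"): 1.0,
    ("conv", "kernel"): 2.0,
    ("name",): "net",
}
linear, conv, others = tree_partition_multi(
    leaves,
    lambda path, leaf: "linear" in path,
    lambda path, leaf: "conv" in path,
)
print(linear, conv, others)
```

With a single predicate this degenerates to the usual two-way split; with several, each leaf lands in exactly one partition, so the partitions can always be recombined losslessly.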