patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Splitting a module into more than 2 partitions #824

Open francois-rozet opened 2 weeks ago

francois-rozet commented 2 weeks ago

I regularly need to split modules into more than 2 partitions, for example linear parameters, conv parameters and everything else. This is not easy to do with the current eqx.partition function.

In Inox, inox.tree_partition allows any number of partitions. If a single predicate (filter) is provided, it has the same behavior as eqx.partition. If more than one are provided, ambiguities are resolved by the order of the predicates.
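To make the "first matching predicate wins" rule concrete, here is a minimal sketch of an N-way partition in the Equinox style (non-selected positions become None). It uses plain jax.tree_util rather than Equinox, Inox, or the actual PR. The names partition_many and combine_many are hypothetical, and filters are assumed to be plain callables on leaves.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


def partition_many(tree, *filters):
    """Hypothetical N-way partition returning len(filters) + 1 trees.

    Each leaf goes to the partition of the first filter that accepts it;
    leaves accepted by no filter go to the last (remainder) partition.
    Positions a partition does not own are replaced by None.
    """
    leaves, treedef = jtu.tree_flatten(tree)
    n = len(filters)
    # Owning partition index for every leaf (first match wins, else n).
    owners = [next((i for i, f in enumerate(filters) if f(x)), n) for x in leaves]
    return [
        jtu.tree_unflatten(treedef, [x if o == k else None for x, o in zip(leaves, owners)])
        for k in range(n + 1)
    ]


def combine_many(*parts):
    """Hypothetical inverse: keeps the first non-None value at each position."""
    def pick(*xs):
        return next((x for x in xs if x is not None), None)
    return jtu.tree_map(pick, *parts, is_leaf=lambda x: x is None)


# Usage: arrays first, then Python ints, then everything else.
tree = {"w": jnp.ones(3), "n": 4, "name": "linear", "b": jnp.zeros(2)}
is_array = lambda x: isinstance(x, jax.Array)
is_int = lambda x: isinstance(x, int)
arrays, ints, rest = partition_many(tree, is_array, is_int)
restored = combine_many(arrays, ints, rest)
```

With a single filter this reduces to the usual two-way split, which matches the backwards-compatible behaviour described above.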

I think it would be easy to adapt the logic to eqx.partition. I can do it if you think this feature is worth adding.

patrick-kidger commented 2 weeks ago

Sure, I'd be happy to take a PR on that!

francois-rozet commented 2 weeks ago

Hi @patrick-kidger, I have created the PR!

On a related subject, I took the opportunity to do a few speed tests. In Equinox, the partitions are trees where some leaves are replaced by None (or another sentinel). In Inox, the partitions are path-leaf mappings (i.e. flat dictionaries), similar to state_dict in PyTorch.
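For a small tree, the two representations look roughly like this. This is an illustrative sketch built on jax.tree_util (tree_flatten_with_path and keystr, available in recent JAX versions), not the actual Equinox or Inox code; the variable names are made up.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu

tree = {"array": jnp.ones(2), "index": 1, "meta": ("irrelevant", "information")}
is_array = lambda x: isinstance(x, jax.Array)

# Equinox style: same structure as `tree`, non-array leaves replaced by None.
tree_arrays = jtu.tree_map(lambda x: x if is_array(x) else None, tree)

# Inox / state_dict style: flat mappings from a leaf's path string to the leaf,
# plus the treedef needed to rebuild the original tree later.
path_leaves, treedef = jtu.tree_flatten_with_path(tree)
dict_arrays = {jtu.keystr(p): x for p, x in path_leaves if is_array(x)}
dict_others = {jtu.keystr(p): x for p, x in path_leaves if not is_array(x)}
```

Here tree_arrays still contains the "index" and "meta" positions (as None), whereas dict_arrays holds a single entry.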

There are three advantages to using dictionaries:

  1. They are easy to "read" for a human.
  2. They are safe(r) to pickle/serialize.
  3. For simple pytrees, the cost of flattening is proportional to the size of the tree. With tree partitions (Equinox), every partition keeps the full structure of the original tree (unselected leaves are replaced by None rather than removed), so flattening the array partition costs about as much as flattening the whole tree, even when there are many more metadata leaves than array leaves. With dictionary partitions (Inox), the array partition is much smaller (has fewer leaves) than the full tree. This results in faster partition flattening and, consequently, faster calls to jitted functions.
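Point 3 can be checked by comparing tree structures. The sketch below mimics what eqx.partition produces (None placeholders) without importing Equinox, and compares it with a one-entry dictionary. It relies on the num_leaves and num_nodes properties of JAX's PyTreeDef.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu

tree = {"array": jnp.ones(()), "meta": list(range(1000))}
is_array = lambda x: isinstance(x, jax.Array)

# Equinox-style array partition: the 1000-element list survives as 1000 Nones.
tree_arrays = jtu.tree_map(lambda x: x if is_array(x) else None, tree)
# Dictionary-style array partition: only the array leaf survives.
dict_arrays = {"array": tree["array"]}

tree_def = jtu.tree_structure(tree_arrays)
dict_def = jtu.tree_structure(dict_arrays)
print(tree_def.num_leaves, tree_def.num_nodes)  # 1 leaf, but over 1000 nodes
print(dict_def.num_leaves, dict_def.num_nodes)  # 1 leaf, 2 nodes (dict + array)
```

Both partitions have a single leaf, but flattening has to visit every node, including the None placeholders.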

Here is a small example illustrating the difference.

```python
import equinox as eqx
import inox
import jax

tree = None

for i in range(100):
    tree = {
        "child": tree,
        "index": i,
        "array": jax.numpy.ones(i + 1),
        "meta": ("irrelevant", "information"),
    }

def f(tree):
    total = 0

    while tree:
        total = total + tree["array"][tree["index"]]
        tree = tree["child"]

    return total

## Equinox
arrays, others = eqx.partition(tree, eqx.is_array)

def g(arrays):
    tree = eqx.combine(arrays, others)
    return f(tree)

g_jit = jax.jit(g)

%timeit g_jit(arrays)  # 67 µs ± 6.64 µs per loop

## Inox
treedef, arrays, others = inox.tree_partition(tree, jax.Array)

def g(arrays):
    tree = inox.tree_combine(treedef, arrays, others)
    return f(tree)

g_jit = jax.jit(g)

%timeit g_jit(arrays)  # 32.4 µs ± 1.19 µs per loop
```

If this difference in efficiency matters to you, I see two options:

  1. Modifying eqx.partition and eqx.combine to return flat (dictionary) partitions.
  2. Adding new methods eqx.tree_partition and eqx.tree_combine (or eqx.leaf_*) that return flat (dictionary) partitions, and possibly updating the tutorials to drop the mentions of eqx.partition and eqx.combine.
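As a rough idea of what option 2 could look like, here is a sketch of a dictionary-returning partition/combine pair built on jax.tree_util. The names leaf_partition and leaf_combine are hypothetical (they are not existing Equinox or Inox functions), and path strings from jax.tree_util.keystr are assumed to be unique within a tree.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


def leaf_partition(tree, is_selected):
    """Hypothetical: split `tree` into (treedef, selected, others) flat dicts."""
    path_leaves, treedef = jtu.tree_flatten_with_path(tree)
    selected, others = {}, {}
    for path, leaf in path_leaves:
        (selected if is_selected(leaf) else others)[jtu.keystr(path)] = leaf
    return treedef, selected, others


def leaf_combine(treedef, *parts):
    """Hypothetical inverse: rebuild the tree from the treedef and flat dicts."""
    merged = {k: v for part in parts for k, v in part.items()}
    # Recover the leaf order by flattening a tree of placeholder leaves.
    skeleton = jtu.tree_unflatten(treedef, [0] * treedef.num_leaves)
    paths = [jtu.keystr(p) for p, _ in jtu.tree_flatten_with_path(skeleton)[0]]
    return jtu.tree_unflatten(treedef, [merged[p] for p in paths])


# Usage mirroring the benchmark: only the small array dict crosses the jit boundary.
tree = {"array": jnp.ones(3), "index": 2, "meta": ("a", "b")}
treedef, arrays, others = leaf_partition(tree, lambda x: isinstance(x, jax.Array))


@jax.jit
def g(arrays):
    t = leaf_combine(treedef, arrays, others)
    return t["array"][t["index"]]
```

The treedef and the non-array dict are closed over, so at each call JAX only flattens a one-entry dictionary.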

knyazer commented 1 week ago

%timeit g_jit(arrays)  # 67 µs ± 6.64 µs per loop
%timeit g_jit(arrays)  # 32.4 µs ± 1.19 µs per loop

I'm not Patrick, but I would like to chime in :)

I don't think optimizing the speed of partitioning is of significant importance:

  1. Partitioning has (effectively) zero runtime overhead, so this should not affect you unless your code is extremely tracing-heavy. I have encountered such code a few times, but it was mostly due to people not knowing how tracing works (e.g. using Python for loops instead of scan).
  2. Note how I used the term 'tracing' in the previous point. Under normal conditions, tracing takes only a small part of compilation: most of the time is usually spent in XLA compilation, during which no Python code is executed. You can see this by enabling JAX's compile log.
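Both points can be illustrated in a few lines. The sketch turns on JAX's jax_log_compiles option (which logs each compilation) and contrasts a Python for loop, which is unrolled during tracing, with jax.lax.scan, which traces its body once. The function names are illustrative.

```python
import jax
import jax.numpy as jnp

# Log a message whenever JAX compiles a function.
jax.config.update("jax_log_compiles", True)


@jax.jit
def unrolled(x):
    # A Python loop is unrolled at trace time: 100 copies of the body are traced.
    for _ in range(100):
        x = jnp.sin(x)
    return x


@jax.jit
def scanned(x):
    # lax.scan traces the body once; the loop runs inside the compiled program.
    def body(carry, _):
        return jnp.sin(carry), None

    x, _ = jax.lax.scan(body, x, None, length=100)
    return x


x = jnp.ones(3)
y_unrolled = unrolled(x)
y_scanned = scanned(x)
```

Both functions compute the same values; the difference lies in how much Python work happens during tracing and how large the program handed to XLA is.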

To sum up, partitioning is usually executed only a constant (and small) number of times per occurrence of eqx.partition in the source code, so an overhead of 30 microseconds is unlikely to play any role.


Btw, I think your example is not a perfect illustration of the performance differences: it seems you use a static for-loop to generate these large trees, which might contribute a lot to the compilation time? I wonder if there is a better example out there :smile: So you might totally have a point if the differences are (much) larger in reality.


As for the other points: more readable, sure; easier to work with / not breaking backwards compatibility, not so sure. On safety, I am not sure whether dictionaries are safer, so I shall withhold judgment.

francois-rozet commented 1 week ago

Hi @knyazer, thank you for jumping in. My analysis is not about the speed of partitioning. It is about the speed of flattening (already built) partitions. Indeed, partitioning happens once. However, flattening happens at every call of a jitted function: JAX must flatten the inputs to check whether the function has already been compiled for the inputs' static structure (treedef).
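The per-call flattening can be observed indirectly: JAX looks up cached compilations using the flattened inputs' treedef (together with array shapes and dtypes). The sketch below counts traces with a Python side effect, which only runs while tracing; n_traces is an illustrative name.

```python
import jax
import jax.numpy as jnp

n_traces = 0


@jax.jit
def g(tree):
    global n_traces
    n_traces += 1  # Python side effect: runs only while tracing
    return sum(jax.tree_util.tree_leaves(tree))


a = {"x": jnp.ones(()), "y": None}
b = {"x": jnp.zeros(()), "y": None}          # same treedef, new values
c = {"x": jnp.ones(()), "y": jnp.ones(())}   # different treedef

g(a)
g(b)  # inputs are flattened, treedef matches: cached compilation is reused
after_ab = n_traces
g(c)  # treedef differs: traced and compiled again
```

Even on a cache hit, the flattening of a and b still happens on every call, which is the cost being measured in the benchmarks.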

My analysis is also not about the speed of compilation. In my example, the compilation happens only once. The time difference is due only to the speed of flattening the inputs. If the trees were larger/deeper, the difference between the two implementations would be even larger.

Regarding safety, what I meant is that pickling/unpickling a dictionary of arrays does not execute any code.
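A flat path-to-array mapping can also be stored in a format that holds only raw array data. The sketch below uses NumPy's .npz format as one such option (not necessarily what Inox does), assuming the leaves have already been converted to NumPy arrays; the keys are illustrative.

```python
import io

import numpy as np

# A flat path -> array mapping, as in a dictionary partition.
state = {"child.array": np.arange(3.0), "array": np.ones(2)}

buf = io.BytesIO()
np.savez(buf, **state)  # one .npy entry per key inside a zip archive
buf.seek(0)

# allow_pickle=False refuses object arrays, so loading never unpickles objects.
with np.load(buf, allow_pickle=False) as data:
    loaded = {k: data[k] for k in data.files}
```

A tree partition would instead need its structure (module classes, None placeholders) serialized as well, which usually means pickling Python objects.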

francois-rozet commented 1 week ago

Here is a more extreme example where the ratio between array leaves and static leaves is very small (1 array leaf, 100,000 static leaves). The function is extremely fast to compile, so the performance gap comes entirely from flattening. In this case, Equinox's approach is more than 1000 times slower than Inox's.

```python
import equinox as eqx
import inox
import jax

tree = {
    "array": jax.numpy.ones(()),
    "meta": list(range(100000)),
}

def f(tree):
    return tree["array"]

## Equinox
arrays, others = eqx.partition(tree, eqx.is_array)

def g(arrays):
    tree = eqx.combine(arrays, others)
    return f(tree)

g_jit = jax.jit(g)
g_jit(arrays)

%timeit g_jit(arrays)  # 3.72 ms ± 87.3 µs per loop

## Inox
treedef, arrays, others = inox.tree_partition(tree, jax.Array)

def g(arrays):
    tree = inox.tree_combine(treedef, arrays, others)
    return f(tree)

g_jit = jax.jit(g)
g_jit(arrays)

%timeit g_jit(arrays)  # 2.7 µs ± 41.3 ns per loop
```

patrick-kidger commented 1 week ago

The choice of dictionaries vs keeping the tree structure was one I considered in the early days of Equinox. I decided against it for several reasons:

(a) In the early days of JAX, there was no way to express the path to a location in a PyTree. This would only have worked for certain kinds of PyTrees.
(b) It's less readable, IMO. I much prefer still having the actual object of the appropriate type.

Ultimately, this is far too breaking a change to revisit now. It's also not one that ever really affects speed: almost all programs are large enough that the speed of the JIT'd region dominates the flattening (your example has a tiny function that does no work), and in the event that even that is too much, it is possible to telescope the flattening.
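One possible reading of "telescoping the flattening" (an interpretation, not necessarily what was meant) is to flatten the partition once, outside the hot loop, and pass only the list of array leaves into the jitted function, rebuilding the tree inside from a treedef captured by closure. A minimal sketch with jax.tree_util; g_flat is a hypothetical name.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu

tree = {"array": jnp.ones(()), "meta": list(range(100_000))}
# Equinox-style array partition (None placeholders keep the full structure).
arrays = jtu.tree_map(lambda x: x if isinstance(x, jax.Array) else None, tree)

# Flatten once, outside the jitted region.
leaves, treedef = jtu.tree_flatten(arrays)


@jax.jit
def g_flat(leaves):
    # Per call, JAX only flattens a short list of arrays; the large structure
    # is rebuilt from the closed-over treedef, and only during tracing.
    rebuilt = jtu.tree_unflatten(treedef, leaves)
    return rebuilt["array"] * 2


out = g_flat(leaves)
```

The trade-off is that the caller now manages the flat leaf list explicitly instead of passing the partitioned tree.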

francois-rozet commented 1 week ago

Note that I also proposed adding new functions, which would not break anything:

Adding new methods eqx.tree_partition and eqx.tree_combine (or eqx.leaf_*) that return flat (dictionary) partitions.