patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Obtaining (name, value) pairs in Equinox #362

Open sachith-gunasekara opened 1 year ago

sachith-gunasekara commented 1 year ago

PyTorch has a built-in method, torch.nn.Module.named_parameters(), that yields the (name, value) pairs for every parameter in a module, including all of its submodules.

It can be used in a loop as follows: for pn, p in self.named_parameters():

Is it possible to achieve something like this in JAX/Equinox?

patrick-kidger commented 1 year ago

Note that the "name" of a parameter isn't a meaningful quantity in Equinox -- modules are just collections of parameters, much like a tuple-of-parameters. You can get just the parameters themselves by using jax.tree_util.tree_leaves(module), as modules are PyTrees.

If your use case is something like "I want to get all the biases", then the usual way to do it would be something like this:

import jax

def has_bias(x):
    return hasattr(x, "bias")

def get_biases(pytree):
    return [x.bias for x in jax.tree_util.tree_leaves(pytree, is_leaf=has_bias) if has_bias(x)]

or maybe

import equinox as eqx
import jax

def is_linear(x):
    return isinstance(x, eqx.nn.Linear)

def get_biases(pytree):
    return [x.bias for x in jax.tree_util.tree_leaves(pytree, is_leaf=is_linear) if is_linear(x)]

If your use case is "I want to zero out all my biases", then you could do that by using the above get_biases function together with

import equinox as eqx
import jax.numpy as jnp

model = eqx.tree_at(get_biases, model, replace_fn=jnp.zeros_like)

Does that help?

ASEM000 commented 1 year ago

As of version 0.4.6, JAX supports named leaves (key paths) for pytrees, so it's up to @patrick-kidger to support them in Equinox.

If adopted, this would simplify a lot of tree manipulation in Equinox (IMO, it is pretty tricky to do in Equinox right now).

For example, your use case is easily handled with the new JAX API:


import pytreeclass as pytc
import jax

class Model(pytc.TreeClass):
    def __init__(self):
        self.a = 1
        self.b = 2

m = Model()

for path, leaf in jax.tree_util.tree_flatten_with_path(m)[0]:
    print(path[-1],leaf)

# .a 1
# .b 2

patrick-kidger commented 1 year ago

That's a good point! Key paths are now part of JAX's public API.

I've just added support in #363.

Anyway, to answer the original question: the equivalent of named_parameters is exactly what @ASEM000 has in his example: jax.tree_util.tree_flatten_with_path.

ASEM000 commented 1 year ago

Nice :D

Once this is merged, you can use (shameless plug) pytreeclass to add functional, composable setters/getters (lens-like) to your Equinox trees, modifying values by name or by boolean mask.


import equinox as eqx
import pytreeclass as pytc
import jax

class Tree(eqx.Module):
    weight: jax.Array = jax.numpy.array([-1, 2, 3])
    bias: jax.Array = jax.numpy.array([1])
    counter: int = 1

    @property
    def at(self):
        return pytc.AtIndexer(self, ())

tree = Tree()

tree = (
    tree.at["counter"]
    .set(1)  # set counter to 1
    .at[jax.tree_util.tree_map(lambda x: x < 0, tree)]
    .set(0)  # set negative values to 0
    .at["bias"].set(100)  # set bias to 100
)

print(tree.weight)
# [0 2 3]
print(tree.bias)
# 100
print(tree.counter)
# 1

sachith-gunasekara commented 1 year ago

@patrick-kidger I understand that you have fixed the issue in #363. However, could you clarify how it can be used in the approach mentioned by @ASEM000?

patrick-kidger commented 1 year ago

@ASEM000 -- that's pretty neat! I really like that the libraries compose like this.

FWIW the Equinox equivalent is eqx.tree_at -- but it could maybe do with a nicer interface.

@sachith-gunasekara -- in the next release, you'll be able to do for path, leaf in jax.tree_util.tree_flatten_with_path(m)[0]. (As in @ASEM000's original post.)

ASEM000 commented 1 year ago

I think tree_at needs a makeover, possibly to match something like the example above. tree_at did a good job when, as you said, named leaves had no meaning in JAX; but now that the key path API is public, I think this assumption can be relaxed. I am happy to help if this sounds reasonable.

patrick-kidger commented 1 year ago

Sure, I'd be happy to see what you have in mind.

sachith-gunasekara commented 1 year ago

Thanks for the prompt responses, @patrick-kidger, @ASEM000.

My use case is walking through all modules and submodules to find a leaf with a specific key, say "att". I believe Equinox can't do this at the moment, at least not until the next release.

patrick-kidger commented 1 year ago

You can do that today by following the pattern described in my first response.