sachith-gunasekara opened 1 year ago
Note that the "name" of a parameter isn't a meaningful quantity in Equinox: modules are just collections of parameters, much like a tuple of parameters. You can get just the parameters themselves by using `jax.tree_util.tree_leaves(module)`, as modules are PyTrees.
If your use case is something like "I want to get all the biases", then the usual way to do it would be something like this:
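```python
import jax

def has_bias(x):
    return hasattr(x, "bias")

def get_biases(pytree):
    # Stop flattening at any node that has a `bias` attribute, then collect those biases.
    return [x.bias for x in jax.tree_util.tree_leaves(pytree, is_leaf=has_bias) if has_bias(x)]
```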
or maybe
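```python
import equinox as eqx
import jax

def is_linear(x):
    return isinstance(x, eqx.nn.Linear)

def get_biases(pytree):
    # Only treat eqx.nn.Linear layers as leaves, then collect their biases.
    return [x.bias for x in jax.tree_util.tree_leaves(pytree, is_leaf=is_linear) if is_linear(x)]
```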
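If your use case is "I want to zero out all my biases", then you can do that with the `get_biases` function above:

```python
import equinox as eqx
import jax.numpy as jnp

# Replace every bias returned by get_biases with zeros of the same shape and dtype.
model = eqx.tree_at(get_biases, model, replace_fn=jnp.zeros_like)
```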
Does that help?
As of version 0.4.6, JAX supports named leaves (key paths) for pytrees, so it's up to @patrick-kidger to support them in Equinox.
If adopted, this would simplify a lot of tree manipulation in Equinox (IMO, it is pretty tricky to do right now).
For example, your use case is easily handled with the new JAX API:
```python
import pytreeclass as pytc
import jax

class Model(pytc.TreeClass):
    def __init__(self):
        self.a = 1
        self.b = 2

m = Model()
for path, leaf in jax.tree_util.tree_flatten_with_path(m)[0]:
    print(path[-1], leaf)
# .a 1
# .b 2
```
That's a good point! Keypaths are now public API for JAX.
I've just added support in #363.
Anyway, to answer the original question: the equivalent of `named_parameters` is exactly what @ASEM000 has in his example, `jax.tree_util.tree_flatten_with_path`.
Nice :D
Once this is merged, you can use (shameless plug) pytreeclass to add functional, composable setters/getters (lens-like) to your Equinox trees, to modify values based on their name or a boolean mask:
```python
from dataclasses import field

import equinox as eqx
import jax
import jax.numpy as jnp
import pytreeclass as pytc

class Tree(eqx.Module):
    # default_factory: Python 3.11+ dataclasses reject unhashable defaults such as JAX arrays
    weight: jax.Array = field(default_factory=lambda: jnp.array([-1, 2, 3]))
    bias: jax.Array = field(default_factory=lambda: jnp.array([1]))
    counter: int = 1

    @property
    def at(self):
        return pytc.AtIndexer(self, ())

tree = Tree()
tree = (
    tree.at["counter"]
    .set(1)  # set counter to 1
    .at[jax.tree_util.tree_map(lambda x: x < 0, tree)]
    .set(0)  # set negative values to 0
    .at["bias"].set(100)  # set bias to 100
)
print(tree.weight)
# [0 2 3]
print(tree.bias)
# 100
print(tree.counter)
# 1
```
@patrick-kidger I understand that you have fixed the issue in #363. However, could you clarify how it can be used in the approach mentioned by @ASEM000?
@ASEM000 -- that's pretty neat! I really like that the libraries compose like this.
FWIW the Equinox equivalent is `eqx.tree_at`, but it could maybe do with a nicer interface.
@sachith-gunasekara, in the next release you'll be able to do `for path, leaf in jax.tree_util.tree_flatten_with_path(m)[0]`, as in @ASEM000's original post.
I think `tree_at` needs a makeover, possibly to match something like the example above. `tree_at` did a good job when, as you said, named leaves had no meaning in JAX, but now that the key path API is public, I think this assumption can be relaxed.
I am happy to help if this sounds reasonable.
Sure, I'd be happy to see what you have in mind.
Thanks for the prompt responses, @patrick-kidger, @ASEM000.
My use case is to run through all the modules and submodules to find a leaf with a specific key, say "att". I believe Equinox can't currently do this, at least not until the next release.
You can do that today by following the pattern described in my first response.
PyTorch has a built-in method, `torch.nn.Module.named_parameters()`, that yields (name, parameter) pairs for a module and all its submodules.
This method can be used in a loop as follows:

```python
for pn, p in self.named_parameters():
    ...
```
Is it possible to achieve something like this in JAX/Equinox?