samuelstevens opened this issue 1 month ago
You can use `jax.tree.leaves` with an `is_leaf` predicate to collect all the submodules you want, then swap them out with `eqx.tree_at`. For example, to wrap every `eqx.nn.Linear`:
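```python
import equinox as eqx
import jax

# Treat each eqx.nn.Linear as a leaf so it is returned whole, not split into arrays.
is_linear = lambda x: isinstance(x, eqx.nn.Linear)
get_linear = lambda m: [x
                        for x in jax.tree.leaves(m, is_leaf=is_linear)
                        if is_linear(x)]
linears = get_linear(model)
wrapped = [Wrapper(m, callback) for m in linears]
# tree_at is out-of-place: it returns a new model instead of mutating `model`.
model = eqx.tree_at(get_linear, model, wrapped)
```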
However, I'm not sure whether your callback will run inside a jitted model.
Wow, that's really neat; I'll try it. I think I can use `jax.debug.callback` or `jax.experimental.io_callback`, though I'm not sure which is better.
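On the jit question: a plain Python side effect inside a traced function runs only while JAX traces it, whereas `jax.debug.callback` is compiled into the program and runs on every call. A minimal check, using two hypothetical counter lists:

```python
import jax
import jax.numpy as jnp

python_calls = []    # appended by a plain Python side effect
callback_calls = []  # appended through jax.debug.callback


@jax.jit
def f(x):
    # Runs only while jit traces f; cached calls skip it.
    python_calls.append(1)
    # Staged into the compiled program; runs on every call.
    jax.debug.callback(lambda v: callback_calls.append(v), x)
    return x * 2


f(jnp.ones(3))
f(jnp.ones(3))  # same shape and dtype: served from the jit cache, no retrace
jax.effects_barrier()
print(len(python_calls), len(callback_calls))  # 1 2
```

For recording activations, `jax.debug.callback` is the lighter-weight choice; `jax.experimental.io_callback` is aimed at impure callbacks whose results flow back into the program, which a logging hook does not need.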
I would like to record model activations in an architecture-invariant way. In PyTorch, this is done with forward hooks: you register a hook on every module that matches some criterion (for example, every instance of an MLP class).
Is there an equivalent strategy in Equinox?
One idea is to create a `class Wrapper(eqx.Module)` that simply wraps a module and, in `__call__`, passes the underlying module's activations to some callback, and then somehow swap these wrappers in for the matching submodules of an Equinox model. Then in the main script, I could do something like:
Is there a better/more obvious way to do this?