patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

PyTorch Hooks in Equinox? #864

Open samuelstevens opened 1 month ago

samuelstevens commented 1 month ago

I would like to record some model activations in an architecture-invariant way. In PyTorch, we can use forward hooks to do this by registering a hook on every module that matches some criterion (for example, all modules of a particular MLP class).

Is there an equivalent strategy in Equinox?

One idea is to create a class Wrapper(eqx.Module) that simply wraps a module, calls some callback in __call__ with the underlying module's activations, and then somehow replaces the matching modules inside an Equinox model:

import equinox as eqx
from typing import Callable

class Wrapper(eqx.Module):
    wrapped: eqx.Module
    callback: Callable

    def __init__(self, module, callback):
        self.wrapped = module
        self.callback = callback

    def __call__(self, *args, **kwargs):
        outs = self.wrapped(*args, **kwargs)
        self.callback(outs)  # this would save to disk or something
        return outs

Then in the main script, I could do something like:

model = MyViT()
for i in range(n_layers):
    model = eqx.tree_at(lambda m: m.layers[i].mlp, model,
                        replace_fn=lambda m: Wrapper(m, my_callback))

Is there a better/more obvious way to do this?

nasyxx commented 1 month ago

You can use jax.tree.leaves to collect all the modules you want.

For example, if you want every Linear layer:

import jax

is_linear = lambda x: isinstance(x, eqx.nn.Linear)
get_linear = lambda m: [x
                        for x in jax.tree.leaves(m, is_leaf=is_linear)
                        if is_linear(x)]
linears = get_linear(model)
wrapped = [Wrapper(m, callback) for m in linears]
model = eqx.tree_at(get_linear, model, wrapped)  # tree_at returns a new model
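
The same pattern works for any module class. For the MLP blocks in your question it might look like this (just a sketch; MlpBlock is a hypothetical class name, substitute whatever your model actually uses):

# MlpBlock is hypothetical; use your model's actual MLP class.
is_mlp = lambda x: isinstance(x, MlpBlock)
get_mlp = lambda m: [x
                     for x in jax.tree.leaves(m, is_leaf=is_mlp)
                     if is_mlp(x)]
model = eqx.tree_at(get_mlp, model, [Wrapper(m, callback) for m in get_mlp(model)])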

However, I'm not sure whether your callback can run inside a jitted function.

samuelstevens commented 1 month ago

Wow, that's really neat; I'll try that. I think I can use jax.debug.callback or jax.experimental.io_callback, though I'm not sure which will be better.
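
If it helps anyone later, here is a rough sketch of what I am imagining with jax.debug.callback (untested; JitSafeWrapper and the field names are just illustrative):

import jax
import equinox as eqx
from typing import Callable

class JitSafeWrapper(eqx.Module):
    wrapped: eqx.Module
    # Mark the callback static so it is treated as metadata, not a pytree leaf.
    callback: Callable = eqx.field(static=True)

    def __call__(self, *args, **kwargs):
        outs = self.wrapped(*args, **kwargs)
        # jax.debug.callback stages the Python function out to the host,
        # so recording activations also works inside jit-compiled code.
        jax.debug.callback(self.callback, outs)
        return outs

My understanding is that jax.experimental.io_callback (with ordered=True) gives stronger guarantees for side effects like disk writes, while jax.debug.callback is geared toward debugging and may be reordered or dropped under some transformations.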