patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Best way to do complex weight sharing #790

Open neel04 opened 4 months ago

neel04 commented 4 months ago

I have a use-case where I want to replicate the whole NN n times (where each NN consists of some number of AttentionBlocks).

This is effectively a Universal Transformer, where a set of layers is repeated n times.

But I want to share the FF part, so that I only have one set of FF weights rather than n of them, while retaining separate attention weights for each of the n copies.

Currently, a single NN PyTree holds all the blocks (4) at once, stacked along a leading axis (this comes from using the scan-over-layers trick to reduce compilation time):

MLP(
  layer_1=LinearProj(
    bias=f32[4,256],
    weight=f32[4,64,256],
    input_dim=64,
    output_dim=256
  ),
  ....

Thus, I have a list of n PyTrees, and I want to share the FF weights across all of them.

To accomplish this weight-tying, I was thinking of doing something like:

for i in range(1, n):
    shared = eqx.nn.Shared(
        NN_list,
        where=lambda x: x[i].mlp.layer_1.weight,
        get=lambda x: x[i - 1].mlp.layer_1.weight,
    )

and then repeat it for the other attributes, like mlp.{layer_1 | layer_2}.{weight | bias}, the LayerNorms, etc.

That seems extremely messy.

My reasoning behind this approach was that during __call__, I could just index into the list of n weight-tied NNs and do something like

for i in range(n):
    input = NN_list[i](input)

What would be the equinox-y way of achieving this?

neel04 commented 4 months ago

I tried something like this, which shares the parameters fine but is excruciatingly slow: about 30-50x slower.

That's probably because the nested eqx.nn.Shared wrappers don't play well with computation, unfortunately 😢

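from typing import Any, Callable, List

import equinox as eqx
from jaxtyping import PyTree

# Method of the model class; RecurrentModule is the user's own module.
@staticmethod
def share_ffn_params(
    block: List[RecurrentModule],
    iters_to_do: int,
) -> List[RecurrentModule]:
    '''
    Given iters_to_do copies of RecurrentModule, tie the feedforward
    (linear) layer parameters of every copy to those of copy 0.
    '''
    shared: List[RecurrentModule] = block

    shared_attrs = ['attention_layers.mlp.layer_1.weight', 'attention_layers.mlp.layer_1.bias',
                    'attention_layers.mlp.layer_2.weight', 'attention_layers.mlp.layer_2.bias']

    # A factory function binds i and attr now, so each lambda keeps its own values.
    def get_pytree_attr(i: int, attr: str) -> Callable:
        return lambda x: get_nested_attr(unwrap_pytree(x)[i], attr)

    # Each Shared wraps the previous one, so the result ends up nested
    # len(shared_attrs) * (iters_to_do - 1) levels deep.
    for i in range(1, iters_to_do):
        for myattr in shared_attrs:
            shared = eqx.nn.Shared(
                shared,
                where=get_pytree_attr(i, myattr),
                get=get_pytree_attr(0, myattr),
            )

    return shared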

def unwrap_pytree(pytree: PyTree) -> PyTree:
    '''
    Recursively strip the nested eqx.nn.Shared wrappers by following
    their .pytree attribute.
    '''
    if hasattr(pytree, 'pytree'):
        return unwrap_pytree(pytree.pytree)
    else:
        return pytree

def get_nested_attr(obj: object, attr: str) -> Any:
    attrs = attr.split('.')
    for a in attrs:
        obj = getattr(obj, a)
    return obj

patrick-kidger commented 3 months ago

So probably the simplest thing to do would be to make the MLP an attribute of the parent module, and then either (if you can) call it directly from there, or (if required) pass this MLP as a __call__-time argument to each of your layers.

FWIW the use of Shared shouldn't impose a runtime penalty: the overhead should vanish under JIT.

neel04 commented 3 months ago

Yeah, I gave up and did that, though it definitely feels like param-sharing could use a bit more love in equinox.

> FWIW the use of Shared shouldn't impose a runtime penalty -- the overhead should vanish under JIT

Well, it was a nested eqx.nn.Shared object, so I don't know how well JIT can handle it, especially since the nesting can easily reach a depth of ~15. I think it might be some sort of memory-layout issue: every time unwrap_callable unwraps the object, it runs tree_at, which actually moves data around instead of just passing pointers.

That's the only thing I can imagine causing these huge slowdowns 🤷

patrick-kidger commented 3 months ago

> Yea I gave up and did that, though it definitely feels like param-sharing could use a bit more love in equinox.

Great. For what it's worth, this passing-around is how we used to do param-sharing in the 'old days' before eqx.nn.Shared. It's a pattern I actually quite like... !

> Well it was a nested eqx.Shared object, so I don't know how well JIT would be able to handle it - especially since it can easily get up to a depth of ~15. I think it might be some sort of a memory layout issue, where everytime unwrap_callable unwraps the object, it runs the tree_at which instead of pointers actually moves data around from somewhere.

So neither eqx.nn.Shared nor eqx.tree_at actually copies memory around or anything like that. They just do pure-Python manipulation to put arrays in the right places. If you're seeing a slowdown here, I'm fairly sure it's due to some other aspect of this.