neel04 opened this issue 4 months ago (Open)
I tried something like this, which does parameter sharing fine but is excruciatingly slow - about 30-50x slower. That's probably because the nested `eqx.Shared` does not play well during computation, unfortunately 😢
from typing import Any, Callable, List

import equinox as eqx
from jaxtyping import PyTree


# (method on the model class)
@staticmethod
def share_ffn_params(
    block: List[RecurrentModule],
    iters_to_do: int,
) -> List[RecurrentModule]:
    '''
    Tie the feedforward/linear parameters across the `iters_to_do` copies of
    `RecurrentModule`, so that every iteration shares iteration 0's weights.
    '''
    shared: List[RecurrentModule] = block
    # Dotted paths of the parameters to tie across iterations.
    shared_attrs = [
        'attention_layers.mlp.layer_1.weight', 'attention_layers.mlp.layer_1.bias',
        'attention_layers.mlp.layer_2.weight', 'attention_layers.mlp.layer_2.bias',
    ]

    def get_pytree_attr(i: int, attr: str) -> Callable:
        # Resolve the i-th block's attribute after unwrapping any Shared layers.
        return lambda x: get_nested_attr(unwrap_pytree(x)[i], attr)

    for i in range(1, iters_to_do):
        for myattr in shared_attrs:
            # Tie iteration i's parameter to iteration 0's copy.
            shared = eqx.nn.Shared(
                shared,
                where=get_pytree_attr(i, myattr),
                get=get_pytree_attr(0, myattr),
            )
    # NB: this is really a nested eqx.nn.Shared wrapping the original list.
    return shared


def unwrap_pytree(pytree: PyTree) -> PyTree:
    '''
    Recursively unwrap nested eqx.nn.Shared wrappers to reach the underlying pytree.
    '''
    if hasattr(pytree, 'pytree'):
        return unwrap_pytree(pytree.pytree)
    return pytree


def get_nested_attr(obj: object, attr: str) -> Any:
    '''Resolve a dotted attribute path, e.g. 'attention_layers.mlp.layer_1.weight'.'''
    for a in attr.split('.'):
        obj = getattr(obj, a)
    return obj
So probably the simplest thing to do would be to make the MLP an attribute of the parent module, and then either (if you can) call it directly from there, or (if required) pass this MLP as a `__call__`-time argument to each of your layers.
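For example (just a sketch - the `Block`/`Model` layout and the sizes here are illustrative, not taken from your code):

```python
import equinox as eqx
import jax


class Block(eqx.Module):
    # Per-iteration attention; note that no MLP is stored on the block itself.
    attention: eqx.nn.MultiheadAttention

    def __call__(self, x, mlp):
        # The shared MLP arrives as a __call__-time argument.
        x = self.attention(x, x, x)
        return jax.vmap(mlp)(x)


class Model(eqx.Module):
    # The parent owns the single MLP, so its parameters exist exactly once.
    blocks: list[Block]
    mlp: eqx.nn.MLP

    def __call__(self, x):
        for block in self.blocks:
            x = block(x, self.mlp)  # every block reuses the same MLP parameters
        return x


key = jax.random.PRNGKey(0)
mlp_key, *block_keys = jax.random.split(key, 5)
model = Model(
    blocks=[
        Block(attention=eqx.nn.MultiheadAttention(num_heads=4, query_size=64, key=k))
        for k in block_keys
    ],
    mlp=eqx.nn.MLP(in_size=64, out_size=64, width_size=256, depth=2, key=mlp_key),
)
```

Only `model.mlp` holds FF parameters, so there is nothing to tie together after the fact.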
FWIW the use of `Shared` shouldn't impose a runtime penalty -- the overhead should vanish under JIT.
Yea I gave up and did that, though it definitely feels like param-sharing could use a bit more love in equinox.
> FWIW the use of `Shared` shouldn't impose a runtime penalty -- the overhead should vanish under JIT
Well, it was a nested `eqx.Shared` object, so I don't know how well JIT would be able to handle it - especially since it can easily get up to a depth of ~15. I think it might be some sort of memory-layout issue, where every time `unwrap_callable` unwraps the object, it runs the `tree_at`, which actually moves data around from somewhere instead of just passing pointers around.

That's the only thing I can imagine causing these huge slowdowns 🤷
> Yea I gave up and did that, though it definitely feels like param-sharing could use a bit more love in equinox.
Great. For what it's worth, this passing-around is how we used to do param-sharing in the 'old days' before `eqx.nn.Shared`. It's a pattern I actually quite like...!
> Well it was a nested `eqx.Shared` object, so I don't know how well JIT would be able to handle it - especially since it can easily get up to a depth of ~15. I think it might be some sort of a memory layout issue, where every time `unwrap_callable` unwraps the object, it runs the `tree_at` which instead of pointers actually moves data around from somewhere.
So neither `eqx.nn.Shared` nor `eqx.tree_at` actually copies memory around or anything like that. They just do pure-Python manipulation to put arrays in the right places. If you're seeing a slowdown here, I'm fairly sure it will be due to some other aspect of this.
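For example, here's a quick way to convince yourself of that (a minimal sketch using `eqx.nn.Linear`; the shapes are arbitrary):

```python
import equinox as eqx
import jax
import jax.numpy as jnp

linear = eqx.nn.Linear(4, 4, key=jax.random.PRNGKey(0))
new_weight = jnp.zeros((4, 4))

# tree_at rebuilds the (tiny) Python structure of the module; the replacement
# array is inserted by reference rather than copied.
updated = eqx.tree_at(lambda m: m.weight, linear, replace=new_weight)
print(updated.weight is new_weight)  # True: same array object, no data movement
```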
I have a use-case where I want to replicate the whole NN `n` times (where each NN consists of some number of `AttentionBlock`s). This is effectively a Universal Transformer, where a set of layers is repeated `n` times.

But I want to share the FF part, so that I only have one set of FF weights rather than `n` of them, but retain separate attention weights for each of the `n` iterations.
Currently, the way I set it up is that a single NN `PyTree` holds all the blocks (`4`) simultaneously (this comes from using the scan-over trick to reduce compilation time). Thus, I have a `List` of `n` `PyTree`s in which I want to share the FF weights.

To accomplish this weight-tying, I was thinking of doing something like:
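(a rough sketch of the idea, not working code - `blocks` stands for the list of `n` block PyTrees, and only a single weight is tied here, between block 0 and block 1:)

```python
import equinox as eqx

# Tie block 1's first MLP weight to block 0's copy, one attribute at a time.
shared = eqx.nn.Shared(
    blocks,
    where=lambda b: b[1].attention_layers.mlp.layer_1.weight,
    get=lambda b: b[0].attention_layers.mlp.layer_1.weight,
)
```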
and then repeat it for the other attributes - like `mlp.{layer_1 | layer_2}.{weight | bias}`, the `Layernorm`s, etc. - which seems extremely messy.
My reasoning behind this approach was that during `__call__`, I could just index into the list of `n` weight-tied NNs and do something like the sketch below.

What would be the equinox-y way of achieving this?
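For reference, the kind of `__call__` I had in mind (hypothetical names; `self.shared_blocks` would be the `eqx.nn.Shared` object, whose call materialises the weight-tied list):

```python
def __call__(self, x):
    # Calling the eqx.nn.Shared wrapper substitutes the tied parameters back in
    # and returns the underlying list of n blocks.
    blocks = self.shared_blocks()
    for i in range(self.iters_to_do):
        x = blocks[i](x)
    return x
```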