Open mnslarcher opened 6 months ago
Good catch! Yeah, we should consider changing this, although at this point, given that it also seems to work well with the double norm, we have to weigh whether it's worth introducing potentially breaking changes to the public implementation.
Sure, it's not a big deal. Maybe I'd consider adding something like `use_norm=True` or `norm=True` to `MappingFeedForwardBlock`. It'd keep things as they are for now, but later on, if you want to turn off normalization when reusing the block, you'll have the option. Anyway, it's pretty much a ~0 impact thing, probably not worth making the code more complex.
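For illustration, here is a minimal sketch of what the suggested flag could look like. The `use_norm` parameter, the layer sizes, and the simple GELU feed-forward here are assumptions for the example; the actual block in the repository may use different projections, activation, and initialization.

```python
import torch
from torch import nn


class RMSNorm(nn.Module):
    """Minimal RMSNorm with a learnable scale (illustrative, not the repo's exact class)."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * rms * self.scale


class MappingFeedForwardBlock(nn.Module):
    """Residual feed-forward block with an optional pre-norm (hypothetical sketch)."""

    def __init__(self, d_model, d_ff, use_norm=True):
        super().__init__()
        # With use_norm=False the first block can skip the redundant normalization;
        # nn.Identity keeps the forward pass unchanged otherwise.
        self.norm = RMSNorm(d_model) if use_norm else nn.Identity()
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.act = nn.GELU()
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        skip = x
        x = self.norm(x)
        x = self.down_proj(self.act(self.up_proj(x)))
        return x + skip
```

Defaulting `use_norm=True` keeps the current behavior, so existing checkpoints and callers are unaffected unless they opt in.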
Hi there!

I've noticed that in the `forward` method of `MappingNetwork`, you apply `RMSNorm` to the input. However, `MappingFeedForwardBlock` also performs normalization, which means the first block normalizes input that has already been normalized.

Wouldn't it make sense to introduce an option to toggle normalization in `MappingFeedForwardBlock` and turn it off for the first block?

To be honest, even with this setup, the `RMSNorm` in the first block still plays a role, as there could be different scales for `skip` and `x`.

Just a thought while reviewing the code – feel free to ignore if it's not relevant!
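To make the "~0 impact" point concrete, here is a quick check, assuming a plain RMSNorm without a learnable scale, that normalizing an already-normalized tensor is almost a no-op (an already-normalized input has RMS ≈ 1, so the second division barely changes it):

```python
import torch


def rms_norm(x, eps=1e-6):
    """Plain RMSNorm without a learnable scale (illustrative)."""
    return x * x.pow(2).mean(dim=-1, keepdim=True).add(eps).rsqrt()


torch.manual_seed(0)
x = torch.randn(4, 64)
once = rms_norm(x)
twice = rms_norm(once)
# The double norm differs from the single norm only on the order of eps.
print(torch.allclose(once, twice, atol=1e-4))
```

Note this only covers the redundancy at the input; with a learnable scale, or once the residual stream drifts away from unit RMS in later blocks, the per-block norm does real work, which matches the observation about `skip` and `x` above.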