Open mmorinag127 opened 3 months ago
Hi @mmorinag127, thanks for your question.
The parameter `mult` is shorthand for the $\alpha$ multipliers from the paper. Where the context disambiguates them, we've just called them `mult`. Versus Table 3 of the paper (v1):
| Hyperparameter | Implementation |
|---|---|
| $\alpha_{\text{ffn-act}}$ | `functional.silu(mult=?)` |
| $\alpha_{\text{attn-softmax}}$ | `functional.scaled_dot_product_attention(mult=?)` |
| $\alpha_{\text{res}}$ | `core.functional.transformer_residual_scaling_rule(residual_mult=?)` |
| $\alpha_{\text{res-attn-ratio}}$ | `core.functional.transformer_residual_scaling_rule(residual_attn_ratio=?)` |
| $\alpha_{\text{loss-softmax}}$ | `functional.cross_entropy(mult=?)` |
They should be included in the associated `unit_scaling.*` modules, except that they're not all plumbed through to the `TransformerLayer`, which is more of a demo than a component designed for general reuse, since we expect library users to have various requirements at this level.
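For concreteness, here's a rough sketch of where these keyword arguments end up in a forward pass. Only the `mult` / `residual_mult` / `residual_attn_ratio` keyword names come from the table above; the positional arguments, shapes and values are illustrative assumptions, not the documented signatures:

```python
import torch
import unit_scaling.functional as uuF
from unit_scaling.core.functional import transformer_residual_scaling_rule

# Dummy activations, just to make the call sites concrete (shapes are arbitrary)
h = torch.randn(2, 16, 256)                      # FFN hidden states
q = k = v = torch.randn(2, 4, 16, 64)            # attention heads
logits, labels = torch.randn(32, 1000), torch.randint(0, 1000, (32,))

h = uuF.silu(h, mult=1.0)                                   # alpha_ffn-act
attn = uuF.scaled_dot_product_attention(q, k, v, mult=1.0)  # alpha_attn-softmax
loss = uuF.cross_entropy(logits, labels, mult=1.0)          # alpha_loss-softmax

# alpha_res and alpha_res-attn-ratio are configured once, via the residual
# scaling rule, rather than at each residual-add call site
residual_rule = transformer_residual_scaling_rule(
    residual_mult=1.0, residual_attn_ratio=1.0
)
```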
> If yes, should we distinguish these parameters separately in some sweep runs?
Yes, in general all the mults should be considered separately. The recipe in the paper suggests that setting most of them to 1 is a good starting point, but if swept during hyperparameter search they should be treated as independent hyperparameters. Our experience of sweeping them one at a time is shown in Figure 11 (lower) of the paper.
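As a rough illustration of that one-at-a-time style of sweep (the dictionary keys below are just labels for this sketch, not library argument names):

```python
# Each multiplier defaults to 1 and is swept on its own log-spaced axis,
# with all the other multipliers held at their defaults.
defaults = {
    "ffn_act_mult": 1.0,
    "attn_softmax_mult": 1.0,
    "residual_mult": 1.0,
    "residual_attn_ratio": 1.0,
    "loss_softmax_mult": 1.0,
}
sweep_values = [2.0 ** k for k in range(-4, 5)]  # 1/16 ... 16

runs = []
for name in defaults:
    for value in sweep_values:
        hp = {**defaults, name: value}  # only one mult deviates from 1 per run
        runs.append(hp)
```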
Thank you so much for your help! This is incredible!
I have another question about other optimizers, such as ScheduleFreeAdam or SharpnessAwareMinimization. Both can be built on top of Adam (or AdamW); in that case, can u-µP be extended to these optimizers?
Hi @mmorinag127. I don't see any reason why that wouldn't be possible. In the same way that these could normally be implemented on top of Adam(W), one could implement them on top of our code. At a glance I don't see any reason why the unit scaling / u-µP rules would need to be modified in these instances.
If you're keen on trying one or both optimizers, we'd be happy to add an implementation ourselves or review a PR.
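For what it's worth, here is a minimal sketch of the SAM side of this, wrapping plain `torch.optim.AdamW` (not part of the library; substituting whichever Adam variant carries the u-µP learning-rate scaling should follow the same pattern):

```python
import torch

class SAM(torch.optim.Optimizer):
    """Minimal Sharpness-Aware Minimization sketch around a base optimizer."""

    def __init__(self, params, base_optimizer_cls=torch.optim.AdamW, rho=0.05, **kwargs):
        defaults = dict(rho=rho, **kwargs)
        super().__init__(params, defaults)
        # The base optimizer shares our param groups, so any per-group settings
        # (e.g. u-µP learning-rate scaling attached to them) carry over unchanged.
        self.base_optimizer = base_optimizer_cls(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self):
        # Ascent step: w <- w + rho * g / ||g||, using gradients from the first backward pass
        grad_norm = torch.norm(torch.stack([
            p.grad.norm(p=2)
            for group in self.param_groups
            for p in group["params"]
            if p.grad is not None
        ]), p=2)
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)
            for p in group["params"]:
                if p.grad is None:
                    continue
                e_w = p.grad * scale
                p.add_(e_w)
                self.state[p]["e_w"] = e_w

    @torch.no_grad()
    def second_step(self):
        # Undo the perturbation, then take the base (AdamW) step with the
        # gradients computed at the perturbed weights.
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.sub_(self.state[p].pop("e_w"))
        self.base_optimizer.step()
```

Each training step would then do two forward/backward passes: call `first_step()` after the first backward, zero the gradients, run the second forward/backward at the perturbed weights, then call `second_step()`.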
Hello,
Thanks a lot! Actually, I am working on a JAX project, so I need to translate this great work into JAX code. I also have a naive question about how to implement the following case:
```python
a, b, c = linear(y)        # one projection of y, chunked into three tensors
out = x * (1 + a) + b      # affine modulation of x, parameterised by y
out = linear(out)
out = x + c * out          # gated residual add back onto x
```
In this case there are two input tensors (x and y): an affine transform parameterised by y is applied to x, and the result is then added back onto x as a gated residual. Working through this case would tell me a lot about u-µP.
Dear authors,
Thank you very much for this great work!! I have a question about the `mult` parameter in the library. Is this one of the hyperparameters described in the u-µP paper?
If yes, should we distinguish these parameters separately in some sweep runs?