graphcore-research / unit-scaling

A library for unit scaling in PyTorch
https://graphcore-research.github.io/unit-scaling/
Apache License 2.0

[Question] What is the 'mult' parameter? #62

Open mmorinag127 opened 3 months ago

mmorinag127 commented 3 months ago

Dear authors,

Thank you very much for this great work!! I have a question about the 'mult' parameter in the library: is this one of the hyperparameters described in the u-µP paper?

If yes, should we distinguish these parameters separately in some sweep runs?

DouglasOrr commented 3 months ago

Hi @mmorinag127, thanks for your question.

The parameter `mult` is shorthand for the $\alpha$ multipliers from the paper; where the context disambiguates them, we've just called them `mult`. Compared against Table 3 of the paper (v1):

| HP | Implementation |
| --- | --- |
| $\alpha_{ffn\mbox{-}act}$ | `functional.silu(mult=?)` |
| $\alpha_{attn\mbox{-}softmax}$ | `functional.scaled_dot_product_attention(mult=?)` |
| $\alpha_{res}$ | `core.functional.transformer_residual_scaling_rule(residual_mult=?)` |
| $\alpha_{res\mbox{-}attn\mbox{-}ratio}$ | `core.functional.transformer_residual_scaling_rule(residual_attn_ratio=?)` |
| $\alpha_{loss\mbox{-}softmax}$ | `functional.cross_entropy(mult=?)` |

They should be included in the associated `unit_scaling.*` modules, except that they're not all plumbed through to the `TransformerLayer`, which is more of a demo than a component designed for general reuse, since we expect library users to have their own requirements at this level.
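At the functional level these look roughly as follows (the keyword names are as in the table above; the positional arguments and shapes here are illustrative, mirroring the `torch.nn.functional` equivalents, so check the docs for exact signatures):

```python
import torch

import unit_scaling.functional as U
from unit_scaling.core.functional import transformer_residual_scaling_rule

x = torch.randn(2, 16, 256)

# alpha_ffn-act: multiplier on the FFN activation
h = U.silu(x, mult=1.0)

# alpha_attn-softmax: multiplier on the attention softmax
q = k = v = torch.randn(2, 8, 16, 32)
attn = U.scaled_dot_product_attention(q, k, v, mult=1.0)

# alpha_loss-softmax: multiplier on the final softmax / loss
logits = torch.randn(32, 1000)
labels = torch.randint(0, 1000, (32,))
loss = U.cross_entropy(logits, labels, mult=1.0)

# alpha_res and alpha_res-attn-ratio: wrapped up in a residual scaling rule,
# which is then consumed by whichever code constructs the residual stream
residual_scaling = transformer_residual_scaling_rule(
    residual_mult=1.0, residual_attn_ratio=1.0
)
```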

> If yes, should we distinguish these parameters separately in some sweep runs?

Yes, in general all mults should be considered separately. The paper's recipe suggests that setting most of them to 1 is a good starting point, but if they are swept during hyperparameter search they should be treated as independent. Our experience of sweeping them one at a time is shown in Figure 11 (lower) of the paper.
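For instance, a one-at-a-time sweep could be laid out like this (the dictionary keys are just labels for the mults above, not library argument names):

```python
# Hold every mult at the default of 1.0 and vary one at a time
# (keys are illustrative labels, not unit_scaling keyword arguments).
baseline = {
    "ffn_act_mult": 1.0,
    "attn_softmax_mult": 1.0,
    "residual_mult": 1.0,
    "residual_attn_ratio": 1.0,
    "loss_softmax_mult": 1.0,
}
sweep_values = [0.25, 0.5, 1.0, 2.0, 4.0]
runs = [{**baseline, name: value} for name in baseline for value in sweep_values]
```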

mmorinag127 commented 2 months ago

Thank you so much for your help! This is incredible!

I have another question, about other optimizers such as ScheduleFreeAdam or SharpnessAwareMinimization. Both can be built on top of Adam (or AdamW); in that case, can u-µP be extended to cover these optimizers?

thecharlieblake commented 2 months ago

Hi @mmorinag127. I don't see any reason why that wouldn't be possible. In the same way these would normally be implemented on top of Adam(W), one could implement them on top of our code. At a glance, I don't see any reason why the unit scaling / u-µP rules would need to be modified in these cases.
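To sketch what I mean, using the schedule-free optimizer as the example (`lr_scale_for` below is a placeholder for whatever per-parameter learning-rate scaling you already apply for Adam(W) under u-µP; it is not a library function):

```python
import torch
import schedulefree  # pip install schedulefree

# Stand-in model: in practice this would be built from unit_scaling modules.
model = torch.nn.Sequential(torch.nn.Linear(256, 1024), torch.nn.Linear(1024, 256))

def lr_scale_for(name: str, param: torch.Tensor) -> float:
    # Placeholder: the usual u-muP per-parameter lr scaling for Adam(W),
    # unchanged by the choice of base optimizer.
    return 1.0

base_lr = 1e-3
param_groups = [
    {"params": [param], "lr": base_lr * lr_scale_for(name, param)}
    for name, param in model.named_parameters()
]

# Drop-in replacement for torch.optim.AdamW; note that schedule-free
# optimizers also need opt.train() / opt.eval() calls around training and eval.
opt = schedulefree.AdamWScheduleFree(param_groups, lr=base_lr)
```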

If you're keen on trying one or both of these optimizers, we'd be happy to add an implementation ourselves or to review a PR.

mmorinag127 commented 2 months ago

Hello,

Thanks a lot! Actually, I am working on a JAX project, so I need to port this great work to JAX. I also have a naive question about how to implement the following case:

```python
a, b, c = linear(y)     # modulation parameters derived from y
out = x * (1 + a) + b   # affine modulation of x by y
out = linear(out)
out = x + c * out       # gated residual add back onto x
```

In this case we have two input tensors (x and y): x is modulated by an affine transform whose parameters are produced from y, passed through a linear layer, and then added back to x via a gated residual connection.

Understanding how to handle this case would teach me a lot about u-µP.