HomebrewNLP / Olmax

HomebrewNLP in JAX flavour for maintainable TPU training
BSD 2-Clause "Simplified" License

Balance update weights of depthwise vs. pointwise convolution #53

Closed: ClashLuke closed this issue 2 years ago

ClashLuke commented 2 years ago

Currently, we balance update sizes by fan-in features. Unfortunately, our bottleneck convolution has a 5x lower learning rate and 4x fewer output features, so its effective update size is 20x smaller (5x from the learning rate times 4x from the output features). Similarly, the dilated (#52) convolution in our bottleneck block has a 10x larger kernel and therefore 40x smaller updates (10x times 4x). The 5x and 10x differences are intended (MuParametrization), but the extra 4x appears because our current MuParametrization implementation accounts only for fan_in, not fan_out.

This issue tracks implementing a "fix" for this 4x reweighting and benchmarking it against the baseline.
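The arithmetic above can be checked with a few lines of Python. This is a sketch under the assumption that the learning-rate multiplier and the output-feature ratio combine multiplicatively, as the issue's numbers imply; `relative_update_size` and `FAN_OUT_CORRECTION` are hypothetical names, not Olmax code.

```python
from fractions import Fraction


def relative_update_size(lr_scale, fan_out_scale):
    """Effective update size relative to a reference layer.

    Assumption: the learning-rate multiplier and the output-feature ratio
    combine multiplicatively, matching the numbers quoted in the issue.
    """
    return lr_scale * fan_out_scale


# Bottleneck convolution: 5x lower learning rate, 4x fewer output features.
bottleneck = relative_update_size(Fraction(1, 5), Fraction(1, 4))
# Dilated convolution (#52): 10x larger kernel (10x lower lr), same 4x fewer outputs.
dilated = relative_update_size(Fraction(1, 10), Fraction(1, 4))
print(bottleneck, dilated)  # 1/20 1/40

# Hypothetical fix: also account for fan_out, cancelling the unintended 4x
# and leaving only the intended MuParametrization factors.
FAN_OUT_CORRECTION = Fraction(4)
print(bottleneck * FAN_OUT_CORRECTION, dilated * FAN_OUT_CORRECTION)  # 1/5 1/10
```

Using `Fraction` keeps the ratios exact, so the remaining 5x and 10x factors show up directly once the 4x is compensated.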

ClashLuke commented 2 years ago

One simple way to implement this would be to balance the updates dynamically on the optimizer side instead of using static multipliers. MuParametrization's static multipliers keep the L2 norm of the updates at approximately 1. However, we could also force the updates to have an L2 norm of 1 by using `exp_avg * rsqrt(exp_avg_sq * exp_avg_sq.sum())` instead of `exp_avg * rsqrt(exp_avg_sq)` for Adam (and analogously for SM3).
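A minimal sketch comparing the two Adam directions from the comment on a single toy tensor. It is written with NumPy so it runs anywhere; `EPS`, the function names, and the explicit unit-norm rescaling are assumptions for illustration (the last one is not from the thread), not Olmax's optimizer code.

```python
import numpy as np

# Assumed epsilon for numerical stability; the formulas in the comment have none.
EPS = 1e-8


def adam_direction(exp_avg, exp_avg_sq):
    """Standard Adam direction: exp_avg * rsqrt(exp_avg_sq)."""
    return exp_avg / (np.sqrt(exp_avg_sq) + EPS)


def globally_scaled_direction(exp_avg, exp_avg_sq):
    """Variant from the comment: exp_avg * rsqrt(exp_avg_sq * exp_avg_sq.sum())."""
    return exp_avg / (np.sqrt(exp_avg_sq * exp_avg_sq.sum()) + EPS)


def unit_norm_direction(exp_avg, exp_avg_sq):
    """Hypothetical alternative, not from the thread: rescale the Adam direction to L2 norm 1."""
    update = adam_direction(exp_avg, exp_avg_sq)
    return update / (np.linalg.norm(update) + EPS)


# Toy tensor: after one step, Adam's bias-corrected moments are g and g**2.
rng = np.random.default_rng(0)
grad = rng.normal(size=(4, 8))
exp_avg, exp_avg_sq = grad, grad ** 2

for fn in (adam_direction, globally_scaled_direction, unit_norm_direction):
    print(fn.__name__, float(np.linalg.norm(fn(exp_avg, exp_avg_sq))))
```

With these toy moments the standard direction has an L2 norm of about sqrt(n) and the comment's variant about sqrt(n) / ||g|| (n being the number of elements, g the gradient), which is why an explicit rescaling is shown alongside it. In a JAX codebase the same calls exist in `jax.numpy` (`jnp.sqrt`, `jnp.linalg.norm`), and `jax.lax.rsqrt` matches the `rsqrt` in the formulas.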

ClashLuke commented 2 years ago

A sweep on wandb, run from the sweep-mup-scales branch, checks whether this matters.