warner-benjamin / fastxtend

Train fastai models faster (and other useful tools)
https://fastxtend.benjaminwarner.dev
MIT License

StableAdamW: Missing divisor in RMS computation #27

Closed: EricZimmermann closed this issue 5 months ago

EricZimmermann commented 5 months ago

Hello,

It seems there is a minor error in how StableAdamW computes the RMS.

The RMS is currently computed here: https://github.com/warner-benjamin/fastxtend/blob/e6b2c39a9d2b70ec1038a6142d79c39765795806/fastxtend/optimizer/stableadam.py#L53

That line is missing the normalization coefficient for the mean, as per the formula in the image attached to the original issue. The corrected line would be:

`rms = torch.norm(p.grad.data.div(root_sqr_avg.maximum(eps_t)), 2) / math.sqrt(p.grad.data.numel())`
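For reference, here is a sketch of the quantity being discussed, written from the StableAdamW definition as we read it in Wortsman et al. (2023), "Stable and low-precision training for large-scale vision-language models". The symbols are our own, not fastxtend's: $g_t$ is the gradient, $v_t$ the second-moment estimate (so $\sqrt{v_t}$ plays the role of `root_sqr_avg`), $\epsilon$ the clamp value, $N$ the number of elements in the parameter, and $\alpha_t$ the base learning rate.

$$
\mathrm{RMS}_t = \sqrt{\frac{1}{N}\sum_{i=1}^{N}\left(\frac{g_{t,i}}{\max\left(\sqrt{v_{t,i}},\,\epsilon\right)}\right)^{2}} = \frac{\left\lVert g_t \oslash \max\left(\sqrt{v_t},\,\epsilon\right)\right\rVert_2}{\sqrt{N}}, \qquad \eta_t = \frac{\alpha_t}{\max\left(1,\,\mathrm{RMS}_t\right)}
$$

Without the $1/\sqrt{N}$ factor, the computed value is the plain L2 norm, which grows roughly like $\sqrt{N}$ for large parameters. As a result, $\max(1, \cdot)$ would almost always select the norm and shrink the learning rate far more than intended.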

Additionally, another eps term could be added to the final normalization step.
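Below is a minimal PyTorch sketch of the normalized RMS proposed in this issue. It assumes that `root_sqr_avg` is the element-wise square root of the second-moment estimate, as its name in the quoted line suggests. The helper `stable_adamw_rms` is hypothetical and is not part of fastxtend's or optimi's API. It also uses `clamp_min` with a float `eps` in place of `.maximum(eps_t)` with a tensor, which gives the same values.

```python
import math

import torch


def stable_adamw_rms(grad: torch.Tensor, root_sqr_avg: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Root-mean-square of grad / max(root_sqr_avg, eps) over all elements.

    Hypothetical helper for illustration only. Assumes root_sqr_avg is
    sqrt(second-moment estimate), matching the name in stableadam.py.
    """
    # Element-wise update ratio, clamping the denominator at eps
    # (same effect as root_sqr_avg.maximum(eps_t) in the quoted line).
    ratio = grad.div(root_sqr_avg.clamp_min(eps))
    # The L2 norm alone is sqrt(sum(ratio**2)); dividing by sqrt(numel)
    # turns it into sqrt(mean(ratio**2)), i.e. a true RMS.
    return torch.linalg.vector_norm(ratio, 2) / math.sqrt(ratio.numel())


# Usage: with grad == root_sqr_avg every ratio is 1, so the RMS is 1,
# while the unnormalized L2 norm grows with the number of elements.
grad = torch.ones(32, 32)
root_sqr_avg = torch.ones(32, 32)
unnormalized = torch.linalg.vector_norm(grad.div(root_sqr_avg.clamp_min(1e-6)), 2)
rms = stable_adamw_rms(grad, root_sqr_avg)
print(unnormalized.item(), rms.item())  # 32.0 1.0
```

Dividing by `math.sqrt(numel)` is equivalent to `ratio.pow(2).mean().sqrt()`. The norm-based form simply mirrors the structure of the original line.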

warner-benjamin commented 5 months ago

I'm in the process of porting over the fastxtend optimizers to use the optimi implementations, which should resolve this.

warner-benjamin commented 5 months ago

@EricZimmermann, #28 tracks the progress of porting over the fastxtend optimizers to use the optimi implementations.