allenai / OLMo

Modeling, training, eval, and inference code for OLMo
https://allenai.org/olmo
Apache License 2.0

Reverse weight decay #567

Open AkshitaB opened 5 months ago

AkshitaB commented 5 months ago

Goal: Perform reverse weight decay on embeddings

Multiply the weight_decay factor for the embeddings layer by (1 - norm(embeddings)).
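For concreteness, a minimal PyTorch sketch of this idea (a toy setup with assumed names and values, such as the "embeddings" group label and the base decay of 0.1; not the PR's actual diff):

```python
import torch
import torch.nn as nn

# Toy sketch (assumed setup, not the PR's actual code): rescale the
# embedding layer's weight_decay by (1 - norm(embeddings)) every step.
emb = nn.Embedding(num_embeddings=1000, embedding_dim=64)
base_weight_decay = 0.1  # assumed value

optimizer = torch.optim.AdamW(
    [{"params": emb.parameters(), "name": "embeddings"}],
    lr=1e-3,
    weight_decay=base_weight_decay,
)

def update_embedding_weight_decay() -> None:
    # Norm of the embedding *parameter* (not its gradient); note the
    # factor goes negative once the norm exceeds 1.
    emb_decay_factor = 1.0 - torch.linalg.vector_norm(emb.weight).item()
    for group in optimizer.param_groups:
        if group.get("name") == "embeddings":
            group["weight_decay"] = base_weight_decay * emb_decay_factor

# Call update_embedding_weight_decay() right before optimizer.step().
```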

TODO:

I tried this on a tiny test model config and got an overflow error. Possibly this will not be an issue with the actual model.

Note: I created the branch from train-olmo-large. See this for the actual diffs for this PR.

dirkgr commented 5 months ago

You are right. Then we need to make sure we compute this every time.

On Fri, May 3, 2024, 08:45 Akshita Bhagia wrote:

> In olmo/train.py (https://github.com/allenai/OLMo/pull/567#discussion_r1589381100):
>
> ```diff
> + if should_log_optim_metrics_this_step:
> +     emb_decay_factor = 1.0 - optim_metrics["param/transformer.wte.weight.norm"]
> + else:
> +     emb_decay_factor = 1.0
> ```
>
> We compute the norm of the gradient every step (grad/transformer.wte.weight.norm), not the norm of the parameter itself (param/transformer.wte.weight.norm). Don't we need the latter?


AkshitaB commented 5 months ago

> You are right. Then we need to make sure we compute this every time.

Done
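For illustration, the fix discussed above might look roughly like this: compute the parameter norm unconditionally each step instead of reusing the logged metric. This is a hypothetical sketch; the `model.transformer.wte` access path is an assumption mirroring the metric name param/transformer.wte.weight.norm, not the PR's exact code.

```python
import torch

def compute_emb_decay_factor(model: torch.nn.Module) -> float:
    # Norm of the embedding parameter itself (cf. the metric
    # "param/transformer.wte.weight.norm"), recomputed every step rather
    # than only on steps where optimizer metrics are logged.
    emb_weight = model.transformer.wte.weight  # assumed OLMo-style layout
    return 1.0 - torch.linalg.vector_norm(emb_weight).item()
```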

dirkgr commented 5 months ago

@epwalsh, can you look at this as well? This gets all up in your code.