kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0

Verifying logic for LR schedule #185

Closed: gupta-abhay closed this issue 2 years ago

gupta-abhay commented 2 years ago

Given the following config:

  "warmup_steps": 3000,
  "anneal_steps": 300000,
  "lr": 1.2e-4,
  "end_lr": 1.2e-5,
  "weight_decay": 0.1,
  "total_steps": 350000,

Is my understanding correct that, for the last 47K steps (after the cosine annealing schedule ends), we are running at a constant LR of 1.2e-5?
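Where the 47K figure comes from: if annealing starts only after warmup, it ends at step 3000 + 300000 = 303000, which leaves 350000 - 303000 = 47000 steps. The sketch below checks that arithmetic. It assumes a linear warmup followed by a cosine decay from `lr` to `end_lr` over `anneal_steps`, with the LR clamped at `end_lr` afterwards. `lr_at` is a hypothetical helper written for illustration, not the repository's scheduler, so confirm the real schedule in the training code.

```python
import math

# Hypothetical sketch (assumption, not the repo's scheduler): linear warmup,
# then cosine annealing that begins after warmup and lasts anneal_steps.
# Past the end of annealing the LR stays clamped at end_lr.
def lr_at(step, warmup_steps=3000, anneal_steps=300000, lr=1.2e-4, end_lr=1.2e-5):
    if step < warmup_steps:
        # Linear warmup from 0 up to the peak LR.
        return lr * step / warmup_steps
    # Fraction of the annealing phase completed, capped at 1.0.
    anneal_pct = min(step - warmup_steps, anneal_steps) / anneal_steps
    # Cosine decay from lr (anneal_pct = 0) down to end_lr (anneal_pct = 1).
    return end_lr + (lr - end_lr) * (1 + math.cos(math.pi * anneal_pct)) / 2

warmup_steps, anneal_steps, total_steps = 3000, 300000, 350000
anneal_end = warmup_steps + anneal_steps      # step where annealing finishes
constant_steps = total_steps - anneal_end     # steps spent at end_lr
print(anneal_end, constant_steps)             # 303000 47000
print(lr_at(anneal_end), lr_at(total_steps))  # 1.2e-05 1.2e-05
```

Under these assumptions every step from 303000 through 350000 returns `end_lr`, which matches the reading in the question.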