One bug in the get_lr() function:
If decay_start_step + num_decay_steps < total_steps, the LR does the wrong thing for the steps after the decay period ends. I assume it should hold the LR constant once the decay period is over. You may need another conditional branch to handle this case.
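A rough sketch of what I mean by the extra branch. I am assuming a constant LR before decay_start_step and a linear decay after it; base_lr, end_lr, and the step argument are hypothetical names of mine, not taken from the code under review, so adapt this to whatever get_lr() actually computes.

```python
def get_lr(step, base_lr, end_lr, decay_start_step, num_decay_steps):
    """Hypothetical sketch: constant LR, then linear decay, then constant again.

    base_lr, end_lr, and the linear decay shape are assumptions made for
    illustration; only decay_start_step and num_decay_steps come from the review.
    """
    decay_end_step = decay_start_step + num_decay_steps

    # Before the decay period: keep the starting LR.
    if step < decay_start_step:
        return base_lr

    # Inside the decay period: interpolate linearly from base_lr to end_lr.
    if step < decay_end_step:
        frac = (step - decay_start_step) / num_decay_steps
        return base_lr + frac * (end_lr - base_lr)

    # After the decay period: this is the branch that appears to be missing.
    # Hold the LR at its final value for all remaining steps.
    return end_lr
```

With this shape, any step at or beyond decay_start_step + num_decay_steps returns end_lr, regardless of how large total_steps is.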
Minor suggestion for visualize: format the long arrays so they are easier to read, e.g.
np.array([
    39884406, 39043, 17289, 7420, 20263,
    3, 7120, 1543, 63, 38532951, 2953546..]