tensorflow/similarity

WarmUpCosine Wrong Computation #299

Closed yonigottesman closed 2 years ago

yonigottesman commented 2 years ago

Hi, while implementing the DINO paper in TF (#108) I noticed that the WarmUpCosine output looks different from the reference implementations. I checked the original SimCLR implementation and it's consistent with the DINO one. Check out the outputs:

DINO

https://github.com/facebookresearch/dino/blob/main/utils.py#L187

import numpy as np

def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
    # Linear warmup from start_warmup_value to base_value over the warmup iterations.
    warmup_schedule = np.array([])
    warmup_iters = warmup_epochs * niter_per_ep
    if warmup_epochs > 0:
        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)

    # Cosine decay from base_value to final_value over the remaining iterations.
    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))

    schedule = np.concatenate((warmup_schedule, schedule))
    assert len(schedule) == epochs * niter_per_ep
    return schedule

lr_schedule = cosine_scheduler(
    lr,
    min_lr,
    epochs,
    steps_per_epoch,
    warmup_epochs=10,
)
plt.plot(lr_schedule)
[plot: lr_schedule, a linear warmup for 10 epochs followed by cosine decay from lr down to min_lr]
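For reference, the post-warmup part of this schedule is the standard cosine decay from `base_value` to `final_value` (this restates the line computing `schedule` above):

$$\eta(t) = \eta_{\mathrm{final}} + \tfrac{1}{2}\,(\eta_{\mathrm{base}} - \eta_{\mathrm{final}})\bigl(1 + \cos(\pi t / T)\bigr)$$

where $t$ indexes the post-warmup steps and $T = \texttt{epochs} \cdot \texttt{niter\_per\_ep} - \texttt{warmup\_iters}$.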

SimCLR

https://github.com/google-research/simclr/blob/master/tf2/model.py#L78

class WarmUpAndCosineDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Applies a warmup schedule on a given learning rate decay schedule."""

  def __init__(self, base_learning_rate, name=None):
    super(WarmUpAndCosineDecay, self).__init__()
    self.base_learning_rate = base_learning_rate
    self._name = name

  def __call__(self, step):
    # NOTE: warmup_steps and total_steps are globals in this simplified snippet.
    with tf.name_scope(self._name or 'WarmUpAndCosineDecay'):
      scaled_lr = self.base_learning_rate

      # Linear warmup from 0 to scaled_lr over warmup_steps.
      learning_rate = (
          step / float(warmup_steps) * scaled_lr if warmup_steps else scaled_lr)

      # Cosine decay learning rate schedule, applied only after the warmup.
      # TODO(srbs): Cache this object.
      cosine_decay = tf.keras.experimental.CosineDecay(
          scaled_lr, total_steps - warmup_steps)
      learning_rate = tf.where(step < warmup_steps, learning_rate,
                               cosine_decay(step - warmup_steps))

      return learning_rate

c = WarmUpAndCosineDecay(lr)
plt.plot([c(i) for i in range(total_steps)])
[plot: same shape as the DINO schedule, a linear warmup followed by cosine decay]

tensorflow-similarity

c = WarmUpCosine(lr, total_steps, warmup_steps, alpha=min_lr / lr)
plt.plot([c(i) for i in range(total_steps)])
[plot: WarmUpCosine output, which does not match the DINO/SimCLR shape]
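One quick way to quantify the mismatch without a plot (a sketch, reusing the `c` from the snippet above; these are the values a correct schedule should produce):

print(float(c(warmup_steps)))  # expected: lr exactly at the end of warmup
print(float(c(total_steps)))   # expected: roughly min_lr at the end of training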

New Implementation

I suggest the following implementation for this package:

class WarmupCosineDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, base_learning_rate, steps, warmup_steps, alpha=0, name=None):
        super().__init__()
        self.base_learning_rate = tf.convert_to_tensor(base_learning_rate, dtype=tf.float32)
        # The cosine decay only covers the post-warmup steps.
        self.cosine_decay = tf.keras.experimental.CosineDecay(base_learning_rate, steps - warmup_steps, alpha)
        self.warmup_steps = tf.convert_to_tensor(warmup_steps, dtype=tf.float32)
        self.name = name

    def __call__(self, step):
        with tf.name_scope(self.name or "WarmupCosineDecay") as name:
            step = tf.cast(step, tf.float32)
            learning_rate = tf.cond(
                step < self.warmup_steps,
                # Linear warmup; divide_no_nan returns 0 when warmup_steps == 0.
                lambda: tf.math.divide_no_nan(step, self.warmup_steps) * self.base_learning_rate,
                lambda: self.cosine_decay(step - self.warmup_steps),
                name=name,
            )
            return learning_rate

c = WarmupCosineDecay(lr, total_steps, warmup_steps, alpha=min_lr / lr)
plt.plot([c(i) for i in range(total_steps)])
[plot: proposed schedule, matching the DINO and SimCLR output]
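Like SimCLR this keeps the warmup and decay branches separate, but `tf.math.divide_no_nan` makes `warmup_steps=0` safe. As a quick sanity check (a minimal sketch with hypothetical values; `lr`, `min_lr`, `total_steps`, and `warmup_steps` below are just for illustration):

import tensorflow as tf

# Hypothetical example values, only for illustration.
lr, min_lr = 1e-3, 1e-5
total_steps, warmup_steps = 1000, 100

c = WarmupCosineDecay(lr, total_steps, warmup_steps, alpha=min_lr / lr)

print(float(c(0)))             # ~0.0: warmup starts from zero
print(float(c(warmup_steps)))  # ~1e-3: base rate exactly when warmup ends
print(float(c(total_steps)))   # ~1e-5: decayed to alpha * base_learning_rate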

@owenvallis what do you think?

owenvallis commented 2 years ago

Thanks for finding this @yonigottesman, and for the proposed solution. Can you add this as a PR to the development branch? I'll cherry-pick the changes into the master branch once we merge the PR.

owenvallis commented 2 years ago

@yonigottesman, I'm actually using this for some benchmark tests at the moment. I added your changes to my local branch, and I can merge the changes if that works for you. Let me know.

yonigottesman commented 2 years ago

Yes, that's cool. I wanted to double-check, but if it's working for you then go ahead.

owenvallis commented 2 years ago

Uploaded the new version to the development branch. @yonigottesman, let me know if this works for you and I'll cherry-pick it into the master branch. Here are a few plots of the new output with various settings.

[plots: first, second, and third warmup_cosine_decay examples with different settings]
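Since the attached plots don't survive here, a minimal sketch to reproduce similar figures with the class proposed above (the settings below are hypothetical, not the ones used for the original plots):

import matplotlib.pyplot as plt

# Hypothetical settings, chosen only to show different warmup/decay shapes.
configs = [
    dict(lr=1e-3, total_steps=1000, warmup_steps=0),
    dict(lr=1e-3, total_steps=1000, warmup_steps=100),
    dict(lr=1e-3, total_steps=1000, warmup_steps=300),
]

for cfg in configs:
    c = WarmupCosineDecay(cfg["lr"], cfg["total_steps"], cfg["warmup_steps"], alpha=0.01)
    plt.plot([float(c(i)) for i in range(cfg["total_steps"])],
             label=f"warmup_steps={cfg['warmup_steps']}")

plt.xlabel("step")
plt.ylabel("learning rate")
plt.legend()
plt.show()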