larq / zoo

Reference implementations of popular Binarized Neural Networks
https://docs.larq.dev/zoo
Apache License 2.0

Reproducing R2B model #233

Open · Hyungjun-K1m opened this issue 4 years ago

Hyungjun-K1m commented 4 years ago

Hi, I tried to reproduce the R2B model results, but I couldn't reach the reported 65.04% top-1 accuracy. In fact, I couldn't even reproduce the first stage, which trains the original full-precision ResNet18 model (expected to reach 70.32% validation accuracy (https://github.com/larq/zoo/issues/196#issuecomment-658080118)). It seems that the default hyper-parameter settings provided in the model zoo are not optimal. Could you share the hyper-parameter settings for each stage of training the R2B model?

P.S. The training curve below is what we obtained by running the default script provided in the Larq model zoo. With the default settings we could only reach 66.03% top-1 val. accuracy.

[Figure: training curve obtained with the default script]

leonoverweel commented 4 years ago

Hi, thanks for raising this. The training script on Zoo is indeed not what we used internally to train the models behind the results we reported; that code is quite tightly coupled to our training infrastructure.

Specifically for the FP baseline (I think it makes sense to get that training properly before looking at the rest), here are a few things I noticed when comparing the logs of our internal run with the default training script on Zoo:

Hyungjun-K1m commented 4 years ago

Thanks for the quick reply!

I've checked what you advised and found that the second and third points are already used in the default setting. Regarding the first point (weight decay and L2 regularization), what do you mean by 'using weight decay instead of L2 regularization'? Do you mean that you used the SGDW or AdamW optimizer proposed by this paper? Also, it seems like you suggest using LR=0.001 instead of the 0.1 used in the default setting. Does that imply that you used the Adam optimizer with LR=0.001?

Thanks, Best regards, Hyungjun

leonoverweel commented 4 years ago

No problem!

We used regular Adam as the optimizer. For weight decay, we added the following kernel_constraint to the Larq layers:

import warnings
from typing import Optional

import tensorflow as tf
from larq import utils as lq_utils

@lq_utils.register_keras_custom_object
class WeightDecay(tf.keras.constraints.Constraint):
    def __init__(
        self,
        max_learning_rate: float,
        weight_decay_constant: float,
        optimizer: Optional[tf.keras.optimizers.Optimizer] = None,
    ) -> None:
        """Weight decay constraint which can be passed in as a `kernel_constraint`.
        This allows for applying weight decay correctly with any optimizer. This is not
        the case when using l2 regularization, which is not equivalent to using weight
        decay for anything other than SGD without momentum.
        When using this class, make sure to pass in the `optimizer` instance that is
        used for training, so that its current (possibly scheduled) learning rate is used.
        :param max_learning_rate: maximum learning rate, used to normalize the current
            learning rate.
        :param weight_decay_constant: strength of the weight decay.
        :param optimizer: keras optimizer that has an `lr` attribute (which can optionally
            be a schedule). If None, the learning rate is assumed to stay constant at
            `max_learning_rate`.
        """
        self.optimizer = optimizer
        self.max_learning_rate = max_learning_rate
        self.weight_decay_constant = weight_decay_constant

        if self.max_learning_rate <= 0:
            warnings.warn(
                "WeightDecay: no weight decay will be applied as the received learning rate is not positive."
            )
            self.multiplier = 0
        else:
            self.multiplier = self.weight_decay_constant / self.max_learning_rate

    def __call__(self, x):
        if self.optimizer is None:
            # No optimizer given: assume a constant learning rate.
            lr = self.max_learning_rate
        elif isinstance(
            self.optimizer.lr, tf.keras.optimizers.schedules.LearningRateSchedule,
        ):
            # Evaluate the schedule at the optimizer's current training step.
            lr = self.optimizer.lr(self.optimizer.iterations)
        else:
            lr = self.optimizer.lr
        # Shrink the kernel by weight_decay_constant * (lr / max_learning_rate) per step.
        return (1.0 - lr * self.multiplier) * x

    def get_config(self):
        # The optimizer is not serialized; a reloaded constraint falls back to a
        # constant learning rate of `max_learning_rate`.
        return {
            "max_learning_rate": self.max_learning_rate,
            "weight_decay_constant": self.weight_decay_constant,
        }
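
The docstring's point that L2 regularization only matches weight decay for SGD without momentum can be checked numerically. This is a minimal plain-Python sketch, assuming a single weight, a toy loss whose gradient is zero (so only the regularizer or decay acts), a standard bias-corrected Adam update, and hypothetical helper names `adam_direction` and `run`. Note that the decoupled step here shrinks the weight by `lr * wd` per step, whereas the `WeightDecay` class above shrinks it by `lr * weight_decay_constant / max_learning_rate`.

```python
import math


def adam_direction(g, state, beta1=0.9, beta2=0.999, eps=1e-8):
    """Bias-corrected Adam direction m_hat / (sqrt(v_hat) + eps); updates `state` in place."""
    state["t"] += 1
    state["m"] = beta1 * state["m"] + (1 - beta1) * g
    state["v"] = beta2 * state["v"] + (1 - beta2) * g * g
    m_hat = state["m"] / (1 - beta1 ** state["t"])
    v_hat = state["v"] / (1 - beta2 ** state["t"])
    return m_hat / (math.sqrt(v_hat) + eps)


def run(mode, steps=100, lr=1e-3, wd=1e-2, w0=1.0):
    """Shrink one weight using only the regularizer (the loss gradient is 0).

    mode: "sgd_l2", "sgd_decay", "adam_l2" or "adam_decay".
    L2 uses the penalty (wd / 2) * w**2, whose gradient is wd * w.
    """
    w = w0
    state = {"t": 0, "m": 0.0, "v": 0.0}
    for _ in range(steps):
        loss_grad = 0.0  # toy loss: no data term
        if mode == "sgd_l2":
            w -= lr * (loss_grad + wd * w)
        elif mode == "sgd_decay":
            w -= lr * loss_grad
            w *= 1.0 - lr * wd  # decoupled: shrink w directly, outside the optimizer
        elif mode == "adam_l2":
            # The L2 gradient is passed through Adam's per-parameter normalization.
            w -= lr * adam_direction(loss_grad + wd * w, state)
        elif mode == "adam_decay":
            w -= lr * adam_direction(loss_grad, state)
            w *= 1.0 - lr * wd
        else:
            raise ValueError(mode)
    return w


for mode in ("sgd_l2", "sgd_decay", "adam_l2", "adam_decay"):
    print(mode, run(mode))
```

With plain SGD the two variants coincide. With Adam, the L2 gradient is divided by its own running RMS, so in this toy setting the weight moves by roughly `lr` per step regardless of `wd`, which is why the two are not interchangeable.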

We used max_learning_rate=0.001, weight_decay_constant=1e-5, and optimizer set to our Adam optimizer instance.

You may want to try both 0.1 and 0.001 as the learning rate (and max_learning_rate); I can't see which of those we used for the FP ResNet.
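
Since the `WeightDecay` constraint scales the kernel by `1 - lr * weight_decay_constant / max_learning_rate`, the fraction removed per step at the peak learning rate equals `weight_decay_constant`, whatever `max_learning_rate` is. A small plain-Python sketch of that arithmetic (the helper names are hypothetical and only mirror the class above):

```python
def weight_decay_multiplier(max_learning_rate, weight_decay_constant):
    # Same rule as WeightDecay.__init__: no decay for a non-positive max_learning_rate.
    if max_learning_rate <= 0:
        return 0.0
    return weight_decay_constant / max_learning_rate


def shrink_per_step(lr, max_learning_rate, weight_decay_constant):
    # WeightDecay.__call__ returns (1 - lr * multiplier) * x, so the fraction
    # removed from each weight in one step is lr * multiplier.
    return lr * weight_decay_multiplier(max_learning_rate, weight_decay_constant)


wd = 1e-5
for max_lr in (1e-3, 0.1):
    print(
        f"max_lr={max_lr:g}: multiplier={weight_decay_multiplier(max_lr, wd):g}, "
        f"shrink at peak lr={shrink_per_step(max_lr, max_lr, wd):.1e}, "
        f"at half lr={shrink_per_step(max_lr / 2, max_lr, wd):.1e}"
    )

# Mismatch: training at lr=0.1 while leaving max_learning_rate at 0.001.
print(f"mismatched: {shrink_per_step(0.1, 1e-3, wd):.1e}")
```

So when switching between 0.1 and 0.001, `max_learning_rate` should change together with the learning rate; leaving it at 0.001 while training at 0.1 would make the decay 100 times stronger.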

Hyungjun-K1m commented 4 years ago

Based on your comments, we tried the Adam optimizer for the ResNet18 FP model. First, we trained the model with the regular Adam optimizer and L2 regularization. We didn't add the WeightDecay class you provided and just ran the training with LR=0.001, WD=1e-5. That resulted in 70.01% val. accuracy, which is much higher than before. Then, we tried to use weight decay instead of L2 regularization, as you suggested. It seems like your team indeed followed this paper in using the Adam optimizer with fixed (decoupled) weight decay instead of L2 regularization. Can you confirm this?

Anyway, we added the code you provided to real_to_bin_nets.py and modified ResNet18FPFactory as follows:

@factory
class ResNet18FPFactory(ResNet18Factory):
    model_name = Field("resnet_fp")
    input_quantizer = None
    kernel_quantizer = None
    optimizer = lambda self: tf.keras.optimizers.Adam(
        CosineDecayWithWarmup(
            max_learning_rate=self.learning_rate,
            warmup_steps=self.warmup_duration * self.steps_per_epoch,
            decay_steps=(self.epochs - self.warmup_duration) * self.steps_per_epoch,
        )
    )
    kernel_constraint = WeightDecay(
        max_learning_rate=1e-3, weight_decay_constant=1e-5, optimizer=optimizer
    )

Since we were not sure how to pass the Adam optimizer instance declared here to the WeightDecay class, we defined the optimizer again right before the kernel_constraint. We are not sure whether this is the correct way to use weight decay with the Adam optimizer. With this configuration, we achieved 67.85% val. accuracy, which is lower than without weight decay. Please let me know if we've done anything wrong.
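
One way to make sure the constraint reads the learning rate of the optimizer that is actually stepping is to build the optimizer once and hand that same object to the constraint. This is a minimal plain-Python sketch using hypothetical stand-ins (`ToyOptimizer`, `ToyWeightDecay`, `SharedFactory`, and a linear warmup schedule) rather than Keras, larq or zookeeper; zookeeper's `@factory`/`Field` mechanics may cache values differently, so it only illustrates the pattern.

```python
from functools import cached_property


def linear_warmup(step, max_lr=1e-3, warmup_steps=10):
    # Hypothetical schedule: ramps linearly from 0 to max_lr, then stays flat.
    return max_lr * min(1.0, step / warmup_steps)


class ToyOptimizer:
    """Hypothetical stand-in for a Keras optimizer: a step counter plus a schedule."""

    def __init__(self, schedule):
        self.schedule = schedule
        self.iterations = 0

    @property
    def lr(self):
        return self.schedule(self.iterations)

    def apply_step(self):
        self.iterations += 1


class ToyWeightDecay:
    """Same update rule as WeightDecay.__call__, reading lr from the given optimizer."""

    def __init__(self, max_learning_rate, weight_decay_constant, optimizer):
        self.optimizer = optimizer
        self.multiplier = weight_decay_constant / max_learning_rate

    def __call__(self, w):
        return (1.0 - self.optimizer.lr * self.multiplier) * w


class SharedFactory:
    @cached_property
    def optimizer(self):
        # Created on first access, then cached on this factory instance.
        return ToyOptimizer(linear_warmup)

    @cached_property
    def kernel_constraint(self):
        # Reuses the cached optimizer instead of constructing a second one.
        return ToyWeightDecay(1e-3, 1e-5, optimizer=self.optimizer)


factory = SharedFactory()
for _ in range(20):
    factory.optimizer.apply_step()
print(factory.kernel_constraint.optimizer is factory.optimizer)  # True
print(factory.kernel_constraint(1.0))
```

Because the constraint only stores a reference, it sees every step the shared optimizer takes. An optimizer object that is never stepped keeps `iterations == 0`, so under a warmup schedule it would keep reporting the step-0 learning rate.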

leonoverweel commented 4 years ago

> First, we tried to train the model with regular Adam optimizer with L2 regularization. We didn't add the WeightDecay class you provided and just ran the training with LR=0.001, WD=1e-5. That results in 70.01% val. accuracy which is much higher than before.

Ah, nice - since this is quite close to the expected 70.32%, I'd recommend just using this setup (Adam + L2) then. Our use of the WeightDecay class instead of L2 actually does not follow the paper, so since that didn't work for you anyway it might be better to just stick with L2.


As a note: we provided the classes in multi_stage_experiments.py primarily as an example of how to use Zoo's multi-stage infrastructure to implement a set of training steps - it is not the exact setup we used to train the pretrained weights available for our implementation of Real-to-Binary nets. Therefore, these exact steps are not expected to reproduce our weights or reported accuracies. (I've added this note to the code in #234.)

Because of some internal (cluster) infrastructure changes between when we ran those old experiments and now, it's quite difficult to recover these exact settings from the experiment logs that are still available (not all of them are). Helping you further with reproducing R2B's exact training results would require us to redo these runs internally, for which I'm afraid we don't currently have the resources. When the authors share their code (brais-martinez/real2binary), hopefully you'll be able to get some insights from there.

Anyhow, best of luck with your reproduction! If you do figure it out and notice any obvious mistakes or easy fixes in our example training code, we'd be very happy to review and merge any PRs. :)

Hyungjun-K1m commented 4 years ago

Okay, we'll try to find the best hyper-parameter settings to reach 65.04% val. accuracy with the R2B model. If there's any improvement or news, we'll let you know.

Thanks for your help!