juntang-zhuang / Adabelief-Optimizer

Repository for NeurIPS 2020 Spotlight "AdaBelief Optimizer: Adapting stepsizes by the belief in observed gradients"

Similarity to AdaHessian #16

Closed davda54 closed 4 years ago

davda54 commented 4 years ago

Hi, first of all, thank you very much for sharing the code for AdaBelief, it looks like a very promising optimizer! :) Have you considered comparing it to AdaHessian? I feel like AdaHessian is using the same trick as you (but they do it less efficiently).

juntang-zhuang commented 4 years ago

Thanks for the comment. I quickly skimmed over their paper; the idea is roughly similar, but very different in terms of implementation. AdaHessian directly computes the Hessian diagonal, while ours approximates it by the change in gradient. I have not compared to AdaHessian in practice. From the paper, they do a back-prop through the back-prop, which seems quite slow (2 to 3 times slower than Adam), while ours is similar to Adam in terms of speed. As for the trick, AdaHessian uses a block-averaging trick, which is rarely discussed in previous work, but I feel it might be helpful for AdaBelief and other optimizers too.
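To make the back-prop-through-the-back-prop point concrete, here is a toy sketch (mine, not code from this repo or from AdaHessian) contrasting the two second-moment quantities: AdaBelief's EMA of (g_t - m_t)^2 needs only the ordinary gradient, while an AdaHessian-style Hutchinson estimate of the Hessian diagonal needs a second backward pass through the gradient graph.

    import torch

    # Toy sketch (not the repo's code): accumulate AdaBelief's s_t and an
    # AdaHessian-style s_t for one parameter vector on f(x) = sum(x^4).
    beta1, beta2 = 0.9, 0.999
    x = torch.randn(5, requires_grad=True)
    m = torch.zeros(5)          # EMA of gradients
    s_belief = torch.zeros(5)   # AdaBelief: EMA of (g - m)^2
    s_hess = torch.zeros(5)     # AdaHessian-style: EMA of (Hessian diagonal estimate)^2

    for _ in range(10):
        loss = (x ** 4).sum()
        # create_graph=True keeps the graph so we can back-prop through the back-prop
        (g,) = torch.autograd.grad(loss, x, create_graph=True)
        # Hutchinson estimate of the Hessian diagonal: E_z[z * (H z)] with Rademacher z
        z = torch.randint_like(x, 0, 2) * 2 - 1
        (hz,) = torch.autograd.grad(g, x, grad_outputs=z)   # second backward pass
        h_diag = (z * hz).detach()

        g = g.detach()
        m = beta1 * m + (1 - beta1) * g
        s_belief = beta2 * s_belief + (1 - beta2) * (g - m) ** 2
        s_hess = beta2 * s_hess + (1 - beta2) * h_diag ** 2
        with torch.no_grad():
            x -= 0.01 * g       # plain gradient step, just to move the iterate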

davda54 commented 4 years ago

Thanks for your quick answer. What I meant is that AdaBelief defines s_t = EMA[(g_t - m_t)^2] (very informally), while AdaHessian defines s_t = EMA[h_t^2] (where h_t is the Hessian trace). Otherwise, both optimizers are the same. I believe that (g_t - m_t)^2 should be highly correlated with h_t^2, as they both show the amount of local curvature.

So I am interested in whether there is a performance-speed trade-off between these two optimizers, or whether AdaBelief is strictly better than AdaHessian in both speed and accuracy.
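Written out as EMA updates (my transcription of the informal definitions above, in the paper's notation, with h_t denoting AdaHessian's curvature term):

    % AdaBelief second moment:
    s_t = \beta_2 \, s_{t-1} + (1 - \beta_2)\,(g_t - m_t)^2
    % AdaHessian-style second moment, with h_t its Hessian-based curvature estimate:
    s_t = \beta_2 \, s_{t-1} + (1 - \beta_2)\, h_t^2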

juntang-zhuang commented 4 years ago

I don't know if AdaBelief is strictly better than AdaHessian; in fact, I would guess there are cases where each outperforms the other. It's hard to determine without extensive experiments. I want to point out that even in theory they are not the same: h_t is not the trace, but the magnitude of the diagonal elements of the Hessian, if I'm correct. Whether EMA[(g_t - m_t)^2] is a good approximation to the Hessian is hard to say. Just as the convergence proof of Adam is much more trouble than that of RMSProp, a single modification in practice can result in a big difference in theory. In terms of empirical results, I think the implementation matters, even if the algorithms look very similar. So I'm sorry I don't have a conclusion yet.

davda54 commented 4 years ago

I see :) Anyway, thanks again for your great work!

sjscotti commented 2 years ago

I know this is closed, but I wanted to agree that if an element of EMA[(g_t - m_t)^2] were divided by the square of the change in the corresponding model parameter, it would be a close finite-difference approximation of the (squared) diagonal Hessian element - which is what AdaHessian uses. Have you tried including this division in your routine to see if it improves results?
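Spelling out the finite-difference relation being suggested (my paraphrase, for a single parameter coordinate i; a secant approximation, not anything from the paper):

    % Secant approximation of a diagonal Hessian element:
    H_{ii} = \frac{\partial^2 f}{\partial \theta_i^2}
           \approx \frac{g_{t,i} - g_{t-1,i}}{\theta_{t,i} - \theta_{t-1,i}},
    % so dividing the squared change in gradient by the squared change in the
    % parameter gives, approximately, the squared diagonal Hessian element:
    \frac{(g_{t,i} - g_{t-1,i})^2}{(\theta_{t,i} - \theta_{t-1,i})^2} \approx H_{ii}^2 .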

juntang-zhuang commented 2 years ago

Thanks for the comment, I have not tried that in my experiments. It sounds like a very interesting idea, and I think it would help. Just one more thing to consider: the gradient w.r.t. the parameters depends on the data, so we might need to use one batch twice or somehow find an approximation. I'll try to pursue it later.

sjscotti commented 2 years ago

Hi, I took a try at implementing my suggestion above as another option in AdaBelief, enabled when the flag fd_adahessian (which stands for "finite-difference Hessian, used as in AdaHessian") is set to True.

Instead of g_t - m_t in the exponential moving average s_t (using your paper's notation), it uses the finite difference of the momentum m (the exponential moving average of the gradient g) with respect to the corresponding model parameter theta between the present and previous steps of the optimizer. The assumption is that the change in the momentum m for parameter theta between steps is primarily due to the change in the corresponding parameter theta. This is similar to what AdaHessian assumes when it uses only the diagonal of the Hessian matrix in the update to its version of s_t.

I am unsure of the best initialization s_0 for this version of AdaBelief, since you can't form a finite difference until step 2 - any suggestions you have would be appreciated. I am training a model with it at present and it appears to be no worse than the default version of AdaBelief. I have time to check the code more carefully while it is training, since I am not sure it doesn't have mistakes. Below is the code for this version, if you would like to comment on it or try it yourself. Regards -Steve

    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                # cast data type
                half_precision = False
                if p.data.dtype == torch.float16:
                    half_precision = True
                    p.data = p.data.float()
                    p.grad = p.grad.float()

                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        'AdaBelief does not support sparse gradients, please consider SparseAdam instead')
                amsgrad = group['amsgrad']

                fd_adahessian = self.fd_adahessian

                state = self.state[p]

                beta1, beta2 = group['betas']

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \
                        if version_higher else torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_var'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \
                        if version_higher else torch.zeros_like(p.data)
                    if fd_adahessian: # create p_old
                        state['p_old'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \
                            if version_higher else torch.zeros_like(p.data) 
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_var'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \
                            if version_higher else torch.zeros_like(p.data)

                # get current state variables
                exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']
                p_old = state['p_old'] if fd_adahessian else None  # 'p_old' only exists when fd_adahessian is True

                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                if fd_adahessian:
                    #SJS note: there is no previous parameter value at step 1; the code below
                    #SJS assumes this also works for the first step, with the previous m being
                    #SJS zero at p_old of zero. Zeroing out delta_m for the first step instead
                    #SJS would do a divide by sqrt(eps) later, which may blow up the routine.
                    # first calculate delta m --- uses present grad and previous m;
                    # (grad - exp_avg) * (1 - beta1) equals m_t - m_(t-1), since
                    # m_t = beta1 * m_(t-1) + (1 - beta1) * g_t
                    delta_m = (grad - exp_avg) * (1 - beta1) #SJS new
                    delta_m.div_(torch.sub(p.data, p_old).add_(group['eps']))   # approximates delta m / delta p
                # Update first and second moment running average
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                if not fd_adahessian:
                    delta_m = grad - exp_avg #SJS original adabelief...  this uses current m which is exp_avg
                exp_avg_var.mul_(beta2).addcmul_( delta_m, delta_m, value=1 - beta2)

                if amsgrad:
                    max_exp_avg_var = state['max_exp_avg_var']
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_var, exp_avg_var.add_(group['eps']), out=max_exp_avg_var)  #SJS want add_ here?

                    # Use the max. for normalizing running avg. of gradient
                    denom = (max_exp_avg_var.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])  #SJS want add_ here?
                else:
                    denom = (exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])  #SJS want 2x add_ here?

                if fd_adahessian:
                    state['p_old'] = p.data.clone() #  p_old goes with present m for next step

                # perform weight decay, check if decoupled weight decay #SJS moved here from earlier because can modify grad that was needed for grad_residual
                if self.weight_decouple:
                    if not self.fixed_decay:
                        p.data.mul_(1.0 - group['lr'] * group['weight_decay'])
                    else:
                        p.data.mul_(1.0 - group['weight_decay'])
                else:
                    if group['weight_decay'] != 0:
                        grad.add_(p.data, alpha=group['weight_decay'])

                # update
                if not self.rectify:
                    # Default update
                    step_size = group['lr'] / bias_correction1
                    p.data.addcdiv_( exp_avg, denom, value=-step_size)

                else:  # Rectified update, forked from RAdam
                    buffered = group['buffer'][int(state['step'] % 10)]
                    if state['step'] == buffered[0]:
                        N_sma, step_size = buffered[1], buffered[2]
                    else:
                        buffered[0] = state['step']
                        beta2_t = beta2 ** state['step']
                        N_sma_max = 2 / (1 - beta2) - 1
                        N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                        buffered[1] = N_sma

                        # more conservative since it's an approximated value
                        if N_sma >= 5:
                            step_size = math.sqrt(
                                (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
                                        N_sma_max - 2)) / (1 - beta1 ** state['step'])
                        elif self.degenerated_to_sgd:
                            step_size = 1.0 / (1 - beta1 ** state['step'])
                        else:
                            step_size = -1
                        buffered[2] = step_size

                    if N_sma >= 5:
                        denom = exp_avg_var.sqrt().add_(group['eps'])
                        p.data.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                    elif step_size > 0:
                        p.data.add_( exp_avg, alpha=-step_size * group['lr'])

                if half_precision:
                    p.data = p.data.half()
                    p.grad = p.grad.half() 

        return loss

UPDATE: I did a little derivation and found that the term g_t - m_t in AdaBelief is equal to beta1 * (m_t - m_(t-1)) / (1 - beta1). So AdaBelief already has the numerator part of the momentum finite difference mentioned above.
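For completeness, the algebra behind that identity follows directly from the momentum update:

    % From m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t:
    g_t - m_t = g_t - \beta_1 m_{t-1} - (1 - \beta_1) g_t
              = \beta_1 (g_t - m_{t-1})
              = \frac{\beta_1}{1 - \beta_1}\,(m_t - m_{t-1}),
    % using m_t - m_{t-1} = (1 - \beta_1)(g_t - m_{t-1}).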