hiroyuki-kasai / SGDLibrary

MATLAB/Octave library for stochastic optimization algorithms: Version 1.0.20
MIT License
215 stars 85 forks source link

damping scheme in obfgs #9

Open hxyokokok opened 4 years ago

hxyokokok commented 4 years ago

Dear Prof. Hiroyuki KASAI,

I am somewhat confused about the damping scheme in obfgs. You cite Wang et al.'s "Stochastic Quasi-Newton Methods for Nonconvex Stochastic Optimization" when mentioning the damping, but the code looks more like the classical Powell damping. In the limited-memory setting, Wang's damping requires only one two-loop recursion, but your implementation calls it twice. In addition, the threshold indicating when to carry out the damping is set to 0.2 in your code, which is the same as in Powell's method but different from Wang's. I also found a probable bug, if it is indeed Powell's damping scheme: in Line 144, the probably correct code is "lbfgs_two_loop_recursion(s, s_array, y_array)" (otherwise you could directly reuse the variable calculated before). Is this a typo, or is my understanding wrong? Thanks!

hiroyuki-kasai commented 4 years ago

Thank you for your comments. Let me have a closer look at the code. By the way, if you can provide the two "correct codes" of Wang's and Powell's methods, it would be very appreciated. Best, Hiro

hxyokokok commented 4 years ago

Sorry for the late reply; the notification was blocked by my mailbox.


% @Author: hxy
% @Date:   2020-05-07 16:53:18
% @Last Modified by:   Xiaoyu He
% @Last Modified time: 2020-09-27 20:57:30

% oLBFGS + damping
% Wang X, Ma S, Goldfarb D, Liu W. 
%   Stochastic Quasi-Newton Methods for Nonconvex Stochastic Optimization. 
%   SIAM Journal on Optimization, 2017, 27(2): 927-956.

function [w, infos] = oLBFGSd(problem, in_opts)
% oLBFGSd  Mini-batch online L-BFGS with a damped curvature-pair update.
%
% Inputs:
%   problem - object exposing dim(), samples() and grad(w, indices).
%   in_opts - option struct merged on top of the defaults. Fields read in
%             this function: k (history length, local default 5), w_init,
%             batch_size, stepsizefun, delta, permute_on, verbose,
%             tol_optgap, max_epoch.
%
% Outputs:
%   w     - final iterate.
%   infos - statistics accumulated through store_infos (defined elsewhere;
%           its exact fields are not visible here, except that gnorm is
%           read for the verbose printout).
%
% Each inner iteration uses the same mini-batch for the gradient before and
% after the step, so the curvature pair (s, y) is built from a consistent
% sample.

    % set dimensions and samples
    d = problem.dim();
    n = problem.samples();  

    % set local options 
    local_opts.k = 5;

    % merge options (defaults <- local_opts <- in_opts; precedence presumably
    % follows the argument order of mergeOptions -- confirm against the helper)
    opts = mergeOptions(get_default_options(d), local_opts);   
    opts = mergeOptions(opts, in_opts);  

    % counters
    iters = 0; % index of mini-batch processing
    epoch = 0; % index of epochs
    grad_calc_count = 0; % number of gradient evaluation

    w = opts.w_init; % initial variable

    % hist_idx lists the occupied columns of S/Y, ordered oldest -> newest.
    % rho and alpha are not preallocated; they grow by indexed assignment.
    hist_idx = [];
    S = zeros(d,opts.k);
    Y = zeros(d,opts.k);

    % store first infos
    clear infos;    
    [infos, f_val, optgap] = store_infos(problem, w, opts, [], epoch, grad_calc_count, 0);

    % display infos
    if opts.verbose > 0
        fprintf('oLBFGSd: Epoch = %03d, cost = %.16e, optgap = %.4e\n', epoch, f_val, optgap);
    end    

    % set start time
    start_time = tic();

    % main loop: one pass of the while body is one epoch
    while (optgap > opts.tol_optgap) && (epoch < opts.max_epoch)

        % re-permute in each epoch
        if opts.permute_on
            perm_idx = randperm(n);
        else
            perm_idx = 1:n;
        end

        % floor(): samples left over after the last full batch are not visited
        for j = 1 : floor(n / opts.batch_size)

            % mini-batch: j-th contiguous slice of the (possibly permuted) order
            indice_j = (j-1) * opts.batch_size + (1:opts.batch_size);
            indice_j = perm_idx(indice_j);

            grad = problem.grad(w, indice_j);
            grad_calc_count = grad_calc_count + opts.batch_size;        

            ss = opts.stepsizefun(iters, opts);             
            % two-loop recursion: first pass walks the history newest -> oldest
            q = grad;
            for i = hist_idx(end:-1:1)
                alpha(i) = rho(i) * S(:,i)'*q;
                q = q - Y(:,i) * alpha(i);
            end
            % rescaling by 1/gamma_ (gamma_ comes from the previous iteration;
            % skipped on the very first iteration, when no pair is stored yet)
            if ~isempty(hist_idx) 
                q = q / gamma_; 
            end
            % second pass walks the history oldest -> newest
            for i = hist_idx
                q = q + S(:,i)*(alpha(i)-rho(i)*Y(:,i)'*q);
            end

            % descend
            w = w - ss * q;

            % circular-buffer slot for the new pair; hist_idx is refreshed here
            % but is only consumed by the next iteration's recursion, after
            % column ptr has been written below
            ptr = mod(iters,opts.k) + 1;
            if iters < opts.k
                hist_idx = 1:ptr;
            else
                hist_idx = [(ptr+1):opts.k 1:ptr];
            end

            % gradient at the new iterate on the SAME mini-batch
            grad_new = problem.grad(w, indice_j);
            grad_calc_count = grad_calc_count + opts.batch_size;        

            s = -ss * q;          % step actually taken (w_new - w_old)
            y = grad_new - grad;  % gradient difference on this batch
            sy = s'*y;
            % scaling used by the next recursion, floored at opts.delta
            gamma_ = max(y'*y/sy,opts.delta); 
            % s' * (gamma_ * I) * s
            s_invH_s = s'*s*gamma_;
            % damping: when s'y is below a quarter of s_invH_s, blend y with
            % gamma_*s; with this theta_ the new s'*y equals 0.25*s_invH_s
            if 0.25*s_invH_s>sy
                theta_ = 0.75*s_invH_s/(s_invH_s - sy);
                y = theta_ * y + (1-theta_) * gamma_ * s;
            end
            % assert(y'*s>0);
            S(:,ptr) = s;
            Y(:,ptr) = y;
            rho(ptr) = 1/(s'*y);
            iters = iters + 1;
        end

        % measure elapsed time
        elapsed_time = toc(start_time);

        % advance the epoch counter
        epoch = epoch + 1;

        % store infos
        [infos, f_val, optgap] = store_infos(problem, w, opts, infos, epoch, grad_calc_count, elapsed_time);        

        % display infos
        if opts.verbose > 0
            fprintf('oLBFGSd: Epoch = %03d, cost = %.16e, optgap = %.4e, |g| = %.4e\n', epoch, f_val, optgap, infos.gnorm(end));
        end

    end

    % report which stopping criterion ended the loop
    if opts.verbose > 0
        if optgap < opts.tol_optgap
            fprintf('Optimality gap tolerance reached: tol_optgap = %g\n', opts.tol_optgap);
        elseif epoch == opts.max_epoch
            fprintf('Max epoch reached: max_epoch = %g\n', opts.max_epoch);
        end
    end

end
`

```matlab
% @Author: hxy
% @Date:   2020-05-07 16:53:18
% @Last Modified by:   Xiaoyu He
% @Last Modified time: 2020-08-08 22:03:14

% BFGS + powell damping

function [w, infos] = oLBFGSpd(problem, in_opts)
% oLBFGSpd  Mini-batch online L-BFGS with a Powell-style damped update.
%
% Inputs:
%   problem - object exposing dim(), samples() and grad(w, indices).
%   in_opts - option struct merged on top of the defaults. Fields read in
%             this function: k (history length, local default 5), w_init,
%             batch_size, stepsizefun, permute_on, verbose, tol_optgap,
%             max_epoch.
%
% Outputs:
%   w     - final iterate.
%   infos - statistics accumulated through store_infos (defined elsewhere;
%           its exact fields are not visible here, except that gnorm is
%           read for the verbose printout).
%
% The search direction and the damping term are both obtained from
% bfgs_two_loop, so the recursion runs twice per mini-batch.

    % set dimensions and samples
    d = problem.dim();
    n = problem.samples();  

    % set local options 
    local_opts.k = 5;

    % merge options (defaults <- local_opts <- in_opts; precedence presumably
    % follows the argument order of mergeOptions -- confirm against the helper)
    opts = mergeOptions(get_default_options(d), local_opts);   
    opts = mergeOptions(opts, in_opts);  

    % counters
    iters = 0; % index of mini-batch processing
    epoch = 0; % index of epochs
    grad_calc_count = 0; % number of gradient evaluation

    w = opts.w_init; % initial variable

    % hist_idx lists the occupied columns of S/Y, ordered oldest -> newest;
    % rho(i) holds 1/(s_i'*y_i) for column i and grows by indexed assignment
    hist_idx = [];
    S = zeros(d,opts.k);
    Y = zeros(d,opts.k);
    rho = [];
    % store first infos
    clear infos;    
    [infos, f_val, optgap] = store_infos(problem, w, opts, [], epoch, grad_calc_count, 0);

    % display infos
    if opts.verbose > 0
        fprintf('oLBFGSpd: Epoch = %03d, cost = %.16e, optgap = %.4e\n', epoch, f_val, optgap);
    end    

    % set start time
    start_time = tic();

    % main loop: one pass of the while body is one epoch
    while (optgap > opts.tol_optgap) && (epoch < opts.max_epoch)

        % re-permute in each epoch
        if opts.permute_on
            perm_idx = randperm(n);
        else
            perm_idx = 1:n;
        end

        % floor(): samples left over after the last full batch are not visited
        for j = 1 : floor(n / opts.batch_size)

            % mini-batch: j-th contiguous slice of the (possibly permuted) order
            indice_j = (j-1) * opts.batch_size + (1:opts.batch_size);
            indice_j = perm_idx(indice_j);

            grad = problem.grad(w, indice_j);
            grad_calc_count = grad_calc_count + opts.batch_size;        

            ss = opts.stepsizefun(iters, opts);             

            % search direction from the stored pairs (equals grad when the
            % history is empty)
            q = bfgs_two_loop(grad,hist_idx,rho,S,Y);

            % descend
            w = w - ss * q;

            % gradient at the new iterate on the SAME mini-batch
            grad_new = problem.grad(w, indice_j);
            grad_calc_count = grad_calc_count + opts.batch_size;        
            s_ = -ss*q;            % step actually taken (w_new - w_old)
            y_ = grad_new - grad;  % gradient difference on this batch

            % NOTE(review): this applies the same recursion that produced the
            % search direction q, so Bs looks like H*s_ (inverse-Hessian
            % approximation applied to s_). Textbook Powell damping uses the
            % Hessian approximation B; if B = inv(H), B*s_ would be -ss*grad.
            % Confirm which operator is intended.
            Bs = bfgs_two_loop(s_,hist_idx,rho,S,Y);
            sBs = s_'*Bs;
            sy = s_'*y_;
            % damping: when s'y is below 0.2*sBs, blend y_ with Bs; with this
            % theta_ the recomputed sy equals 0.2*sBs
            if sy < 0.2*sBs
                theta_ = 0.8 * sBs / (sBs - sy);
                y_ = theta_ * y_ + (1-theta_)*Bs;
                sy = s_'*y_;
            end
            % assert(sy>0);

            % circular-buffer slot for the new pair; after this update
            % hist_idx ends with ptr, i.e. the newest pair is last
            ptr = mod(iters,opts.k) + 1;
            if iters < opts.k
                hist_idx = 1:ptr;
            else
                hist_idx = [(ptr+1):opts.k 1:ptr];
            end

            S(:,ptr) = s_;
            Y(:,ptr) = y_;
            rho(ptr) = 1/sy;

            iters = iters + 1;
        end

        % measure elapsed time
        elapsed_time = toc(start_time);

        % advance the epoch counter
        epoch = epoch + 1;

        % store infos
        [infos, f_val, optgap] = store_infos(problem, w, opts, infos, epoch, grad_calc_count, elapsed_time);        

        % display infos
        if opts.verbose > 0
            fprintf('oLBFGSpd: Epoch = %03d, cost = %.16e, optgap = %.4e, |g| = %.4e\n', epoch, f_val, optgap, infos.gnorm(end));
        end

    end

    % report which stopping criterion ended the loop
    if opts.verbose > 0
        if optgap < opts.tol_optgap
            fprintf('Optimality gap tolerance reached: tol_optgap = %g\n', opts.tol_optgap);
        elseif epoch == opts.max_epoch
            fprintf('Max epoch reached: max_epoch = %g\n', opts.max_epoch);
        end
    end

end

function q = bfgs_two_loop(v,hist_idx,rho,S,Y)
% bfgs_two_loop  Apply the L-BFGS two-loop recursion to the vector v.
%
%   v        - vector to transform.
%   hist_idx - indices of the stored pairs, ordered oldest -> newest.
%   rho      - rho(i) multiplies the inner products taken with pair i.
%   S, Y     - matrices whose i-th columns hold the stored pair i.
%
% With an empty history the input is returned unchanged.

    q = v;

    % backward pass: newest pair first
    newest_first = hist_idx(end:-1:1);
    for k = newest_first
        coef(k) = rho(k) * S(:,k)'*q;
        q = q - Y(:,k) * coef(k);
    end

    % initial scaling taken from the most recent pair: (s'y)/(y'y)
    if ~isempty(hist_idx)
        last = hist_idx(end);
        s_dot_y = S(:,last)'*Y(:,last);
        y_dot_y = Y(:,last)'*Y(:,last);
        q = q * s_dot_y / y_dot_y;
    end

    % forward pass: oldest pair first
    for k = hist_idx
        correction = rho(k)*Y(:,k)'*q;
        q = q + S(:,k)*(coef(k) - correction);
    end
end
```
hiroyuki-kasai commented 4 years ago

Dear Xiaoyu He,

Thank you for your codes. Let me go through them.

Best,

Hiro