Chuck-Chae / RL_Abstraction_MC

0 stars 0 forks source link

DoubleQ_cbm_optim #25

Open duriseul opened 3 months ago

duriseul commented 3 months ago

function [tx,F,H,G,flag,k,P,NLL] = cbm_optim(h,optconfig,rng,numrep,init0,fid)
% Minimizes the objective handle h using MATLAB's fminunc routine (for
% MATLAB versions after 2013a), restarting from multiple initial points
% until a positive-definite Hessian and a small gradient are obtained, or
% the initialization budget is exhausted.
%
% Inputs:
%   h         - function handle for the objective, as required by fminunc
%   optconfig - struct with fields: algorithm, gradient, hessian,
%               ObjectiveLimit, tolgrad, numinit_up, numinit_med, verbose,
%               and optionally initmethod: 'fixed' | 'random' | 'previous'
%   rng       - 2-by-d matrix; rng(1,:) lower and rng(2,:) upper bounds,
%               used only for generating initial points
%   numrep    - number of initializations (default 1)
%   init0     - optional matrix of user-supplied initial points, one per row
%   fid       - file identifier for logging (default 1, i.e. stdout)
%
% Outputs:
%   tx,F,H,G  - best parameters, objective value, Hessian and gradient
%   flag      - 1: positive-definite Hessian and gradient below tolerance,
%               .5: positive-definite Hessian but gradient above tolerance,
%               0: no acceptable optimum found
%   k         - number of initializations actually used
%   P,NLL     - all candidate parameters and objective values, one per row
%
% implemented by Payam Piray, Aug 2018
%==========================================================================
if nargin<4, numrep=1; end
if nargin<5, init0=[]; end
if nargin<6, fid =1; end

options = optimoptions('fminunc','Algorithm',optconfig.algorithm,'Display','off',...
    'GradObj',optconfig.gradient,'Hessian',optconfig.hessian,...
    'ObjectiveLimit',optconfig.ObjectiveLimit);

tolG       = optconfig.tolgrad;
numrep_up  = optconfig.numinit_up;
numrep_med = optconfig.numinit_med;
verbose    = optconfig.verbose;

F    = 10^16;
flag = 0;
tx   = nan;
H    = nan;
G    = nan;
r    = rng(2,:)-rng(1,:); % range width, for random initialization
numrep = numrep + size(init0,1);

% Default the initialization method only when the caller did not choose one.
% (Previously 'random' was forced unconditionally, which made the 'fixed'
% and 'previous' branches below dead code and silently ignored init0.)
if ~isfield(optconfig,'initmethod') || isempty(optconfig.initmethod)
    optconfig.initmethod = 'random';
end

k = 0; P = []; NLL = [];
while( (k<numrep) || (k>=numrep && k<numrep_med && flag==.5) || (k>=numrep && k<numrep_up && flag==0) )
    k = k+1;

    % Choose an initial point according to the configured method.
    switch optconfig.initmethod
        case 'fixed'
            try
                init = init0(k,:);
            catch
                init = mean(rng); % fall back to the midpoint of the range
            end
        case 'random'
            init = rand(size(r)).*r + rng(1,:); % uniform draw within the range
        case 'previous'
            if k > 1 && ~isempty(P) % P may be empty if all prior runs failed
                init = P(end,:) + 0.01*randn(size(r)); % small perturbation of the last candidate
            else
                init = rand(size(r)).*r + rng(1,:);
            end
        otherwise
            init = rand(size(r)).*r + rng(1,:);
    end

    try
        [tx_tmp, F_tmp, ~,~,G_tmp,H_tmp] = fminunc(h, init, options);
        % chol's second output is nonzero iff H_tmp is NOT positive definite.
        [~,ishesspos] = chol(H_tmp);
        ishesspos = ~logical(ishesspos);

        sumG = mean(abs(G_tmp));

        if (flag~=1 || (F_tmp<F)) && ishesspos && (sumG<tolG)
                flag = 1;
                tx = tx_tmp;
                F = F_tmp;
                H = H_tmp;
                G = G_tmp;
        end
        if (flag~=1 && (F_tmp<F)) && ishesspos && (sumG>tolG) % minimal condition
            flag  = .5;
            tx = tx_tmp;
            F  = F_tmp;
            H  = H_tmp;
            G  = G_tmp;
        end

        P   = [P; tx_tmp]; %#ok<AGROW>
        NLL = [NLL; F_tmp]; %#ok<AGROW>
    catch msg
        logging(verbose,fid,sprintf('--- This initialization was aborted (there might be a problem with the model)\n'));
        logging(verbose,fid,sprintf('--- The message of optimization routine is:\n'));
        logging(verbose,fid,sprintf('---    %s\n',msg.message));
    end

end

switch flag
    case 0
        logging(verbose,fid,sprintf('--- No positive hessian found in spite of %d initialization.\n',k));
    case .5
        logging(verbose,fid,sprintf('--- Positive hessian found, but not a good gradient in spite of %d initialization.\n',k));
    case 1
        if k>numrep
            % logging(verbose,fid,sprintf('--- Optimized with %d initializations(>%d specified by user).\n',k,numrep));
        end
end

end

function logging(verbose,fid,str)
% Write str to the console (when verbose is true) and, if fid refers to an
% open file (fid > 1), to that file as well. Behaves like fprintf.
if verbose
    fprintf(str);
end
if fid > 1
    fprintf(fid, str);
end
end