Chuck-Chae / RL_Abstraction_MC

0 stars 0 forks source link

DoubleQ_4states #23

Open duriseul opened 3 months ago

duriseul commented 3 months ago

function [loglik] = model_RL_4states_v2(parameters, subj)
% MODEL_RL_4STATES_V2  Log-likelihood of choices under a Double-Q learner.
%
%   loglik = model_RL_4states_v2(parameters, subj)
%
%   parameters(1)  unconstrained learning-rate parameter; passed through a
%                  logistic function to obtain alpha in (0, 1).
%   parameters(2)  unconstrained inverse-temperature parameter; passed
%                  through exp() to obtain beta > 0.
%   subj           struct with per-trial fields:
%                    actions  - chosen action coded {0, 1}
%                    Outcome  - outcome of the chosen action
%                    States4  - state index coded from 0 (used when
%                               num_state == 4)
%                    block    - block label of every trial
%                  (States2, or Color/Orientation/Mouth_dir, are only read
%                  when num_state is changed to 2 or 8 respectively.)
%
%   loglik         sum over ALL trials of log(p(chosen action) + eps).
%
%   Q-tables are re-initialised to 0.5 at the start of every block.
%
%   NOTE(review): which Q-table is updated on a trial is decided by rand,
%   so the returned likelihood is stochastic unless the caller fixes the
%   random seed. Confirm this is intended before using it in an optimiser.

num_state  = 4;
num_action = 2;

nd_alpha = parameters(1);            % normally-distributed alpha
alpha    = 1 / (1 + exp(-nd_alpha)); % transformed to lie between 0 and 1

nd_beta = parameters(2);
beta    = exp(nd_beta);              % transformed to be positive

% unpack data
actions = subj.actions + 1;          % adding 1 such that actions are {1, 2}
outcome = subj.Outcome;

% number of trials in each block
% NOTE(review): trials are consumed sequentially via ctr, so this assumes
% each block's trials are contiguous and appear in sorted block order.
U = unique(subj.block);
numTrials = zeros(1, length(U));
for i = 1:length(U)
    numTrials(i) = sum(subj.block == U(i));
end

% Probability of the chosen action on every trial. Allocated ONCE, outside
% the block loop: it is indexed by the global trial counter and summed after
% all blocks, so resetting it per block would discard earlier blocks and
% leave zeros (i.e. log(eps) terms) in their place.
p = nan(1, sum(numTrials));

% counter over all trials
ctr = 1;

for b = 1:length(numTrials)
    % number of trials in this block
    T = numTrials(b);

    % Q-value for each action (two Q tables for Double Q-learning),
    % reset at the start of each block
    Q1 = .5 * ones(num_state, num_action);
    Q2 = .5 * ones(num_state, num_action);

    for t = 1:T

        % 1-based state index for the current trial
        switch num_state
            case 8
                state = subj.Mouth_dir(ctr) * 2 * 2 ...
                      + subj.Orientation(ctr) * 2 ...
                      + subj.Color(ctr) + 1;
            case 4
                state = subj.States4(ctr) + 1;
            case 2
                state = subj.States2(ctr) + 1;
            otherwise
                error('model_RL_4states_v2:numState', ...
                      'Unsupported num_state: %d', num_state);
        end

        % combined Q-value (sum of both tables) drives the choice
        Q_combined = Q1 + Q2;

        % softmax over two actions: probability of action 1, then action 2
        p1 = 1 / (1 + exp(-beta * (Q_combined(state, 1) - Q_combined(state, 2))));
        p2 = 1 - p1;

        % read info for the current trial
        a = actions(ctr); % action on this trial
        o = outcome(ctr); % outcome on this trial

        % store probability of the chosen action
        if a == 1
            p(ctr) = p1;
        elseif a == 2
            p(ctr) = p2;
        end

        % non-chosen action (1 <-> 2)
        a2 = mod(a, 2) + 1;

        % Randomly update one of the Q-tables. The chosen action is moved
        % toward the outcome o, the non-chosen action toward 1 - o.
        if rand < 0.5
            % Update Q1
            delta1 = o - Q1(state, a);       % prediction error
            Q1(state, a) = Q1(state, a) + (alpha * delta1);

            delta2 = 1 - o - Q1(state, a2);  % prediction error, non-chosen action
            Q1(state, a2) = Q1(state, a2) + (alpha * delta2);
        else
            % Update Q2
            delta1 = o - Q2(state, a);       % prediction error
            Q2(state, a) = Q2(state, a) + (alpha * delta1);

            delta2 = 1 - o - Q2(state, a2);  % prediction error, non-chosen action
            Q2(state, a2) = Q2(state, a2) + (alpha * delta2);
        end

        ctr = ctr + 1;
    end

end

% log-likelihood is defined as the sum of log-probability of choice data
% (given the parameters).
% eps is a very small number in MATLAB; adding it has no practical effect
% but avoids log(0) = -Inf when p underflows to 0.
loglik = sum(log(p + eps));
end