Chuck-Chae / RL_Abstraction_MC


MC_8states_v2.m #27

LEEMINJIII commented 3 months ago

function [loglik] = model_RL_4states_v2_MC(parameters, subj)

num_state = 8; num_action = 2;

nd_alpha = parameters(1); % normally-distributed alpha
alpha = 1/(1+exp(-nd_alpha)); % alpha (transformed to be between zero and one)

nd_beta = parameters(2); % normally-distributed beta
beta = exp(nd_beta); % beta (transformed to be positive)
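% Note: the raw parameters are unconstrained, so they can be passed directly
% to an unconstrained optimizer; the sigmoid maps nd_alpha into (0, 1) and
% the exponential maps nd_beta into (0, Inf).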

% unpack data
actions = subj.actions + 1; % adding 1 such that actions are {1, 2}
outcome = subj.Outcome;
color = subj.Color;
orientation = subj.Orientation;
mouth_dir = subj.Mouth_dir;

% counter over all trials
ctr = 1;

% calculating numTrials
U = unique(subj.block);
numTrials = zeros(1, length(U));
for i = 1:length(U)
    ind = find(subj.block == U(i));
    numTrials(i) = length(ind);
end
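% numTrials(i) is the number of trials in block U(i); each block is treated
% as one episode by the Monte Carlo update further below.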

% Q-value for each action
q = .5 * ones(num_state, num_action); % Q-values for both actions initialized at 0.5

% to save probability of choice; currently empty, will be filled below
p = [];

for t1 = 1:size(numTrials, 2) % loop over blocks
T = numTrials(t1); % number of trials in this block

% Initialize episode data
episode = struct('state', [], 'action', [], 'reward', []);

for t = 1:T   

    if num_state == 8
        state = mouth_dir(ctr) * 2 * 2 + orientation(ctr) * 2 + color(ctr) + 1;
    elseif num_state == 4
        state = subj.States4(ctr) + 1;
    elseif num_state == 2
        state = subj.States2(ctr) + 1;
    end
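    % For num_state == 8 the state index is a binary code over the three
    % stimulus features (assuming each is coded 0/1):
    % state = 4*mouth_dir + 2*orientation + color + 1, so every
    % (mouth_dir, orientation, color) triple maps to a state in 1..8.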

    % probability of action 1
    % this is equivalent to the softmax function, but overcomes the problem
    % of overflow when q-values or beta are large.        
    p1 = 1 / (1 + exp(-beta * (q(state, 1) - q(state, 2))));

    % probability of action 2
    p2 = 1 - p1;
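    % Equivalently, p1 = exp(beta*q(state,1)) / (exp(beta*q(state,1)) + exp(beta*q(state,2)));
    % dividing numerator and denominator by exp(beta*q(state,1)) gives the
    % logistic form above. Even if exp(...) overflows here, 1/(1+Inf) evaluates
    % to 0, whereas the ratio form would give Inf/Inf = NaN.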

    % read info for the current trial
    a = actions(ctr); % action on this trial
    o = outcome(ctr); % outcome on this trial

    % store probability of the chosen action
    if a == 1
        p(ctr) = p1;
    elseif a == 2
        p(ctr) = p2;
    end

    % Save episode data
    episode.state(end + 1) = state;
    episode.action(end + 1) = a;
    episode.reward(end + 1) = o;
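    % After trial t of the block, episode.state/action/reward each hold t
    % entries; the index check in the update below relies on this.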

    ctr = ctr + 1;

    % Monte Carlo update, run after every trial using the episode data
    % collected so far in this block (episode entries beyond the current
    % trial do not exist yet and are skipped by the index check below)
    G = 0;
    for k = T:-1:1
        % Ensure the index is within the range of episode.reward
        if k <= length(episode.reward)
            G = episode.reward(k) + alpha * G;
            state = episode.state(k);
            a = episode.action(k);

            % move the chosen action's value toward the return G
            q(state, a) = q(state, a) + (1 / T) * (G - q(state, a));

            % Update the value for the other action as well, toward 1 - G
            a2 = mod(a, 2) + 1;
            q(state, a2) = q(state, a2) + (1 / T) * ((1 - G) - q(state, a2));
        end
    end
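    % The backward recursion gives G_k = r_k + alpha*r_(k+1) + alpha^2*r_(k+2) + ...,
    % i.e. alpha acts as a discount factor on later outcomes, and 1/T is a
    % fixed Monte Carlo averaging step size for this block.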
end

end

% log-likelihood is defined as the sum of log-probability of choice data
% (given the parameters).
loglik = sum(log(p + eps));
% Note that eps is a very small number in MATLAB (type eps in the command
% window to see how small it is), which does not have any effect in practice,
% but it overcomes the problem of underflow when p is very very small
% (effectively 0).

end
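For reference, a minimal usage sketch: the field names below are the ones the function unpacks (actions, Outcome, Color, Orientation, Mouth_dir, block), but the random data, the subj variable, and the fminunc fit are illustrative assumptions rather than part of the repository. Also note that MATLAB dispatches on the file name, so if the function is saved as MC_8states_v2.m (the file in the issue title), the call should use that name instead of the declared model_RL_4states_v2_MC.

% --- illustrative example only (assumed field coding and optimizer) ---
nTrials = 100;
subj = struct();
subj.actions     = randi([0 1], nTrials, 1);   % raw choices coded 0/1
subj.Outcome     = randi([0 1], nTrials, 1);   % outcomes coded 0/1
subj.Color       = randi([0 1], nTrials, 1);   % binary stimulus features (assumed 0/1)
subj.Orientation = randi([0 1], nTrials, 1);
subj.Mouth_dir   = randi([0 1], nTrials, 1);
subj.block       = ones(nTrials, 1);           % one block = one episode

% evaluate the log-likelihood at unconstrained parameter values [nd_alpha, nd_beta]
ll = model_RL_4states_v2_MC([0 0], subj);

% (assumed) maximum-likelihood fit with an unconstrained optimizer
negLL = @(x) -model_RL_4states_v2_MC(x, subj);
x_hat = fminunc(negLL, [0 0]);
alpha_hat = 1 / (1 + exp(-x_hat(1)));   % back-transform to (0, 1)
beta_hat  = exp(x_hat(2));              % back-transform to (0, Inf)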