Chuck-Chae / RL_Abstraction_MC


MC_8states_code #2

Closed duriseul closed 3 months ago

duriseul commented 3 months ago

function [loglik] = model_RL_8states_v2(parameters, subj)

num_state = 8; num_action = 2;

nd_alpha = parameters(1); % normally-distributed alpha
alpha = 1 / (1 + exp(-nd_alpha)); % alpha (transformed to be between zero and one)

nd_beta = parameters(2); % normally-distributed beta
beta = exp(nd_beta); % beta (transformed to be positive)

% unpack data
actions = subj.actions + 1; % adding 1 such that actions are {1, 2}
outcome = subj.Outcome;
color = subj.Color;
orientation = subj.Orientation;
mouth_dir = subj.Mouth_dir;

% counter over all trials
ctr = 1;

% calculating numTrails (number of trials in each block)
U = unique(subj.block());
numTrails = zeros(1, length(U));
for i = 1:length(U)
    ind = find(subj.block() == U(i));
    numTrails(i) = length(ind);
end

loglik = 0; % Initialize log-likelihood

for t1 = 1:size(numTrails, 2) % loop over blocks (episodes)

% number of trials in this block
T = numTrails(t1);

% Q-value for each action
q = .5 * ones(num_state, num_action); % Q-values for both actions initialized at 0.5

% to save episode data
episode_states = zeros(T, 1);
episode_actions = zeros(T, 1);
episode_rewards = zeros(T, 1);
p = zeros(T, 1); % to save probability of choice

for t = 1:T
    if num_state == 8
        % combine three binary features into a single state index in {1, ..., 8}
        state = mouth_dir(ctr) * 2 * 2 + orientation(ctr) * 2 + color(ctr) + 1;
    elseif num_state == 4
        state = subj.States4(ctr) + 1;
    elseif num_state == 2
        state = subj.States2(ctr) + 1;
    end

    % probability of action 1
    % this is equivalent to the softmax function, but overcomes the problem
    % of overflow when q-values or beta are large.
    p1 = 1 / (1 + exp(-beta * (q(state, 1) - q(state, 2))));

    % probability of action 2
    p2 = 1 - p1;

    % read info for the current trial
    a = actions(ctr); % action on this trial
    o = outcome(ctr); % outcome on this trial

    % store probability of the chosen action
    if a == 1
        p(t) = p1;
    elseif a == 2
        p(t) = p2;
    end

    % Save the episode data
    episode_states(t) = state;
    episode_actions(t) = a;
    episode_rewards(t) = o;

    ctr = ctr + 1;
end

% Monte Carlo update after the episode ends
returns = zeros(T, 1); % undiscounted return from each trial to the end of the block
G = 0;
for t = T:-1:1
    G = episode_rewards(t) + G;
    returns(t) = G;
end

for t = 1:T
    state = episode_states(t);
    a = episode_actions(t);
    G = returns(t);

    % update the chosen action toward the observed return
    q(state, a) = q(state, a) + alpha * (G - q(state, a));

    % update the unchosen action toward the complementary return (1 - G)
    a2 = mod(a, 2) + 1; % index of the other action
    q(state, a2) = q(state, a2) + alpha * ((1 - G) - q(state, a2));
end

% Accumulate log-likelihood
loglik = loglik + sum(log(p + eps));

end

end
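A minimal usage sketch, assuming synthetic data: the subj field names below are taken from the function body above, but the trial counts, block structure, starting values, and the choice of fminsearch are illustrative only.

% --- minimal usage sketch (synthetic data; field names taken from the function above) ---
nTrials = 160; % assumed number of trials, split into four equal blocks
subj = struct();
subj.actions     = randi([0 1], nTrials, 1);     % choices coded as {0, 1}
subj.Outcome     = randi([0 1], nTrials, 1);     % outcomes coded as {0, 1}
subj.Color       = randi([0 1], nTrials, 1);     % binary stimulus features
subj.Orientation = randi([0 1], nTrials, 1);
subj.Mouth_dir   = randi([0 1], nTrials, 1);
subj.block       = repelem((1:4)', nTrials / 4); % block label per trial

% evaluate the log-likelihood at nd_alpha = 0 (alpha = 0.5) and nd_beta = 0 (beta = 1)
loglik = model_RL_8states_v2([0; 0], subj);

% fit by minimizing the negative log-likelihood (fminsearch is one option)
negloglik = @(p) -model_RL_8states_v2(p, subj);
p_hat = fminsearch(negloglik, [0; 0]);
alpha_hat = 1 / (1 + exp(-p_hat(1))); % transform back to the bounded scale
beta_hat  = exp(p_hat(2));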