Chuck-Chae / RL_Abstraction_MC


TDLAMBDA_8states_v2.m #5


Chuck-Chae commented 3 months ago

function [loglik] = model_RL_8states_v2(parameters, subj)

num_state = 8; num_action = 2;

nd_alpha = parameters(1);           % normally-distributed alpha
alpha = 1 / (1 + exp(-nd_alpha));   % alpha (transformed to be between zero and one)

nd_beta = parameters(2);            % normally-distributed beta
beta = exp(nd_beta);                % inverse temperature (transformed to be positive)

% Lambda value for TD(lambda)
lambda = parameters(3);

% Unpack data
actions = subj.actions + 1;         % adding 1 such that actions are {1, 2}
outcome = subj.Outcome;
color = subj.Color;
orientation = subj.Orientation;
mouth_dir = subj.Mouth_dir;

% Counter over all trials
ctr = 1;

% Calculate numTrails (number of trials in each block)
U = unique(subj.block());
numTrails = zeros(1, length(U));
for i = 1:length(U)
    ind = find(subj.block() == U(i));
    numTrails(i) = length(ind);
end

% Initialize eligibility traces for each state-action pair
e = zeros(num_state, num_action);

% To save probability of the chosen action on every trial. Preallocated with
% NaNs outside the block loop (ctr runs over all trials of all blocks), and
% filled in below.
p = nan(1, sum(numTrails));

for t1 = 1:size(numTrails, 2)

% Number of trials in this block
T = numTrails(t1);

% Q-value for each action
q = 0.5 * ones(num_state, num_action); % Q-value for both actions initialized at 0.5

% Reset eligibility traces at the start of each block, together with the
% Q-values, so traces from the previous block do not leak into the new one
e = zeros(num_state, num_action);

for t = 1:T   
    if num_state == 8
        state = mouth_dir(ctr) * 2 * 2 + orientation(ctr) * 2 + color(ctr) + 1;
    elseif num_state == 4
        state = subj.States4(ctr) + 1;
    elseif num_state == 2
        state = subj.States2(ctr) + 1;
    end
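    % For example, with mouth_dir = 1, orientation = 0, and color = 1 the
    % 8-state code gives state = 1*2*2 + 0*2 + 1 + 1 = 6, so the three binary
    % features index the states 1 through 8.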

    % Probability of action 1
    % This is equivalent to the softmax function, but overcomes the problem
    % of overflow when q-values or beta are large.        
    p1 = 1 / (1 + exp(-beta * (q(state, 1) - q(state, 2))));
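    % Note: this is algebraically identical to
    % exp(beta*q(state,1)) / (exp(beta*q(state,1)) + exp(beta*q(state,2)))
    % (divide numerator and denominator by exp(beta*q(state,1))). Only the
    % difference in Q-values is exponentiated, so the Inf/Inf = NaN problem of
    % the ratio form is avoided; here p1 simply saturates at 0 or 1.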

    % Probability of action 2
    p2 = 1 - p1;

    % Read info for the current trial
    a = actions(ctr); % action on this trial
    o = outcome(ctr); % outcome on this trial

    % Store probability of the chosen action
    if a == 1
        p(ctr) = p1;
    elseif a == 2
        p(ctr) = p2;
    end

    % Prediction error
    delta = o - q(state, a);

    % Update eligibility traces
    e = lambda * e;
    e(state, a) = e(state, a) + 1;

    % Update Q-values for all state-action pairs
    q = q + alpha * delta * e;
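    % TD(lambda) bookkeeping: every trace decays by lambda each trial and the
    % trace of the just-visited (state, action) pair is incremented, so the
    % current prediction error delta also updates recently visited pairs in
    % proportion to their traces. With lambda = 0 this reduces to a one-step
    % delta-rule update of q(state, a) only.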

    ctr = ctr + 1;
end

end

% Log-likelihood is defined as the sum of log-probability of choice data
% (given the parameters); eps guards against log(0).
loglik = sum(log(p + eps));

% The function returns the (positive) log-likelihood; when fitting with
% fminunc or another minimizer, negate loglik in the caller.

end
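For reference, below is a minimal sketch of how this likelihood function might be called. The synthetic subj fields, their 0/1 coding, and the parameter values are illustrative assumptions for testing, not data or code from the repository.

% Minimal usage sketch with synthetic data (assumed binary 0/1 coding).
nTrials = 100;
subj.actions     = randi([0 1], nTrials, 1);   % choices, coded 0/1
subj.Outcome     = randi([0 1], nTrials, 1);   % outcomes, coded 0/1
subj.Color       = randi([0 1], nTrials, 1);   % binary stimulus features
subj.Orientation = randi([0 1], nTrials, 1);
subj.Mouth_dir   = randi([0 1], nTrials, 1);
subj.block       = ones(nTrials, 1);           % a single block

params = [0, 0, 0.5];                          % nd_alpha, nd_beta, lambda
ll = model_RL_8states_v2(params, subj);

% When fitting with fminunc, minimize the negative log-likelihood:
% est = fminunc(@(pars) -model_RL_8states_v2(pars, subj), params);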