Chuck-Chae / RL_Abstraction_MC

0 stars 0 forks source link

VFA_8stastes_v3.m #17

Open LEEMINJIII opened 3 months ago

LEEMINJIII commented 3 months ago

function [loglik] = model_RL_8states_v2(parameters, subj)
% MODEL_RL_8STATES_V2  Log-likelihood of binary choices under Q-learning
% over 8 states x 2 actions with linear function approximation.
%
% The feature vector is a one-hot indicator over (state, action) pairs, so
% the linear approximation is equivalent to tabular Q-learning; this lets
% us read and update single weight entries instead of forming full feature
% vectors.
%
% Inputs:
%   parameters - 2-element vector of unconstrained parameters:
%                parameters(1): nd_alpha, sigmoid-transformed to the
%                               learning rate alpha in (0, 1)
%                parameters(2): nd_beta, exp-transformed to the positive
%                               inverse temperature beta
%   subj       - struct with per-trial vectors:
%                actions (0/1), Outcome, Color (0/1), Orientation (0/1),
%                Mouth_dir (0/1), block (block label per trial)
%
% Output:
%   loglik     - sum over trials of log P(chosen action)

num_state = 8;
num_action = 2;

% Transform unconstrained parameters to their natural ranges.
nd_alpha = parameters(1);
alpha = 1 / (1 + exp(-nd_alpha));   % learning rate in (0, 1)
nd_beta = parameters(2);
beta = exp(nd_beta);                % inverse temperature > 0

% Weights for linear function approximation: one entry per (state, action).
theta = zeros(num_state * num_action, 1);

% Unpack data; actions recoded from {0, 1} to {1, 2} for MATLAB indexing.
actions = subj.actions + 1;
outcome = subj.Outcome;
color = subj.Color;
orientation = subj.Orientation;
mouth_dir = subj.Mouth_dir;

% Counter over all trials (runs continuously across blocks).
ctr = 1;

% Number of trials in each block.
U = unique(subj.block);
numTrials = zeros(1, length(U));
for i = 1:length(U)
    numTrials(i) = sum(subj.block == U(i));
end

% Probability of the chosen action on each trial (filled in below).
p = nan(1, sum(numTrials));

for b = 1:length(numTrials)
    T = numTrials(b);  % trials in this block

    for t = 1:T
        % State index in 1..8 built from three binary cues.
        state = mouth_dir(ctr) * 4 + orientation(ctr) * 2 + color(ctr) + 1;

        % Guard against malformed cue data.
        if state > num_state || state < 1
            error('State index out of bounds: %d', state);
        end

        action = actions(ctr);   % chosen action on this trial (1 or 2)
        o = outcome(ctr);        % outcome on this trial

        % Q-values of both actions in this state. With one-hot features,
        % theta' * phi_a reduces to reading the corresponding weight.
        base = (state - 1) * num_action;
        Q = theta(base + (1:num_action))';

        % Logistic (softmax over two actions) choice probabilities.
        p1 = 1 / (1 + exp(-beta * (Q(1) - Q(2))));  % P(action 1)
        p2 = 1 - p1;                                % P(action 2)

        % Store probability of the chosen action.
        if action == 1
            p(ctr) = p1;
        elseif action == 2
            p(ctr) = p2;
        end

        % TD update: with a one-hot feature vector, the gradient step
        % touches only the chosen (state, action) weight.
        delta = o - Q(action);  % prediction error
        theta(base + action) = theta(base + action) + alpha * delta;

        ctr = ctr + 1;
    end
end

% Log-likelihood of the choice data. eps (smallest increment at 1.0)
% guards against log(0) underflow when p is effectively zero; it has no
% practical effect otherwise.
loglik = sum(log(p + eps));
end