Chuck-Chae / RL_Abstraction_MC

0 stars 0 forks source link

VFA_4states_v3.m #13

Closed LEEMINJIII closed 3 months ago

LEEMINJIII commented 3 months ago

function [loglik] = model_VFA_4states_v3(parameters, subj)
% MODEL_VFA_4STATES_V3  Log-likelihood of choice data under a linear
% value-function-approximation (VFA) learner with 4 states and 2 actions.
%
%   loglik = model_VFA_4states_v3(parameters, subj)
%
%   Inputs
%     parameters(1) : unconstrained learning-rate parameter; mapped through a
%                     logistic function so that alpha lies in (0, 1).
%     parameters(2) : unconstrained inverse-temperature parameter; mapped
%                     through exp so that beta is positive.
%     subj          : per-subject data with the following per-trial fields
%                       subj.actions  - chosen action coded {0, 1}
%                       subj.Outcome  - outcome received on the trial
%                       subj.States4  - state index coded {0, 1, 2, 3}
%                       subj.block    - block label of the trial
%
%   Output
%     loglik        : sum over trials of log P(chosen action | parameters).
%
%   Errors are raised if a state or an action index falls outside its
%   valid range.

num_state  = 4;
num_action = 2;

% Transform the unconstrained parameters to their natural ranges.
nd_alpha = parameters(1);
alpha    = 1 / (1 + exp(-nd_alpha));   % learning rate in (0, 1)

nd_beta = parameters(2);
beta    = exp(nd_beta);                % inverse temperature > 0

% Weights of the linear approximator: one weight per (state, action) pair,
% laid out as index = (state - 1) * num_action + action.
theta = zeros(num_state * num_action, 1);

% Unpack data.
actions = subj.actions + 1;            % shift so that actions are {1, 2}
outcome = subj.Outcome;

% Number of trials in each block.
U = unique(subj.block);
numTrails = zeros(1, length(U));
for i = 1:length(U)
    numTrails(i) = sum(subj.block == U(i));
end

% Probability of the chosen action on every trial (filled in below).
p = nan(1, sum(numTrails));

% Counter over all trials, across blocks.
ctr = 1;

% NOTE: theta is intentionally NOT reset between blocks; learning carries
% over from one block to the next.
for t1 = 1:length(numTrails)
    T = numTrails(t1);                 % number of trials in this block

    for t = 1:T
        % Read the current trial.
        state = subj.States4(ctr) + 1; % shift so that states are {1..4}
        a     = actions(ctr);          % action taken on this trial
        o     = outcome(ctr);          % outcome on this trial

        % Ensure state is within bounds.
        if state > num_state || state < 1
            error('State index out of bounds: %d', state);
        end

        % Ensure action is within bounds. Without this check an invalid
        % action would silently address another state's weight or fail
        % with an obscure indexing error.
        if ~(a == 1 || a == 2)
            error('Action index out of bounds: %d', a);
        end

        % Q-value estimates for BOTH actions in the current state. With
        % one-hot (state, action) features, theta' * phi(s, a) is simply
        % the corresponding weight.
        idx = (state - 1) * num_action + (1:num_action);
        Q   = theta(idx)';

        % Softmax choice rule for two actions.
        p1 = 1 / (1 + exp(-beta * (Q(1) - Q(2))));  % probability of action 1
        p2 = 1 - p1;                                % probability of action 2

        % Store the probability of the action that was actually chosen.
        if a == 1
            p(ctr) = p1;
        else
            p(ctr) = p2;
        end

        % Feature vector of the chosen (state, action) pair.
        phi = zeros(num_state * num_action, 1);
        phi(idx(a)) = 1;

        % Gradient update on the squared prediction error.
        delta = o - Q(a);                    % prediction error
        theta = theta + alpha * delta * phi;

        ctr = ctr + 1;
    end
end

% Log-likelihood is the sum of log-probabilities of the observed choices.
% eps is added so that log never receives an exact zero when p underflows;
% it has no practical effect otherwise.
loglik = sum(log(p + eps));
end