Hi,
I am currently working on a psychology project that estimates model parameters using NumPyro MCMC inference. However, I've found no tutorials in the NumPyro documentation covering this use case, so I have run into several challenges and am seeking your assistance and advice.
The project aims to fit a simple Q-learning model to a psychological instrumental learning task. On each trial, a participant is presented with one of two stimuli and must choose between two options, learning the correct choice from the feedback received. Each participant completes a total of 320 trials.
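For concreteness, the data is a long-format pandas DataFrame with one row per trial; the column names below are the ones the model code expects (the values are made up purely for illustration):
import pandas as pd

# one row per trial; 'sub_id', 's', 'a', 'r' are the columns the model reads
data = pd.DataFrame({
    'sub_id': [0, 0, 0, 1, 1, 1],   # participant identifier
    's':      [0, 1, 1, 0, 1, 0],   # stimulus (state) shown on the trial
    'a':      [1, 0, 1, 1, 0, 0],   # chosen action (0 or 1)
    'r':      [1, 0, 1, 0, 1, 0],   # feedback / reward received
})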
For modeling, we assume each individual maintains a 2 (stimuli) x 2 (actions) Q table, which is updated via the following recursion:
def q_update(q, info):
    # unpack the trial-level inputs
    s, a, r, alpha, beta = info
    # forward: softmax probability of choosing action 1 in state s
    f = beta * q[s, :]
    p = jax.nn.softmax(f)[1]
    # update: delta rule on the reward prediction error
    rpe = r - q[s, a]
    q = q.at[s, a].set(q[s, a] + alpha * rpe)
    return q, p
where q is the Q table, s is the state on trial t, a is the action on trial t, and r is the reward (feedback). The model contains two free parameters: alpha (the learning rate) and beta (the inverse temperature).
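In standard notation, the recursion implements a delta-rule update with a softmax choice rule:

p(a_t = 1 | s_t) = exp(beta * Q(s_t, 1)) / sum_a' exp(beta * Q(s_t, a'))
Q(s_t, a_t) <- Q(s_t, a_t) + alpha * (r_t - Q(s_t, a_t))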
We want to estimate the model parameters in a hierarchical form, where each participant i has their own set of free parameters (alpha_i, beta_i), and the participant-specific parameters share group-level distributions (alpha_i ~ N(mu_a, sig_a), beta_i ~ N(mu_b, sig_b)).
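In the code below, I draw these participant-level parameters using a non-centered parameterization (via TransformReparam), i.e. each parameter is an affine transform of a standard-normal draw:

z_i ~ N(0, 1),  alpha_i = mu_a + sig_a * z_i  (and likewise for beta_i)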
We construct the model as follows:
import time
import pickle

import jax
import jax.numpy as jnp
from jax.lax import scan

import numpyro as npyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from numpyro.infer.reparam import TransformReparam


class rl:

    def __init__(self, nS, nA):
        self.nS = nS
        self.nA = nA

    def loop_model(self, data):

        def q_update(q, info):
            # unpack the trial-level inputs
            s, a, r, alpha, beta = info
            # forward: softmax probability of choosing action 1 in state s
            f = beta * q[s, :]
            p = jax.nn.softmax(f)[1]
            # update: delta rule on the reward prediction error
            rpe = r - q[s, a]
            q = q.at[s, a].set(q[s, a] + alpha * rpe)
            return q, p

        # get subject list
        sub_lst = data['sub_id'].unique()
        n_sub = len(sub_lst)

        # group-level priors
        a_mu = npyro.sample('alpha_mu', dist.Normal(.4, .5))
        a_sig = npyro.sample('alpha_sig', dist.HalfNormal(.5))
        b_mu = npyro.sample('beta_mu', dist.Normal(1., .5))
        b_sig = npyro.sample('beta_sig', dist.HalfNormal(.5))

        # participant-level parameters (non-centered, see above)
        with npyro.plate('sub_id', n_sub):
            with npyro.handlers.reparam(config={'alpha': TransformReparam(),
                                                'beta': TransformReparam()}):
                alpha0 = npyro.sample(
                    'alpha', dist.TransformedDistribution(
                        dist.Normal(0., 1.),
                        dist.transforms.AffineTransform(a_mu, a_sig)))
                beta0 = npyro.sample(
                    'beta', dist.TransformedDistribution(
                        dist.Normal(0., 1.),
                        dist.transforms.AffineTransform(b_mu, b_sig)))

        # run each participant's trial sequence through the Q-learning recursion
        for i, sub_id in enumerate(sub_lst):
            q0 = jnp.zeros([self.nS, self.nA])
            s = data.query(f'sub_id=={sub_id}')['s'].values
            a = data.query(f'sub_id=={sub_id}')['a'].values
            r = data.query(f'sub_id=={sub_id}')['r'].values
            T = len(r)
            alpha = alpha0[i] * jnp.ones([T,])
            beta = beta0[i] * jnp.ones([T,])
            q, probs = scan(q_update, q0, (s, a, r, alpha, beta), length=T)
            npyro.sample(f'a_hat_{sub_id}', dist.Bernoulli(probs=probs), obs=a)

    def sample(self, data, mode='loop', seed=1234,
               n_samples=50000, n_warmup=100000):
        # set the random key
        rng_key = jax.random.PRNGKey(seed)
        # sampling
        start_time = time.time()
        kernel = NUTS(getattr(self, f'{mode}_model'))
        posterior = MCMC(kernel, num_chains=4,
                         num_samples=n_samples,
                         num_warmup=n_warmup)
        posterior.run(rng_key, data)
        samples = posterior.get_samples()
        posterior.print_summary()
        end_time = time.time()
        print(f'Sampling takes {end_time - start_time:.2f}s')
        with open(f'data/rlq_{mode}.pkl', 'wb') as handle:
            pickle.dump(samples, handle)


if __name__ == '__main__':
    npyro.set_host_device_count(4)
    agent = rl(2, 2)
    agent.sample(sim_data, mode='loop')  # sim_data: the simulated dataset (sketch below)
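Here, sim_data is the simulated dataset generated in the linked script. A minimal sketch of how such data can be simulated from the same Q-learning model (the "true" parameter values and the reward rule below are made up for illustration):
import numpy as np
import pandas as pd

def simulate(n_sub=20, T=320, nS=2, nA=2, seed=0):
    rng = np.random.default_rng(seed)
    rows = []
    for i in range(n_sub):
        # made-up "true" subject-level parameters
        alpha = float(np.clip(rng.normal(.4, .1), .01, .99))
        beta = float(rng.normal(1., .2))
        q = np.zeros([nS, nA])
        for t in range(T):
            s = int(rng.integers(nS))                      # random stimulus
            p = np.exp(beta * q[s] - (beta * q[s]).max())  # stable softmax
            p = p / p.sum()
            a = int(rng.random() < p[1])                   # sample the action
            # illustrative reward rule: the stimulus-matching action pays off 80%
            r = int(rng.random() < (.8 if a == s else .2))
            q[s, a] += alpha * (r - q[s, a])
            rows.append(dict(sub_id=i, s=s, a=a, r=r))
    return pd.DataFrame(rows)

sim_data = simulate()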
This is what I got: the r_hat values look OK, but there are too many divergences. Here are my questions:
1. Am I implementing the model correctly? Are there any significant mistakes? The estimated parameters look pretty accurate to me, so I cannot tell whether I have made any.
2. Can we replace the for loop with a parallel or vectorized alternative? I tried using vmap to vectorize the scan over participants, but it was slow and also diverged (see vmap_model in the linked script; a rough sketch of what I mean is below). Do you have any advice?
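For reference, here is a minimal sketch of the kind of vectorization I mean (illustrative only, not the exact vmap_model in the repo; it assumes all participants share the same trial count and that alpha0/beta0 are drawn from the same hierarchical priors as in loop_model):
def vmap_likelihood(s_all, a_all, r_all, alpha0, beta0, nS, nA):
    # s_all, a_all, r_all: [n_sub, T] arrays stacked beforehand
    n_sub, T = s_all.shape

    def run_subject(s, a, r, alpha, beta):
        q0 = jnp.zeros([nS, nA])
        xs = (s, a, r, alpha * jnp.ones([T]), beta * jnp.ones([T]))
        _, probs = scan(q_update, q0, xs)   # same q_update as above
        return probs

    # vectorize the per-subject scan over the leading (participant) axis
    probs = jax.vmap(run_subject)(s_all, a_all, r_all, alpha0, beta0)
    with npyro.plate('sub_id', n_sub, dim=-2):
        with npyro.plate('trial', T, dim=-1):
            npyro.sample('a_hat', dist.Bernoulli(probs=probs), obs=a_all)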
Here is the full script you can run: https://github.com/fangzefunny/rl-mcmcest
Looking forward to your reply! Thank you!
Best,
Zeming