pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Estimate the free parameters of the Q learning model using MCMC inference #1602

Closed: fangzefunny closed this issue 1 year ago.

fangzefunny commented 1 year ago

Hi,

I am currently working on a psychology project that estimates model parameters with NumPyro's MCMC inference. However, I haven't found any tutorials in the NumPyro documentation that could guide me, so I have run into several challenges and would appreciate your help and advice.

The project aims to fit a simple Q-learning model to a psychological instrumental learning task. On each trial, a participant is shown one of two stimuli and must choose one of two actions, learning which choice is correct from the feedback received. Each participant completes 320 trials in total.
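For parameter-recovery checks it can help to simulate this task. The sketch below is hypothetical: the reward rule (the correct action for stimulus s is action s, which yields r = 1, otherwise r = 0), the default alpha and beta values, and the function name simulate_task are assumptions, not details of the actual experiment. Simulated agents use the softmax Q-learning rule specified in the next paragraph, and the output has the sub_id, s, a, r columns that the model code below reads.

```python
import numpy as np
import pandas as pd


def simulate_task(n_sub=4, n_trials=320, alpha=0.3, beta=5.0, seed=0):
    """Simulate a 2-stimulus x 2-action instrumental learning task.

    Hypothetical task rule (an assumption): the correct action for stimulus s
    is action s; choosing it gives r = 1, otherwise r = 0.
    """
    rng = np.random.default_rng(seed)
    rows = []
    for sub_id in range(n_sub):
        q = np.zeros((2, 2))                       # Q table: stimuli x actions
        for _ in range(n_trials):
            s = rng.integers(2)                    # stimulus shown on this trial
            f = beta * q[s]
            p1 = np.exp(f[1] - np.logaddexp(f[0], f[1]))  # softmax P(a = 1 | s)
            a = int(rng.random() < p1)             # sample the choice
            r = float(a == s)                      # hypothetical reward rule
            q[s, a] += alpha * (r - q[s, a])       # delta-rule update
            rows.append((sub_id, s, a, r))
    return pd.DataFrame(rows, columns=['sub_id', 's', 'a', 'r'])
```

For example, simulate_task(n_sub=20) returns a DataFrame with 20 x 320 rows that can be passed as the data argument of the model.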

For modeling, we assume that each individual maintains a 2 (stimuli) x 2 (actions) Q table, which is updated recursively as follows:


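import jax
import jax.numpy as jnp


def q_update(q, info):

    # unpack the per-trial inputs
    s, a, r, alpha, beta = info

    # forward: softmax probability of choosing action 1 in state s
    f = beta * q[s, :]
    p = jax.nn.softmax(f)[1]

    # update: move Q(s, a) toward the reward by the prediction error
    rpe = r - q[s, a]
    q = q.at[s, a].set(q[s, a] + alpha * rpe)

    return q, p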

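where q is the Q table, s is the state (stimulus) at trial t, a is the action at trial t, and r is the reward (feedback). The function returns the updated table together with p, the probability of choosing action 1. The model contains two free parameters: alpha (the learning rate) and beta (the inverse temperature).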
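The same trial-level computation written as equations (a restatement of q_update, with actions indexed 0 and 1; note that p is computed from the table before it is updated):

$$
p_t = \frac{\exp\big(\beta\, Q_t(s_t, 1)\big)}{\sum_{a' \in \{0, 1\}} \exp\big(\beta\, Q_t(s_t, a')\big)}, \qquad a_t \sim \mathrm{Bernoulli}(p_t)
$$

$$
Q_{t+1}(s_t, a_t) = Q_t(s_t, a_t) + \alpha \big(r_t - Q_t(s_t, a_t)\big)
$$

All other entries of $Q$ stay unchanged, and the table starts at $Q_1 = 0$.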

We want to estimate the parameters hierarchically: each participant i has their own free parameters (alpha_i, beta_i), and all participant-level parameters are drawn from shared group-level distributions (alpha_i ~ N(mu_a, sig_a), beta_i ~ N(mu_b, sig_b)).
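For reference, here is the hierarchy with the priors used in the model code below. The TransformReparam/AffineTransform construction there corresponds to the non-centered form on the second line; $\mathcal{N}(\mu, \sigma)$ takes a standard deviation, as dist.Normal(loc, scale) does:

$$
\mu_a \sim \mathcal{N}(0.4, 0.5), \quad \sigma_a \sim \mathrm{HalfNormal}(0.5), \quad \mu_b \sim \mathcal{N}(1, 0.5), \quad \sigma_b \sim \mathrm{HalfNormal}(0.5)
$$

$$
z^{a}_i,\ z^{b}_i \sim \mathcal{N}(0, 1), \qquad \alpha_i = \mu_a + \sigma_a z^{a}_i, \qquad \beta_i = \mu_b + \sigma_b z^{b}_i
$$

so that, marginally, $\alpha_i \sim \mathcal{N}(\mu_a, \sigma_a)$ and $\beta_i \sim \mathcal{N}(\mu_b, \sigma_b)$.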

We construct the model as follows:


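import pickle
import time

import jax
import jax.numpy as jnp
import numpyro as npyro
import numpyro.distributions as dist
from jax.lax import scan
from numpyro.infer import MCMC, NUTS
from numpyro.infer.reparam import TransformReparam


class rl:

    def __init__(self, nS, nA):
        self.nS = nS  # number of stimuli (states)
        self.nA = nA  # number of actions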

    def loop_model(self, data):

        def q_update(q, info):

            # unpack the per-trial inputs
            s, a, r, alpha, beta = info

            # forward: softmax probability of choosing action 1 in state s
            f = beta * q[s, :]
            p = jax.nn.softmax(f)[1]

            # update: move Q(s, a) toward the reward by the prediction error
            rpe = r - q[s, a]
            q = q.at[s, a].set(q[s, a] + alpha * rpe)

            return q, p

        # get the subject list and sample the group-level parameters
        sub_lst = data['sub_id'].unique()
        n_sub  = len(sub_lst)
        a_mu   = npyro.sample('alpha_mu', dist.Normal(.4, .5))
        a_sig  = npyro.sample('alpha_sig', dist.HalfNormal(.5))
        b_mu   = npyro.sample('beta_mu', dist.Normal(1, .5))
        b_sig  = npyro.sample('beta_sig', dist.HalfNormal(.5))
        # subject-level parameters in non-centered form:
        # alpha = a_mu + a_sig * z, with z ~ N(0, 1) (same for beta)
        with npyro.plate('sub_id', n_sub):
            with npyro.handlers.reparam(config={'alpha': TransformReparam(),
                                                'beta': TransformReparam()}):
                alpha0 = npyro.sample(
                    'alpha', dist.TransformedDistribution(dist.Normal(0., 1.),
                             dist.transforms.AffineTransform(a_mu, a_sig)))
                beta0  = npyro.sample(
                    'beta', dist.TransformedDistribution(dist.Normal(0., 1.),
                             dist.transforms.AffineTransform(b_mu, b_sig)))
        # likelihood: one scan over each subject's trials
        # (JAX unrolls this Python loop when it traces the model)
        for i, sub_id in enumerate(sub_lst):
            q0 = jnp.zeros([self.nS, self.nA])
            sub_data = data.query(f'sub_id=={sub_id}')
            s = sub_data['s'].values
            a = sub_data['a'].values
            r = sub_data['r'].values
            T = len(r)
            alpha = alpha0[i] * jnp.ones([T,])
            beta  = beta0[i] * jnp.ones([T,])
            q, probs = scan(q_update, q0, (s, a, r, alpha, beta), length=T)
            npyro.sample(f'a_hat_{sub_id}', dist.Bernoulli(probs=probs), obs=a)

    def sample(self, data, mode='loop', seed=1234,
               n_samples=50000, n_warmup=100000):

        # set the random key
        rng_key = jax.random.PRNGKey(seed)

        # sampling: pick the model method by name (e.g. 'loop' -> loop_model)
        start_time = time.time()
        kernel = NUTS(getattr(self, f'{mode}_model'))
        posterior = MCMC(kernel, num_chains=4,
                         num_samples=n_samples,
                         num_warmup=n_warmup)
        posterior.run(rng_key, data)
        samples = posterior.get_samples()
        posterior.print_summary()
        end_time = time.time()
        print(f'Sampling takes {end_time - start_time:.2f}s')

        with open(f'data/rlq_{mode}.pkl', 'wb') as handle:
            pickle.dump(samples, handle)

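if __name__ == '__main__':
    # must run before JAX initializes its backend, so that 4 CPU devices
    # are available for the 4 chains
    npyro.set_host_device_count(4)
    agent = rl(2, 2)
    # sim_data: pandas DataFrame with columns sub_id, s, a, r (not shown here)
    agent.sample(sim_data, mode='loop')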

This is what I got.


The r_hat looks OK, but there are too many divergences. Here are my questions:

  1. Am I implementing the model correctly? Are there any significant mistakes? The estimated parameters look fairly accurate to me, so I cannot tell whether I have made any.
  2. Can the "for loop" over participants be replaced with a parallel approach? I tried using vmap to vectorize the scan across participants, but it was slow and also produced divergences (see vmap_model in the script). Do you have any advice?

Here is the script, which you can run: https://github.com/fangzefunny/rl-mcmcest

Looking forward to your reply! Thank you!

Best, Zeming

martinjankowiak commented 1 year ago

@fangzefunny can you please ask your modeling question on the forum? GitHub issues are meant for bug reports, feature requests, etc.