lnccbrown / HSSM

Development of HSSM package

Long running time of `sample_posterior_predictive()` and eventual death by OOM #414

Closed: jainraj closed this issue 4 months ago

jainraj commented 4 months ago

Bug Description

I am using a ddm_sdv model with a blackbox likelihood, sampled with 6 chains, each with 1250 tuning samples and 1250 draws. I want to generate predicted RTs for 740 out-of-distribution (OOD) trials using all posterior parameter samples. However, the code runs for a long time and is eventually killed by the OOM killer (and this is on a machine with 1000 GB of RAM!).

Point 1:

I did some quick-and-dirty profiling of sampling time as a function of the number of OOD samples and of n_samples. Here is what I found:

[Figure: time (averaged over n_samples) vs. number of OOD samples]

(The light blue line is a linear fit to the dark blue data points.) This is what anyone would expect: time grows linearly with the number of OOD samples.

[Figure: time (averaged over number of OOD samples) vs. n_samples]

(The light blue line is an exponential fit to the dark blue data points.) This is what I found surprising. I would expect something close to linear growth: ideally, each of the n_samples parameter draws yields one RT and choice per trial, which then become part of the posterior predictive samples.
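One way to make the "linear vs. exponential" comparison quantitative is to fit a straight line to time vs. n_samples, fit a straight line to log(time) vs. n_samples (an exponential model), and compare the residuals. A minimal NumPy sketch, assuming the measured timings are available as two arrays; `compare_growth` is a hypothetical helper, not part of HSSM.

```python
import numpy as np


def compare_growth(n_samples, seconds):
    """Compare a linear fit t = a + b*n with an exponential fit t = exp(c + d*n).

    Returns the sum of squared residuals (in seconds^2) for each model;
    the smaller value indicates the better-fitting growth pattern.
    """
    n = np.asarray(n_samples, dtype=float)
    t = np.asarray(seconds, dtype=float)

    # Linear model: least-squares line through (n, t).
    # np.polyfit returns coefficients highest degree first: [slope, intercept].
    b, a = np.polyfit(n, t, deg=1)
    linear_sse = float(np.sum((t - (a + b * n)) ** 2))

    # Exponential model: fit a line to log(t), then map back to seconds.
    d, c = np.polyfit(n, np.log(t), deg=1)
    exp_sse = float(np.sum((t - np.exp(c + d * n)) ** 2))

    return {"linear_sse": linear_sse, "exponential_sse": exp_sse}
```

Fitting the exponential model in log space is a simple choice; it weights relative rather than absolute errors, which is reasonable when timings span an order of magnitude.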

Point 2:

The memory footprint keeps growing over time until the OOM killer terminates the process. My guess is that some sort of memory leak is responsible, because storing 1250 RT and response values for each of the 740 samples should not require much space.
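A back-of-the-envelope estimate supports this. Assuming the output holds one RT and one response for every (chain, draw, trial) combination, i.e. all 6 × 1250 posterior draws for the 740 OOD trials, stored as 8-byte floats (float64), the raw predictive array is well under 100 MB. `predictive_size_bytes` is a hypothetical helper for illustration:

```python
def predictive_size_bytes(n_chains, n_draws, n_trials, n_values=2, bytes_per_value=8):
    """Bytes needed for n_values outputs (e.g. rt, response) per chain/draw/trial."""
    return n_chains * n_draws * n_trials * n_values * bytes_per_value


size = predictive_size_bytes(n_chains=6, n_draws=1250, n_trials=740)
print(f"{size / 1e6:.1f} MB")  # 88.8 MB
```

Even allowing for several intermediate copies, this is roughly four orders of magnitude below 1000 GB, which suggests the growth comes from something other than storing the final samples.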

HSSM version 0.2.0

To Reproduce

To reproduce (something along the lines of) the plots above (idata holds 6 chains with 1250 draws each):

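from time import time

# Time sample_posterior_predictive over a grid of n_samples and OOD subset sizes.
# ddm_model, idata and test_df are assumed to be defined already.
for n_samples in [50, 100, 150, 200, 300, 350, 400]:
    for ood_samples in [6, 8, 10, 12, 14, 16, 18]:
        start = time()
        output = ddm_model.sample_posterior_predictive(
            idata=idata, inplace=False, data=test_df[:ood_samples], n_samples=n_samples
        )
        end = time()
        print(ood_samples, n_samples, end - start)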

Additional context

  1. Data details:

    • Stimulus is coded as -1, 0 and 1 (three levels)
    • data has ~25 participants with about 100 trials per participant.
  2. Model fit details (hierarchical, with no common intercepts):

    import numpy
    import hssm

    parameters = [
    {
        "name": "v",
        "formula": "v ~ 0 + (1|participant_id) + (stimulus|participant_id)",
        "link": "identity",
        "prior": {
            "1|participant_id": {
                "name": "Normal",
                "mu": {
                    "name": "Normal",
                    "mu": 2,
                    "sigma": 3,
                    "initval": 2,
                },
                "sigma": {
                    "name": "HalfNormal",
                    "sigma": 2,
                    "initval": 0.1,
                },
            },
            "stimulus|participant_id": {
                "name": "Normal",
                "mu": {
                    "name": "Normal",
                    "mu": 0,
                    "sigma": 15,
                    "initval": 0,
                },
                "sigma": {
                    "name": "Uniform",
                    "lower": 1e-10,
                    "upper": 100,
                    "initval": 0.1,
                },
            },
        },
        "bounds": (-numpy.inf, numpy.inf),
    },
    {
        "name": "z",
        "formula": "z ~ 0 + (1|participant_id)",
        "link": "gen_logit",
        "prior": {
            "1|participant_id": {
                "name": "Normal",
                "mu": {
                    "name": "Normal",
                    "mu": 0.5,
                    "sigma": 0.5,
                    "initval": 0.5,
                },
                "sigma": {
                    "name": "HalfNormal",
                    "sigma": 0.05,
                    "initval": 0.1,
                },
            },
        },
        "bounds": (0.0, 1.0),
    },
    {
        "name": "a",
        "formula": "a ~ 0 + (1|participant_id)",
        "link": "identity",
        "prior": {
            "1|participant_id": {
                "name": "Gamma",
                "mu": {
                    "name": "Gamma",
                    "mu": 1.5,
                    "sigma": 0.75,
                    "initval": 1,
                },
                "sigma": {
                    "name": "HalfNormal",
                    "sigma": 0.1,
                    "initval": 0.1,
                },
            },
        },
        "bounds": (0, numpy.inf),
    },
    {
        "name": "t",
        "formula": "t ~ 0 + (1|participant_id)",
        "link": "identity",
        "prior": {
            "1|participant_id": {
                "name": "Gamma",  # deviating from Wiecki's paper because the HDDM library implements Gamma
                "mu": {
                    "name": "Gamma",
                    "mu": 0.4,
                    "sigma": 0.2,
                    "initval": 0.001,  # small value here can help in convergence
                },
                "sigma": {
                    "name": "HalfNormal",
                    "sigma": 1,
                    "initval": 0.2,
                },
            },
        },
        "bounds": (0.0, numpy.inf),
    },
    {
        "name": "sv",
        "prior": {
            "name": "HalfNormal",
            "sigma": 2.0,
            "initval": 1.0,
        },
        "bounds": (0, numpy.inf),
    },
    ]

    ddm_model = hssm.HSSM(
        data=data,
        model='ddm_sdv',
        loglik_kind='blackbox',
        p_outlier=0.05,
        include=parameters,
    )

jainraj commented 4 months ago

If you could suggest any workaround to use until the root cause is found, that would be a great help!

AlexanderFengler commented 4 months ago

Hey @jainraj,

we looked into this and can confirm that something funny is happening with how the computation scales. We haven't pinned down the exact culprit yet, but the run time does not seem to scale as expected with the n_samples argument.

If you look at this script:

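import pandas as pd
from time import time

timings_df = pd.DataFrame(columns = ['n_samples', 'ood_samples', 'time'])
divisor = 500  # number of posterior predictive draws requested per call
for n_draws in [500]:
    start = time()
    # Split n_draws into n_draws // divisor calls with n_samples=divisor each.
    for i in range(n_draws // divisor):
        output = ddm_model.sample_posterior_predictive(idata=idata, inplace=False, 
                                                    data=data, n_samples=int(divisor))
        if i % 10 == 0:
            print(i)
    end = time()
    print('n_samples' + ', ' + 'ood_samples', n_draws, data.shape[0], 'time: ', end-start)
    # Append a new row (loc[-1] would overwrite the same row on every iteration).
    timings_df.loc[len(timings_df)] = [n_draws, data.shape[0], end-start]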

If the divisor is increased (i.e., more draws are requested per call), the computation slows down as you observed, and my laptop crashed when I set n_draws = 1000, divisor = 1000. As a workaround, you can run this kind of computation sequentially with small n_samples and combine the results afterwards. We will also include a monkey-patch along those lines in the PR that follows from this issue.
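The batching workaround can be sketched as follows. `simulate_batch` is a hypothetical stand-in for one call to `ddm_model.sample_posterior_predictive(..., n_samples=batch_size)` that returns fake RTs with shape (batch_size, n_trials). The real method does not return a plain NumPy array, so combining its outputs would need the appropriate ArviZ/xarray concatenation rather than `np.concatenate`; this only illustrates the batching logic.

```python
import numpy as np


def simulate_batch(batch_size, n_trials, rng):
    """Hypothetical stand-in for one posterior predictive call (fake RTs)."""
    return rng.gamma(shape=2.0, scale=0.3, size=(batch_size, n_trials))


def predict_in_batches(total_draws, batch_size, n_trials, seed=0):
    """Run several small calls instead of one big one and stack the results."""
    rng = np.random.default_rng(seed)
    chunks = []
    remaining = total_draws
    while remaining > 0:
        size = min(batch_size, remaining)  # the last batch may be smaller
        chunks.append(simulate_batch(size, n_trials, rng))
        remaining -= size
    # Combine along the draw axis so the result matches one large call's shape.
    return np.concatenate(chunks, axis=0)


out = predict_in_batches(total_draws=1000, batch_size=100, n_trials=740)
print(out.shape)  # (1000, 740)
```

Each call only has to hold batch_size draws at a time. Note that if every call draws its own random subsample of the posterior, the batches together may repeat or skip some posterior draws.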

For completeness, here is the code used to simulate the dataset:

import ssms
import hssm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytensor
hssm.set_floatX('float32', jax = True)
import re

# number of subjects and trials
n_subj = 10
n_trial = 100

# group level parameters
group_mu_a = 1.2
group_sd_a = 0.3
group_mu_z = 0.6
group_sd_z = 0.075
group_mu_t = 0.3
group_sd_t = 0.075

# subject level parameters
subj_mu_a = np.random.normal(loc = group_mu_a, scale = group_sd_a, size = n_subj)
subj_mu_a_trialwise = np.repeat(subj_mu_a, n_trial) 
subj_mu_z = np.random.normal(loc = group_mu_z, scale = group_sd_z, size = n_subj)
subj_mu_z_trialwise = np.repeat(subj_mu_z, n_trial) 
subj_mu_t = np.random.normal(loc = group_mu_t, scale = group_sd_t, size = n_subj)
subj_mu_t_trialwise = np.repeat(subj_mu_t, n_trial) 

# parameter with continuous covariate
stimulus = np.random.uniform(low = -1,high = 1, size = n_subj * n_trial)
v_mu = 1.5
v_beta = 0.2
v_trialwise = v_mu + v_beta * stimulus 
participant_id = np.repeat(np.arange(n_subj), n_trial).astype(int)

# make dataframe
data = pd.DataFrame(np.vstack([v_trialwise, subj_mu_a_trialwise, 
                               subj_mu_z_trialwise, subj_mu_t_trialwise, 
                               stimulus, participant_id]).T,
                    columns = ['v', 'a', 'z', 't', 'stimulus', 'participant_id'])
data['participant_id'] = data['participant_id'].values.astype(int)
sim_out = ssms.basic_simulators.simulator.simulator(model = 'ddm',
                                                    theta = data[['v', 'a', 'z', 't']],
                                                    n_samples = 1)
data['rt'] = sim_out['rts'].squeeze()
data['response'] = sim_out['choices'].squeeze()

data[data.select_dtypes(include = ['float64']).columns] = \
    data.select_dtypes(include = ['float64']).astype('float32')
jainraj commented 4 months ago

Thanks, Dr. @AlexanderFengler, for reproducing the issue and for the sample script.

Originally, I wanted posterior predictive samples for all posterior samples of the parameters (why use a subset when I have the full set?). I considered combining the results of several runs with small n_samples, but each function call draws its own random subsample. So there is no guarantee that all posterior samples of the parameters will be used, and some samples may be selected more than once.

Is there a way to use the current code with sequential small n_samples to iterate through the full set of posterior samples? If not, could the functionality be added to the monkey-patch (although I am unsure about the right software engineering approach)?
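One way to guarantee that every posterior draw is used exactly once is to partition the full set of (chain, draw) index pairs into disjoint chunks and process one chunk per call, instead of letting each call pick its own random subsample. A minimal NumPy sketch, assuming 6 chains × 1250 draws; `disjoint_draw_chunks` is a hypothetical helper, and how each chunk would be handed to `sample_posterior_predictive` (for instance by subsetting `idata.posterior` to those draws) is an assumption not confirmed in this thread.

```python
import numpy as np


def disjoint_draw_chunks(n_chains, n_draws, chunk_size):
    """Split all (chain, draw) pairs into disjoint chunks that cover each pair once."""
    # Every (chain, draw) combination as rows of an (n_chains * n_draws, 2) array.
    chains, draws = np.meshgrid(np.arange(n_chains), np.arange(n_draws), indexing="ij")
    pairs = np.column_stack([chains.ravel(), draws.ravel()])
    # Enough chunks that none exceeds chunk_size rows.
    n_chunks = int(np.ceil(len(pairs) / chunk_size))
    # array_split keeps the row order and never repeats or drops a row.
    return np.array_split(pairs, n_chunks)


chunks = disjoint_draw_chunks(n_chains=6, n_draws=1250, chunk_size=500)
print(len(chunks), sum(len(c) for c in chunks))  # 15 7500
```

Because the chunks are a partition rather than independent random subsamples, concatenating the per-chunk predictions covers the full posterior without duplicates.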

AlexanderFengler commented 4 months ago

Hey @jainraj,

trying to merge the monkey-patch today. Will ping here.

Best, Alex