biocore / BIRDMAn

Bayesian Inferential Regression for Differential Microbiome Analysis
BSD 3-Clause "New" or "Revised" License

Abstracting out parallelism #51

Closed mortonjt closed 3 years ago

mortonjt commented 3 years ago

As discussed earlier, I think the most reasonable step forward is to not try to handle parallelism ourselves, but to keep the door open for users to use their own favorite parallelism tool (e.g. multiprocessing, joblib, dask, snakemake, job arrays, disBatch ...)

I'm thinking that introducing some form of ModelIterator abstraction would be key to enabling this, since fitting multiple features is embarrassingly parallel. Off the cuff, I'd imagine the architecture would look something like this:

class Model(SomeABC):
    def __init__(self, feature_id, **kwargs):
        self.feature_id = feature_id
        ...
    def fit(self, args):
        """ Fit model """
        ...
    def save(self, filename):
        """ Serialize model to some location via json format. """
        ...
    @classmethod
    def load(cls, filename):
        """ Load model from a path to a json file. Important for job-array-like parallelism. """
        ...

class ModelIterator(SomeOtherABC):
    def __init__(self, **kwargs):
        # some setup
        self.feature_ids = ...  # all feature ids, pulled from the biom table or similar
        self.models = [Model(fid, **kwargs) for fid in self.feature_ids] # instantiate all models
    def __iter__(self):
        for feature_id, model in zip(self.feature_ids, self.models):
            yield feature_id, model
    def merge(self, paths):
        """ update models with fitted analogues. """
    def fit(self):
        """ Fit in serial in case users don't want to do it themselves. """

From there, if the user really wants parallelism, they can code it up themselves -- we can provide tutorials on how to do this, but BIRDMAn itself won't provide any parallelism support. Below are some code skeletons of what these tutorials could look like.

## An example of running this with dask
import dask

fit_input = ...  # some arguments to be passed into fit
nb_models = ModelIterator(biom_table, sample_metadata)
outputs = []
for i, m in nb_models:
    # delay the fit call itself (m.fit, not m.fit()) so dask schedules it lazily
    o = dask.delayed(m.fit)(fit_input)
    outputs.append(o)
outputs = dask.compute(*outputs)
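
The same pattern would carry over to joblib or multiprocessing with no changes on the model side; here is a minimal joblib sketch under the same assumptions as the dask example (fit_input, biom_table, and sample_metadata are placeholders):

## An example of running this with joblib (hypothetical; same assumptions as the dask example)
from joblib import Parallel, delayed

fit_input = ...  # some arguments to be passed into fit
nb_models = ModelIterator(biom_table, sample_metadata)
# fit each per-feature model in its own worker; n_jobs=-1 uses all available cores
fitted = Parallel(n_jobs=-1)(delayed(m.fit)(fit_input) for _, m in nb_models)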

## An example of running this with disBatch / job arrays
import os
import subprocess
import tempfile

nb_models = ModelIterator(biom_table, sample_metadata)
save_directory = 'somewhere/out/there'
paths = []
with tempfile.TemporaryDirectory() as temp_dir_name:
    task_fp = os.path.join(temp_dir_name, 'tasks.txt')
    with open(task_fp, 'w') as fh:
        for i, m in nb_models:
            model_path = os.path.join(save_directory, i)
            paths.append(model_path)
            m.save(model_path)
            ### Assume that a separate CLI script has already been created that takes in a model as input via the load function
            ### fit_single_model.py --input-model <> --biom-table <> --sample-metadata <> ...
            ### (`args` here comes from the surrounding tutorial's own argument parsing)
            cmd_ = f'fit_single_model.py --biom-table {args.biom_table} --metadata-file {args.metadata_file} --input-model {model_path} --fitted-model {model_path}'
            fh.write(cmd_ + '\n')  # one task per line
    ## Run disBatch with the SLURM environment variables
    cmd = f'disBatch {task_fp}'
    output = subprocess.run(cmd, env=slurm_env, check=True, shell=True)
    ## probably will need to merge results
    nb_models.merge(paths)
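
For completeness, the hypothetical fit_single_model.py referenced in the task file might look roughly like this (flag names and fit arguments are assumptions, not an existing script):

## Hypothetical fit_single_model.py, as referenced in the task file above
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--biom-table', required=True)
parser.add_argument('--metadata-file', required=True)
parser.add_argument('--input-model', required=True)
parser.add_argument('--fitted-model', required=True)
args = parser.parse_args()

# Deserialize a single per-feature model, fit it, and write the fitted model back out
model = Model.load(args.input_model)
model.fit(args)  # the Model sketch above takes the fit arguments as a single object
model.save(args.fitted_model)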

I may have left out some details, feel free to follow up.

mortonjt commented 3 years ago

Another thought on the save / load serialization -- perhaps we don't need to serialize the CmdStanMCMC model itself; we just need to serialize the parameters required to initialize one. This can easily be done by saving / loading JSON files.
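
A minimal sketch of that idea, assuming each per-feature model keeps track of its Stan file and sampling arguments (the helper names and spec fields are illustrative, not existing BIRDMAn code):

## Hypothetical sketch: persist only what is needed to re-create a fit, not the fit itself
import json
from cmdstanpy import CmdStanModel

def save_model_spec(filename, stan_file, data, chains=4, seed=42):
    # data would be the per-feature dict eventually passed to CmdStanModel.sample
    spec = {'stan_file': stan_file, 'data': data, 'chains': chains, 'seed': seed}
    with open(filename, 'w') as fh:
        json.dump(spec, fh)

def load_and_fit(filename):
    # Read the spec back and re-run sampling to obtain a CmdStanMCMC object
    with open(filename) as fh:
        spec = json.load(fh)
    model = CmdStanModel(stan_file=spec['stan_file'])
    return model.sample(data=spec['data'], chains=spec['chains'], seed=spec['seed'])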