AlexRodis / bayesian-models

A small library built on top of `pymc` that implements many common models
Apache License 2.0

[ENH]: Context Manager API #80

Open AlexRodis opened 1 year ago


Complicated models, particularly deep ones, end up with rather ugly APIs and deep nesting. For example:

```python
import pymc as pm

from bayesian_models.core import ResponseFunctionsComponent, LikelihoodComponent, FreeVarsComponent
from bayesian_models.core import GaussianProcessCoreComponent, ModelAdaptorComponent, Distribution, distribution
from bayesian_models.core import GPLayer, GaussianSubprocess
# (import of GaussianProcess elided)

# Every layer nests its own list of subprocesses
layers = [
    GPLayer(
        [
            GaussianSubprocess(...),
            ...
        ]
    ),
    ...
]

likelihood = LikelihoodComponent(pm.StudentT, var_mappings=dict(...))
responses = ResponseFunctionsComponent(...)
rvs = FreeVarsComponent(...)

# All components must be built up front and threaded through the constructor
obj = GaussianProcess(layers, likelihood, responses=responses, extra_rvs=rvs, ...)
```

It's possible to simplify this API by hijacking the context manager protocol:

```python
with GaussianProcess() as obj:
    likelihood = LikelihoodComponent(...)
    layers = [...]
    responses = ResponseFunctionsComponent(...)
    free_rvs = FreeVarsComponent(...)
    ...
```

Ordinarily, a context manager's `__enter__` and `__exit__` methods do not have access to the caller's scope, and hence cannot hook into the variables defined in the code block under the manager. pymc accomplishes something very similar with its `Model` context. Roughly, the pattern is:

```python
import threading


class Context:
    """Base class for objects that push themselves onto a per-thread stack."""

    # Shared by every subclass: one stack per thread
    contexts = threading.local()

    def __enter__(self):
        type(self).get_contexts().append(self)
        return self

    def __exit__(self, typ, value, traceback):
        type(self).get_contexts().pop()

    @classmethod
    def get_contexts(cls):
        # no race-condition here, cls.contexts is a thread-local object
        # be sure not to override contexts in a subclass however!
        if not hasattr(cls.contexts, "stack"):
            cls.contexts.stack = []
        return cls.contexts.stack

    @classmethod
    def get_context(cls):
        """Return the deepest context on the stack."""
        try:
            return cls.get_contexts()[-1]
        except IndexError:
            raise TypeError("No context on context stack")


class Model(Context):
    """Minimal stand-in for a concrete context class (pymc's Model plays this role)."""


def modelcontext(model):
    """Return the given model or try to find it in the context if there was
    none supplied.
    """
    if model is None:
        return Model.get_context()
    return model
```
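The pattern rests on two properties: nested `with` blocks stack up, with `get_context` returning the innermost one, and each thread gets its own stack. A minimal self-contained sketch checking both (the class name `Ctx` and the helper `other_thread_depth` are hypothetical, written only for this illustration):

```python
import threading


class Ctx:
    # One stack per thread, reached through a shared threading.local
    contexts = threading.local()

    def __init__(self, name):
        self.name = name

    def __enter__(self):
        type(self).get_contexts().append(self)
        return self

    def __exit__(self, typ, value, traceback):
        type(self).get_contexts().pop()

    @classmethod
    def get_contexts(cls):
        if not hasattr(cls.contexts, "stack"):
            cls.contexts.stack = []
        return cls.contexts.stack

    @classmethod
    def get_context(cls):
        try:
            return cls.get_contexts()[-1]
        except IndexError:
            raise TypeError("No context on context stack")


def other_thread_depth():
    """Report how many contexts a freshly started thread can see."""
    result = []
    t = threading.Thread(target=lambda: result.append(len(Ctx.get_contexts())))
    t.start()
    t.join()
    return result[0]


with Ctx("outer"):
    with Ctx("inner"):
        innermost = Ctx.get_context().name      # "inner"
        depth = len(Ctx.get_contexts())         # 2
        depth_elsewhere = other_thread_depth()  # 0: that thread has its own stack
    after_inner = Ctx.get_context().name        # "outer"

remaining = len(Ctx.get_contexts())             # 0
```

Because `__exit__` always pops, the stack is restored even when the block raises, so a failed model definition does not leave a stale context behind.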

Component objects need to call `get_context` to latch onto the top of the context stack, which holds an instance of `BayesianModel` (so `BayesianModel` should not define `__slots__`), and then register themselves on it as attributes:

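```python
from dataclasses import dataclass

@dataclass
class FreeVarsComponent:
    ...

    def __post_init__(self):
        ...
        # Components need a ref to their caller: the BayesianModel
        # currently on top of the context stack
        BayesianModel.get_context().free_vars_component = self
```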
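Putting the two pieces together, the proposed API might work roughly as follows. This is a sketch, not the library's implementation. `BayesianModel`, `GaussianProcess`, the component fields, the `required` tuple and the attribute names the components register under are all hypothetical stand-ins for the real `bayesian_models` classes. Components register themselves on the innermost model in `__post_init__`, and the model checks in `__exit__` that the required pieces were supplied:

```python
import threading
from dataclasses import dataclass, field


class Context:
    contexts = threading.local()

    def __enter__(self):
        type(self).get_contexts().append(self)
        return self

    def __exit__(self, typ, value, traceback):
        type(self).get_contexts().pop()

    @classmethod
    def get_contexts(cls):
        if not hasattr(cls.contexts, "stack"):
            cls.contexts.stack = []
        return cls.contexts.stack

    @classmethod
    def get_context(cls):
        try:
            return cls.get_contexts()[-1]
        except IndexError:
            raise TypeError("No context on context stack")


class BayesianModel(Context):
    """Hypothetical base model; no __slots__, so components can attach freely."""

    required = ()  # attribute names that must be registered before exiting

    def __exit__(self, typ, value, traceback):
        super().__exit__(typ, value, traceback)  # always pop first
        if typ is None:  # validate only if the block itself did not raise
            missing = [n for n in self.required if not hasattr(self, n)]
            if missing:
                raise ValueError(f"missing components: {missing}")


class GaussianProcess(BayesianModel):
    required = ("likelihood_component",)


@dataclass
class LikelihoodComponent:
    distribution: str  # stand-in for e.g. pm.StudentT

    def __post_init__(self):
        # Latch onto the innermost model and register under a fixed name
        BayesianModel.get_context().likelihood_component = self


@dataclass
class FreeVarsComponent:
    names: list = field(default_factory=list)

    def __post_init__(self):
        BayesianModel.get_context().free_vars_component = self


with GaussianProcess() as obj:
    LikelihoodComponent("StudentT")
    FreeVarsComponent(["sigma"])

print(obj.likelihood_component.distribution)  # StudentT
print(obj.free_vars_component.names)          # ['sigma']
```

Deferring validation to `__exit__` means the components can be declared in any order inside the block. The cost is that errors surface only when the block ends, not at the offending line.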

This may be a backwards-incompatible change, and it's not clear exactly how it should be implemented.
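One possible way to soften the incompatibility, offered only as a sketch: components register themselves only when a model is active, and otherwise stay free-standing so they can still be passed to the constructor as before. The module-level stack, the `active_models` helper and the `likelihood` constructor argument are hypothetical:

```python
import threading
from dataclasses import dataclass

_local = threading.local()  # hypothetical per-thread holder for the model stack


def active_models():
    if not hasattr(_local, "stack"):
        _local.stack = []
    return _local.stack


class BayesianModel:
    def __init__(self, likelihood=None):
        # Old style: components handed to the constructor explicitly
        if likelihood is not None:
            self.likelihood_component = likelihood

    def __enter__(self):
        active_models().append(self)
        return self

    def __exit__(self, typ, value, traceback):
        active_models().pop()


@dataclass
class LikelihoodComponent:
    distribution: str

    def __post_init__(self):
        stack = active_models()
        # New style only when a model is active; outside a `with` block the
        # component stays free-standing, so the old constructor API still works
        if stack:
            stack[-1].likelihood_component = self


old = BayesianModel(LikelihoodComponent("StudentT"))  # existing call style

with BayesianModel() as new:                           # context-manager style
    LikelihoodComponent("Normal")
```

The trade-off is that a component created by mistake outside any `with` block no longer fails loudly. It is silently left unattached until someone passes it to a constructor.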