a pythonic interface for R's brms
brms is a fantastic R package that allows users to fit many kinds of Bayesian regression models - linear models, GLMs, survival analysis, etc - all in a multilevel context. Models are concisely specified using R's formula syntax, and the corresponding Stan program and data are automatically generated.
pybrms aims to bring the ease of use of brms to Python users: more sampling, less index gymnastics, and fewer shape errors.
Install via pip:
pip install pybrms
This installs the Python package along with its Python dependencies. The first time you call pybrms, it installs brms and its dependencies. Don't worry if you don't have R or brms installed - they will be installed in your current virtual environment.
Let's use pybrms to fit a Poisson regression model that includes an interaction term and a by-patient varying intercept. pybrms can import all datasets that are included in brms:
import pybrms

epilepsy = pybrms.get_brms_data("epilepsy")
Fitting the model is as simple as it is in brms:
model = pybrms.fit(
    formula = "count ~ zAge + zBase * Trt + (1 | patient)",
    data = epilepsy,
    family = "poisson"
)
INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_6f531b7e8a9bc73464e98930b52f4547 NOW.
You can also pass a list of priors, a family argument (gaussian, binomial, poisson, etc.), and optional pystan arguments such as the number of chains or samples.
When sampling is complete, fit returns a pystan StanFit4Model object. The generated Stan code is also available:
model.get_stanmodel().show()
StanModel object 'anon_model_6f531b7e8a9bc73464e98930b52f4547' coded as follows:
// generated with brms 2.10.5
functions {
}
data {
  int<lower=1> N; // number of observations
  int Y[N]; // response variable
  int<lower=1> K; // number of population-level effects
  matrix[N, K] X; // population-level design matrix
  // data for group-level effects of ID 1
  int<lower=1> N_1; // number of grouping levels
  int<lower=1> M_1; // number of coefficients per level
  int<lower=1> J_1[N]; // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_1_1;
  int prior_only; // should the likelihood be ignored?
}
transformed data {
  int Kc = K - 1;
  matrix[N, Kc] Xc; // centered version of X without an intercept
  vector[Kc] means_X; // column means of X before centering
  for (i in 2:K) {
    means_X[i - 1] = mean(X[, i]);
    Xc[, i - 1] = X[, i] - means_X[i - 1];
  }
}
parameters {
  vector[Kc] b; // population-level effects
  // temporary intercept for centered predictors
  real Intercept;
  vector<lower=0>[M_1] sd_1; // group-level standard deviations
  // standardized group-level effects
  vector[N_1] z_1[M_1];
}
transformed parameters {
  // actual group-level effects
  vector[N_1] r_1_1 = (sd_1[1] * (z_1[1]));
}
model {
  // initialize linear predictor term
  vector[N] mu = Intercept + Xc * b;
  for (n in 1:N) {
    // add more terms to the linear predictor
    mu[n] += r_1_1[J_1[n]] * Z_1_1[n];
  }
  // priors including all constants
  target += student_t_lpdf(Intercept | 3, 1, 10);
  target += student_t_lpdf(sd_1 | 3, 0, 10)
    - 1 * student_t_lccdf(0 | 3, 0, 10);
  target += normal_lpdf(z_1[1] | 0, 1);
  // likelihood including all constants
  if (!prior_only) {
    target += poisson_log_lpmf(Y | mu);
  }
}
generated quantities {
  // actual population-level intercept
  real b_Intercept = Intercept - dot_product(means_X, b);
}
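For readers who prefer equations to Stan, the generated program can be read as the model below. This is a sketch in our own notation, not something brms emits: α_c stands for the temporary Intercept on centered predictors, x^c_n for row n of Xc, j[n] for J_1[n], σ_1 for sd_1[1], Z_n for Z_1_1[n] (expected to be 1 for a varying intercept), and x̄ for means_X.

```latex
\begin{aligned}
y_n &\sim \operatorname{Poisson}\bigl(\exp(\mu_n)\bigr), & n &= 1, \dots, N \\
\mu_n &= \alpha_c + \mathbf{x}^{c}_{n}\,\mathbf{b} + r_{j[n]}\, Z_n \\
r_j &= \sigma_1\, z_j, \qquad z_j \sim \operatorname{Normal}(0, 1), & j &= 1, \dots, N_1 \\
\alpha_c &\sim \operatorname{Student-}t(3,\, 1,\, 10) \\
\sigma_1 &\sim \operatorname{Student-}t^{+}(3,\, 0,\, 10) \\
b_{\text{Intercept}} &= \alpha_c - \bar{\mathbf{x}}^{\top}\mathbf{b}
\end{aligned}
```

The program states no prior for b, which in Stan amounts to an improper flat prior on the population-level effects. The last line is the generated quantity that maps the centered intercept back to the original predictor scale.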
Behind the scenes, pybrms calls brms via rpy2, handling the Python-to-R object conversions in both directions and making sure that Stan gets the dtypes it expects by parsing the model's data block.
More specifically, pybrms calls two brms functions, make_stancode and make_standata, which generate the appropriate model code, design matrices, etc. These are then "pulled back" to Python and fed into pystan.
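To make the "parsing the model's data block" step more concrete, here is a minimal sketch of the idea using only the standard library. It is not pybrms's actual implementation: the helper names stan_data_types and coerce_to_stan are hypothetical, the parser only recognises the basic declaration types that appear in the program above, and it works on plain Python scalars and lists rather than on the arrays a real conversion would handle.

```python
import re

# One Stan declaration, e.g. "int<lower=1> J_1[N]" or "matrix[N, K] X".
_DECLARATION = re.compile(
    r"(int|real|vector|row_vector|matrix)\b"  # base type
    r"(?:<[^>]*>)?"                           # optional bounds, e.g. <lower=1>
    r"(?:\[[^\]]*\])?"                        # optional sizes, e.g. [N, K]
    r"\s+(\w+)"                               # variable name
)


def stan_data_types(stan_code):
    """Map each variable declared in the Stan `data` block to 'int' or 'real'."""
    # Anchoring at line start skips the `transformed data` block.
    block = re.search(r"^\s*data\s*\{(.*?)^\s*\}", stan_code, re.DOTALL | re.MULTILINE)
    if block is None:
        return {}
    body = re.sub(r"//[^\n]*", "", block.group(1))  # drop line comments
    types = {}
    for statement in body.split(";"):
        match = _DECLARATION.match(statement.strip())
        if match:
            base_type, name = match.groups()
            types[name] = "int" if base_type == "int" else "real"
    return types


def coerce_to_stan(data, types):
    """Cast scalars and (nested) lists to the declared type; leave unknown names alone."""
    def cast(value, caster):
        if isinstance(value, (list, tuple)):
            return [cast(item, caster) for item in value]
        return caster(value)

    casters = {"int": int, "real": float}
    coerced = {}
    for name, value in data.items():
        caster = casters.get(types.get(name))
        coerced[name] = value if caster is None else cast(value, caster)
    return coerced


stan_code = """
data {
  int<lower=1> N; // number of observations
  int Y[N]; // response variable
  vector[N] Z_1_1;
}
transformed data {
  int Kc = K - 1;
}
"""
types = stan_data_types(stan_code)
print(types)
print(coerce_to_stan({"N": 2.0, "Y": [3.0, 5.0], "Z_1_1": [1, 1]}, types))
```

The point of the sketch is the direction of the check: the Stan program, not the Python data, decides whether a value should be an integer or a real, so counts that arrive from R as doubles can be cast back to integers before sampling.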
By default, brms uses non- or weakly-informative priors on model parameters. You can specify more informative priors using the following syntax:
model = pybrms.fit("count ~ zAge + zBase * Trt + (1|patient) + (1|obs)",
                   data = epilepsy, family = "poisson",
                   priors = [("student_t(5,0,10)", "b"),
                             ("cauchy(0,2)", "sd")]
)
INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_3b2730a3196dc4b804b959d98397ba09 NOW.
Priors are passed as a list of tuples that follow the argument order of brms's set_prior: the first element is a Stan distribution, the second is the class (b, Intercept, sd), the third is coef, and so on.
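The positional convention is easier to see when each tuple element is paired with the argument it fills. The sketch below does only that. The helper label_prior is hypothetical (it is not part of pybrms), the fourth argument name, group, comes from the brms documentation for set_prior rather than from this README, and the zAge-specific prior is an illustrative addition.

```python
# Leading arguments of brms's set_prior(), in positional order.
SET_PRIOR_ARGS = ("prior", "class", "coef", "group")


def label_prior(prior_tuple):
    """Pair each element of a prior tuple with the set_prior argument it fills."""
    if not 1 <= len(prior_tuple) <= len(SET_PRIOR_ARGS):
        raise ValueError(
            f"expected 1 to {len(SET_PRIOR_ARGS)} elements, got {len(prior_tuple)}"
        )
    return dict(zip(SET_PRIOR_ARGS, prior_tuple))


priors = [
    ("student_t(5,0,10)", "b"),    # all population-level effects
    ("normal(0,1)", "b", "zAge"),  # only the zAge coefficient (illustrative)
    ("cauchy(0,2)", "sd"),         # group-level standard deviations
]
for prior in priors:
    print(label_prior(prior))
```

Shorter tuples simply leave the later arguments at whatever defaults set_prior uses, so a two-element tuple applies to the whole class.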
Error handling happens at the rpy2 level, which catches the R error and raises it as a Python RRuntimeError exception:
model = pybrms.fit("count_typo ~ zAge + zBase * Trt + (1|patient) + (1|obs)",
                   data = epilepsy, family = "poisson",
                   priors = [("student_t(5,0,10)", "b"),
                             ("cauchy(0,2)", "sd")]
)
WARNING:rpy2.rinterface_lib.callbacks:R[write to console]: Error: The following variables are missing in 'data':
'count_typo'
---------------------------------------------------------------------------
RRuntimeError Traceback (most recent call last)
<ipython-input-6-3103bec4b060> in <module>
2 data = epilepsy, family = "poisson",
3 priors = [("student_t(5,0,10)", "b"),
----> 4 ("cauchy(0,2)", "sd")]
5 )
~/projects/pybrms/pybrms/pybrms.py in fit(formula, data, priors, family, sample_prior, sample, **pystan_args)
128 family=family,
129 priors=brms_prior,
--> 130 sample_prior=sample_prior,
131 )
132 model_data = _convert_R_to_python(formula, data, family)
~/projects/pybrms/pybrms/pybrms.py in get_stan_code(formula, data, priors, family, sample_prior)
54 if len(priors)>0:
55 return brms.make_stancode(
---> 56 formula=formula, data=data, prior=priors, family=family, sample_prior=sample_prior
57 )[0]
58 else:
~/miniconda3/envs/rpy/lib/python3.6/site-packages/rpy2/robjects/functions.py in __call__(self, *args, **kwargs)
190 kwargs[r_k] = v
191 return (super(SignatureTranslatedFunction, self)
--> 192 .__call__(*args, **kwargs))
193
194
~/miniconda3/envs/rpy/lib/python3.6/site-packages/rpy2/robjects/functions.py in __call__(self, *args, **kwargs)
119 else:
120 new_kwargs[k] = conversion.py2rpy(v)
--> 121 res = super(Function, self).__call__(*new_args, **new_kwargs)
122 res = conversion.rpy2py(res)
123 return res
~/miniconda3/envs/rpy/lib/python3.6/site-packages/rpy2/rinterface_lib/conversion.py in _(*args, **kwargs)
26 def _cdata_res_to_rinterface(function):
27 def _(*args, **kwargs):
---> 28 cdata = function(*args, **kwargs)
29 # TODO: test cdata is of the expected CType
30 return _cdata_to_rinterface(cdata)
~/miniconda3/envs/rpy/lib/python3.6/site-packages/rpy2/rinterface.py in __call__(self, *args, **kwargs)
783 error_occured))
784 if error_occured[0]:
--> 785 raise embedded.RRuntimeError(_rinterface._geterrmessage())
786 return res
787
RRuntimeError: Error: The following variables are missing in 'data':
'count_typo'
This makes sure you can debug your data/formula/etc without actually leaving python.
Since pybrms returns a pystan object, we can easily visualize the results using arviz:
import arviz as az
inference_data = az.from_pystan(model)
az.plot_posterior(inference_data, var_names=['b', 'Intercept']);
For a more detailed walkthrough, see the accompanying blog post.