paul-buerkner / brms

brms R package for Bayesian generalized multivariate non-linear multilevel models using Stan
https://paul-buerkner.github.io/brms/
GNU General Public License v2.0

Enable access to priors pre-trained with neural networks #1482

Closed elizavetasemenova closed 9 months ago

elizavetasemenova commented 1 year ago

It would be a great feature to enable access to priors pre-trained externally with neural networks. An immediate application is spatial modelling. There, the computationally intensive Gaussian process, or the multivariate normal distribution representing the spatial random effect, is replaced by an approximation provided by a pre-trained neural network, e.g. the decoder of a variational autoencoder (VAE). The training of such priors is described in the following papers:

1. PriorVAE
2. $\pi$VAE
3. PriorCVAE

PriorVAE and PriorCVAE are tied to a fixed spatial structure, for example the borders of a country at admin2 level. The syntax could be similar to that of the CAR model, for example:

```r
fit <- brm(y | trials(size) ~ x1 + x2 + nn(type = 'PriorCVAE', level = 'adm2', iso3 = 'ZWE'),
           data = dat,
           family = binomial())
```

Here `type` is the type of neural network used (i.e., its architecture), `level` is the administrative level (typically adm1 or adm2), and `iso3` is the three-letter ISO country code.
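A minimal Python sketch of how the three arguments of the proposed `nn()` term might be validated and turned into a lookup key for a model repository. Python is used only for illustration; this is not brms code. `NNPriorSpec`, the whitelists and the `<type>/<iso3>/<level>` key layout are all hypothetical, and restricting `level` to adm1/adm2 takes the "typically" above as a hard rule, which is an assumption.

```python
from dataclasses import dataclass

# Hypothetical whitelists: only the two spatially fixed architectures named in
# this issue, and the two admin levels described as typical.
KNOWN_TYPES = {"PriorVAE", "PriorCVAE"}
KNOWN_LEVELS = {"adm1", "adm2"}


@dataclass(frozen=True)
class NNPriorSpec:
    """Arguments of the proposed nn() term (hypothetical, not an existing API)."""
    type: str
    level: str
    iso3: str

    def __post_init__(self):
        if self.type not in KNOWN_TYPES:
            raise ValueError(f"unknown architecture: {self.type!r}")
        if self.level not in KNOWN_LEVELS:
            raise ValueError(f"level must be one of {sorted(KNOWN_LEVELS)}")
        # ISO 3166-1 alpha-3 codes are three upper-case letters, e.g. 'ZWE'.
        if len(self.iso3) != 3 or not self.iso3.isalpha() or not self.iso3.isupper():
            raise ValueError(f"iso3 must be a three-letter upper-case code, got {self.iso3!r}")

    def key(self) -> str:
        # Hypothetical repository layout: <type>/<iso3>/<level>
        return f"{self.type}/{self.iso3}/{self.level}"


spec = NNPriorSpec(type="PriorCVAE", level="adm2", iso3="ZWE")
print(spec.key())  # PriorCVAE/ZWE/adm2
```

Validating eagerly means a typo such as `level = 'adm3'` fails before any weights are downloaded or any Stan code is generated.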

paul-buerkner commented 1 year ago

Thanks for opening this issue! A few questions to start with:

elizavetasemenova commented 1 year ago

Say, for example, we want to work with a GP with kernel $k(x_1, x_2) = \exp\left(-\frac{\|x_1 - x_2\|^2}{2\theta^2}\right)$.
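A small stdlib-Python sketch of how this kernel gives the $n \times n$ covariance matrix $K$ with entries $K_{ij} = k(x_i, x_j)$. The helper names are illustrative, and treating spatial locations as coordinate tuples with Euclidean distance is an assumption.

```python
import math


def se_kernel(x1, x2, theta):
    """Squared-exponential kernel exp(-||x1 - x2||^2 / (2 theta^2))."""
    d = math.dist(x1, x2)  # Euclidean distance between two coordinate tuples
    return math.exp(-d * d / (2.0 * theta * theta))


def kernel_matrix(points, theta):
    """Covariance matrix K with K[i][j] = k(points[i], points[j])."""
    return [[se_kernel(a, b, theta) for b in points] for a in points]


# Three 2-D locations with lengthscale theta = 0.5
K = kernel_matrix([(0.0, 0.0), (0.5, 0.0), (0.0, 1.0)], theta=0.5)
```

Every diagonal entry is 1, and two points exactly one lengthscale apart have correlation $e^{-1/2}$.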

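Normally, we would compute $z \sim \mathcal{N}(0, I_n)$ and $f = L_\theta z$, where $L_\theta = \operatorname{chol}(K)$ and $K$ is the $n \times n$ covariance matrix obtained by evaluating $k$ at the $n$ locations. At inference, the parameters that get updated are $z$ and the hyperparameter $\theta$.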
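A minimal stdlib-Python sketch of this non-centred construction. Because $\mathbb{E}[zz^\top] = I_n$, the draw $f = L_\theta z$ has covariance $L_\theta L_\theta^\top = K$, which is why sampling $z$ instead of $f$ is valid. The `cholesky` and `draw_gp` helpers and the small diagonal jitter are illustrative choices, not something specified in this thread.

```python
import math
import random


def cholesky(K):
    """Lower-triangular L with L L^T == K (Cholesky-Banachiewicz order)."""
    n = len(K)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][j] = math.sqrt(K[i][i] - s)
            else:
                L[i][j] = (K[i][j] - s) / L[j][j]
    return L


def draw_gp(K, rng, jitter=1e-9):
    """Draw f = L z with z ~ N(0, I_n), so that Cov(f) = L L^T = K (+ jitter)."""
    n = len(K)
    # A tiny diagonal jitter keeps the factorisation stable for near-singular K.
    Kj = [[K[i][j] + (jitter if i == j else 0.0) for j in range(n)] for i in range(n)]
    L = cholesky(Kj)
    z = [rng.gauss(0.0, 1.0) for _ in range(n)]
    # L is lower-triangular, so row i only needs z[0..i].
    f = [sum(L[i][k] * z[k] for k in range(i + 1)) for i in range(n)]
    return f, z, L


f, z, L = draw_gp([[1.0, 0.5], [0.5, 1.0]], random.Random(0))
```

The $O(n^3)$ factorisation, repeated whenever $\theta$ changes, is the cost that the pre-trained decoder is meant to avoid.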

In the new setting, we compute $z \sim \mathcal{N}(0, I_d)$ and $f = \phi_{w, \theta}(z)$, where $d$ is the latent dimension and $\phi_{w, \theta}(\cdot)$ is a deterministic transformation (the pre-trained decoder). The parameters $w$ (network weights) are constant and are passed to the Stan model as data. At inference, the parameters that get updated are again $z$ and the hyperparameter $\theta$ (this describes the PriorCVAE version; PriorVAE has no $\theta$ to infer).
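Written out for a multilayer perceptron with two hidden layers and a linear output (the architecture of the Stan example quoted in this thread), the decoder is as follows. Here $\tanh$ acts elementwise, and one dense layer computes $(x^\top W)^\top + B = W^\top x + B$. How PriorCVAE feeds $\theta$ into the network is an assumption in this sketch: it uses the common conditional-VAE convention of concatenating $\theta$ to $z$, in which case $W_1$ has $d + \dim(\theta)$ rows.

$$
\phi_w(z) = W_3^\top \tanh\left(W_2^\top \tanh\left(W_1^\top z + B_1\right) + B_2\right) + B_3,
\qquad
\phi_{w,\theta}(z) = \phi_w\left(\begin{bmatrix} z \\ \theta \end{bmatrix}\right),
$$

with $w = \{W_1, W_2, W_3, B_1, B_2, B_3\}$ fixed. The sampler therefore only explores the $d$ latent coordinates (plus $\theta$), rather than $n$ correlated values of $f$.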

A Stan example (from https://github.com/MLGlobalHealth/pi-vae/blob/main/src_stan/stan_1D.stan):

```stan
functions {
  // One dense layer: (x^T W)^T + B, i.e. W' * x + B
  vector layer(vector x, matrix W, vector B) {
    return transpose(transpose(x) * W) + B;
  }
  // Decoder: two tanh hidden layers followed by a linear output layer
  vector generator_stan(vector input, matrix W1, matrix W2, matrix W3, vector B1, vector B2, vector B3) {
    return layer(tanh(layer(tanh(layer(input, W1, B1)), W2, B2)), W3, B3);
  }
}
data {
  int<lower=1> p;    // input dimensionality (latent dimension of VAE)
  int<lower=1> p1;   // hidden layer 1 number of units
  int<lower=1> p2;   // hidden layer 2 number of units
  int<lower=1> n;    // output dimensionality
  matrix[p, p1] W1;  // weights matrix for layer 1
  vector[p1] B1;     // bias vector for layer 1
  matrix[p1, p2] W2; // weights matrix for layer 2
  vector[p2] B2;     // bias vector for layer 2
  matrix[p2, n] W3;  // weights matrix for output layer
  vector[n] B3;      // bias vector for output layer

  vector[n] y;       // observations
}
parameters {
  vector[p] z;           // latent input to the decoder
  real<lower=0> sigma2;  // observation noise, used as the SD in normal()
}
transformed parameters {
  vector[n] f;
  f = generator_stan(z, W1, W2, W3, B1, B2, B3);
}
model {
  z ~ normal(0, 1);
  sigma2 ~ normal(0, 10);  // half-normal because of the lower bound
  y ~ normal(f, sigma2);
}
generated quantities {
  vector[n] y2;  // posterior predictive draws
  for (i in 1:n)
    y2[i] = normal_rng(f[i], sigma2);
}
```
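A plain-Python mirror of `layer` and `generator_stan`, sketched under the assumption that one would want to check exported weights outside Stan, e.g. against the trained network's own output for the same `z`. Weight matrices are stored as lists of rows with shape `[inputs][outputs]`, matching `matrix[p, p1] W1` in the data block. The helper names are illustrative.

```python
import math


def layer(x, W, B):
    """Dense layer matching the Stan `layer`: (x^T W)^T + B, i.e. W^T x + B.

    W is a list of rows with shape [len(x)][len(B)].
    """
    return [sum(x[i] * W[i][j] for i in range(len(x))) + B[j] for j in range(len(B))]


def generator(z, W1, W2, W3, B1, B2, B3):
    """Mirror of `generator_stan`: two tanh hidden layers, linear output."""
    h1 = [math.tanh(v) for v in layer(z, W1, B1)]
    h2 = [math.tanh(v) for v in layer(h1, W2, B2)]
    return layer(h2, W3, B3)
```

Keeping the same argument order as the Stan function makes it easy to feed both implementations from one set of exported weights.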

The `generator_stan` function defines the deterministic transformation $\phi(\cdot)$. In principle, an NN architecture of any complexity can be used, as long as it can be re-implemented in Stan. To get started, it could be enough to use multilayer perceptrons (MLPs), which are straightforward to re-implement in Stan by hand.
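Because an MLP is just a chain of dense layers, its Stan function can even be generated mechanically. A minimal Python sketch, assuming tanh on every hidden layer and a linear output layer as in `generator_stan`; the generator function itself is hypothetical, while the emitted Stan code follows the structure of the example above.

```python
def stan_mlp_functions(n_layers: int, activation: str = "tanh") -> str:
    """Emit a Stan `functions` block for an MLP with `n_layers` dense layers.

    Hidden layers use `activation`; the last layer is linear. The argument
    order (all W's, then all B's) follows generator_stan.
    """
    if n_layers < 1:
        raise ValueError("need at least one layer")
    ws = ", ".join(f"matrix W{k}" for k in range(1, n_layers + 1))
    bs = ", ".join(f"vector B{k}" for k in range(1, n_layers + 1))
    expr = "input"
    for k in range(1, n_layers + 1):
        expr = f"layer({expr}, W{k}, B{k})"
        if k < n_layers:  # no activation after the output layer
            expr = f"{activation}({expr})"
    return (
        "functions {\n"
        "  vector layer(vector x, matrix W, vector B) {\n"
        "    return W' * x + B;\n"
        "  }\n"
        f"  vector generator_stan(vector input, {ws}, {bs}) {{\n"
        f"    return {expr};\n"
        "  }\n"
        "}\n"
    )


print(stan_mlp_functions(3))
```

For `n_layers = 3` this reproduces the nesting of the hand-written `generator_stan`; `W' * x + B` is the same affine map as `transpose(transpose(x) * W) + B`.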

We could create a separate GitHub repository containing the pre-trained models.

This is a hard question at this stage. For the time being, this user-facing interface should suffice:

```r
nn(
  type = 'PriorCVAE',  # NN architecture
  level = 'adm2',      # administrative level
  iso3 = 'ZWE'         # country
)
```

which would automatically pull the NN architecture (defined in Stan) and the pre-trained weights from an external repository.
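Once the weights are pulled, they still have to reach the Stan model as data. A minimal Python sketch, assuming the weights arrive as nested lists and are written in CmdStan's JSON data format, where a `matrix[p, p1]` is a list of `p` rows. The function name and the shape checks are illustrative; the keys match the data block of the Stan example above.

```python
import json


def stan_data_for_decoder(weights, biases, y):
    """Build the data dict for the decoder model from pre-trained weights.

    weights: [W1, W2, W3], each a list of rows with shape [in][out]
    biases:  [B1, B2, B3], each a list of length out
    """
    W1, W2, W3 = weights
    B1, B2, B3 = biases
    p, p1 = len(W1), len(W1[0])
    p2, n = len(W3), len(W3[0])
    # Consecutive layers must agree, otherwise Stan rejects the data at load time.
    if len(W2) != p1 or len(W2[0]) != p2:
        raise ValueError("W2 shape does not match W1/W3")
    if (len(B1), len(B2), len(B3)) != (p1, p2, n):
        raise ValueError("bias lengths do not match layer widths")
    if len(y) != n:
        raise ValueError("y must have one entry per output location")
    return {"p": p, "p1": p1, "p2": p2, "n": n,
            "W1": W1, "B1": B1, "W2": W2, "B2": B2, "W3": W3, "B3": B3,
            "y": y}
```

The dict can be saved with `json.dump` and passed to a compiled CmdStan model as `./model sample data file=decoder_data.json`.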

paul-buerkner commented 1 year ago

Thank you for the additional details. Do you think the methodology as a whole is production-ready? Asked differently, is now a good time to work on a brms implementation, or would it make more sense to wait until the methodology is more mature?

paul-buerkner commented 9 months ago

I will close this issue now to reduce the load of the brms issue tracker. If you want to revisit this issue, just write here and we can discuss if the methods are worth implementing in brms.