aesara-devs / aehmc

An HMC/NUTS implementation in Aesara
MIT License
AeHMC

AeHMC provides implementations of the HMC and NUTS samplers in Aesara (https://github.com/aesara-devs/aesara).
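For orientation, here is the standard HMC formulation these samplers build on. It is the textbook version, not a transcription of AeHMC's code. We assume, based only on their names, that the `step_size` and `inverse_mass_matrix` arguments in the Get started example play the roles of the step size ε and the inverse mass matrix M⁻¹.

```latex
% Hamiltonian for target density \pi(q) and momentum p \sim \mathcal{N}(0, M)
H(q, p) = -\log \pi(q) + \tfrac{1}{2}\, p^\top M^{-1} p

% One leapfrog step with step size \varepsilon
p_{1/2} = p_0 + \tfrac{\varepsilon}{2} \nabla_q \log \pi(q_0), \qquad
q_1 = q_0 + \varepsilon\, M^{-1} p_{1/2}, \qquad
p_1 = p_{1/2} + \tfrac{\varepsilon}{2} \nabla_q \log \pi(q_1)

% Plain HMC: accept the end point (q', p') of a fixed-length trajectory with probability
\alpha = \min\!\left(1, \exp\{H(q, p) - H(q', p')\}\right)
```

NUTS uses the same Hamiltonian and leapfrog integrator. However, it chooses the trajectory length adaptively and selects the next state among the points of the trajectory, so the fixed-length acceptance rule applies to plain HMC only.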

Get started

import aesara
from aesara import tensor as at
from aesara.tensor.random.utils import RandomStream

from aeppl import joint_logprob

from aehmc import nuts

# A simple normal distribution
Y_rv = at.random.normal(0, 1)

def logprob_fn(y):
    return joint_logprob(realized={Y_rv: y})[0]

# Build the transition kernel
srng = RandomStream(seed=0)
kernel = nuts.new_kernel(srng, logprob_fn)

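# Create a value variable for Y_rv and the initial state of the chain
y_vv = Y_rv.clone()
initial_state = nuts.new_state(y_vv, logprob_fn)

# Build one symbolic NUTS transition with a fixed step size
# and a scalar inverse mass matrix
step_size = at.as_tensor(1e-2)
inverse_mass_matrix = at.as_tensor(1.0)
chain_info, updates = kernel(initial_state, step_size, inverse_mass_matrix)

# Compile a function that maps a position to the next position of the chain;
# pass `updates` so the random number generator state advances between calls
next_step_fn = aesara.function([y_vv], chain_info.state.position, updates=updates)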

print(next_step_fn(0))
# 1.1034719409361107
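The compiled `next_step_fn` performs a single transition. A chain is drawn by feeding each returned position back in as the next input. The sketch below illustrates that pattern and the kind of work a transition does, without depending on Aesara or AeHMC. It is a minimal NumPy implementation of plain HMC with a fixed number of leapfrog steps, targeting the same standard normal. It is not AeHMC's NUTS kernel, which picks the trajectory length adaptively. `hmc_step`, `num_steps`, and the parameter values are hypothetical choices made for illustration.

```python
import numpy as np

# Target from the Get started example: Y ~ Normal(0, 1).
def logprob(y):
    # Log-density of N(0, 1), dropping the constant -0.5 * log(2 * pi)
    return -0.5 * y**2

def grad_logprob(y):
    # Gradient of logprob with respect to y
    return -y

def hmc_step(rng, y, step_size, inverse_mass_matrix, num_steps):
    """One plain-HMC transition for a scalar position (hypothetical, not aehmc)."""
    # Sample a momentum p ~ N(0, M), with mass M = 1 / inverse_mass_matrix
    p = rng.normal(0.0, np.sqrt(1.0 / inverse_mass_matrix))

    # Leapfrog integration: half momentum step, alternating full steps,
    # then a final half momentum step
    y_new = y
    p_new = p + 0.5 * step_size * grad_logprob(y_new)
    for i in range(num_steps):
        y_new = y_new + step_size * inverse_mass_matrix * p_new
        if i < num_steps - 1:
            p_new = p_new + step_size * grad_logprob(y_new)
    p_new = p_new + 0.5 * step_size * grad_logprob(y_new)

    # Metropolis correction based on the change in total energy
    def energy(q, m):
        return -logprob(q) + 0.5 * inverse_mass_matrix * m**2

    log_accept_prob = energy(y, p) - energy(y_new, p_new)
    if np.log(rng.uniform()) < log_accept_prob:
        return y_new
    return y

# Run a chain by feeding each new position back in as the next input
rng = np.random.default_rng(0)
y = 0.0
samples = []
for _ in range(5000):
    y = hmc_step(rng, y, step_size=0.5, inverse_mass_matrix=1.0, num_steps=10)
    samples.append(y)
samples = np.array(samples)

# The sample mean should be near 0 and the sample variance near 1
print(samples.mean(), samples.var())
```

The loop at the end mirrors how `next_step_fn(y)` would be called repeatedly. The difference is that the random state lives in a NumPy `Generator` here, while in the Aesara version it is advanced through `updates`.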

Install

The latest release of AeHMC can be installed from PyPI using pip:

pip install aehmc

Or via conda-forge:

conda install -c conda-forge aehmc

The current development branch of AeHMC can be installed from GitHub using pip:

pip install git+https://github.com/aesara-devs/aehmc

Get help

Report bugs by opening an issue. If you have a question about using AeHMC, start a discussion. For real-time feedback or more general chat about AeHMC, use our Discord server or Gitter room.

Contribute

AeHMC welcomes contributions. A good place to start is the list of open issues.

If you want to implement a new feature, open a discussion or come chat with us on Discord or Gitter.