ManifoldFR opened this issue 5 years ago.
Fwiw, I've been using this for a while in my own code:
```rust
use failure::{Context, ResultExt};
use ndarray::{Data, DataClone, DataOwned, OwnedRepr, ViewRepr};
use ndarray::prelude::*;
use ndarray_linalg::cholesky::{CholeskyInto, UPLO};
use ndarray_rand::RandomExt;
use rand::distributions::{Distribution, Normal};
use rand::Rng;
use std::clone::Clone;
use std::fmt::{self, Debug};
use std::ops::AddAssign;
// ...

/// Multivariate Gaussian distribution.
#[derive(PartialEq, Deserialize, Serialize)]
#[serde(bound(deserialize = "S: DataOwned, S::Elem: ::serde::Deserialize<'de>"))]
#[serde(bound(serialize = "S: Data, S::Elem: ::serde::Serialize"))]
pub struct GaussianDistroBase<S>
where
    S: Data<Elem = f64>,
{
    pub mean: ArrayBase<S, Ix1>,
    pub covariance: ArrayBase<S, Ix2>,
}

pub type GaussianDistro = GaussianDistroBase<OwnedRepr<f64>>;
pub type GaussianDistroView<'a> = GaussianDistroBase<ViewRepr<&'a f64>>;

impl<S> GaussianDistroBase<S>
where
    S: Data<Elem = f64>,
{
    pub fn len(&self) -> usize {
        assert_eq!(self.mean.len(), self.covariance.len_of(Axis(0)));
        assert_eq!(self.mean.len(), self.covariance.len_of(Axis(1)));
        self.mean.len()
    }

    pub fn to_owned(&self) -> GaussianDistro {
        GaussianDistro {
            mean: self.mean.to_owned(),
            covariance: self.covariance.to_owned(),
        }
    }

    pub fn view(&self) -> GaussianDistroView {
        GaussianDistroView {
            mean: self.mean.view(),
            covariance: self.covariance.view(),
        }
    }
}

impl<S> Clone for GaussianDistroBase<S>
where
    S: DataClone<Elem = f64>,
{
    fn clone(&self) -> Self {
        GaussianDistroBase {
            mean: self.mean.clone(),
            covariance: self.covariance.clone(),
        }
    }
}

impl<S> Debug for GaussianDistroBase<S>
where
    S: Data<Elem = f64>,
{
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        fmt.debug_struct("GaussianDistroBase")
            .field("mean", &self.mean)
            .field("covariance", &self.covariance)
            .finish()
    }
}

#[derive(Debug, Fail)]
#[fail(display = "error sampling from multivariate normal distribution: {}", _0)]
pub struct GaussianSampleError(Context<String>);

impl From<Context<String>> for GaussianSampleError {
    fn from(context: Context<String>) -> GaussianSampleError {
        GaussianSampleError(context)
    }
}

impl<S> Distribution<Result<Array1<f64>, GaussianSampleError>> for GaussianDistroBase<S>
where
    S: Data<Elem = f64>,
{
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Result<Array1<f64>, GaussianSampleError> {
        let mut cov = self.covariance.to_owned();
        // Add a small multiple of I for numerical reasons.
        cov.diag_mut().add_assign(1e2 * ::std::f64::EPSILON);
        let chol = cov.cholesky_into(UPLO::Lower)
            .context("error factoring covariance".into())?;
        Ok(chol.dot(&Array1::random_using(self.len(), Normal::new(0., 1.), rng)) + &self.mean)
    }
}
// ...
```
It's probably a bit more complex than what you're looking for (since `GaussianDistroBase` is generic over the storage type `S`), but it could be simplified.
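To make the "generic over storage" idea concrete without pulling in `ndarray`, here is a std-only analogue. It is a sketch with hypothetical names, not the code above: the storage parameter `S` is any `AsRef<[f64]>`, so `Vec<f64>` plays the role of `OwnedRepr<f64>` and `&[f64]` plays the role of `ViewRepr<&f64>`. The covariance is assumed to be stored row-major as a flat `n * n` slice. The simplified, owned-only variant amounts to fixing `S = Vec<f64>`.

```rust
/// Hypothetical std-only analogue of `GaussianDistroBase<S>`:
/// `S` abstracts over owned (`Vec<f64>`) vs. borrowed (`&[f64]`) storage.
#[derive(Debug, Clone, PartialEq)]
pub struct GaussianBase<S: AsRef<[f64]>> {
    pub mean: S,
    /// Row-major n x n covariance matrix.
    pub covariance: S,
}

/// Owned storage, analogous to `GaussianDistro`.
pub type Gaussian = GaussianBase<Vec<f64>>;
/// Borrowed storage, analogous to `GaussianDistroView<'a>`.
pub type GaussianView<'a> = GaussianBase<&'a [f64]>;

impl<S: AsRef<[f64]>> GaussianBase<S> {
    /// Dimension n, after checking that the covariance really is n x n.
    pub fn len(&self) -> usize {
        let n = self.mean.as_ref().len();
        assert_eq!(self.covariance.as_ref().len(), n * n);
        n
    }

    /// Copies whatever storage we have into owned `Vec`s.
    pub fn to_owned(&self) -> Gaussian {
        GaussianBase {
            mean: self.mean.as_ref().to_vec(),
            covariance: self.covariance.as_ref().to_vec(),
        }
    }

    /// Borrows the data without copying.
    pub fn view(&self) -> GaussianView<'_> {
        GaussianBase {
            mean: self.mean.as_ref(),
            covariance: self.covariance.as_ref(),
        }
    }
}
```

The cost of the generic version is mostly the extra `where` bounds on every `impl`; the payoff is that a caller can sample from borrowed data without copying it first.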
For the purpose of `ndarray-rand`, adding a dependency on `ndarray-linalg` seems unfortunate because `ndarray-linalg` requires non-Rust code (the LAPACK implementation). I suppose it would be fine if we put the functionality behind a feature flag. What do you think, @bluss?
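A minimal sketch of what the feature-flag approach could look like on the Rust side. The feature name `multivariate` and the function `multivariate_available` are hypothetical. On the Cargo side, `ndarray-linalg` would be declared as an optional dependency and listed in a `[features]` entry, so that LAPACK is only pulled in for users who opt in.

```rust
// Hypothetical feature name `multivariate`. In Cargo.toml it would be a
// `[features]` entry enabling an optional `ndarray-linalg` dependency.

/// Compiled only when the crate is built with `--features multivariate`.
#[cfg(feature = "multivariate")]
pub mod multivariate {
    // The Cholesky-based sampler (which needs ndarray-linalg, and therefore
    // a LAPACK implementation) would live here.
}

/// Lets callers check whether the gated module was compiled in.
/// `cfg!` evaluates to a compile-time `true`/`false`.
pub fn multivariate_available() -> bool {
    cfg!(feature = "multivariate")
}
```

With default features, the gated module simply does not exist, so the default build of the crate stays pure Rust.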
Yes, a feature flag seems appropriate. I'm using the same approach in an MCMC algorithms crate I'm working on (to add support for multivariate distributions using `ndarray`).
Your code seems a bit overkill; I was only planning to implement it for the `OwnedRepr<f64>` data type at first, but I'll look into making it a bit more generic!
I'll make a pull request so you can check my work.
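For comparison, here is a dependency-free sketch of an owned-only sampler that follows the same steps as the code above: add a small multiple of the identity to the covariance, take the lower Cholesky factor `L`, then return `mean + L z` with `z` standard normal. It swaps in plain row-major `Vec<f64>` for `ndarray` arrays, a hand-written Cholesky for `ndarray-linalg`'s `cholesky_into`, and the Box–Muller transform for `rand`'s `Normal`. The function names are hypothetical, and the caller is assumed to supply uniform draws in (0, 1].

```rust
/// Lower-triangular Cholesky factor L (row-major) of an n x n matrix A = L Lᵀ.
/// Returns an error if a pivot is not strictly positive.
pub fn cholesky_lower(a: &[f64], n: usize) -> Result<Vec<f64>, String> {
    assert_eq!(a.len(), n * n, "matrix must be n x n");
    let mut l = vec![0.0; n * n];
    for i in 0..n {
        for j in 0..=i {
            // s = sum_{k<j} L[i][k] * L[j][k]
            let s: f64 = (0..j).map(|k| l[i * n + k] * l[j * n + k]).sum();
            if i == j {
                let d = a[i * n + i] - s;
                if d.is_nan() || d <= 0.0 {
                    return Err(format!("matrix not positive definite at pivot {}", i));
                }
                l[i * n + i] = d.sqrt();
            } else {
                l[i * n + j] = (a[i * n + j] - s) / l[j * n + j];
            }
        }
    }
    Ok(l)
}

/// One standard-normal draw via Box–Muller; `uniform` must return values in (0, 1].
fn standard_normal(uniform: &mut impl FnMut() -> f64) -> f64 {
    let u1 = uniform();
    let u2 = uniform();
    (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
}

/// Draws x = mean + L z, where z ~ N(0, I) and (cov + jitter * I) = L Lᵀ.
pub fn sample_mvn(
    mean: &[f64],
    cov: &[f64],
    jitter: f64,
    uniform: &mut impl FnMut() -> f64,
) -> Result<Vec<f64>, String> {
    let n = mean.len();
    assert_eq!(cov.len(), n * n, "covariance must be n x n");
    // Add a small multiple of I for numerical reasons, as in the code above.
    let mut c = cov.to_vec();
    for i in 0..n {
        c[i * n + i] += jitter;
    }
    let l = cholesky_lower(&c, n)?;
    let z: Vec<f64> = (0..n).map(|_| standard_normal(&mut *uniform)).collect();
    // L is lower-triangular, so row i only needs z[0..=i].
    Ok((0..n)
        .map(|i| mean[i] + (0..=i).map(|k| l[i * n + k] * z[k]).sum::<f64>())
        .collect())
}
```

Passing the uniform source in as a closure keeps the sketch free of any particular RNG crate; a real `ndarray-rand` version would instead take `&mut R where R: Rng`.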
Implementing multivariate normal distributions involves a bit of boilerplate, and possibly the use of `ndarray-linalg` to perform a Cholesky decomposition. Would it be worthwhile to implement this in the crate itself? I made a fork and started writing some code.
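For reference, these are the standard results behind the Cholesky approach (not specific to any crate). If the covariance factors as $\Sigma = LL^{\top}$ with $L$ lower-triangular, then shifting and scaling a standard normal vector gives the desired mean and covariance, and $L$ can be computed column by column:

```latex
\Sigma = L L^{\top}, \qquad z \sim \mathcal{N}(0, I_n), \qquad x = \mu + L z

\mathbb{E}[x] = \mu + L\,\mathbb{E}[z] = \mu, \qquad
\operatorname{Cov}(x) = L\,\mathbb{E}[z z^{\top}]\,L^{\top} = L I_n L^{\top} = \Sigma

L_{jj} = \sqrt{\Sigma_{jj} - \sum_{k<j} L_{jk}^{2}}, \qquad
L_{ij} = \frac{1}{L_{jj}} \Big( \Sigma_{ij} - \sum_{k<j} L_{ik} L_{jk} \Big) \quad (i > j)
```

The square root requires every pivot to be positive, i.e. $\Sigma$ positive definite. Adding a small $\varepsilon I$ to $\Sigma$ before factoring helps when $\Sigma$ is only positive semidefinite or rounding makes a pivot slightly negative, at the cost of sampling from $\mathcal{N}(\mu, \Sigma + \varepsilon I)$ instead.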