rust-ndarray / ndarray

ndarray: an N-dimensional array with array views, multidimensional slicing, and efficient operations
https://docs.rs/ndarray/
Apache License 2.0

Multivariate normal distribution in ndarray-rand #582

Status: Open. ManifoldFR opened this issue 5 years ago.

ManifoldFR commented 5 years ago

Implementing a multivariate normal distribution involves a bit of boilerplate, and possibly ndarray-linalg to perform a Cholesky decomposition. Would it be worth implementing it in the crate itself? I made a fork and started writing some code.
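The Cholesky route works because of a standard identity. Factor the covariance as Σ = L Lᵀ with L lower triangular, draw z with independent standard normal entries, and return μ + L z. The mean and covariance then come out right:

```latex
\begin{aligned}
\Sigma &= L L^\top, \qquad z \sim \mathcal{N}(0, I_n), \qquad x = \mu + L z,\\
\mathbb{E}[x] &= \mu + L\,\mathbb{E}[z] = \mu,\\
\operatorname{Cov}(x) &= L\,\operatorname{Cov}(z)\,L^\top = L\,I_n\,L^\top = \Sigma.
\end{aligned}
```

Since x is an affine function of a Gaussian vector, it is itself Gaussian, so x ~ N(μ, Σ).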

jturner314 commented 5 years ago

FWIW, I've been using this in my own code for a while:

use failure::{Context, ResultExt};
use ndarray::{Data, DataClone, DataOwned, OwnedRepr, ViewRepr};
use ndarray::prelude::*;
use ndarray_linalg::cholesky::{CholeskyInto, UPLO};
use ndarray_rand::RandomExt;
use rand::distributions::{Distribution, Normal};
use rand::Rng;
use std::fmt::{self, Debug};
use std::ops::AddAssign;

// ...

/// Multivariate Gaussian distribution.
#[derive(PartialEq, Deserialize, Serialize)]
#[serde(bound(deserialize = "S: DataOwned, S::Elem: ::serde::Deserialize<'de>"))]
#[serde(bound(serialize = "S: Data, S::Elem: ::serde::Serialize"))]
pub struct GaussianDistroBase<S>
where
    S: Data<Elem = f64>,
{
    pub mean: ArrayBase<S, Ix1>,
    pub covariance: ArrayBase<S, Ix2>,
}

pub type GaussianDistro = GaussianDistroBase<OwnedRepr<f64>>;

pub type GaussianDistroView<'a> = GaussianDistroBase<ViewRepr<&'a f64>>;

impl<S> GaussianDistroBase<S>
where
    S: Data<Elem = f64>,
{
    pub fn len(&self) -> usize {
        assert_eq!(self.mean.len(), self.covariance.len_of(Axis(0)));
        assert_eq!(self.mean.len(), self.covariance.len_of(Axis(1)));
        self.mean.len()
    }

    pub fn to_owned(&self) -> GaussianDistro {
        GaussianDistro {
            mean: self.mean.to_owned(),
            covariance: self.covariance.to_owned(),
        }
    }

    pub fn view(&self) -> GaussianDistroView {
        GaussianDistroView {
            mean: self.mean.view(),
            covariance: self.covariance.view(),
        }
    }
}

impl<S> Clone for GaussianDistroBase<S>
where
    S: DataClone<Elem = f64>,
{
    fn clone(&self) -> Self {
        GaussianDistroBase {
            mean: self.mean.clone(),
            covariance: self.covariance.clone(),
        }
    }
}

impl<S> Debug for GaussianDistroBase<S>
where
    S: Data<Elem = f64>,
{
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        fmt.debug_struct("GaussianDistroBase")
            .field("mean", &self.mean)
            .field("covariance", &self.covariance)
            .finish()
    }
}

#[derive(Debug, Fail)]
#[fail(display = "error sampling from multivariate normal distribution: {}", _0)]
pub struct GaussianSampleError(Context<String>);

impl From<Context<String>> for GaussianSampleError {
    fn from(context: Context<String>) -> GaussianSampleError {
        GaussianSampleError(context)
    }
}

impl<S> Distribution<Result<Array1<f64>, GaussianSampleError>> for GaussianDistroBase<S>
where
    S: Data<Elem = f64>,
{
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Result<Array1<f64>, GaussianSampleError> {
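        let mut cov = self.covariance.to_owned();
        // Add a small multiple of the identity ("jitter") to the diagonal so the
        // Cholesky factorization succeeds even when the covariance is only positive
        // semi-definite, or slightly indefinite due to rounding.
        cov.diag_mut().add_assign(1e2 * ::std::f64::EPSILON);
        // Lower-triangular factor L with L L^T = cov.
        let chol = cov.cholesky_into(UPLO::Lower)
            .context("error factoring covariance".to_string())?;
        // x = mean + L z with z ~ N(0, I) has covariance L L^T = cov.
        Ok(chol.dot(&Array1::random_using(self.len(), Normal::new(0., 1.), rng)) + &self.mean)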
    }
}

// ...

It's probably a bit more complex than what you're looking for (since GaussianDistroBase is generic over storage S), but it could be simplified.
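Stripped of the storage generics and the `failure` error plumbing, the core of `sample` above is "factor the covariance, then return μ + L z". The sketch below shows only that last step with plain `Vec<f64>`s and no external crates. It assumes the lower-triangular factor `L` is already available in row-major order. `XorShift64` and `sample_mvn` are hypothetical names for this illustration; the tiny xorshift64* generator and the Box-Muller transform stand in for `rand::Rng` and `rand`'s `Normal` purely to keep the example dependency-free.

```rust
/// Tiny xorshift64* generator (state must be nonzero). Used only to keep this
/// sketch dependency-free; real code would take a `rand::Rng` instead.
struct XorShift64(u64);

impl XorShift64 {
    fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x >> 12;
        x ^= x << 25;
        x ^= x >> 27;
        self.0 = x;
        x.wrapping_mul(0x2545_F491_4F6C_DD1D)
    }

    /// Uniform sample strictly inside (0, 1): top 52 bits plus one half, scaled by 2^-52.
    fn next_f64(&mut self) -> f64 {
        ((self.next_u64() >> 12) as f64 + 0.5) / (1u64 << 52) as f64
    }

    /// Standard normal sample via the Box-Muller transform (cosine branch only).
    fn next_standard_normal(&mut self) -> f64 {
        let u1 = self.next_f64();
        let u2 = self.next_f64();
        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
    }
}

/// Draws one sample x = mean + L z, where `chol_lower` is the row-major
/// n x n lower-triangular Cholesky factor L of the covariance and z ~ N(0, I).
fn sample_mvn(mean: &[f64], chol_lower: &[f64], rng: &mut XorShift64) -> Vec<f64> {
    let n = mean.len();
    assert_eq!(chol_lower.len(), n * n, "factor must be n x n");
    let z: Vec<f64> = (0..n).map(|_| rng.next_standard_normal()).collect();
    (0..n)
        .map(|i| {
            // Row i of L is zero beyond column i, so only columns 0..=i contribute.
            let lz: f64 = (0..=i).map(|j| chol_lower[i * n + j] * z[j]).sum();
            mean[i] + lz
        })
        .collect()
}
```

With `L = [[1, 0], [0.5, 2]]` the implied covariance is `L Lᵀ = [[1, 0.5], [0.5, 4.25]]`, so averaging many samples should recover those entries approximately. With a real `rand` setup only the source of `z` changes; the affine map `mean + L z` carries over unchanged.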

For the purpose of ndarray-rand, adding a dependency on ndarray-linalg seems unfortunate because ndarray-linalg requires non-Rust code (the LAPACK implementation). I suppose it would be fine if we put the functionality behind a feature flag. What do you think @bluss?
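If the LAPACK dependency is the main concern, one option (not what the code above does) is a small pure-Rust Cholesky factorization, which is short and adequate for the modest dimensions typical of sampling. A minimal sketch, assuming a dense row-major `n x n` matrix stored in a flat slice; `cholesky_lower` is a hypothetical helper, not an ndarray or ndarray-linalg API:

```rust
/// Lower-triangular Cholesky factor L (with L L^T = A) of a symmetric
/// positive-definite matrix A, both stored row-major in flat slices of
/// length n * n. Computed row by row (Cholesky-Banachiewicz order).
/// Returns `None` if a pivot is not strictly positive (or is NaN), i.e. A is
/// not numerically positive definite.
fn cholesky_lower(a: &[f64], n: usize) -> Option<Vec<f64>> {
    assert_eq!(a.len(), n * n, "matrix must be n x n");
    let mut l = vec![0.0; n * n];
    for i in 0..n {
        for j in 0..=i {
            // Dot product of the already-computed parts of rows i and j of L.
            let s: f64 = (0..j).map(|k| l[i * n + k] * l[j * n + k]).sum();
            if i == j {
                let pivot = a[i * n + i] - s;
                // Written as !(pivot > 0.0) so that NaN is rejected as well.
                if !(pivot > 0.0) {
                    return None;
                }
                l[i * n + i] = pivot.sqrt();
            } else {
                // Row j was finished earlier, so L[j][j] is already known.
                l[i * n + j] = (a[i * n + j] - s) / l[j * n + j];
            }
        }
    }
    Some(l)
}
```

The jitter trick from `sample` above (adding `1e2 * f64::EPSILON` to the diagonal before factoring) applies here too. For large matrices an optimized LAPACK routine will typically be faster, which is a reasonable argument for keeping the ndarray-linalg path behind a feature flag.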

ManifoldFR commented 5 years ago

Yes, a feature flag seems appropriate. I'm using the same approach in an MCMC algorithms crate I'm working on (to add support for multivariate distributions using ndarray).

Your code seems a bit overkill for what I need; I was only planning to implement it for the OwnedRepr<f64> storage type at first, but I'll look into making it a bit more generic!

I'll make a pull request so you can check my work.