lawmurray / Birch

A probabilistic programming language that combines automatic differentiation, automatic marginalization, and automatic conditioning within Monte Carlo methods.
https://birch-lang.org
Apache License 2.0

Change from split() to vector_split() #14

Closed: MaurusGubser closed this issue 3 years ago

MaurusGubser commented 3 years ago

I am using Birch to implement an HMM for a model of two legs in motion. Basically, I have a state transition model that is linear and Gaussian with a non-diagonal covariance, and an observation model that is non-linear with Gaussian noise. I therefore use a vector expression for the state variable, which I split into a vector of expressions to compute the observation variable.

Previously, I used split in my code, but since the most recent update of Birch, split is deprecated. I guess I should use vector_split, but it is not clear to me how a VectorSplitExpression is used. With split, I could define a new variable x (a vector of expressions) whose elements I could access using the bracket operator; see the code fragment below.

function state_to_observation(u:Random<Real[_]>) -> Array<Expression<Real>> {
  let x <- split<Real>(u);
  y:Array<Expression<Real>>;

  y.insert(1, x[15]*cst_0 + x[13]*cos(x[3]) + (x[14] + g)*sin(x[3]));
  ...
  return y;
}

How do I use vector_split, and how can I access the elements of the vector afterwards?

I'm not sure whether this is the right place to ask a question, but I could not find anywhere else to do so.

lawmurray commented 3 years ago

Thanks for the question. The aim has been to make split unnecessary in situations like this, so that you can access elements directly with the new function element(). Can you try something like this?

function state_to_observation(u:Random<Real[_]>) -> Array<Expression<Real>> {
  y:Array<Expression<Real>>;
  y.insert(1, element(u, 15)*cst_0 + element(u, 13)*cos(element(u, 3)) + (element(u, 14) + g)*sin(element(u, 3));
  ...
  return y;
}

Overloading the square brackets again is on the to-do list, so that, e.g., element(u, 15) can once again be written as the nicer u[15].