rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0
325 stars 17 forks

Create an `Array1D` type for position, momentum, etc in sampling #2

Closed rlouf closed 4 years ago

rlouf commented 4 years ago

Shapes can quickly get confusing in the dynamics generators, kernels, adaptive schemes, etc., since we constantly have a choice between passing positions and momenta as a 1D array or as a collection of objects with different shapes.

In mcx, inputs to the elementary parts of the samplers are assumed to be 1D arrays. I specified it in the docstring of each file, but it can still be confusing for someone who missed the note, or who read it but is not sure which variables it applies to.
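To make "1D array" concrete, here is a minimal sketch of flattening a collection of differently shaped values into one flat array. It uses NumPy for illustration rather than mcx's actual code, and the `flatten` helper and variable names are hypothetical:

```python
import numpy as np


def flatten(values):
    """Concatenate arrays of arbitrary shapes into a single 1D array."""
    return np.concatenate([np.ravel(v) for v in values])


# Hypothetical position made of a scalar, a vector and a matrix.
position = [np.array(1.0), np.zeros(3), np.ones((2, 2))]
flat_position = flatten(position)

print(flat_position.shape)  # (8,)
```

The elementary parts of a sampler would then only ever see `flat_position`, never the original collection.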

To improve the readability of the code and allow some static type checking, I propose to define an `Array1D` type to indicate which variables are supposed to be flat arrays. I also suggest, while we're at it, creating a type for mass matrices, which can be 1D or 2D arrays.
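A rough sketch of what such types could look like, assuming NumPy arrays and `typing.NewType`. `MassMatrix`, `as_array1d`, `as_mass_matrix` and `kinetic_energy` are hypothetical names for illustration, not mcx's API, and reading a 1D mass matrix as a diagonal one is an assumption:

```python
from typing import NewType

import numpy as np

# Hypothetical names: distinct types that a static checker such as mypy can track.
Array1D = NewType("Array1D", np.ndarray)
# Assumption: a 1D mass matrix holds the diagonal, a 2D one is dense.
MassMatrix = NewType("MassMatrix", np.ndarray)


def as_array1d(x) -> Array1D:
    """Validate at runtime that `x` is flat, then tag it as Array1D."""
    arr = np.asarray(x)
    if arr.ndim != 1:
        raise ValueError(f"expected a 1D array, got shape {arr.shape}")
    return Array1D(arr)


def as_mass_matrix(x) -> MassMatrix:
    """Validate at runtime that `x` is 1D or 2D, then tag it as MassMatrix."""
    arr = np.asarray(x)
    if arr.ndim not in (1, 2):
        raise ValueError(f"expected a 1D or 2D array, got shape {arr.shape}")
    return MassMatrix(arr)


def kinetic_energy(momentum: Array1D, inverse_mass_matrix: MassMatrix) -> float:
    """Compute 0.5 * p^T M^{-1} p; the signature documents the expected shapes."""
    if inverse_mass_matrix.ndim == 1:
        velocity = inverse_mass_matrix * momentum  # diagonal case
    else:
        velocity = inverse_mass_matrix @ momentum  # dense case
    return 0.5 * float(momentum @ velocity)
```

`NewType` does not check shapes by itself: the static checker only verifies that a value went through `as_array1d` or `as_mass_matrix`, and the shape check happens inside those helpers at runtime.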

rlouf commented 4 years ago

Not needed anymore now that we use 1D arrays everywhere in the core. Closing.