wilsonmr / anvil

Repository containing code related to flow based generative model
https://wilsonmr.github.io/anvil/
GNU General Public License v3.0

Spline flows #52

Closed jmarshrossney closed 3 years ago

jmarshrossney commented 4 years ago

A single branch / PR for all three implemented 'spline' flows, containing three new coupling layers:

  1. Linear spline: [0, 1] -> [0, 1], introduced in https://arxiv.org/pdf/1808.03856.pdf.
  2. Quadratic spline: same as above.
  3. Circular spline: S1 -> S1, based on the "monotonic rational quadratic transform" introduced in https://arxiv.org/pdf/1906.04032.pdf, but with boundary conditions chosen so that the transformation maps 1-dimensional directional data, as proposed in https://arxiv.org/pdf/2002.02428.pdf (2.1.2).
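
For reference, a minimal sketch of the idea behind the linear spline (item 1), not the code in this branch. It assumes equal-width bins and a softmax over the network outputs, and handles a single scalar `x`; the name `linear_spline` and its arguments are made up for illustration.

```python
import math
from itertools import accumulate


def linear_spline(x, logits):
    """Map x in [0, 1] to [0, 1] through a piecewise-linear CDF.

    `logits` stands in for the unnormalised network outputs, one per bin.
    Returns the transformed value and the log of the Jacobian dy/dx.
    """
    if not 0.0 <= x <= 1.0:
        raise ValueError("x must lie in [0, 1]")
    n_bins = len(logits)

    # Softmax turns the logits into positive bin masses that sum to one.
    shift = max(logits)
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    masses = [e / total for e in exps]

    # CDF evaluated at the left edge of every bin (plus the final value 1).
    cdf = [0.0] + list(accumulate(masses))

    # Bin containing x; x == 1 is assigned to the last bin.
    b = min(int(x * n_bins), n_bins - 1)
    alpha = x * n_bins - b  # fractional position inside the bin, in [0, 1]

    y = cdf[b] + alpha * masses[b]
    # Slope of the CDF inside bin b is mass / width, with width = 1 / n_bins.
    log_det = math.log(masses[b] * n_bins)
    return y, log_det
```

With all logits equal the bin masses are uniform, so the map reduces to the identity and the log-Jacobian is zero.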

They all use a third-party package called torchsearchsorted to sort the x_b into the correct bin. Its advantage over the numpy version is that it can sort the entire batch in one go, even if the bin boundaries vary within the batch.

However, this does require reshaping the x_b tensor into a tensor of dimensions (n_batch * size_half, 1) so that the sorting can be done with one call to searchsorted. It does make me uneasy to mix the batch and lattice dimensions, but it's just for one step and it's a lot faster than looping!
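
To illustrate the reshaping step, here is a rough NumPy sketch. It does not call torchsearchsorted; a comparison count over each row stands in for the row-wise search. The function name `find_bins` and the shape of `knots` are assumptions made for the example.

```python
import numpy as np


def find_bins(x_b, knots):
    """Return the bin index of every element of x_b in a single pass.

    x_b:   (n_batch, size_half) values in [0, 1)
    knots: (n_batch, size_half, n_bins + 1) increasing bin boundaries,
           a different set for every element of x_b
    """
    n_batch, size_half = x_b.shape
    n_knots = knots.shape[-1]

    # Merge the batch and lattice dimensions so that each row holds one
    # value together with its own bin boundaries.
    x_flat = x_b.reshape(n_batch * size_half, 1)
    knots_flat = knots.reshape(n_batch * size_half, n_knots)

    # Counting the interior boundaries at or below x gives the bin index.
    bins_flat = (knots_flat[:, 1:-1] <= x_flat).sum(axis=1)

    # Undo the merge before handing the indices back.
    return bins_flat.reshape(n_batch, size_half)


# Small example with bin boundaries that vary across the batch.
rng = np.random.default_rng(0)
n_batch, size_half, n_bins = 4, 3, 5
x_b = rng.uniform(size=(n_batch, size_half))
widths = rng.uniform(0.1, 1.0, size=(n_batch, size_half, n_bins))
widths /= widths.sum(axis=-1, keepdims=True)
knots = np.concatenate(
    [np.zeros((n_batch, size_half, 1)), np.cumsum(widths, axis=-1)], axis=-1
)
bin_idx = find_bins(x_b, knots)
```

The batch and lattice dimensions are only merged inside the function, so the rest of the layer still sees tensors of shape (n_batch, size_half).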

One thing I haven't done is check for a valid base distribution: currently, if you use the linear or quadratic splines with any base distribution not defined on [0, 1], it will break. Perhaps this is a good time to use reportengine checks?
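
A sketch of what such a check might look like in plain Python. All the layer and base-distribution names below are hypothetical placeholders rather than anvil's real identifiers, and the function is not wired into reportengine; it only shows the compatibility logic that a check could wrap.

```python
import math

# Hypothetical layer names: the domain each coupling layer expects.
LAYER_DOMAIN = {
    "linear_spline": (0.0, 1.0),
    "quadratic_spline": (0.0, 1.0),
}

# Hypothetical base distribution names and their supports.
BASE_SUPPORT = {
    "uniform": (0.0, 1.0),
    "normal": (-math.inf, math.inf),
}


class IncompatibleBaseError(ValueError):
    """Raised when the base distribution can produce values outside the layer domain."""


def check_base_matches_layer(layer, base):
    """Fail early if `base` is not supported inside the domain of `layer`."""
    if layer not in LAYER_DOMAIN:
        return  # no domain restriction recorded for this layer
    lo, hi = LAYER_DOMAIN[layer]
    base_lo, base_hi = BASE_SUPPORT[base]
    if base_lo < lo or base_hi > hi:
        raise IncompatibleBaseError(
            f"{layer} expects inputs in [{lo}, {hi}], "
            f"but base '{base}' has support [{base_lo}, {base_hi}]"
        )
```

Running the check before any training starts would turn the current breakage into a readable error message.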


If we wanted to improve performance we could, in future, try the "one-blob encoding" discussed in https://arxiv.org/pdf/1808.03856.pdf, which involves replacing a single 'x' with a Gaussian that activates several bins at once.
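
As a rough picture of the encoding, the sketch below places a Gaussian at the input and evaluates it at the bin centres. Evaluating at the centres, the default of 32 bins and a width of one bin are assumptions made here for simplicity; the paper should be consulted for the exact discretisation.

```python
import math


def one_blob(s, n_bins=32, sigma=None):
    """Encode a scalar s in [0, 1] as a vector of n_bins activations.

    A Gaussian centred on s is evaluated at each bin centre, so the bins
    next to the one containing s are activated as well.
    """
    if sigma is None:
        sigma = 1.0 / n_bins  # assumed kernel width: one bin
    centres = [(i + 0.5) / n_bins for i in range(n_bins)]
    return [math.exp(-0.5 * ((c - s) / sigma) ** 2) for c in centres]


encoded = one_blob(0.3)
```

Unlike a one-hot encoding, nearby inputs produce overlapping activation patterns, which gives the network a smoother signal about where inside a bin the input sits.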

jmarshrossney commented 3 years ago

Done in #61