Circular spline: S^1 -> S^1, based on the monotonic rational quadratic transform introduced in https://arxiv.org/pdf/1906.04032.pdf, but with boundary conditions chosen so that the transformation maps 1-dimensional directional data (angles on the circle) to itself, as proposed in https://arxiv.org/pdf/2002.02428.pdf (2.1.2).
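For reference, here is a minimal sketch of the forward pass of a circular rational quadratic spline in PyTorch. It assumes angles in [0, 2π), unconstrained network outputs named `w_raw`, `h_raw` and `d_raw` (hypothetical names, not the ones used in this PR), and a softmax/softplus parameterisation like the one in Durkan et al. The circular boundary condition is imposed by reusing the derivative at 0 as the derivative at 2π, so the map and its derivative are continuous around the circle. It calls `torch.searchsorted`, which recent PyTorch versions provide natively with the same batched behaviour as torchsearchsorted. The inverse (a quadratic solve per bin) is not shown.

```python
import math

import torch
import torch.nn.functional as F


def circular_rq_spline(x, w_raw, h_raw, d_raw):
    """Forward pass of a circular rational quadratic spline on [0, 2*pi).

    x:      (N,) angles in [0, 2*pi)
    w_raw:  (N, K) unnormalised bin widths (hypothetical network output)
    h_raw:  (N, K) unnormalised bin heights
    d_raw:  (N, K) unnormalised knot derivatives; K rather than K + 1 values,
            because the derivative at 2*pi is tied to the one at 0
    Returns y (N,) and log|dy/dx| (N,).
    """
    two_pi = 2 * math.pi
    # Widths and heights are positive and sum to 2*pi, so the knots span the circle
    widths = two_pi * F.softmax(w_raw, dim=-1)
    heights = two_pi * F.softmax(h_raw, dim=-1)
    # Positive derivatives; repeat the first one at the end (periodic boundary)
    derivs = F.softplus(d_raw)
    derivs = torch.cat([derivs, derivs[:, :1]], dim=-1)  # (N, K + 1)

    # Knot coordinates, running from 0 to 2*pi
    x_knots = F.pad(torch.cumsum(widths, dim=-1), (1, 0))  # (N, K + 1)
    y_knots = F.pad(torch.cumsum(heights, dim=-1), (1, 0))

    # Bin index of each x; every row has its own knots, searched in one call
    k = torch.searchsorted(x_knots[:, 1:-1].contiguous(), x.unsqueeze(-1)).squeeze(-1)

    def pick(t, idx):
        # Select t[n, idx[n]] for every row n
        return t.gather(-1, idx.unsqueeze(-1)).squeeze(-1)

    xk, yk = pick(x_knots, k), pick(y_knots, k)
    wk, hk = pick(widths, k), pick(heights, k)
    dk, dk1 = pick(derivs, k), pick(derivs, k + 1)
    sk = hk / wk  # average slope of the bin

    # Rational quadratic interpolant and its derivative (Durkan et al.)
    xi = (x - xk) / wk
    xi1 = xi * (1 - xi)
    denom = sk + (dk1 + dk - 2 * sk) * xi1
    y = yk + hk * (sk * xi**2 + dk * xi1) / denom
    dydx = sk**2 * (dk1 * xi**2 + 2 * sk * xi1 + dk * (1 - xi) ** 2) / denom**2
    return y, torch.log(dydx)
```

Because the derivative at 2π is the same parameter as the derivative at 0, the log-Jacobian is continuous across the point where the circle is cut, which is what distinguishes this from the interval version of the transform.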
All three spline coupling layers use a third-party package, torchsearchsorted, to find the bin that each element of `x_b` falls into.
This has an advantage over the numpy version: even if the bin boundaries vary within the batch, it can still search the entire batch in one go.
However, it does require reshaping the `x_b` tensor into a tensor of shape `(n_batch * size_half, 1)`, so that the search can be done with a single call to `searchsorted`. It does make me uneasy to mix the batch and lattice dimensions, but it's only for one step and it's a lot faster than looping!
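A sketch of the reshape-then-search step, assuming each lattice site has its own `n_bins + 1` increasing knots stored in a tensor of shape `(n_batch, size_half, n_bins + 1)` (an assumed layout, not necessarily the one in this PR). It uses the built-in `torch.searchsorted` in place of torchsearchsorted, and `find_bins` is a hypothetical helper name. Only the interior knots are searched, so the result is a bin index in `0 .. n_bins - 1`.

```python
import torch


def find_bins(x_b, knots):
    """Return the bin index of every element of x_b.

    x_b:   (n_batch, size_half) values to bin
    knots: (n_batch, size_half, n_bins + 1) increasing bin boundaries,
           which may differ for every sample and lattice site
    """
    n_batch, size_half = x_b.shape
    # Merge batch and lattice dimensions so one call handles everything
    flat_x = x_b.reshape(n_batch * size_half, 1)
    flat_knots = knots.reshape(n_batch * size_half, -1)
    # Search the interior boundaries only, giving indices 0 .. n_bins - 1
    idx = torch.searchsorted(flat_knots[:, 1:-1].contiguous(), flat_x)
    # Restore the original (n_batch, size_half) layout
    return idx.reshape(n_batch, size_half)
```

Both `reshape` calls flatten the batch and lattice dimensions in the same row-major order, so row `b * size_half + i` of the flattened knots always belongs to `x_b[b, i]`, and reshaping back restores the pairing. That is why mixing the two dimensions for this one step is safe.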
One thing I haven't done yet is check for a valid base distribution: currently, using the linear or quadratic splines with any base distribution not defined on [0, 1] will break them. Perhaps this is a good time to use reportengine checks?
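The logic such a check would need can be sketched as a plain function that compares the support of the base distribution with the domain each spline layer expects. All names below (`BASE_SUPPORT`, `LAYER_DOMAIN`, `check_base_compatible` and the dictionary keys) are hypothetical, and the sketch does not use reportengine's own check API; it only shows the comparison a reportengine check could wrap.

```python
import math
from typing import Sequence

# Hypothetical supports of base distributions and domains of spline layers;
# the keys are illustrative, not this project's actual config names.
BASE_SUPPORT = {
    "uniform": (0.0, 1.0),
    "gaussian": (-math.inf, math.inf),
    "uniform_angle": (0.0, 2 * math.pi),
}
LAYER_DOMAIN = {
    "linear_spline": (0.0, 1.0),
    "quadratic_spline": (0.0, 1.0),
    "circular_spline": (0.0, 2 * math.pi),
}


def check_base_compatible(base: str, layers: Sequence[str]) -> None:
    """Raise ValueError if a spline layer's domain differs from the base support."""
    support = BASE_SUPPORT[base]
    for layer in layers:
        domain = LAYER_DOMAIN.get(layer)
        # Layers without a fixed domain (e.g. affine ones) are not checked here
        if domain is not None and domain != support:
            raise ValueError(
                f"{layer} acts on {domain}, but base distribution "
                f"{base!r} has support {support}"
            )
```

This assumes the spline layers act directly on samples from the base distribution; a stack that puts a domain-changing layer before a spline would need a more careful check.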
If we want to improve performance, in future we could try the "one-blob encoding" discussed in https://arxiv.org/pdf/1808.03856.pdf, which replaces a single input value 'x' with a Gaussian that activates several bins at once.
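A sketch of one-blob encoding as we understand it from that paper: each value in [0, 1] becomes an `n_bins` vector whose entries are an unnormalised Gaussian centred on the value and evaluated at the bin centres. The width σ = 1/n_bins and the evaluation at bin centres are assumptions of this sketch, and the paper's exact discretisation may differ; `one_blob` is a hypothetical helper name.

```python
import torch


def one_blob(x, n_bins):
    """One-blob encode values in [0, 1].

    Each scalar becomes an n_bins vector: an unnormalised Gaussian of width
    1/n_bins, centred on the value, evaluated at the bin centres.
    x: (...,) -> (..., n_bins)
    """
    centres = (torch.arange(n_bins, dtype=x.dtype) + 0.5) / n_bins
    sigma = 1.0 / n_bins
    # Broadcast each value against all bin centres
    return torch.exp(-0.5 * ((x.unsqueeze(-1) - centres) / sigma) ** 2)
```

For the coupling layers, `one_blob(x_b, n_bins)` would turn `x_b` of shape `(n_batch, size_half)` into `(n_batch, size_half, n_bins)`, which could then be flattened before being passed to the network.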
A single branch / PR for all three implemented spline flows, each added as a new coupling layer.