Extension of #48 that uses a coupling transformation which is piecewise quadratic in the input data.
Also introduced in https://arxiv.org/pdf/1808.03856.pdf.
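For reference, here is a minimal NumPy sketch of what a single unit's piecewise-quadratic transform computes. It follows my reading of the construction in the linked paper: the net outputs unnormalized bin widths `w_hat` (K values) and vertex heights `v_hat` (K + 1 values). The widths go through a softmax, and the heights are rescaled so that the piecewise-linear pdf integrates to 1 on [0, 1]. The CDF is then quadratic within each bin. The helper names `normalize` and `pq_forward` are hypothetical and are not the names used in this PR.

```python
import numpy as np

def normalize(w_hat, v_hat):
    """Map unnormalized net outputs to widths W (K,) and vertex heights V (K+1,)."""
    w = np.exp(w_hat - w_hat.max())
    w = w / w.sum()                      # softmax: widths sum to 1
    v = np.exp(v_hat)
    # Area under the piecewise-linear pdf: sum_k (V_k + V_{k+1}) / 2 * W_k
    area = np.sum(0.5 * (v[:-1] + v[1:]) * w)
    return w, v / area                   # pdf now integrates to 1 on [0, 1]

def pq_forward(x, w, v):
    """Piecewise-quadratic CDF y(x) and its derivative (the pdf) for scalar x in [0, 1]."""
    edges = np.concatenate([[0.0], np.cumsum(w)])
    # Bin containing x: count of edges <= x, minus one
    b = int(np.searchsorted(edges, x, side="right")) - 1
    b = min(b, len(w) - 1)               # guard for x on the last edge
    alpha = (x - edges[b]) / w[b]        # relative position inside bin b
    # Cumulative area of all bins to the left of b
    left = np.sum(0.5 * (v[:b] + v[1:b + 1]) * w[:b])
    y = left + alpha * v[b] * w[b] + 0.5 * alpha**2 * (v[b + 1] - v[b]) * w[b]
    pdf = v[b] + alpha * (v[b + 1] - v[b])   # dy/dx, used for the log-Jacobian
    return y, pdf
```

Because the pdf is piecewise linear and continuous, `y` is monotone and differentiable everywhere. The log-Jacobian contribution of a unit is `log(pdf)`.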
The bin widths are part of the neural net output (rather than being a global constant), which means the bisection search has to be done on a per-unit basis. To avoid looping, which is very slow, the `x_b` tensor is temporarily reshaped into a tensor of dimensions `(n_batch * size_half, 1)` so that the bin search can be done with one call to `searchsorted`. It does make me uneasy to mix the batch and lattice dimensions, but it's just for one step and it's a lot faster than looping!
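To illustrate why merging the batch and lattice dimensions is safe here, below is a NumPy sketch of the same one-call bin search. The search is independent per unit, so flattening to one row per unit loses nothing. `torch.searchsorted` accepts a 2-D `sorted_sequence` with one row per row of values, which is what the `(n_batch * size_half, 1)` reshape relies on. `np.searchsorted` only takes a 1-D sorted array, so this sketch swaps in a different trick: it shifts row `r` by `2r` so all rows concatenate into one globally sorted array. It assumes per-unit edges are increasing, start at 0 and lie in [0, 1]. The function name `batched_bin_index` is hypothetical.

```python
import numpy as np

def batched_bin_index(x_b, edges):
    """
    x_b:   (n_batch, size_half) inputs in [0, 1]
    edges: (n_batch, size_half, n_bins + 1) increasing per-unit bin edges, edges[..., 0] == 0
    returns (n_batch, size_half) integer bin indices
    """
    n_batch, size_half = x_b.shape
    n_edges = edges.shape[-1]
    n_units = n_batch * size_half
    # Merge batch and lattice dims: one row of edges (and one value) per unit
    x_flat = x_b.reshape(n_units, 1)
    e_flat = edges.reshape(n_units, n_edges)
    # Shift row r by 2r; since everything lies in [0, 1], rows no longer overlap
    offset = 2.0 * np.arange(n_units)[:, None]
    idx = np.searchsorted((e_flat + offset).ravel(), (x_flat + offset).ravel(), side="right")
    # Global position -> number of edges <= x within the row -> bin index
    bins = idx - np.arange(n_units) * n_edges - 1
    bins = np.clip(bins, 0, n_edges - 2)   # x == 1 falls in the last bin
    return bins.reshape(n_batch, size_half)
```

The offset trick costs some floating-point resolution for very large `n_units`. The batched `torch.searchsorted` call avoids that issue, which is one more reason to prefer it over both looping and this workaround.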