tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

rational quadratic spline does not work with batchsize>1 #941

Closed: VMBoehm closed this issue 4 years ago.

VMBoehm commented 4 years ago

https://github.com/tensorflow/probability/blob/f051e03dd3cc847d31061803c2b31c564562a993/tensorflow_probability/python/bijectors/rational_quadratic_spline.py#L59-L409

The example in the documentation, as well as my own code, runs fine with batch size 1. Increasing the batch size raises the following error (in this example the batch size is 2 and the data dimension is 64, half of which is masked):

InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument: Incompatible shapes: [2,32] vs. [64]
      [[{{node mynvp_1/log_prob/chain_of_MatvecLU_of_real_nvp_of_MatvecLU/inverse/real_nvp/inverse/mynvp_1_log_prob_chain_of_MatvecLU_of_real_nvp_of_MatvecLU_inverse_real_nvp_inverse_RationalQuadraticSpline/inverse/GreaterEqual}}]]
      [[Neg/_385]]
  (1) Invalid argument: Incompatible shapes: [2,32] vs. [64]
      [[{{node mynvp_1/log_prob/chain_of_MatvecLU_of_real_nvp_of_MatvecLU/inverse/real_nvp/inverse/mynvp_1_log_prob_chain_of_MatvecLU_of_real_nvp_of_MatvecLU_inverse_real_nvp_inverse_RationalQuadraticSpline/inverse/GreaterEqual}}]]
0 successful operations.
0 derived errors ignored.

Traceback:
    x1 = self._bijector_fn(y0, self._bijector_input_units(),
  File "/global/homes/v/vboehm/.conda/envs/tf22/lib/python3.8/site-packages/tensorflow_probability/python/bijectors/bijector.py", line 1086, in inverse
    return self._call_inverse(y, name, **kwargs)
  File "/global/homes/v/vboehm/.conda/envs/tf22/lib/python3.8/site-packages/tensorflow_probability/python/bijectors/bijector.py", line 1058, in _call_inverse
    mapping = mapping.merge(x=self._inverse(y, **kwargs))
  File "/global/homes/v/vboehm/.conda/envs/tf22/lib/python3.8/site-packages/tensorflow_probability/python/bijectors/rational_quadratic_spline.py", line 307, in _inverse
    d = self._compute_shared(y=y)
  File "/global/homes/v/vboehm/.conda/envs/tf22/lib/python3.8/site-packages/tensorflow_probability/python/bijectors/rational_quadratic_spline.py", line 245, in _compute_shared
    out_of_bounds = (x_or_y <= kx_or_ky_min) | (x_or_y >= kx_or_ky_max)
  File "/global/homes/v/vboehm/.conda/envs/tf22/lib/python3.8/site-packages/tensorflow/python/ops/gen_math_ops.py", line 4049, in greater_equal
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
  File "/global/homes/v/vboehm/.conda/envs/tf22/lib/python3.8/site-packages/tensorflow/python/framework/op_def_library.py", line 742, in _apply_op_helper
    op = g._create_op_internal(op_type_name, inputs, dtypes=None,
  File "/global/homes/v/vboehm/.conda/envs/tf22/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 3319, in _create_op_internal
    ret = Operation(
  File "/global/homes/v/vboehm/.conda/envs/tf22/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 1791, in __init__
    self._traceback = tf_stack.extract_stack()

I tried fixing it by changing the shapes, which resolves the tf.where error at this point but triggers further shape errors down the road.
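The shapes in the error message can be reproduced without TensorFlow. The pure-Python sketch below rests on two assumptions the thread does not spell out: that the documentation's `SplineParams` reshaped the Dense output with `tf.reshape(x, [-1, nbins])`, and that the spline's batch shape is every axis of its parameters except the innermost (bins) axis. The helpers `resolve_reshape`, `broadcast`, and `spline_batch_shape` are hypothetical, written only for this illustration.

```python
from math import prod


def resolve_reshape(num_elements, target):
    """Fill in the single -1 entry of `target`, the way tf.reshape does."""
    known = prod(d for d in target if d != -1)
    if num_elements % known:
        raise ValueError(f"cannot reshape {num_elements} elements into {target}")
    return [num_elements // known if d == -1 else d for d in target]


def broadcast(a, b):
    """Broadcast two shapes with NumPy/TF rules; raise on a mismatch."""
    out = []
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and 1 not in (da, db):
            raise ValueError(f"Incompatible shapes: {list(a)} vs. {list(b)}")
        out.append(db if da == 1 else da)
    return out[::-1]


def spline_batch_shape(batch, nunits, nbins, keep_units_axis):
    """Batch shape of a spline whose parameters come from a Dense output of
    shape [batch, nunits * nbins]: every axis except the last (bins) one.

    keep_units_axis=False models the assumed documentation reshape [-1, nbins];
    keep_units_axis=True models the reshape [-1, nunits, nbins].
    """
    target = [-1, nunits, nbins] if keep_units_axis else [-1, nbins]
    return resolve_reshape(batch * nunits * nbins, target)[:-1]


nunits, nbins = 32, 32  # 64-dim data with half masked; nbins only sets the last axis

# Reshape to [-1, nbins]: the batch and unit axes get merged into one.
print(spline_batch_shape(1, nunits, nbins, keep_units_axis=False))  # [32]
print(spline_batch_shape(2, nunits, nbins, keep_units_axis=False))  # [64]

# The input has shape [batch, nunits]; comparing it with the knot bounds broadcasts.
print(broadcast([1, nunits], [32]))  # [1, 32]
try:
    broadcast([2, nunits], [64])
except ValueError as err:
    print(err)  # Incompatible shapes: [2, 32] vs. [64]

# Reshape to [-1, nunits, nbins]: the batch shape is [batch, nunits].
print(spline_batch_shape(2, nunits, nbins, keep_units_axis=True))  # [2, 32]
```

With batch size 1 the merged axis happens to equal `nunits`, so broadcasting succeeds by accident; with any larger batch it becomes `batch * nunits` and no longer lines up with `[batch, nunits]`.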

VMBoehm commented 4 years ago

Some reshaping acrobatics solved it.

srbittner commented 3 years ago

> some reshaping acrobatics solved it

@VMBoehm Thank you for raising this issue. Did you have to edit the Bijector source code, or did you just edit the SplineParams() example module?

The example in the documentation should be updated to support batch size > 1. This is really the only sensible use case for normalizing flows.

srbittner commented 3 years ago

I didn't have to change the source code in https://github.com/tensorflow/probability/blob/f051e03dd3cc847d31061803c2b31c564562a993/tensorflow_probability/python/bijectors/rational_quadratic_spline.py#L59-L409 to make it work for batch size > 1. Only the following changes to the provided example were needed.

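import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

class SplineParams(tf.Module):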

    def __init__(self, nunits, nbins=32):
        self._nunits = nunits
        self._nbins = nbins
        self._built = False
        self._bin_widths = None
        self._bin_heights = None
        self._knot_slopes = None

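    def _bin_positions(self, x):
        # Keep the batch axis: [batch, nunits * nbins] -> [batch, nunits, nbins].
        x = tf.reshape(x, [-1, self._nunits, self._nbins])
        return tf.math.softmax(x, axis=-1) * (2 - self._nbins * 1e-2) + 1e-2

    def _slopes(self, x):
        # Same for slopes: [batch, nunits * (nbins - 1)] -> [batch, nunits, nbins - 1].
        x = tf.reshape(x, [-1, self._nunits, self._nbins - 1])
        return tf.math.softplus(x) + 1e-2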

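    def __call__(self, x, nunits):
        # RealNVP calls this with the masked (conditioning) units x and the
        # number of units the spline should transform.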
        if not self._built:
            self._bin_widths = tf.keras.layers.Dense(
              nunits * self._nbins, activation=self._bin_positions, name='w')
            self._bin_heights = tf.keras.layers.Dense(
              nunits * self._nbins, activation=self._bin_positions, name='h')
            self._knot_slopes = tf.keras.layers.Dense(
              nunits * (self._nbins - 1), activation=self._slopes, name='s')
            self._built = True
        return tfb.RationalQuadraticSpline(
            bin_widths=self._bin_widths(x),
            bin_heights=self._bin_heights(x),
            knot_slopes=self._knot_slopes(x))

N = 100 # batch size
D = 15 # dimensionality
nsplits = 3

xs = np.random.randn(N, D).astype(np.float32)  # 2-D input: Keras Dense rejects rank-1 vectors.
nmasked = [5*i for i in range(nsplits)] # dimensions to mask in RealNVP
nunits = [D - x for x in nmasked]
splines = [SplineParams(nunits[i]) for i in range(nsplits)]

def spline_flow():
    stack = tfb.Identity()
    for i in range(nsplits):
        stack = tfb.RealNVP(nmasked[i], bijector_fn=splines[i])(stack)
    return stack

ys = spline_flow().forward(xs)
ys_inv = spline_flow().inverse(ys)  # ys_inv ~= xs
assert(np.isclose(xs, ys_inv).all())
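The constants in `_bin_positions` appear chosen so that the bin widths (and heights) sum to 2 with each bin at least 1e-2 wide, which makes the knots span [-1, 1] given `RationalQuadraticSpline`'s default `range_min=-1`; `_slopes` keeps the `nbins - 1` interior knot slopes positive. A pure-Python mirror of the two activations for a single unit illustrates this. `bin_positions` and `slopes` here are hypothetical helpers, not TFP functions, and the logits are arbitrary stand-ins for a Dense output.

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def bin_positions(logits, min_width=1e-2, interval=2.0):
    """Mirror of SplineParams._bin_positions for one unit: bin sizes that are
    each at least `min_width` and sum to `interval`."""
    nbins = len(logits)
    return [p * (interval - nbins * min_width) + min_width for p in softmax(logits)]


def slopes(logits, min_slope=1e-2):
    """Mirror of SplineParams._slopes: stable softplus plus a floor."""
    return [math.log1p(math.exp(-abs(v))) + max(v, 0.0) + min_slope for v in logits]


nbins = 32
logits = [0.1 * k - 1.0 for k in range(nbins)]  # arbitrary stand-in logits
widths = bin_positions(logits)
print(round(sum(widths), 6))             # 2.0, i.e. knots span [-1, 1]
print(min(widths) >= 1e-2)               # True
print(len(slopes(logits[: nbins - 1])))  # 31 interior knot slopes
```

The softmax fixes the total, and the `+ 1e-2` floor (pre-subtracted as `nbins * 1e-2`) prevents degenerate zero-width bins.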

I'd be happy to make a PR for this change.

brianwa84 commented 3 years ago

Please do, you're much more recently acquainted w/ that code than I am at this point!

(Sent by email in reply to Sean Robert Bittner's comment of Wed, Dec 16, 2020 at 12:21 PM.)