keras-team / tf-keras

The TensorFlow-specific implementation of the Keras API, which was the default Keras from 2019 to 2023.
Apache License 2.0

Can't slice along batch for MixtureSameFamily produced by Keras layer #309

Open henrypinkard opened 1 year ago

henrypinkard commented 1 year ago

I'm trying to make some dense layers that output the parameters of a Gaussian mixture (a mixture density network). I want to run a batch of data through the network (for speed), get out a batch of distributions, and then slice it to work with only some elements of the batch at a time. If I were doing this with just TFP, I would call:

import numpy as np
import tensorflow_probability as tfp
tfd = tfp.distributions

num_mixture_components = 12
batch_size = 4

probs = np.random.rand(batch_size, num_mixture_components)
loc = np.random.rand(batch_size, num_mixture_components)
scale = np.random.rand(batch_size, num_mixture_components)

gm = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=probs),
    components_distribution=tfd.Normal(loc=loc, scale=scale))

print(gm)
# slice along batch
print(gm[1])

This works as expected, giving:

tfp.distributions.MixtureSameFamily("MixtureSameFamily", batch_shape=[4], event_shape=[], dtype=float64)
tfp.distributions.MixtureSameFamily("MixtureSameFamily", batch_shape=[], event_shape=[], dtype=float64)
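
For readers unfamiliar with batch slicing, here is a minimal plain-Python sketch (not TFP code) of what `gm[1]` means conceptually: each batch element is its own mixture, and slicing picks out one row of every parameter array. The helper names (`normal_log_prob`, `mixture_log_prob`, `slice_batch`) and the parameter values are hypothetical, and the sketch assumes the mixture weights are normalized before use.

```python
import math


def normal_log_prob(x, loc, scale):
    """Log-density of a univariate normal distribution at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def mixture_log_prob(x, probs, locs, scales):
    """Log-density of a 1-D Gaussian mixture (weights are normalized here)."""
    total = sum(probs)
    terms = [math.log(p / total) + normal_log_prob(x, m, s)
             for p, m, s in zip(probs, locs, scales)]
    # Log-sum-exp over components for numerical stability.
    top = max(terms)
    return top + math.log(sum(math.exp(t - top) for t in terms))


# Hypothetical batch of 2 mixtures with 3 components each.
batch_probs = [[0.2, 0.3, 0.5], [0.6, 0.3, 0.1]]
batch_locs = [[-1.0, 0.0, 1.0], [0.0, 2.0, 4.0]]
batch_scales = [[0.5, 1.0, 0.5], [1.0, 1.0, 2.0]]


def slice_batch(i):
    """Return the parameters of batch element i (the analogue of gm[i])."""
    return batch_probs[i], batch_locs[i], batch_scales[i]


probs, locs, scales = slice_batch(1)
print(mixture_log_prob(0.5, probs, locs, scales))
```

The same idea suggests one possible direction when slicing the distribution object fails: slice the raw parameters first and build a distribution from the slice. Whether that is practical for a given model is an assumption, not something confirmed in this thread.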

However, when I try the same thing with a mixture density network in Keras, I get an error:

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow.keras.layers as tfkl
import tensorflow.keras as tfk

num_mixture_components = 12

l = tfkl.Input(shape=(100,))

# Make a fully connected network that outputs parameters of Gaussian mixture
mu = tfkl.Dense(units=num_mixture_components, activation=None)(l)
sigma = tfkl.Dense(units=num_mixture_components, activation='softplus')(l)
alpha = tfkl.Dense(units=num_mixture_components, activation='softmax')(l)
stacked = tfkl.Concatenate()([mu, sigma, alpha])

mixture = tfp.layers.MixtureNormal(num_mixture_components,
                                   event_shape=[], name="test")(stacked)
model = tf.keras.Model(inputs=l, outputs=mixture)

out = model(np.random.rand(4, 100))

gm = out.tensor_distribution

print(gm)
# slice along batch
print(gm[1])

This gives me the following cryptic error:

tfp.distributions._MixtureSameFamily("MixtureSameFamily", batch_shape=[4], event_shape=[], dtype=float32)

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In [72], line 27
     25 print(gm)
     26 # slice along batch
---> 27 print(gm[1])

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/distributions/distribution.py:852, in Distribution.__getitem__(self, slices)
    825 def __getitem__(self, slices):
    826   """Slices the batch axes of this distribution, returning a new instance.
    827 
    828   ```python
   (...)
    850     dist: A new `tfd.Distribution` instance with sliced parameters.
    851   """
--> 852   return slicing.batch_slice(self, {}, slices)

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/internal/slicing.py:220, in batch_slice(batch_object, params_overrides, slices, bijector_x_event_ndims)
    217 slice_overrides_seq = slice_overrides_seq + [(slices, params_overrides)]
    218 # Re-doing the full sequence of slice+copy override work here enables
    219 # gradients all the way back to the original batch_objectribution's arguments.
--> 220 batch_object = _apply_slice_sequence(
    221     orig_batch_object,
    222     slice_overrides_seq,
    223     bijector_x_event_ndims=bijector_x_event_ndims)
    224 setattr(batch_object,
    225         PROVENANCE_ATTR,
    226         batch_object._no_dependency((orig_batch_object, slice_overrides_seq)))  # pylint: disable=protected-access
    227 return batch_object

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/internal/slicing.py:179, in _apply_slice_sequence(batch_object, slice_overrides_seq, bijector_x_event_ndims)
    177 """Applies a sequence of slice or copy-with-overrides operations to `batch_object`."""
    178 for slices, overrides in slice_overrides_seq:
--> 179   batch_object = _apply_single_step(
    180       batch_object,
    181       slices,
    182       overrides,
    183       bijector_x_event_ndims=bijector_x_event_ndims)
    184 return batch_object

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/internal/slicing.py:168, in _apply_single_step(batch_object, slices, params_overrides, bijector_x_event_ndims)
    166   override_dict = {}
    167 else:
--> 168   override_dict = _slice_params_to_dict(
    169       batch_object, slices, bijector_x_event_ndims=bijector_x_event_ndims)
    170 override_dict.update(params_overrides)
    171 parameters = dict(batch_object.parameters, **override_dict)

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/internal/slicing.py:153, in _slice_params_to_dict(batch_object, slices, bijector_x_event_ndims)
    150 else:
    151   batch_shape = batch_object.experimental_batch_shape_tensor(
    152       x_event_ndims=bijector_x_event_ndims)
--> 153 return batch_shape_lib.map_fn_over_parameters_with_event_ndims(
    154     batch_object,
    155     functools.partial(_slice_single_param,
    156                       slices=slices,
    157                       batch_shape=batch_shape),
    158     bijector_x_event_ndims=bijector_x_event_ndims)

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/internal/batch_shape_lib.py:367, in map_fn_over_parameters_with_event_ndims(batch_object, fn, bijector_x_event_ndims, require_static, **parameter_kwargs)
    361     elif (properties.is_tensor
    362           and not tf.is_tensor(param)
    363           and not tf.nest.is_nested(param_event_ndims)):
    364       # As a last resort, try an explicit conversion.
    365       param = tensor_util.convert_nonref_to_tensor(param, name=param_name)
--> 367   results[param_name] = nest.map_structure_up_to(
    368       param, fn, param, param_event_ndims)
    369 return results

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow/python/util/nest.py:1435, in map_structure_up_to(shallow_tree, func, *inputs, **kwargs)
   1361 @tf_export("__internal__.nest.map_structure_up_to", v1=[])
   1362 def map_structure_up_to(shallow_tree, func, *inputs, **kwargs):
   1363   """Applies a function or op to a number of partially flattened inputs.
   1364 
   1365   The `inputs` are flattened up to `shallow_tree` before being mapped.
   (...)
   1433     `shallow_tree`.
   1434   """
-> 1435   return map_structure_with_tuple_paths_up_to(
   1436       shallow_tree,
   1437       lambda _, *values: func(*values),  # Discards the path arg.
   1438       *inputs,
   1439       **kwargs)

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow/python/util/nest.py:1535, in map_structure_with_tuple_paths_up_to(shallow_tree, func, *inputs, **kwargs)
   1526 flat_value_gen = (
   1527     flatten_up_to(  # pylint: disable=g-complex-comprehension
   1528         shallow_tree,
   1529         input_tree,
   1530         check_types,
   1531         expand_composites=expand_composites) for input_tree in inputs)
   1532 flat_path_gen = (
   1533     path
   1534     for path, _ in _yield_flat_up_to(shallow_tree, inputs[0], is_nested_fn))
-> 1535 results = [
   1536     func(*args, **kwargs) for args in zip(flat_path_gen, *flat_value_gen)
   1537 ]
   1538 return pack_sequence_as(structure=shallow_tree, flat_sequence=results,
   1539                         expand_composites=expand_composites)

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow/python/util/nest.py:1536, in <listcomp>(.0)
   1526 flat_value_gen = (
   1527     flatten_up_to(  # pylint: disable=g-complex-comprehension
   1528         shallow_tree,
   1529         input_tree,
   1530         check_types,
   1531         expand_composites=expand_composites) for input_tree in inputs)
   1532 flat_path_gen = (
   1533     path
   1534     for path, _ in _yield_flat_up_to(shallow_tree, inputs[0], is_nested_fn))
   1535 results = [
-> 1536     func(*args, **kwargs) for args in zip(flat_path_gen, *flat_value_gen)
   1537 ]
   1538 return pack_sequence_as(structure=shallow_tree, flat_sequence=results,
   1539                         expand_composites=expand_composites)

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow/python/util/nest.py:1437, in map_structure_up_to.<locals>.<lambda>(_, *values)
   1361 @tf_export("__internal__.nest.map_structure_up_to", v1=[])
   1362 def map_structure_up_to(shallow_tree, func, *inputs, **kwargs):
   1363   """Applies a function or op to a number of partially flattened inputs.
   1364 
   1365   The `inputs` are flattened up to `shallow_tree` before being mapped.
   (...)
   1433     `shallow_tree`.
   1434   """
   1435   return map_structure_with_tuple_paths_up_to(
   1436       shallow_tree,
-> 1437       lambda _, *values: func(*values),  # Discards the path arg.
   1438       *inputs,
   1439       **kwargs)

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/internal/slicing.py:101, in _slice_single_param(param, param_event_ndims, slices, batch_shape)
     85 """Slices into the batch shape of a single parameter.
     86 
     87 Args:
   (...)
     98     `slices`.
     99 """
    100 # Broadcast the parmameter to have full batch rank.
--> 101 param = batch_shape_lib.broadcast_parameter_with_batch_shape(
    102     param, param_event_ndims, ps.ones_like(batch_shape))
    103 param_batch_shape = batch_shape_lib.get_batch_shape_tensor_part(
    104     param, param_event_ndims)
    105 # At this point the param should have full batch rank, *unless* it's an
    106 # atomic object like `tfb.Identity()` incapable of having any batch rank.

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/internal/batch_shape_lib.py:274, in broadcast_parameter_with_batch_shape(param, param_event_ndims, batch_shape)
    270 base_shape = ps.concat([batch_shape,
    271                         ps.ones([param_event_ndims], dtype=np.int32)],
    272                        axis=0)
    273 if hasattr(param, '_broadcast_parameters_with_batch_shape'):
--> 274   return param._broadcast_parameters_with_batch_shape(base_shape)  # pylint: disable=protected-access
    275 elif hasattr(param, 'matmul'):
    276   # TODO(davmre): support broadcasting LinearOperator parameters.
    277   return param

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/distributions/distribution.py:952, in Distribution._broadcast_parameters_with_batch_shape(self, batch_shape)
    926 def _broadcast_parameters_with_batch_shape(self, batch_shape):
    927   """Broadcasts each parameter's batch shape with the given `batch_shape`.
    928 
    929   This is semantically equivalent to wrapping with the `BatchBroadcast`
   (...)
    950       the given `batch_shape`.
    951   """
--> 952   return self.copy(
    953       **batch_shape_lib.broadcast_parameters_with_batch_shape(
    954           self, batch_shape))

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/distributions/distribution.py:915, in Distribution.copy(self, **override_parameters_kwargs)
    897 """Creates a deep copy of the distribution.
    898 
    899 Note: the copy distribution may continue to depend on the original
   (...)
    909     `dict(self.parameters, **override_parameters_kwargs)`.
    910 """
    911 try:
    912   # We want track provenance from origin variables, so we use batch_slice
    913   # if this distribution supports slicing. See the comment on
    914   # PROVENANCE_ATTR in batch_slicing.py
--> 915   return slicing.batch_slice(self, override_parameters_kwargs, Ellipsis)
    916 except NotImplementedError:
    917   pass

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/internal/slicing.py:220, in batch_slice(batch_object, params_overrides, slices, bijector_x_event_ndims)
    217 slice_overrides_seq = slice_overrides_seq + [(slices, params_overrides)]
    218 # Re-doing the full sequence of slice+copy override work here enables
    219 # gradients all the way back to the original batch_objectribution's arguments.
--> 220 batch_object = _apply_slice_sequence(
    221     orig_batch_object,
    222     slice_overrides_seq,
    223     bijector_x_event_ndims=bijector_x_event_ndims)
    224 setattr(batch_object,
    225         PROVENANCE_ATTR,
    226         batch_object._no_dependency((orig_batch_object, slice_overrides_seq)))  # pylint: disable=protected-access
    227 return batch_object

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/internal/slicing.py:179, in _apply_slice_sequence(batch_object, slice_overrides_seq, bijector_x_event_ndims)
    177 """Applies a sequence of slice or copy-with-overrides operations to `batch_object`."""
    178 for slices, overrides in slice_overrides_seq:
--> 179   batch_object = _apply_single_step(
    180       batch_object,
    181       slices,
    182       overrides,
    183       bijector_x_event_ndims=bijector_x_event_ndims)
    184 return batch_object

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/internal/slicing.py:172, in _apply_single_step(batch_object, slices, params_overrides, bijector_x_event_ndims)
    170 override_dict.update(params_overrides)
    171 parameters = dict(batch_object.parameters, **override_dict)
--> 172 return type(batch_object)(**parameters)

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/decorator.py:231, in decorate.<locals>.fun(*args, **kw)
    229 def fun(*args, **kw):
    230     if not kwsyntax:
--> 231         args, kw = fix(args, kw, sig)
    232     return caller(func, *(extras + args), **kw)

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/decorator.py:203, in fix(args, kwargs, sig)
    199 def fix(args, kwargs, sig):
    200     """
    201     Fix args and kwargs to be consistent with the signature
    202     """
--> 203     ba = sig.bind(*args, **kwargs)
    204     ba.apply_defaults()  # needed for test_dan_schult
    205     return ba.args, ba.kwargs

File ~/mambaforge/envs/phenotypes/lib/python3.10/inspect.py:3179, in Signature.bind(self, *args, **kwargs)
   3174 def bind(self, /, *args, **kwargs):
   3175     """Get a BoundArguments object, that maps the passed `args`
   3176     and `kwargs` to the function's signature.  Raises `TypeError`
   3177     if the passed arguments can not be bound.
   3178     """
-> 3179     return self._bind(args, kwargs)

File ~/mambaforge/envs/phenotypes/lib/python3.10/inspect.py:3168, in Signature._bind(self, args, kwargs, partial)
   3166         arguments[kwargs_param.name] = kwargs
   3167     else:
-> 3168         raise TypeError(
   3169             'got an unexpected keyword argument {arg!r}'.format(
   3170                 arg=next(iter(kwargs))))
   3172 return self._bound_arguments_cls(self, arguments)

TypeError: got an unexpected keyword argument 'reinterpreted_batch_ndims'

Environment: TensorFlow v2.11.0-rc2-17-gd5b57ca93e5 (2.11.0), tensorflow-probability 0.19.0, Python 3.10.6, Ubuntu 18.04

Any ideas why this is happening and how to fix it?

tilakrayal commented 1 year ago

@henrypinkard, the reinterpreted_batch_ndims parameter controls the number of batch dimensions that are absorbed as event dimensions, with reinterpreted_batch_ndims <= len(batch_shape). For example, the _log_prob function performs a reduce_sum over the rightmost reinterpreted_batch_ndims dimensions after calling the base distribution's log_prob. Because the batch dimensions index independent distributions, the resulting multivariate distribution has independent components. Since reinterpreted_batch_ndims is part of TensorFlow Probability (TFP), I request that you raise this in the TFP repo for more assistance. Thank you!
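
To make the reduce_sum description above concrete, here is a small plain-Python sketch (not TFP code) of how reinterpreting batch dimensions as event dimensions changes the shape of the log-probability. The helper names and input values are hypothetical, and the sketch only handles a `[batch, components]` layout with `reinterpreted_batch_ndims` of 0 or 1.

```python
import math


def normal_log_prob(x, loc, scale):
    """Log-density of a univariate normal distribution at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def elementwise_log_prob(x, loc, scale):
    """Base distribution: one log-density per entry, shape [B, K]."""
    return [[normal_log_prob(xi, m, s) for xi, m, s in zip(xr, lr, sr)]
            for xr, lr, sr in zip(x, loc, scale)]


def independent_log_prob(x, loc, scale, reinterpreted_batch_ndims):
    """Sum the base log-probs over the rightmost reinterpreted dims."""
    base = elementwise_log_prob(x, loc, scale)
    if reinterpreted_batch_ndims == 0:
        # Nothing is absorbed: batch shape [B, K], scalar events.
        return base
    if reinterpreted_batch_ndims == 1:
        # The last axis becomes the event: batch shape [B], event shape [K].
        return [sum(row) for row in base]
    raise ValueError("this sketch only supports 0 or 1")


x = [[0.0, 1.0, -1.0], [0.5, 0.5, 0.5]]
loc = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
scale = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]

per_entry = independent_log_prob(x, loc, scale, 0)  # 2 x 3 values
per_batch = independent_log_prob(x, loc, scale, 1)  # 2 values
print(per_entry)
print(per_batch)
```

With `reinterpreted_batch_ndims=1`, each batch element yields a single log-probability, which is the sum of the independent per-component terms.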

henrypinkard commented 1 year ago

Thanks for the explanation. I also opened an issue on TFP (https://github.com/tensorflow/probability/issues/1679), though there has been no response yet.

tilakrayal commented 1 year ago

@henrypinkard, could you please move this issue to closed status, since it is already being tracked there? Thank you!

google-ml-butler[bot] commented 1 year ago

This issue has been automatically marked as stale because it has no recent activity. It will be closed if no further activity occurs. Thank you.

henrypinkard commented 1 year ago

Still no activity on the TFP issue or clarity on how this might be fixed.