tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

Autoregressive flow on jax #1169

Open mattwescott opened 3 years ago

mattwescott commented 3 years ago

The docstring for MaskedAutoregressiveFlow suggests AutoregressiveNetwork, which fails on the jax substrate due to Keras dependencies. If there are still missing pieces, are you interested in contributions?

@sharadmv @SiegeLordEx

brianwa84 commented 3 years ago

We haven't yet settled on a "right" way to do NNs in JAX. For now, an approach you can use is passing kwargs to log_prob, something like below.

If you have a proposal for how to integrate these bijectors (also thinking about RealNVP) with flax/objax/haiku/oryx, it'd be great to share it. You could perhaps send a PR adding a notebook under discussion/examples.

import numpy as np
import jax
import jax.numpy as jnp
from jax import tree_util
from matplotlib.pyplot import hexbin, scatter, show, title, xlim, ylim
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

np.random.seed(213)

n = 2000
x2 = np.random.randn(n).astype(dtype=np.float32) * 2.
x1 = np.random.randn(n).astype(dtype=np.float32) + (x2 * x2 / 4.)
data = np.stack([x1, x2], axis=-1)
scatter(*data.T); title('data'); show()

def make_bijector(y, layers=None):
  # Dense tanh MLP whose weights arrive via the `layers` kwarg. The layers are
  # unmasked, so the outputs are not autoregressive in y; this mainly shows
  # how per-call weights reach bijector_fn.
  for layer in layers[:-1]:
    print(y.shape, [l.shape for l in layer])
    y = jnp.tanh(jnp.einsum('...b,ba->...a', y, layer[0]) + layer[1])
  print(y.shape, [l.shape for l in layers[-1]])
  y = jnp.einsum('...b,ba->...a', y, layers[-1][0]) + layers[-1][1]
  print('final', y.shape)
  # Column 0 is the shift, column 1 the log-scale.
  return tfb.Shift(y[..., :1])(tfb.Scale(log_scale=y[..., 1:]))

# One (kernel, bias) pair per layer: 2 -> 20 -> 10 -> 10 -> 2.
layers = [(np.random.normal(size=[2, 20]), np.random.normal(size=[20])),
          (np.random.normal(size=[20, 10]), np.random.normal(size=[10])),
          (np.random.normal(size=[10, 10]), np.random.normal(size=[10])),
          (np.random.normal(size=[10, 2]), np.random.normal(size=[2]))]
total_steps = 0

made = tfp.bijectors.MaskedAutoregressiveFlow(bijector_fn=make_bijector)
# Extra keyword arguments to forward/inverse are forwarded to bijector_fn.
print(made.forward(np.random.uniform(size=[5, 2]), layers=layers))
print(made.inverse(np.random.uniform(size=[2000, 2]), layers=layers))

distribution = tfd.TransformedDistribution(
    distribution=tfd.Sample(tfd.Normal(loc=0., scale=1.), sample_shape=[2]),
    bijector=made)
# On the distribution, the same kwargs are passed through bijector_kwargs.
print(distribution.sample(len(data), bijector_kwargs=dict(layers=layers),
                          seed=jax.random.PRNGKey(123)).shape)
print(distribution.log_prob(data, bijector_kwargs=dict(layers=layers)))

lr = 1e-12
n = 200
for i in range(n):
  lp, f_vjp = jax.vjp(
      lambda layers: (
          distribution.log_prob(data, bijector_kwargs=dict(layers=layers)).sum()  # likelihood
          + sum(tfd.Normal(0., 1.).log_prob(l).sum()  # N(0, 1) log prior on weights
                for l in tree_util.tree_leaves(layers))),
      layers)
  if i % 10 == 0 or i == n - 1:
    print(lp.sum(), total_steps)
  if not np.isfinite(lp).all():
    raise ValueError('found nan/inf')
  if lp.sum() > -1e7: lr = 1e-8
  if lp.sum() > -1e5: lr = 1e-6
  if lp.sum() > -2e4: lr = 1e-5
  # The VJP with cotangent lr yields lr * gradient, i.e. a gradient-ascent step.
  grads = f_vjp(jnp.ones_like(lp) * lr)
  layers_flat, tree = tree_util.tree_flatten(layers)
  grads_flat = tree_util.tree_leaves(grads)
  layers = tree_util.tree_unflatten(
      tree, [l + g for (l, g) in zip(layers_flat, grads_flat)])
  total_steps += 1

scatter(*distribution.sample(len(data), seed=jax.random.PRNGKey(123),
                             bijector_kwargs=dict(layers=layers)).T)
title('samples after fit'); xlim(data[:, 0].min(), data[:, 0].max())
ylim(data[:, 1].min(), data[:, 1].max()); show()

hexbin(*made.inverse(data, layers=layers).T); title('pulled back data'); show()
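The parameter update in the loop above flattens the weights, zips them with the gradients, and unflattens again. As a minimal sketch of an alternative (not the approach used in the thread), `jax.value_and_grad` plus `jax.tree_util.tree_map` performs the same leaf-wise update on any pytree. To keep it runnable without TFP, the objective `toy_log_joint` is a hypothetical stand-in that fits a 1-D Gaussian's `loc` and `log_scale`; in the example above it would be the `log_prob(...).sum()` plus log-prior term.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


def toy_log_joint(params, data):
  """Hypothetical stand-in objective: Gaussian log-likelihood plus N(0, 1) log prior."""
  loc, log_scale = params["loc"], params["log_scale"]
  z = (data - loc) / jnp.exp(log_scale)
  log_lik = jnp.sum(-0.5 * z**2 - log_scale - 0.5 * jnp.log(2 * jnp.pi))
  # Standard-normal log prior on every parameter leaf (MAP-style regularizer).
  log_prior = sum(jnp.sum(-0.5 * leaf**2 - 0.5 * jnp.log(2 * jnp.pi))
                  for leaf in tree_util.tree_leaves(params))
  return log_lik + log_prior


@jax.jit
def ascent_step(params, data, lr):
  """One gradient-ascent step; tree_map applies the update to every leaf."""
  value, grads = jax.value_and_grad(toy_log_joint)(params, data)
  new_params = tree_util.tree_map(lambda p, g: p + lr * g, params, grads)
  return new_params, value


data = 3.0 + 0.5 * jax.random.normal(jax.random.PRNGKey(0), (2000,))
params = {"loc": jnp.array(0.0), "log_scale": jnp.array(0.0)}
for _ in range(500):
  params, value = ascent_step(params, data, 1e-4)
```

Because `tree_map` preserves the pytree structure, the same `ascent_step` pattern works unchanged for a list of `(kernel, bias)` tuples like `layers`.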

Brian Patton | Software Engineer | bjp@google.com

On Mon, Nov 16, 2020 at 8:29 AM Matt Wescott notifications@github.com wrote:


SiegeLordEx commented 3 years ago

The current recommendation is to import _make_dense_autoregressive_masks from the masked_autoregressive module and use it to construct the MADE network with your favorite JAX NN library, as is done here: https://github.com/tensorflow/probability/blob/84257b1071b35a8c20bf91ba413ccf0f3a313d7b/tensorflow_probability/python/bijectors/masked_autoregressive.py#L1166-L1185
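A MADE network enforces the autoregressive property by multiplying each dense kernel by a fixed binary mask, so that output `i` can never see input `j >= i`. Below is a minimal pure-JAX sketch of that idea, assuming left-to-right input order, cycled hidden degrees, and tanh hidden layers. It does not call TFP's private `_make_dense_autoregressive_masks` (its exact signature and output layout are not reproduced here), and the helper names `made_masks`, `init_made`, and `made_apply` are hypothetical. The final lines check the autoregressive property directly with `jax.jacfwd`.

```python
import numpy as np
import jax
import jax.numpy as jnp


def made_masks(event_size, hidden_units):
  """Hypothetical helper: boolean MADE masks for left-to-right input order.

  Hidden masks have shape [fan_in, fan_out]; the last mask has shape
  [last_hidden_units, event_size] and is shared by every output parameter.
  """
  in_degrees = np.arange(1, event_size + 1)  # input i (0-based) gets degree i + 1
  degrees = [in_degrees]
  for units in hidden_units:
    # Cycle hidden degrees through 1 .. event_size - 1.
    degrees.append(np.arange(units) % max(1, event_size - 1) + 1)
  # A hidden unit of degree k may see inputs/units of degree <= k ...
  masks = [d_out[None, :] >= d_in[:, None]
           for d_in, d_out in zip(degrees[:-1], degrees[1:])]
  # ... while output i may only see units of strictly smaller degree.
  masks.append(degrees[-1][:, None] < in_degrees[None, :])
  return masks


def init_made(key, event_size, hidden_units, num_params=2):
  """Hypothetical helper: random (kernel, bias) pairs for the masked MLP."""
  sizes = [event_size] + list(hidden_units)
  keys = jax.random.split(key, len(sizes))
  layers = []
  for k, fan_in, fan_out in zip(keys[:-1], sizes[:-1], sizes[1:]):
    layers.append((jax.random.normal(k, (fan_in, fan_out)) / np.sqrt(fan_in),
                   jnp.zeros(fan_out)))
  w_out = jax.random.normal(keys[-1], (sizes[-1], event_size, num_params))
  layers.append((w_out / np.sqrt(sizes[-1]), jnp.zeros((event_size, num_params))))
  return layers


def made_apply(layers, masks, x):
  """Masked tanh MLP; returns (shift, log_scale), each shaped like x."""
  h = x
  for (w, b), m in zip(layers[:-1], masks[:-1]):
    h = jnp.tanh(h @ (w * m) + b)
  w_out, b_out = layers[-1]
  out = jnp.einsum('...h,hdp->...dp', h, w_out * masks[-1][:, :, None]) + b_out
  return out[..., 0], out[..., 1]


event_size, hidden_units = 3, [16, 16]
masks = made_masks(event_size, hidden_units)
layers = init_made(jax.random.PRNGKey(0), event_size, hidden_units)
x = jax.random.normal(jax.random.PRNGKey(1), (event_size,))

# d shift_i / d x_j must vanish for j >= i: the Jacobian is strictly lower triangular.
jac = jax.jacfwd(lambda v: made_apply(layers, masks, v)[0])(x)
```

`made_apply` returns a `(shift, log_scale)` pair shaped like its input; the weights could be passed in through a `bijector_fn` and `bijector_kwargs`, as in Brian's example above.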

Internally, we've experimented with writing this in Flax and Oryx, but the Flax version, at least, wasn't ideal. Flax recently released a new version of their API (Flax 2, aka Linen), so at the very least (given Flax's popularity), we'd appreciate a way to construct a Flax 2 MADE that works well within the bijector framework. Looking at the Flax 2 examples, they appear to fit well within the conditional bijector framework that Brian showed above (where you use bijector_kwargs), but perhaps there's a more idiomatic way to do it. Ultimately, we're looking for an idiomatic integration that would feel natural to JAX users.

Tell us if you have any ideas (otherwise, we don't have plans to push on this internally in the immediate future).

mattwescott commented 3 years ago

Thanks @brianwa84 and @SiegeLordEx. Here is an example with multiple layers using Haiku. This approach seems tolerable for my current use cases, but I could be missing something important. Comments, ideas, and patches are all appreciated.

And looking further ahead, Oryx has a promising set of abstractions; I'm curious to see that implementation.

SiegeLordEx commented 3 years ago

Thanks @mattwescott, that looks really cool! If you have time, would you be willing to send a pull request to add it to core TFP? I think it's a useful example. As far as I can tell, the only changes you'd need to make are to add one of our standard headers (see the other examples) and to fix a few very minor typos.

xaviergonzalez commented 1 year ago

Has there been any movement on this issue? When I run this very simple code in Google Colab:

import jax
import jax.numpy as jnp
import jax.nn as jnn
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
base_dist = tfd.MultivariateNormalDiag(loc=jnp.zeros(2), scale_diag=jnp.ones(2))
maf = tfd.TransformedDistribution(
    distribution=base_dist,
    bijector=tfb.MaskedAutoregressiveFlow(
        shift_and_log_scale_fn=tfb.AutoregressiveNetwork(
            params=2, hidden_units=[512, 512])))

I get the following error:

AttributeError                            Traceback (most recent call last)
<ipython-input-3-26736552cb30> in <cell line: 2>()
      1 base_dist = tfd.MultivariateNormalDiag(loc=jnp.zeros(2), scale_diag=jnp.ones(2))
----> 2 made = tfb.AutoregressiveNetwork(params=2, hidden_units=[32, 32], activation=jnn.relu)
      3 maf = tfd.TransformedDistribution(distribution=base_dist, bijector=made)

/usr/local/lib/python3.10/dist-packages/tensorflow_probability/substrates/jax/bijectors/masked_autoregressive.py in __init__(self, params, event_shape, conditional, conditional_event_shape, conditional_input_layers, hidden_units, input_order, hidden_degrees, activation, use_bias, kernel_initializer, bias_initializer, kernel_regularizer, bias_regularizer, kernel_constraint, bias_constraint, validate_args, **kwargs)
    965     self._kernel_regularizer = kernel_regularizer
    966     self._bias_regularizer = bias_regularizer
--> 967     self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
    968     self._bias_constraint = bias_constraint
    969     self._validate_args = validate_args

AttributeError: module 'tensorflow_probability.python.internal.backend.jax.numpy_keras' has no attribute 'constraints'

Am I to understand from the discussion above that tfb.AutoregressiveNetwork does not work as documented in the JAX substrate, and can only be made to work by constructing the masks somewhat manually?
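Whatever network is used, what `MaskedAutoregressiveFlow` needs is a function returning `(shift, log_scale)` in which entry `i` depends only on `y[..., :i]`. The sketch below exercises that contract without Keras or TFP. It assumes the convention in the `MaskedAutoregressiveFlow` docstring as we understand it: the forward (sampling) direction computes `y = x * exp(log_scale) + shift` and needs `event_size` sequential passes, while the inverse (density) direction is a single pass. `toy_shift_and_log_scale`, `maf_forward`, `maf_inverse`, and `maf_log_prob` are hypothetical names; the toy function uses exclusive cumulative sums, and a masked network would take its place in practice.

```python
import jax
import jax.numpy as jnp


def toy_shift_and_log_scale(y):
  """Hypothetical autoregressive function: output i depends only on y[..., :i]."""
  exclusive = jnp.cumsum(y, axis=-1) - y  # sum of y[..., :i] for each i
  return 0.5 * exclusive, 0.1 * jnp.tanh(exclusive)


def maf_inverse(y, fn):
  """Density direction, one pass: returns x and log|det dx/dy|."""
  shift, log_scale = fn(y)
  x = (y - shift) * jnp.exp(-log_scale)
  return x, -jnp.sum(log_scale, axis=-1)


def maf_forward(x, fn):
  """Sampling direction: pass k fixes y[..., k], so event_size passes suffice."""
  y = jnp.zeros_like(x)
  for _ in range(x.shape[-1]):
    shift, log_scale = fn(y)
    y = x * jnp.exp(log_scale) + shift
  return y


def maf_log_prob(y, fn):
  """log p(y) with a standard-normal base distribution."""
  x, ildj = maf_inverse(y, fn)
  base = jnp.sum(-0.5 * x**2 - 0.5 * jnp.log(2 * jnp.pi), axis=-1)
  return base + ildj


x = jax.random.normal(jax.random.PRNGKey(0), (4, 3))
y = maf_forward(x, toy_shift_and_log_scale)
x_back, ildj = maf_inverse(y, toy_shift_and_log_scale)
# Brute-force log|det| of the inverse Jacobian for the first sample.
jac = jax.jacfwd(lambda v: maf_inverse(v, toy_shift_and_log_scale)[0])(y[0])
brute_ildj = jnp.linalg.slogdet(jac)[1]
```

The brute-force Jacobian check is only practical for small `event_size`, but it is a cheap way to catch a shift/log-scale function that violates the autoregressive property: both the round trip and the log-determinant comparison break when it does.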