keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Probability Distributions in Keras-Core #18435

Open · LukeWood opened this issue 1 year ago

LukeWood commented 1 year ago

Hello everyone! I recently wrote https://github.com/keras-team/keras-io/pull/1440 and I'd like to make this example fully keras-core friendly. This requires implementations of probability distributions of some sort.

While reviewing TensorFlow Probability distributions and torch.distributions, I noticed that the distribution classes house many methods, not just a sample() method!

Roughly they share (from TFP):

Reference: https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/Distribution


I personally think we should introduce a probability distribution class to Keras Core! We could back it with NumPy, tfp.distributions.Categorical, or torch.distributions, depending on the backend.

It might be a bit of work to make the APIs match, but it seems worthwhile. Once we agree on a class design, I'd love to take a stab at adding one!
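To make the class-based proposal concrete, here is a minimal, backend-free sketch in pure Python. The name Categorical mirrors tfp.distributions.Categorical, but this class, its method set (sample, log_prob, prob, entropy), and the rng parameter are illustrative assumptions, not an existing Keras API; a real implementation would dispatch to each backend instead of using the standard library.

```python
import math
import random


def _log_softmax(logits):
    # Numerically stable log-softmax: subtract the max before exponentiating.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


class Categorical:
    """Hypothetical Categorical distribution (sketch, not a Keras API)."""

    def __init__(self, logits):
        self.logits = list(logits)
        # Normalized log-probabilities are computed once and reused by every method.
        self._log_probs = _log_softmax(self.logits)

    def log_prob(self, value):
        # Log-probability of a single integer category index.
        return self._log_probs[value]

    def prob(self, value):
        return math.exp(self.log_prob(value))

    def sample(self, rng=None):
        # Draw one category index; pass a random.Random for reproducibility.
        if rng is None:
            rng = random.Random()
        weights = [math.exp(lp) for lp in self._log_probs]
        return rng.choices(range(len(weights)), weights=weights, k=1)[0]

    def entropy(self):
        # H = -sum_k p_k * log(p_k)
        return -sum(math.exp(lp) * lp for lp in self._log_probs)
```

For example, Categorical(logits=[0.0, 0.0]).log_prob(1) is log(0.5), about -0.6931, and .entropy() on the same object is log(2), about 0.6931. The appeal of the class is that the normalization is done once and shared by all methods; the cost is one more object type per distribution.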

fchollet commented 1 year ago

Thanks for the suggestion. Are these methods / attributes widely used? Do we really need a class as opposed to just functions (like the normal, truncated_normal, and uniform functions we already have)? What are the probability distributions we should include?

For instance, JAX includes many probability distributions (e.g. in the jax.scipy module), but they're all functions. https://jax.readthedocs.io/en/latest/jax.scipy.html
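The functional style described here can be sketched with the standard library alone. The names normal_logpdf and normal_cdf are hypothetical; their (x, loc, scale) signature is modeled on jax.scipy.stats.norm.logpdf and jax.scipy.stats.norm.cdf. Every call receives the distribution parameters explicitly, and no object holds state between calls.

```python
import math


def normal_logpdf(x, loc=0.0, scale=1.0):
    # log N(x; loc, scale) = -0.5 * z**2 - log(scale) - 0.5 * log(2 * pi)
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def normal_cdf(x, loc=0.0, scale=1.0):
    # Phi(z) = 0.5 * (1 + erf(z / sqrt(2)))
    z = (x - loc) / scale
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))
```

Adding another quantity (say, a survival function) means adding one more flat function rather than one more method on every distribution class.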

LukeWood commented 1 year ago

Good question 🤔 The answer is not clear to me.

One thing that makes the distribution class feel like an anti-pattern is that I feel you will always invoke a method on it immediately.

I.e.: Categorical(logits=…).log_prob()

So it does seem like a better fit for functions. The only issue is then the namespace could get a bit ugly:

categorical_log_prob(logits=…, samples=…)
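With plain functions, the Categorical(logits=…).log_prob() pattern collapses into a single call. A minimal sketch, assuming plain Python lists for logits and integer category indices for samples; categorical_sample and its num_samples and seed parameters are hypothetical additions for symmetry, not an existing API.

```python
import math
import random


def _log_softmax(logits):
    # Numerically stable log-softmax: subtract the max before exponentiating.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def categorical_log_prob(logits, samples):
    # Log-probability of each integer sample under Categorical(logits).
    log_probs = _log_softmax(logits)
    return [log_probs[s] for s in samples]


def categorical_sample(logits, num_samples, seed=None):
    # Draw num_samples category indices; the same seed gives the same draws.
    rng = random.Random(seed)
    weights = [math.exp(lp) for lp in _log_softmax(logits)]
    return rng.choices(range(len(logits)), weights=weights, k=num_samples)
```

For example, categorical_log_prob(logits=[0.0, math.log(3.0)], samples=[0, 1]) returns the logs of 0.25 and 0.75. Note that each function re-normalizes the logits, which is the work a class would have cached.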

It reminds me of what we have for data types in KerasCV: every type has a format converter or two. Maybe it makes sense to namespace them somehow, though I know that's not common practice in Keras.
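One way to namespace the functions without reintroducing stateful objects is to group them per distribution, in the spirit of jax.scipy.stats.norm.logpdf. In a real package each group would more likely be a submodule; the types.SimpleNamespace below is only a stand-in so the sketch fits in one file, and the name categorical is hypothetical.

```python
import math
import types


def _log_softmax(logits):
    # Numerically stable log-softmax: subtract the max before exponentiating.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def _log_prob(logits, samples):
    log_probs = _log_softmax(logits)
    return [log_probs[s] for s in samples]


def _entropy(logits):
    # H = -sum_k p_k * log(p_k)
    return -sum(math.exp(lp) * lp for lp in _log_softmax(logits))


# Stateless functions grouped under one name per distribution. Functions stored
# on a SimpleNamespace are not bound methods, so they are called as plain functions.
categorical = types.SimpleNamespace(log_prob=_log_prob, entropy=_entropy)
```

Calls then read as categorical.log_prob(logits=..., samples=...), which keeps the top-level namespace short while each function stays a plain, stateless call.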

fchollet commented 1 year ago

> The only issue is then the namespace could get a bit ugly

If there is a standard precedent for it (like SciPy and JAX), this should not be a big issue. From the user's perspective, it's about the same amount of information to ingest, but it's somewhat easier to manage with functions: navigating a longer list of functions that all work roughly the same way is simpler than navigating a medium-size list of classes combined with a medium-size list of accompanying methods (nested complexity is worse).