LukeWood opened this issue 1 year ago.
Thanks for the suggestion. Are these methods / attributes widely used? Do we really need a class, as opposed to just functions (like the normal, truncated_normal and uniform functions we already have)? Which probability distributions should we include?
For instance, JAX includes many probability distributions (e.g. in the scipy module), but they're all functions: https://jax.readthedocs.io/en/latest/jax.scipy.html
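For comparison, here is a minimal sketch of what the function-only style looks like for a single distribution. It is loosely modeled on jax.scipy.stats.norm.logpdf(x, loc, scale). The names normal_logpdf and normal_pdf are hypothetical, chosen only for illustration, and the code uses plain Python floats instead of backend tensors.

```python
import math


def normal_logpdf(x, loc=0.0, scale=1.0):
    """Log-density of Normal(loc, scale) at x (hypothetical function-style API)."""
    if scale <= 0:
        raise ValueError("scale must be positive")
    z = (x - loc) / scale
    # log N(x; loc, scale) = -z^2 / 2 - log(scale) - log(2*pi) / 2
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def normal_pdf(x, loc=0.0, scale=1.0):
    """Density, derived from the log-density for numerical consistency."""
    return math.exp(normal_logpdf(x, loc, scale))


print(normal_pdf(0.0))  # approximately 0.3989, i.e. 1 / sqrt(2 * pi)
```

Each distribution then adds a handful of top-level functions rather than a class with methods.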
Good question 🤔 The answer is not clear to me.
One thing that makes the distribution class feel like an anti-pattern is that I feel you will always invoke one of its methods immediately, i.e.:
Categorical(logits=…).log_prob()
So it does seem like a better fit for functions. The only issue is that the namespace could then get a bit ugly:
categorical_log_prob(logits=…, samples=…)
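To make the trade-off concrete, here is a minimal pure-Python sketch of both styles for a single, unbatched logits vector. Categorical, categorical_log_prob and the helper _log_softmax are hypothetical names used for illustration, not an existing Keras API.

```python
import math


def _log_softmax(logits):
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    return [l - log_z for l in logits]


class Categorical:
    """Hypothetical class style: construct the object, then call a method."""

    def __init__(self, logits):
        self.logits = list(logits)

    def log_prob(self, samples):
        log_p = _log_softmax(self.logits)
        return [log_p[s] for s in samples]


def categorical_log_prob(logits, samples):
    """Hypothetical function style: same computation, no intermediate object."""
    log_p = _log_softmax(logits)
    return [log_p[s] for s in samples]


logits = [1.0, 2.0, 0.5]
samples = [0, 1, 1]
# The object is built and consumed in one expression, so both styles agree.
assert Categorical(logits=logits).log_prob(samples) == categorical_log_prob(
    logits=logits, samples=samples
)
```

When the object is always consumed in the same expression, the class adds no state that the function lacks. This is the argument for functions made above.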
It reminds me of what we have for data types in KerasCV: every type has a format converter or two. Maybe it makes sense to namespace them somehow, though I know that's not common practice in Keras.
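One way to namespace plain functions without introducing classes is to group them per distribution, similar to how jax.scipy.stats.norm.logpdf sits under a per-distribution module. The sketch below uses types.SimpleNamespace as a stand-in for a submodule. The categorical namespace and its log_prob / sample functions are hypothetical.

```python
import math
import random
from types import SimpleNamespace


def _log_softmax(logits):
    m = max(logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    return [l - log_z for l in logits]


def _categorical_log_prob(logits, samples):
    log_p = _log_softmax(logits)
    return [log_p[s] for s in samples]


def _categorical_sample(logits, num_samples, seed=None):
    # A seeded local RNG keeps draws reproducible without touching global state.
    rng = random.Random(seed)
    probs = [math.exp(lp) for lp in _log_softmax(logits)]
    return rng.choices(range(len(logits)), weights=probs, k=num_samples)


# Hypothetical namespace: one object per distribution, holding plain functions.
categorical = SimpleNamespace(
    log_prob=_categorical_log_prob,
    sample=_categorical_sample,
)

draws = categorical.sample([0.0, 0.0, 5.0], num_samples=10, seed=0)
print(categorical.log_prob([0.0, 0.0, 5.0], draws))
```

In a real package, categorical would more likely be a submodule than a SimpleNamespace. The call sites (categorical.log_prob(...)) would look the same.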
> The only issue is then the namespace could get a bit ugly
If there is a standard precedent for it (like scipy and JAX), this should not be a big issue. From the user's perspective, it's about the same amount of information to ingest, but somewhat easier to manage with functions. Navigating a longer list of functions that all work roughly the same way is easier than navigating a medium-size list of classes combined with a medium-size list of accompanying methods (nested complexity is worse).
Hello everyone! I recently wrote https://github.com/keras-team/keras-io/pull/1440, and I'd like to make this example fully keras-core friendly. This requires implementations of probability distributions of some sort.
While reviewing TensorFlow Probability distributions and torch.distributions, I noticed that the distribution classes house many methods, not just a sample() method! Roughly, they share the methods of TFP's base Distribution class:
Reference: https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/Distribution
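To illustrate the kind of surface these classes share, here is a minimal, hypothetical pure-Python Normal class. The method names (sample, log_prob, prob, mean, stddev, variance, entropy) exist on TFP's Distribution. torch.distributions exposes mean, stddev and variance as properties rather than methods, which is one of the API mismatches a Keras wrapper would have to smooth over. This is a sketch for discussion, not either library's implementation.

```python
import math
import random


class Normal:
    """Hypothetical pure-Python Normal with a TFP-like method surface."""

    def __init__(self, loc=0.0, scale=1.0):
        if scale <= 0:
            raise ValueError("scale must be positive")
        self.loc = float(loc)
        self.scale = float(scale)

    def sample(self, n=1, seed=None):
        # Draw n independent samples using a seeded local RNG.
        rng = random.Random(seed)
        return [rng.gauss(self.loc, self.scale) for _ in range(n)]

    def log_prob(self, x):
        z = (x - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2.0 * math.pi)

    def prob(self, x):
        return math.exp(self.log_prob(x))

    def mean(self):
        return self.loc

    def stddev(self):
        return self.scale

    def variance(self):
        return self.scale ** 2

    def entropy(self):
        # Differential entropy of a Gaussian: 0.5 * log(2 * pi * e * sigma^2).
        return 0.5 * math.log(2.0 * math.pi * math.e * self.scale ** 2)


dist = Normal(loc=1.0, scale=2.0)
print(dist.mean(), dist.variance())  # 1.0 4.0
```

Summary statistics such as mean, variance and entropy are where a class helps more than a flat function list: they are properties of the parameterized distribution, not of any particular sample.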
I personally think we should introduce probability distribution classes to keras-core! We could back them with NumPy, tfp.distributions (e.g. tfp.distributions.Categorical) and torch.distributions.
It might be a bit of work to make the APIs match, but it seems worthwhile. Once we agree on a class design, I'd love to take a stab at adding one!
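One possible shape for "making the APIs match" is a small abstract interface plus a per-backend registry. The sketch below is hypothetical throughout: Distribution, PythonCategorical, make_categorical and the registry are illustrative names, and only a pure-Python reference backend is implemented. Assumed extensions that are not shown here: TensorFlow and torch entries would wrap tfp.distributions.Categorical(logits=...) and torch.distributions.Categorical(logits=...), and the backend name might come from keras-core's backend() helper.

```python
import abc
import math
import random


class Distribution(abc.ABC):
    """Hypothetical backend-agnostic interface every backend must implement."""

    @abc.abstractmethod
    def sample(self, n, seed=None):
        ...

    @abc.abstractmethod
    def log_prob(self, value):
        ...


class PythonCategorical(Distribution):
    """Reference implementation over a single, unbatched logits vector."""

    def __init__(self, logits):
        m = max(logits)
        log_z = m + math.log(sum(math.exp(l - m) for l in logits))
        self._log_p = [l - log_z for l in logits]

    def sample(self, n, seed=None):
        rng = random.Random(seed)
        weights = [math.exp(lp) for lp in self._log_p]
        return rng.choices(range(len(weights)), weights=weights, k=n)

    def log_prob(self, value):
        return self._log_p[value]


# Hypothetical registry: backend name -> concrete Categorical implementation.
_CATEGORICAL_IMPLS = {"python": PythonCategorical}


def make_categorical(logits, backend="python"):
    """Return the Categorical implementation registered for `backend`."""
    try:
        impl = _CATEGORICAL_IMPLS[backend]
    except KeyError:
        raise NotImplementedError(
            f"No Categorical implementation for backend {backend!r}"
        ) from None
    return impl(logits)


dist = make_categorical([0.0, math.log(3.0)])
print(round(math.exp(dist.log_prob(1)), 6))  # 0.75
```

The abstract base class makes missing methods fail at construction time, which is one way to keep backend implementations from silently drifting apart.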