google-deepmind / distrax

Possible feature request: `distrax` equivalent of `tfp.layers` #185

Closed: homerjed closed this issue 1 year ago

homerjed commented 2 years ago

Hi,

Big fan of the library. I was hoping to create a model like the one below (from here) and was wondering whether you plan to add any features analogous to those in `tensorflow_probability.layers`?

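import tensorflow as tf
import tensorflow_probability as tfp
tfk, tfkl, tfpl = tf.keras, tf.keras.layers, tfp.layers  # aliases as used in the TFP docs
# d: event size (dimension of the Gaussian), defined elsewhere
model = tfk.Sequential([
    tfkl.Dense(tfpl.MultivariateNormalTriL.params_size(d)),
    tfpl.MultivariateNormalTriL(d),
])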

I have a haiku / distrax implementation of the above layer, but I am missing a consistent way to constrain the (full) covariance matrix and the mean of the Gaussian distribution to "allowable" values.

Cheers
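
A minimal sketch of one common way to get such a constraint, written in plain JAX. This is an illustration under assumptions, not distrax's own API and not necessarily what `tfp.layers` does internally. The helper names `params_size` and `unpack_params` and the `min_scale` floor are hypothetical. The idea is to leave the mean unconstrained (any real vector is a valid mean) and to parameterise the covariance through a lower-triangular factor whose diagonal is made strictly positive with a softplus, so that the implied full covariance is symmetric positive definite.

```python
import jax
import jax.numpy as jnp


def params_size(d):
    """Number of unconstrained values: d for the mean, d*(d+1)/2 for the scale."""
    return d + d * (d + 1) // 2


def unpack_params(params, d, min_scale=1e-5):
    """Map a flat unconstrained vector to (loc, scale_tri).

    Hypothetical helper, not part of distrax. `scale_tri` is lower triangular
    with a strictly positive diagonal, so `scale_tri @ scale_tri.T` is a
    symmetric positive-definite (full) covariance matrix.
    """
    loc = params[..., :d]          # the mean needs no constraint
    tril_flat = params[..., d:]    # entries of the lower triangle, row by row

    # Scatter the flat entries into the lower triangle of a (d, d) matrix.
    rows, cols = jnp.tril_indices(d)
    scale_tri = jnp.zeros(params.shape[:-1] + (d, d), dtype=params.dtype)
    scale_tri = scale_tri.at[..., rows, cols].set(tril_flat)

    # Force the diagonal to be positive; min_scale is an assumed small floor.
    diag = jax.nn.softplus(jnp.diagonal(scale_tri, axis1=-2, axis2=-1)) + min_scale
    idx = jnp.arange(d)
    scale_tri = scale_tri.at[..., idx, idx].set(diag)
    return loc, scale_tri
```

With this in place, the final layer of the network could be something like `hk.Linear(params_size(d))`, and the resulting `loc` and `scale_tri` could then be passed to `distrax.MultivariateNormalTri(loc=loc, scale_tri=scale_tri)`.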