google-deepmind / distrax

Apache License 2.0
529 stars, 32 forks

Update categorical.py #236

Closed: cyprienc closed this pull request 1 year ago

cyprienc commented 1 year ago

Sets dtype in the jax.nn.one_hot calls so that their outputs match the dtype of the distribution's parameters, instead of using one_hot's default float dtype.
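A minimal sketch of why this matters, assuming categorical log-probabilities are computed by masking log-softmaxed logits with a one-hot encoding of the value. The helper log_prob_from_logits is hypothetical and is not the Distrax code. Without an explicit dtype, jax.nn.one_hot returns its default float dtype (float32 under JAX's default config), so bfloat16 logits multiplied by that mask are promoted to float32. Passing dtype=logits.dtype keeps the result in the parameters' dtype.

```python
import jax
import jax.numpy as jnp


def log_prob_from_logits(logits: jax.Array, value: jax.Array) -> jax.Array:
  """Hypothetical helper (not the Distrax implementation).

  Returns the log-probability of integer `value` under a categorical
  distribution parameterized by `logits`.
  """
  num_categories = logits.shape[-1]
  # log_softmax preserves the input dtype (e.g. bfloat16 stays bfloat16).
  log_probs = jax.nn.log_softmax(logits, axis=-1)
  # Request the one-hot mask in the logits' dtype; with the default dtype,
  # the product below would be promoted to float32 for bfloat16 logits.
  value_one_hot = jax.nn.one_hot(value, num_categories, dtype=logits.dtype)
  # Select the log-probability of the observed category.
  return jnp.sum(log_probs * value_one_hot, axis=-1)


logits = jnp.array([[0.1, 0.5, -1.0]], dtype=jnp.bfloat16)
value = jnp.array([1])
print(log_prob_from_logits(logits, value).dtype)  # bfloat16
```

Dropping the dtype argument from the one_hot call in this sketch would make the returned log-probability float32 even though the logits are bfloat16, which is the mismatch the pull request describes.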

franrruiz commented 1 year ago

Thank you @cyprienc for your contribution to Distrax!