Status: Closed (closed by cyprienc 1 year ago)
Set dtype explicitly in jax.nn.one_hot calls so that the returned array has the same dtype as the parameters, rather than the default floating-point dtype.
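A minimal sketch of the mismatch this change addresses, not the actual Distrax code: the names `logits` and `labels` are hypothetical stand-ins for a distribution's parameters and sampled indices. Without an explicit `dtype`, `jax.nn.one_hot` returns JAX's default floating-point type, which can differ from a lower-precision parameter dtype such as `bfloat16`.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters stored in reduced precision.
logits = jnp.zeros((3,), dtype=jnp.bfloat16)
# Hypothetical class indices to encode.
labels = jnp.array([0, 2])

# Without dtype, one_hot falls back to the default float dtype.
default_one_hot = jax.nn.one_hot(labels, logits.shape[-1])

# Passing dtype keeps the result consistent with the parameters.
matched_one_hot = jax.nn.one_hot(labels, logits.shape[-1], dtype=logits.dtype)

print(default_one_hot.dtype, matched_one_hot.dtype)
```

Passing `dtype=logits.dtype` avoids implicit type promotion when the one-hot array is later combined with the parameters.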
Thank you @cyprienc for your contribution to Distrax!