I'd say that `jax.random.truncated_normal` is defined correctly as is, with no complex dtype support, since there isn't a canonical notion of the truncated normal distribution over complex numbers.
This situation seems to suggest a feature request for Flax (if it isn't already supported): accepting custom distributions for initialization, or anything else along those lines. You could then write and supply whatever distribution you have in mind, whether that's with a box or a spherical constraint. It seems like you've requested roughly this in google/flax#805.
Discussion seems to have picked up on the Flax thread, where it probably ought to continue for now, so I'm going to close this. Feel free to reopen if it comes back to `jax.random`.
Hello @froystig. I agree that there is no consensus on how to generalize `jax.random.uniform` and `truncated_normal` to complex numbers (while `normal` was already done in #4805), so I'd suggest implementing complex variance-scaling initializers in `jax.nn.initializers` (with `flax.nn.initializers` simply importing them).
The generalization of `glorot_normal` to complex numbers is discussed in, e.g., C. Trabelsi et al., Deep Complex Networks, 2017, and widely cited since. They propose an axisymmetric distribution, implemented by first sampling the modulus via the inverse of its radial CDF and then sampling the phase uniformly. Truncating the modulus is not hard either, although it isn't discussed in that paper; see the sketch below.
Also, I'd suggest generalizing `glorot_uniform` to a uniform distribution on a disk rather than on a square. That makes more sense in the context of variance-scaling initializers, where we multiply the weight matrix by another random matrix of inputs and analyze the variance of the outputs.
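For reference, sampling uniformly on a complex disk is straightforward (again a hypothetical sketch, not existing API):

```python
import jax
import jax.numpy as jnp

def uniform_disk(key, r_max, shape):
    # Hypothetical helper, for illustration only.
    key_r, key_phase = jax.random.split(key)
    # For a uniform density on the disk, P(|z| <= r) = (r / r_max)^2,
    # so the modulus is r_max * sqrt(u) with u ~ Uniform(0, 1).
    r = r_max * jnp.sqrt(jax.random.uniform(key_r, shape))
    theta = jax.random.uniform(key_phase, shape, maxval=2.0 * jnp.pi)
    return r * jnp.exp(1j * theta)
```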
These have been implemented in netket/netket#840. If you agree with the design, I'll open a PR to JAX. Thank you!
Paging @jekbradbury for thoughts on adding this to `jax.nn`.
Also cc @avital and @jheek from the flax thread mentioned above.
Any thoughts on this (or is everyone on holiday)?
Bump @jekbradbury, what do you think about this proposal?
It sounds like there's consensus from researchers on a subset of initializers, and I think we'd be happy to include those.
Follow-up to #4680 and #4805.
My objective is to use Flax to build complex-valued neural networks along the lines of the sketch below (the specific layer and dtypes are just illustrative):
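```python
# Hedged sketch: the issue's original snippet is not preserved here;
# nn.Dense with param_dtype stands in for whatever module was used.
import jax
import jax.numpy as jnp
import flax.linen as nn

# A Dense layer asked to hold complex parameters.
model = nn.Dense(features=4, param_dtype=jnp.complex64)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 3), jnp.complex64))
# One would expect a complex kernel here, but at the time the default
# truncated-normal-based initializer produced real values.
print(jax.tree_util.tree_map(lambda x: x.dtype, params))
```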
However, as you can see, right now this returns the wrong result and the network ends up with real parameters. The culprit is `jax.random.truncated_normal`, which does not support complex dtypes. I would like to add support for complex dtypes to it (and possibly to `jax.random.uniform`), similarly to what I have done for `normal` in #4805; however, some care should be taken with the API. The signature of `truncated_normal` right now is:
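```python
# Paraphrased from the JAX source of that period; exact defaults may
# differ between versions.
def truncated_normal(key, lower, upper, shape=None, dtype=np.float64):
  """Sample standard normal random values truncated to [lower, upper]."""
  ...
```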
However, such a generalisation is not trivial. While the `normal` distribution in the complex plane is well defined as a rescaled normal distribution with independent real and imaginary parts, how should the truncation be applied in the complex plane?

One possibility would be to truncate within a square box between `lower_left` and `upper_right` corners in the complex plane; however, this is not consistent with the definition of a complex uniform distribution.

A more consistent definition might be to truncate within a circle, specifying the centre and the radius. However, this would make the arguments `lower` and `upper` assume a different meaning depending on the input dtype. Would this be OK? A sketch of the two options follows.
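To make the trade-off concrete, here is a hedged sketch of both schemes built from existing `jax.random` primitives (the function names and the unit-variance normalization are mine, not proposed API):

```python
import jax
import jax.numpy as jnp

def complex_truncated_normal_box(key, lower, upper, shape):
    # Box truncation: real and imaginary parts are truncated
    # independently, so `lower`/`upper` keep their real-valued meaning.
    key_re, key_im = jax.random.split(key)
    re = jax.random.truncated_normal(key_re, lower, upper, shape)
    im = jax.random.truncated_normal(key_im, lower, upper, shape)
    return (re + 1j * im) / jnp.sqrt(2.0)  # unit-variance convention

def complex_truncated_normal_disk(key, radius, shape):
    # Disk truncation: axisymmetric, but `lower`/`upper` are replaced
    # by a single `radius`, changing the meaning of the arguments.
    key_r, key_phase = jax.random.split(key)
    # The modulus of a standard complex normal is Rayleigh-distributed
    # with CDF F(r) = 1 - exp(-r^2); invert it restricted to [0, radius].
    u_max = 1.0 - jnp.exp(-radius ** 2)
    u = jax.random.uniform(key_r, shape, maxval=u_max)
    r = jnp.sqrt(-jnp.log1p(-u))
    theta = jax.random.uniform(key_phase, shape, maxval=2.0 * jnp.pi)
    return r * jnp.exp(1j * theta)
```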