jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Complex-valued initializers: `lecun_normal` #5312

Closed PhilipVinc closed 3 years ago

PhilipVinc commented 3 years ago

Follow up to #4680 and #4805

My objective is to use flax to build complex-valued neural networks in the following fashion:

>>> import jax, flax
>>> m=flax.linen.Dense(3, dtype=jax.numpy.complex64)
>>> _, weights = m.init_with_output(jax.random.PRNGKey(0), (3,))
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
>>> weights
FrozenDict({
    params: {
        kernel: DeviceArray([[ 0.67380387, -0.3294223 , -0.9614107 ]], dtype=float32),
        bias: DeviceArray([0., 0., 0.], dtype=float32),
    },
})

However, as you can see, this currently returns the wrong result and the network ends up with real parameters. The culprit is jax.random.truncated_normal, which does not support complex dtypes.
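For concreteness, a minimal illustration of the asymmetry (assuming #4805 is merged, so `normal` accepts complex dtypes while `truncated_normal` still rejects them):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Works after #4805: standard complex normal (real/imaginary parts i.i.d. N(0, 1/2)).
z = jax.random.normal(key, (3,), jnp.complex64)

# Fails: truncated_normal only accepts floating-point dtypes.
# jax.random.truncated_normal(key, -1.0, 1.0, (3,), jnp.complex64)
```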

I would like to add complex dtype support to it (and possibly to jax.random.uniform), similarly to what I did for normal in #4805; however, some care should be taken with the API:

The current signature of truncated_normal is:

truncated_normal(key,
                 lower: Union[float, jax._src.numpy.lax_numpy.ndarray],
                 upper: Union[float, jax._src.numpy.lax_numpy.ndarray],
                 shape=None,
                 dtype: numpy.dtype = <class 'numpy.float64'>) -> jax._src.numpy.lax_numpy.ndarray
    Args:
      key: a PRNGKey used as the random key.
      lower: a float or array of floats representing the lower bound for
        truncation. Must be broadcast-compatible with ``upper``.
      upper: a float or array of floats representing the upper bound for
        truncation. Must be broadcast-compatible with ``lower``.

However, such a generalisation is not trivial: while the normal distribution in the complex plane is well defined as a rescaled normal distribution over independent real and imaginary parts, how should the truncation be applied in the complex plane?

One possibility would be to truncate to a square box between its lower-left and upper-right corners in the complex plane; however, this is not consistent with the definition of a complex-uniform distribution.

A more consistent definition might be to truncate within a circle, specified by a centre and a radius.

However, this would make the arguments lower and upper take on a different meaning depending on the input. Would that be ok?
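To make the circular option concrete, here is a rough sketch of what a circularly-truncated complex normal sampler could look like. `truncated_complex_normal` and its `centre`/`radius` arguments are hypothetical names, not an existing jax.random API; the modulus is drawn by inverting the CDF F(r) = 1 - exp(-r^2) of the modulus of a unit-variance complex normal, restricted to [0, radius], and the phase is drawn uniformly:

```python
import jax
import jax.numpy as jnp

def truncated_complex_normal(key, centre=0.0, radius=1.0, shape=(), dtype=jnp.complex64):
    """Sample z = centre + r * exp(i*phi) with r <= radius (hypothetical helper)."""
    key_r, key_phi = jax.random.split(key)
    real_dtype = jnp.zeros((), dtype).real.dtype      # e.g. float32 for complex64
    u = jax.random.uniform(key_r, shape, real_dtype)
    # Inverse CDF of the modulus of a unit-variance complex normal,
    # F(r) = 1 - exp(-r**2), restricted to [0, radius].
    r = jnp.sqrt(-jnp.log(1.0 - u * (1.0 - jnp.exp(-radius**2))))
    phi = jax.random.uniform(key_phi, shape, real_dtype, 0.0, 2.0 * jnp.pi)
    return (centre + r * jnp.exp(1j * phi)).astype(dtype)

samples = truncated_complex_normal(jax.random.PRNGKey(0), radius=2.0, shape=(3, 3))
```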

froystig commented 3 years ago

I'd say that jax.random.truncated_normal is defined correctly as is, with no complex dtype support, since there isn't a canonical notion of the truncated normal distribution over complex numbers.

This situation seems to suggest a feature request for Flax—if it isn't already supported—of accepting custom distributions for initialization (or anything else along those lines). You could then write and supply whatever distribution you have in mind, whether that's with a box or spherical constraint. It seems like you've requested roughly this in google/flax#805.
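For example, a hedged sketch of that route, assuming Flax calls the initializer with (key, shape) or (key, shape, dtype) as jax.nn.initializers do; `complex_lecun_normal` is just an illustrative name, not an existing function:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

def complex_lecun_normal(key, shape, dtype=jnp.complex64):
    # Complex normal with variance 1 / fan_in:
    # real/imaginary parts each N(0, 1 / (2 * fan_in)).
    fan_in = shape[0]
    return jax.random.normal(key, shape, dtype) / jnp.sqrt(fan_in)

m = nn.Dense(3, dtype=jnp.complex64, kernel_init=complex_lecun_normal)
variables = m.init(jax.random.PRNGKey(0), jnp.ones((1, 4), jnp.complex64))
```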

froystig commented 3 years ago

Discussion seems to have picked up on the Flax thread, where it probably ought to go for now, so I'm going to close this. Feel free to reopen if it comes back to jax.random.

wdphy16 commented 3 years ago

Hello @froystig. I agree that there is no consensus on how to generalize jax.random.uniform and truncated_normal to complex numbers (normal was already done in #4805), and I'd suggest implementing complex variance-scaling initializers in jax.nn.initializers (with flax.nn.initializers simply importing them).

The generalization of glorot_normal to complex numbers is discussed in, e.g., C. Trabelsi et al., Deep Complex Networks, 2017, and has been widely cited since. They propose an axisymmetric distribution, implemented by first sampling the modulus from the radial CDF and then sampling the phase uniformly. It is not hard to truncate the modulus as well, although that is not discussed in the paper.
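A hedged sketch of that construction (illustrative only, not an existing jax.nn.initializers function and not the NetKet implementation), with the Rayleigh scale chosen so that E[|w|^2] = 2 / (fan_in + fan_out) as in Glorot initialization:

```python
import jax
import jax.numpy as jnp

def complex_glorot_normal(key, shape, dtype=jnp.complex64):
    fan_in, fan_out = shape[0], shape[1]
    sigma = jnp.sqrt(1.0 / (fan_in + fan_out))       # Rayleigh scale: E[r**2] = 2 * sigma**2
    key_r, key_phi = jax.random.split(key)
    u = jax.random.uniform(key_r, shape)
    r = sigma * jnp.sqrt(-2.0 * jnp.log(1.0 - u))    # modulus via the inverse Rayleigh CDF
    phi = jax.random.uniform(key_phi, shape, maxval=2.0 * jnp.pi)
    return (r * jnp.exp(1j * phi)).astype(dtype)
```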

Also, I'd suggest generalizing glorot_uniform to a uniform disk rather than a uniform square. This makes more sense in the context of variance-scaling initializers, where we multiply the weight matrix by another random matrix of inputs and analyze the variance of the outputs.
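A corresponding uniform-disk sketch (again illustrative names only), with the radius chosen so that a uniform draw from the disk also satisfies E[|w|^2] = radius^2 / 2 = 2 / (fan_in + fan_out):

```python
import jax
import jax.numpy as jnp

def complex_glorot_uniform(key, shape, dtype=jnp.complex64):
    fan_in, fan_out = shape[0], shape[1]
    radius = 2.0 / jnp.sqrt(fan_in + fan_out)        # E[|w|**2] = radius**2 / 2
    key_r, key_phi = jax.random.split(key)
    u = jax.random.uniform(key_r, shape)
    r = radius * jnp.sqrt(u)                         # area-uniform radius on the disk
    phi = jax.random.uniform(key_phi, shape, maxval=2.0 * jnp.pi)
    return (r * jnp.exp(1j * phi)).astype(dtype)
```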

These initializers have been implemented in netket/netket#840. If you agree with the design, I'll open a PR to JAX. Thank you!

froystig commented 3 years ago

Paging @jekbradbury for thoughts on adding to jax.nn.

froystig commented 3 years ago

Also cc @avital and @jheek from the flax thread mentioned above.

PhilipVinc commented 3 years ago

Any thoughts on this (or is everyone on holiday)?

PhilipVinc commented 3 years ago

Bump @jekbradbury, what do you think about this proposal?

jekbradbury commented 3 years ago

It sounds like there's consensus from researchers on a subset of initializers, and I think we'd be happy to include those.