DBraun opened this issue 3 months ago
There's an explanation for 1/sqrt(3): the variance of a uniform distribution between -1 and 1 is 1/3, so the standard deviation is 1/sqrt(3). I hope that's a clue for finding why PyTorch seems to do WeightNorm one way and Flax does it another.
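For reference, the variance of a uniform distribution on $[a, b]$ is $(b - a)^2 / 12$, so on $[-1, 1]$:

$$\operatorname{Var}\bigl(\mathcal{U}(-1, 1)\bigr) = \frac{(1 - (-1))^2}{12} = \frac{4}{12} = \frac{1}{3}, \qquad \sigma = \frac{1}{\sqrt{3}} \approx 0.577.$$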
System information

`pip show flax jax jaxlib`

Problem you have encountered:
`flax.linen.WeightNorm` needs a special `scale_init` in order to match PyTorch. I have written an example in both PyTorch and Flax that produces the same outputs.

About Conv
Before talking about WeightNorm, I first have to show that the convolutions before the weight norm produce the same outputs. That's the purpose of `run_custom_conv()` in both scripts. The torch documentation for Conv2d gives a formula for initializing the kernel and the bias. In my Flax script, I have a `make_initializer` which uses `in_channels`, like the fan-in operation described by the torch docs. I looked at the source code of `variance_scaling`, and it turns out that you can use `kernel_init = nn.initializers.variance_scaling(1/3, "fan_in", "uniform")` in JAX instead of `make_initializer(...)`. Needing to use 1/3 is a little unintuitive, but no big deal.

Other users have pointed out that you can't use `variance_scaling` for the `bias_init` (https://github.com/google/flax/issues/2749). One solution is to refactor one's code to use `make_initializer`. If you need a fan-out operation, like how torch does ConvTranspose, it's also easy to refactor `make_initializer`.
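Roughly, what I mean by `make_initializer` is something like the following (a simplified sketch, not the exact code from my script; the fan-in here is assumed to be `in_channels * prod(kernel_size)`, as in the torch docs):

```python
import jax
import jax.numpy as jnp
from flax import linen as nn


def make_initializer(fan_in):
    """Uniform init on [-1/sqrt(fan_in), 1/sqrt(fan_in)], like torch's Conv2d default."""
    bound = 1.0 / jnp.sqrt(fan_in)

    def init(key, shape, dtype=jnp.float32):
        return jax.random.uniform(key, shape, dtype, minval=-bound, maxval=bound)

    return init


# A uniform distribution on [-b, b] has variance b**2 / 3, so with
# b = 1/sqrt(fan_in) the kernel variance is 1/(3 * fan_in). That is exactly
# what variance_scaling(1/3, "fan_in", "uniform") produces, hence the 1/3:
kernel_init = nn.initializers.variance_scaling(1/3, "fan_in", "uniform")

# variance_scaling can't be used for bias_init (flax#2749), but
# make_initializer(fan_in) can, since the fan-in is passed in explicitly.
```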
About WeightNorm
I have a guess that Flax WeightNorm needs `scale_init = nn.initializers.constant(1/jnp.sqrt(3))` in order to match PyTorch. I arrived at this number through a bit of trial and error, and I'm fairly sure it isn't 0.5. I would like to know if someone can explain why. A sketch of how I would wire this in is at the end of this post.

Here's the PyTorch:
and its output:
and its two graphs:
Here's the Flax:
Here's the Flax output:
and its two graphs:
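To make the guess above concrete, here is a rough sketch of how I would wire the `scale_init` into `nn.WeightNorm` (a minimal stand-in module, not the full script from this post; the layer shape and feature sizes are just placeholders):

```python
import jax
import jax.numpy as jnp
from flax import linen as nn


class WNConv(nn.Module):
    """Conv wrapped in WeightNorm, initialized to (hopefully) match PyTorch."""
    features: int = 16

    @nn.compact
    def __call__(self, x):
        conv = nn.Conv(
            self.features,
            kernel_size=(3, 3),
            # Matches torch's Conv2d default kernel init (see "About Conv").
            kernel_init=nn.initializers.variance_scaling(1/3, "fan_in", "uniform"),
        )
        # The guess: a constant scale of 1/sqrt(3) makes Flax's WeightNorm
        # agree with torch.nn.utils.weight_norm at initialization.
        return nn.WeightNorm(conv, scale_init=nn.initializers.constant(1/jnp.sqrt(3)))(x)


# Example initialization on NHWC input:
x = jnp.ones((1, 8, 8, 3))
params = WNConv().init(jax.random.PRNGKey(0), x)
```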