DBraun opened this issue 1 month ago
On a quick look at the torch documentation and the source code of `jax.random.truncated_normal`, it seems that JAX uses `jax.random.uniform` (called here) to uniformly sample between a customized min-max range. This might explain why the min/max values of PyTorch are more divergent from 0, as it is based on a distribution that has a higher chance to be out-of-bound.

If you'd like to know more, I'd recommend opening an issue/question on the JAX GitHub for a response from the authors.
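To make the inverse-CDF scheme described above concrete, here is a stdlib-only sketch (my own mock-up, not the actual JAX implementation): draw uniforms between the CDF values of the truncation bounds, then map them back through the inverse normal CDF, so no sample can land out of bounds.

```python
import random
from statistics import NormalDist

def trunc_normal_inverse_cdf(lower, upper, n, seed=0):
    """Sketch: sample a standard normal truncated to [lower, upper]
    via an inverse-CDF transform of a bounded uniform draw."""
    nd = NormalDist()
    rng = random.Random(seed)
    a, b = nd.cdf(lower), nd.cdf(upper)  # uniform range is [cdf(lower), cdf(upper)]
    return [nd.inv_cdf(rng.uniform(a, b)) for _ in range(n)]

samples = trunc_normal_inverse_cdf(-2.0, 2.0, 10_000)
# Every sample is strictly inside the bounds by construction.
print(min(samples), max(samples))
```

Because the uniform draw is confined to `[cdf(lower), cdf(upper)]`, out-of-bound values are impossible by construction, rather than being clipped or redrawn after the fact.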
Thanks for taking a look.
I've been plotting histograms, and I've observed that I can get the same behavior between PyTorch and JAX with this procedure:

In JAX, if you change the standard deviation parameter, the "shape" of the histogram doesn't change. If the x-axis is set to auto, you essentially see the same shape but with different bounds. This is not true for PyTorch. In PyTorch, to get the same behavior, you'd multiply both the lower/upper bounds and the standard deviation by the same factor.
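A small stdlib-Python illustration of that shape-invariance claim (my own mock-up, not the real JAX/PyTorch code): if the bounds are in standard-deviation units and the scale is applied after truncation, changing the scale just stretches the same histogram, whereas fixed absolute bounds change the shape unless they are scaled along with the std.

```python
import random
from statistics import NormalDist

nd = NormalDist()

def trunc_std_normal(lower, upper, n, seed=0):
    # Inverse-CDF sampling of a standard normal truncated to [lower, upper].
    rng = random.Random(seed)
    a, b = nd.cdf(lower), nd.cdf(upper)
    return [nd.inv_cdf(rng.uniform(a, b)) for _ in range(n)]

base = trunc_std_normal(-2.0, 2.0, 10_000)

# JAX-style: truncate first (bounds in std units), then scale by the stddev.
jax_small = [0.5 * x for x in base]   # std = 0.5, samples capped at +/-1
jax_large = [2.0 * x for x in base]   # std = 2.0, samples capped at +/-4
# Same histogram shape: jax_large is jax_small stretched by a factor of 4.

# PyTorch-style: a=-2, b=2 are absolute. With std = 2.0 that means truncation
# at only +/-1 standard deviation -- a visibly different, flatter shape.
torch_large = [2.0 * x for x in trunc_std_normal(-2.0 / 2.0, 2.0 / 2.0, 10_000)]
# To recover the JAX shape you would scale a and b along with the std.
```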
I think that the Convert PyTorch models to Flax guide should have a section dedicated to initializers. I'm porting training code, not just weights, so it's helpful to have notes on initializers.
In my work so far, I think I've noticed that to get PyTorch behavior:

- use `nn.initializers.variance_scaling(1/3, "fan_in", "uniform")` instead of `lecun_normal`, but for `bias_init` you have to implement it yourself: https://github.com/google/flax/discussions/4131
- use `nn.initializers.variance_scaling(1/3, "fan_out", "uniform")` instead of `lecun_normal`, but for `bias_init` you have to implement it yourself: https://github.com/google/flax/discussions/4131
- use `nn.initializers.normal(1)` instead of `nn.initializers.variance_scaling(1.0, "fan_in", "normal", out_axis=0)`
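For the bias case, a hand-rolled initializer along the lines of the linked discussion might look like the sketch below. It assumes you want PyTorch's `nn.Linear` default of `Uniform(-1/sqrt(fan_in), 1/sqrt(fan_in))`; the function name and the explicit `in_features` argument are my own, since a bias-shaped array alone doesn't carry `fan_in` information.

```python
import math
import jax
import jax.numpy as jnp

def pytorch_style_bias_init(in_features):
    """Hypothetical bias initializer mimicking PyTorch's nn.Linear default:
    Uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)). fan_in is passed explicitly
    because it can't be inferred from the bias shape."""
    bound = 1.0 / math.sqrt(in_features)
    def init(key, shape, dtype=jnp.float32):
        return jax.random.uniform(key, shape, dtype, minval=-bound, maxval=bound)
    return init

# Usage, e.g. with flax.linen: nn.Dense(64, bias_init=pytorch_style_bias_init(128))
b = pytorch_style_bias_init(128)(jax.random.PRNGKey(0), (64,))
```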
System information
Neither `nn.initializers.truncated_normal` nor `jax.nn.initializers.truncated_normal` is similar enough to PyTorch's `nn.init.trunc_normal_`. All of these use a lower of -2 and an upper of 2 by default.

I'm running a test to make sure the outputs are similar if given the same arguments.
Here's my JAX code:
Here's my PyTorch code:
JAX output:
PyTorch output:
Although the std values look close enough, the min and max seem off.

However, let's look at the JAX output again if I set `lower=-4`, even though PyTorch is using -2.

JAX output:
Now min/max line up with PyTorch better. I haven't figured out in the source code what explains this, but it would be nice to document it if it's an intended design.
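One plausible explanation, worth confirming against the sources: JAX's `lower`/`upper` appear to be expressed in standard-deviation units (truncation happens before the result is scaled by the stddev), while PyTorch's `a`/`b` are absolute values. With a small std, PyTorch's defaults of `a=-2, b=2` sit many standard deviations out and barely truncate anything, so the observed min/max drift toward roughly ±4 std over enough samples, which would be why JAX with `lower=-4` lines up. A stdlib-only sketch of the two conventions (my own mock-up, not the real implementations):

```python
import random
from statistics import NormalDist

nd = NormalDist()

def trunc_std_normal(lower, upper, n, seed=0):
    # Inverse-CDF sampling of a standard normal truncated to [lower, upper].
    rng = random.Random(seed)
    a, b = nd.cdf(lower), nd.cdf(upper)
    return [nd.inv_cdf(rng.uniform(a, b)) for _ in range(n)]

std, n = 0.02, 50_000
# JAX-like: bounds in std units, applied before scaling -> hard cap at +/- 2*std.
jax_like = [std * x for x in trunc_std_normal(-2.0, 2.0, n)]
# PyTorch-like: a=-2, b=2 are absolute, i.e. +/-100 std here -> barely truncates.
torch_like = [std * x for x in trunc_std_normal(-2.0 / std, 2.0 / std, n, seed=1)]
print(min(jax_like), min(torch_like))  # torch_like's min falls well below -2*std
```

Under this reading, raising the JAX `lower` to -4 simply matches the ±4-std extremes an effectively untruncated normal produces at this sample size.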