google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Truncated Normal initializer doesn't match PyTorch #4091

Open DBraun opened 1 month ago

DBraun commented 1 month ago

System information

Neither Flax's nn.initializers.truncated_normal nor jax.nn.initializers.truncated_normal matches PyTorch's nn.init.trunc_normal_ closely enough. All three default to a lower bound of -2 and an upper bound of 2.

I'm running a test to make sure the outputs are similar if given the same arguments.

Here's my JAX code.

import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn

def summary_stats(name, x):
    print(f'Stats for {name}:')
    print('shape: ', x.shape)
    print('min: ', x.min())
    print('max: ', x.max())
    print('std: ', jnp.std(x))
    # print(x)

def make_trunc(key, stddev):
    lower = -2
    upper = -lower
    shape = (4096,)
    return nn.initializers.truncated_normal(stddev, lower=lower, upper=upper)(key, shape=shape)
    # return jax.nn.initializers.truncated_normal(stddev, lower=lower, upper=upper)(key, shape=shape)

summary_stats('trunc', make_trunc(random.key(0), .02))
summary_stats('trunc', make_trunc(random.key(1), .04))
summary_stats('trunc', make_trunc(random.key(2), .06))

Here's my PyTorch code:

import torch
import torch.nn as nn

def summary_stats(name, x):
    print(f'Stats for {name}:')
    print('shape: ', x.shape)
    print('min: ', x.min().item())
    print('max: ', x.max().item())
    print('std: ', x.std().item())
    # print(x)

t = torch.zeros((4096,))

nn.init.trunc_normal_(t, std=0.02)
summary_stats('pytorch', t)
nn.init.trunc_normal_(t, std=0.04)
summary_stats('pytorch', t)
nn.init.trunc_normal_(t, std=0.06)
summary_stats('pytorch', t)

JAX output:

Stats for trunc:
shape:  (4096,)
min:  -0.039960504
max:  0.039985776
std:  0.018035976
Stats for trunc:
shape:  (4096,)
min:  -0.0797701
max:  0.07959695
std:  0.03491081
Stats for trunc:
shape:  (4096,)
min:  -0.11983735
max:  0.119733654
std:  0.053539284

PyTorch output:

Stats for pytorch:
shape:  torch.Size([4096])
min:  -0.06634494662284851
max:  0.0743524581193924
std:  0.020439231768250465
Stats for pytorch:
shape:  torch.Size([4096])
min:  -0.13382470607757568
max:  0.12441056221723557
std:  0.03931436687707901
Stats for pytorch:
shape:  torch.Size([4096])
min:  -0.22086666524410248
max:  0.20988918840885162
std:  0.05979840084910393

Although the std values look close enough, the min and max seem off.

However, here is the JAX output again with lower=-4 (and therefore upper=4), even though PyTorch is still using -2 and 2.

JAX output:

Stats for trunc:
shape:  (4096,)
min:  -0.07253291
max:  0.076029524
std:  0.020624608
Stats for trunc:
shape:  (4096,)
min:  -0.13531744
max:  0.12941647
std:  0.039439432
Stats for trunc:
shape:  (4096,)
min:  -0.21360189
max:  0.20679429
std:  0.060959056

Now min/max line up with PyTorch better. I haven't figured out in the source code what explains this, but it would be nice to document it if it's an intended design.

IvyZX commented 1 month ago

On a quick look at the torch documentation and the source code of jax.random.truncated_normal, it seems that:

This might explain why PyTorch's min/max values lie farther from 0, as it is based on a distribution that has a higher chance of falling out of bounds.
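As a rough illustration of the difference in bounds, here is a stdlib-only sketch. It assumes PyTorch's a/b are absolute values and JAX's lower/upper are multiples of stddev, which is consistent with the min/max values reported above. The helper name normal_mass_outside is made up for this example.

```python
import math


def normal_mass_outside(lower, upper, std):
    """Probability that a N(0, std^2) sample falls outside [lower, upper].

    The bounds here are absolute values, not multiples of std.
    """
    def cdf(x):
        return 0.5 * (1.0 + math.erf(x / (std * math.sqrt(2.0))))

    return 1.0 - (cdf(upper) - cdf(lower))


std = 0.02

# Assumed PyTorch convention: [-2, 2] is absolute, i.e. 100 standard
# deviations away when std=0.02, so truncation almost never triggers.
torch_like = normal_mass_outside(-2.0, 2.0, std)

# Assumed JAX convention: [-2, 2] is in units of std, so the effective
# bounds are [-0.04, 0.04] when std=0.02.
jax_like = normal_mass_outside(-2.0 * std, 2.0 * std, std)

print("mass cut off, PyTorch-style bounds:", torch_like)
print("mass cut off, JAX-style bounds:    ", jax_like)
```

Under these assumptions the PyTorch call behaves like an almost untruncated normal at small std, while the JAX call always cuts at two standard deviations.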

If you'd like to know more, I'd recommend opening an issue or question on the JAX GitHub for a response from the authors.

DBraun commented 1 month ago

Thanks for taking a look.

I've been plotting histograms and I've observed that I can get the same behavior between PyTorch and JAX with this procedure:

  1. Use the same std deviation argument in both PyTorch and JAX.
  2. Take the lower/upper values that you're using in PyTorch and divide by the std deviation to get the lower/upper to use in JAX.
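The two steps above can be sketched as a small conversion helper. This is only a sketch: it assumes mean=0 on the PyTorch side, and torch_bounds_to_jax and make_matching_init are hypothetical names, not part of either library.

```python
def torch_bounds_to_jax(a, b, std):
    """Convert absolute PyTorch-style bounds (a, b) into bounds expressed
    in units of the standard deviation. Assumes mean=0."""
    if std <= 0:
        raise ValueError("std must be positive")
    return a / std, b / std


def make_matching_init(a=-2.0, b=2.0, std=0.02):
    """Hypothetical helper: build a JAX initializer intended to mimic
    torch.nn.init.trunc_normal_(t, std=std, a=a, b=b)."""
    # Imported lazily so the bound arithmetic above works without JAX.
    from jax.nn import initializers

    lower, upper = torch_bounds_to_jax(a, b, std)
    return initializers.truncated_normal(stddev=std, lower=lower, upper=upper)


# PyTorch defaults a=-2, b=2 with std=0.02 map to roughly -100 and 100.
lower, upper = torch_bounds_to_jax(-2.0, 2.0, 0.02)
print(lower, upper)
```

With JAX installed, make_matching_init()(key, shape) should then produce samples whose histogram can be compared against the PyTorch one.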

In JAX, if you change the standard deviation parameter, the "shape" of the histogram doesn't change. If the x-axis is set to auto, you essentially see the same shape but with different bounds. This is not true for PyTorch. To get the same behavior in PyTorch, you'd multiply the lower/upper bounds and the standard deviation by the same factor.
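One way to put a number on the "same shape" observation is the standard deviation of a normal truncated symmetrically at k standard deviations, divided by the untruncated standard deviation. The sketch below uses the textbook formula for that ratio and assumes the two bound conventions described above. truncated_std_ratio is a name made up for this example.

```python
import math


def truncated_std_ratio(k):
    """Std of a standard normal truncated to [-k, k].

    Equivalently, (std after truncation) / (std before truncation) when
    the bounds sit k standard deviations from the mean.
    """
    pdf = math.exp(-0.5 * k * k) / math.sqrt(2.0 * math.pi)
    mass = math.erf(k / math.sqrt(2.0))  # P(|Z| <= k)
    return math.sqrt(1.0 - 2.0 * k * pdf / mass)


for std in (0.02, 0.04, 0.06):
    jax_k = 2.0          # assumed JAX convention: always 2 std away
    torch_k = 2.0 / std  # assumed PyTorch convention: absolute bound of 2
    print(std, truncated_std_ratio(jax_k), truncated_std_ratio(torch_k))
```

Under the JAX convention the ratio is the same for every std, which matches the unchanged histogram shape. Under the PyTorch convention it depends on std, and at these small values it is essentially 1.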

I think the "Convert PyTorch models to Flax" guide should have a section dedicated to initializers. I'm porting training code, not just weights, so notes on initializers would be helpful.

In my work so far I think I've noticed that to get PyTorch behavior