google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

nn.WeightNorm needs a special scale_init to match PyTorch weight_norm #4138

Open DBraun opened 3 months ago

DBraun commented 3 months ago

System information

Problem you have encountered:

flax.linen.WeightNorm needs a special scale_init in order to match PyTorch. I have written equivalent scripts in PyTorch and Flax whose outputs have matching statistics.

About Conv

Before talking about WeightNorm, I first have to show that the convolutions without weight norm produce outputs with matching statistics. That's the purpose of run_custom_conv() in both scripts. The torch documentation for Conv2d gives the formula for initializing both the kernel and the bias: sample from U(-sqrt(k), sqrt(k)) with k = groups / (C_in * prod(kernel_size)). My Flax script implements this in make_initializer, which depends on in_channels, so it is a fan-in scaling like the one the torch docs describe. After reading the source code of variance_scaling, I found that for the kernel you can use kernel_init = nn.initializers.variance_scaling(1/3, "fan_in", "uniform") in JAX instead of make_initializer(...). Needing to use 1/3 is a little unintuitive, but no big deal.
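The 1/3 can be traced to how each library defines its bound. PyTorch's Conv2d draws its default weight with kaiming_uniform_(weight, a=sqrt(5)); the leaky-ReLU gain sqrt(2 / (1 + a²)) = sqrt(1/3) times sqrt(3 / fan_in) collapses to sqrt(1 / fan_in), which is the sqrt(k) from the docs. variance_scaling with "fan_in" and "uniform" samples from U(-limit, limit) with limit = sqrt(3 * scale / fan_in), so scale = 1/3 gives the same bound. The plain-Python sketch below only compares the bounds; the three helper functions are hypothetical, and fan_in is assumed to be (in_channels / groups) * prod(kernel_size), which is what both libraries' grouped kernel shapes imply.

```python
import math

def torch_conv_bound(in_channels, kernel_size, groups=1):
    # PyTorch Conv2d docs: U(-sqrt(k), sqrt(k)), k = groups / (C_in * prod(kernel_size))
    k = groups / (in_channels * math.prod(kernel_size))
    return math.sqrt(k)

def kaiming_uniform_bound(fan_in, a=math.sqrt(5)):
    # Bound used by kaiming_uniform_ with the leaky_relu gain sqrt(2 / (1 + a^2))
    gain = math.sqrt(2.0 / (1 + a ** 2))
    return gain * math.sqrt(3.0 / fan_in)

def variance_scaling_uniform_bound(scale, fan_in):
    # variance_scaling(scale, "fan_in", "uniform") samples U(-limit, limit)
    return math.sqrt(3.0 * scale / fan_in)

# (in_channels, kernel_size, groups) combinations to compare
for in_ch, ks, groups in [(1, (3, 3), 1), (32, (5, 1), 1), (1024, (3, 1), 1), (64, (3, 3), 4)]:
    fan_in = (in_ch // groups) * math.prod(ks)
    print(
        in_ch, ks, groups,
        round(torch_conv_bound(in_ch, ks, groups), 6),
        round(kaiming_uniform_bound(fan_in), 6),
        round(variance_scaling_uniform_bound(1 / 3, fan_in), 6),
    )
```

All three columns agree for every row, which is why scale = 1/3 reproduces the PyTorch kernel init.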

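Other users have pointed out that you can't use variance_scaling for the bias_init (https://github.com/google/flax/issues/2749). That makes sense, since the bias bound depends on the kernel's fan-in, which the 1-D bias shape doesn't carry. One workaround is to compute the bound explicitly, as make_initializer does. If you need a fan-out variant, such as torch's ConvTranspose2d, where k is computed from the output channels instead, make_initializer is easy to adapt.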
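A sketch of what that adaptation could look like: a hypothetical make_uniform_initializer that takes a transposed flag and switches k to use out_channels, assuming the ConvTranspose2d docs formula k = groups / (C_out * prod(kernel_size)). To keep it runnable on its own, it samples with a NumPy Generator instead of jax.random; the helper name and the transposed flag are illustrative, not part of Flax.

```python
import math
import numpy as np

def make_uniform_initializer(in_channels, out_channels, kernel_size, groups=1, transposed=False):
    """Hypothetical variant of make_initializer: PyTorch-style U(-sqrt(k), sqrt(k)).

    transposed=False: k = groups / (in_channels * prod(kernel_size))   (Conv2d, fan-in)
    transposed=True:  k = groups / (out_channels * prod(kernel_size))  (ConvTranspose2d)
    """
    channels = out_channels if transposed else in_channels
    bound = math.sqrt(groups / (channels * math.prod(kernel_size)))

    def init_fn(rng, shape, dtype=np.float32):
        # Same (key, shape, dtype) argument order as a Flax initializer,
        # but `rng` is a NumPy Generator here instead of a JAX PRNG key.
        return rng.uniform(-bound, bound, size=shape).astype(dtype)

    init_fn.bound = bound  # exposed only so the bound can be inspected
    return init_fn

rng = np.random.default_rng(0)

# One initializer covers both kernel and bias: the bound comes from the layer
# configuration, not from the (1-D) shape of the parameter being initialized.
conv_init = make_uniform_initializer(32, 128, (5, 1))
kernel = conv_init(rng, (5, 1, 32, 128))  # Flax layout (H, W, in, out)
bias = conv_init(rng, (128,))

tconv_init = make_uniform_initializer(32, 128, (5, 1), transposed=True)
print(conv_init.bound, tconv_init.bound)
```

In the Flax script, the returned function would be passed as both kernel_init and bias_init once the NumPy sampling is swapped back to jax.random.uniform.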

About WeightNorm

My guess is that Flax's WeightNorm needs scale_init = nn.initializers.constant(1/jnp.sqrt(3)) in order to match PyTorch. I arrived at this value by trial and error, and I also don't think it's 0.5. Can someone explain why?
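One way to see where a mismatch can come from is to write out the reparametrization both libraries use, w = g * v / ||v||, with the norm taken per output channel. torch.nn.utils.weight_norm initializes g to the per-channel norm of the layer's existing weight, so at initialization w is exactly the default Conv2d weight. Flax's WeightNorm initializes its scale with scale_init, which defaults to ones, so each output filter starts with unit norm instead. The NumPy sketch below is not either library's code; it assumes PyTorch's (O, I, H, W) layout and groups=1, and applies both initial choices of g to the same kernel.

```python
import numpy as np

def weight_norm(v, g, axes):
    # w = g * v / ||v||, L2 norm over every axis except the output-channel axis
    norm = np.sqrt(np.sum(v ** 2, axis=axes, keepdims=True))
    return g * v / norm

rng = np.random.default_rng(0)
out_ch, in_ch, kh, kw = 128, 32, 5, 1
bound = 1 / np.sqrt(in_ch * kh * kw)  # PyTorch Conv2d default bound (groups=1)
v = rng.uniform(-bound, bound, size=(out_ch, in_ch, kh, kw))  # (O, I, H, W)
axes = (1, 2, 3)

# PyTorch-style: g starts as the per-channel norm of v, so w == v at init
g_torch = np.sqrt(np.sum(v ** 2, axis=axes, keepdims=True))
w_torch = weight_norm(v, g_torch, axes)

# Flax-style default: scale starts at ones, so every output filter has unit norm
g_flax = np.ones((out_ch, 1, 1, 1))
w_flax = weight_norm(v, g_flax, axes)

filter_norms = np.sqrt(np.sum(w_torch ** 2, axis=axes))
flax_norms = np.sqrt(np.sum(w_flax ** 2, axis=axes))
print(np.allclose(w_torch, v), filter_norms.mean(), flax_norms.mean())
```

The mean of filter_norms is the value a constant Flax scale would have to start at for the two layers to produce weights of the same magnitude.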

Here's the PyTorch script:

from einops import rearrange
from matplotlib import pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.nn.utils import weight_norm
from torch.nn import functional as F

def WNConv2d(*args, act=True, **kwargs):
    conv = weight_norm(nn.Conv2d(*args, **kwargs))
    if not act:
        return conv
    return nn.Sequential(conv, nn.LeakyReLU(0.1))

class MPD(nn.Module):
    def __init__(self, period):
        super().__init__()
        self.period = period
        self.convs = nn.ModuleList(
            [
                WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
                WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
                WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
                WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
                WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
            ]
        )
        self.conv_post = WNConv2d(
            1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
        )

    def pad_to_period(self, x):
        t = x.shape[-1]
        x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
        return x

    def forward(self, x):
        fmap = []

        x = self.pad_to_period(x)
        x = rearrange(x, "b c (l p) -> b c l p", p=self.period)

        for layer in self.convs:
            x = layer(x)
            fmap.append(x)

        x = self.conv_post(x)
        fmap.append(x)

        return fmap

def summary_stats(name, x, ax):
    x = x.detach().cpu().numpy()
    ax.hist(x.reshape(-1), bins=100, alpha=0.5, label=name)
    ax.set_title(f'PyTorch Histogram of {name}')
    ax.set_xlabel('Value')
    ax.set_ylabel('Frequency')
    ax.legend(loc='upper right')

    print(f'Stats for {name}:')
    print(f'shape:', list(x.shape))
    print(f'mean: {np.mean(x):,.5f} min: {np.min(x):,.5f} max: {np.max(x):,.5f} std: {np.std(x):,.5f}')

def run_MPD():

    B, C, T = 1, 1, 44100
    x = torch.rand((B, C, T)).cuda()*2-1
    period = 2

    model = MPD(period).cuda()

    fmaps = model(x)

    # Create a tall figure with one subplot for each feature map
    fig, axs = plt.subplots(len(fmaps), 1, figsize=(10, 18))
    fig.tight_layout(pad=5.0)  # Adjust the spacing between subplots

    # Plot each histogram on a different subplot
    for i, (fmap, ax) in enumerate(zip(fmaps, axs)):
        summary_stats(f"fmap {i}", fmap, ax)
        print()

    plt.show()

    from torchinfo import summary
    summary(model,
            col_names=['input_size', 'output_size', 'num_params'],
            input_size=x.shape,
            depth=5,
            verbose=1,
            )

def run_custom_conv():

    B, C, H, W = 1, 1, 25, 25
    x = torch.rand((B, C, H, W)).cuda()*2-1

    model = nn.Conv2d(C, out_channels=32, kernel_size=(3, 3), padding=0).cuda()

    fmap = model(x)

    fig, ax = plt.subplots(1, 1, figsize=(8, 6))

    summary_stats(f"fmap 0", fmap.reshape(-1), ax)
    print()

    plt.show()

    from torchinfo import summary
    summary(model,
            col_names=['input_size', 'output_size', 'num_params'],
            input_size=x.shape,
            depth=5,
            verbose=1,
            )

if __name__ == '__main__':
    print('running custom conv:')
    run_custom_conv()
    print('running MPD:')
    run_MPD()

and its output:

running custom conv:
Stats for fmap 0:
shape: [16928]
mean: -0.02200 min: -1.46258 max: 1.46658 std: 0.38855

===================================================================================================================
Layer (type:depth-idx)                   Input Shape               Output Shape              Param #
===================================================================================================================
Conv2d                                   [1, 1, 25, 25]            [1, 32, 23, 23]           320
===================================================================================================================
Total params: 320
Trainable params: 320
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0.17
===================================================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.14
Params size (MB): 0.00
Estimated Total Size (MB): 0.14
===================================================================================================================
running MPD:
C:\Python311\Lib\site-packages\torch\nn\utils\weight_norm.py:28: UserWarning: torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.
  warnings.warn("torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.")
Stats for fmap 0:
shape: [1, 32, 7351, 2]
mean: 0.18682 min: -0.19188 max: 1.80779 std: 0.27846

Stats for fmap 1:
shape: [1, 128, 2451, 2]
mean: 0.07238 min: -0.10415 max: 0.91518 std: 0.11993

Stats for fmap 2:
shape: [1, 512, 817, 2]
mean: 0.02764 min: -0.03801 max: 0.46201 std: 0.04894

Stats for fmap 3:
shape: [1, 1024, 273, 2]
mean: 0.01188 min: -0.01440 max: 0.14850 std: 0.02060

Stats for fmap 4:
shape: [1, 1024, 273, 2]
mean: 0.00579 min: -0.00673 max: 0.07628 std: 0.00976

Stats for fmap 5:
shape: [1, 1, 273, 2]
mean: -0.00274 min: -0.01184 max: 0.00474 std: 0.00262

===================================================================================================================
Layer (type:depth-idx)                   Input Shape               Output Shape              Param #
===================================================================================================================
MPD                                      [1, 1, 44100]             [1, 32, 7351, 2]          --
├─ModuleList: 1-1                        --                        --                        --
│    └─Sequential: 2-1                   [1, 1, 22051, 2]          [1, 32, 7351, 2]          --
│    │    └─Conv2d: 3-1                  [1, 1, 22051, 2]          [1, 32, 7351, 2]          224
│    │    └─LeakyReLU: 3-2               [1, 32, 7351, 2]          [1, 32, 7351, 2]          --
│    └─Sequential: 2-2                   [1, 32, 7351, 2]          [1, 128, 2451, 2]         --
│    │    └─Conv2d: 3-3                  [1, 32, 7351, 2]          [1, 128, 2451, 2]         20,736
│    │    └─LeakyReLU: 3-4               [1, 128, 2451, 2]         [1, 128, 2451, 2]         --
│    └─Sequential: 2-3                   [1, 128, 2451, 2]         [1, 512, 817, 2]          --
│    │    └─Conv2d: 3-5                  [1, 128, 2451, 2]         [1, 512, 817, 2]          328,704
│    │    └─LeakyReLU: 3-6               [1, 512, 817, 2]          [1, 512, 817, 2]          --
│    └─Sequential: 2-4                   [1, 512, 817, 2]          [1, 1024, 273, 2]         --
│    │    └─Conv2d: 3-7                  [1, 512, 817, 2]          [1, 1024, 273, 2]         2,623,488
│    │    └─LeakyReLU: 3-8               [1, 1024, 273, 2]         [1, 1024, 273, 2]         --
│    └─Sequential: 2-5                   [1, 1024, 273, 2]         [1, 1024, 273, 2]         --
│    │    └─Conv2d: 3-9                  [1, 1024, 273, 2]         [1, 1024, 273, 2]         5,244,928
│    │    └─LeakyReLU: 3-10              [1, 1024, 273, 2]         [1, 1024, 273, 2]         --
├─Conv2d: 1-2                            [1, 1024, 273, 2]         [1, 1, 273, 2]            3,074
===================================================================================================================
Total params: 8,221,154
Trainable params: 8,221,154
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 8.23
===================================================================================================================
Input size (MB): 0.18
Forward/backward pass size (MB): 24.43
Params size (MB): 32.88
Estimated Total Size (MB): 57.49
===================================================================================================================

and its two graphs (images not preserved): PyTorch histograms of the custom-conv output and of the six MPD feature maps.

Here's the Flax script:

from einops import rearrange
from flax import linen as nn
import jax
from jax import numpy as jnp
from matplotlib import pyplot as plt
import numpy as np

def make_initializer(in_channels, out_channels, kernel_size, groups):
    # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
    k = groups / (in_channels * jnp.prod(jnp.array(kernel_size)))
    scale = jnp.sqrt(k)

    def init_fn(key, shape, dtype):
        return jax.random.uniform(key, shape, minval=-scale, maxval=scale, dtype=dtype)

    return init_fn

class CustomConv(nn.Conv):

    @nn.compact
    def __call__(self, x):

        # note: we just ignore whatever self.kernel_init is
        kernel_init = make_initializer(
            x.shape[-1], self.features, self.kernel_size, self.feature_group_count
        )

        if self.use_bias:
            # note: we just ignore whatever self.bias_init is
            bias_init = make_initializer(
                x.shape[-1], self.features, self.kernel_size, self.feature_group_count
            )
        else:
            bias_init = None

        # todo: try using these instead
        # kernel_init = nn.initializers.variance_scaling(1/3, "fan_in", "uniform")  # same as kernel_init above
        # bias_init = nn.initializers.constant(1)

        return nn.Conv(
            features=self.features,
            kernel_size=self.kernel_size,
            strides=self.strides,
            padding=self.padding,
            input_dilation=self.input_dilation,
            kernel_dilation=self.kernel_dilation,
            feature_group_count=self.feature_group_count,
            use_bias=self.use_bias,
            mask=self.mask,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision,
            kernel_init=kernel_init,
            bias_init=bias_init
        )(x)

class LeakyReLU(nn.Module):

    negative_slope: float = .01

    @nn.compact
    def __call__(self, x):
        return nn.leaky_relu(x, negative_slope=self.negative_slope)

def WNConv2d(*args, **kwargs):
    scale_init = nn.initializers.constant(1/jnp.sqrt(3))
    # scale_init = nn.initializers.constant(1)  # todo: try using this instead
    conv = nn.WeightNorm(CustomConv(*args, **kwargs), scale_init=scale_init)
    return conv

class MPD(nn.Module):

    period: int

    def pad_to_period(self, x):
        t = x.shape[-1]
        x = jnp.pad(x, pad_width=((0, 0), (0, 0), (0, self.period - t % self.period)), mode='reflect')
        return x

    @nn.compact
    def __call__(self, x):
        convs = [
            WNConv2d(features=32, kernel_size=(5, 1), strides=(3, 1), padding=((2, 2), (0, 0))),
            WNConv2d(features=128, kernel_size=(5, 1), strides=(3, 1), padding=((2, 2), (0, 0))),
            WNConv2d(features=512, kernel_size=(5, 1), strides=(3, 1), padding=((2, 2), (0, 0))),
            WNConv2d(features=1024, kernel_size=(5, 1), strides=(3, 1), padding=((2, 2), (0, 0))),
            WNConv2d(features=1024, kernel_size=(5, 1), strides=(1, 1), padding=((2, 2), (0, 0))),
            WNConv2d(features=1, kernel_size=(3, 1), strides=(1, 1), padding=((1, 1), (0, 0))),
        ]

        fmap = []

        x = self.pad_to_period(x)
        x = rearrange(x, "b c (l p) -> b l p c", p=self.period)

        for i, layer in enumerate(convs):
            x = layer(x)
            if i != (len(convs) - 1):
                x = LeakyReLU(negative_slope=0.1)(x)
            fmap.append(x)

        return fmap

def summary_stats(name, x, ax):
    x = np.array(x)
    ax.hist(x.reshape(-1), bins=100, alpha=0.5, label=name)
    ax.set_title(f'JAX Histogram of {name}')
    ax.set_xlabel('Value')
    ax.set_ylabel('Frequency')
    ax.legend(loc='upper right')

    print(f'Stats for {name}:')
    print(f'shape:', list(x.shape))
    print(f'mean: {np.mean(x):,.5f} min: {np.min(x):,.5f} max: {np.max(x):,.5f} std: {np.std(x):,.5f}')

def run_MPD():

    key = jax.random.PRNGKey(0)
    B, C, T = 1, 1, 44100
    x = jax.random.uniform(jax.random.PRNGKey(0), shape=(B, C, T), minval=-1.0, maxval=1.0)
    period = 2

    model = MPD(period)
    fmaps, variables = model.init_with_output({"params": key}, x)

    # Create a tall figure with one subplot for each feature map
    fig, axs = plt.subplots(len(fmaps), 1, figsize=(10, 18))
    fig.tight_layout(pad=5.0)  # Adjust the spacing between subplots

    # Plot each histogram on a different subplot
    for i, (fmap, ax) in enumerate(zip(fmaps, axs)):
        summary_stats(f"fmap {i}", fmap, ax)
        print()

    plt.show()

    print(model.tabulate({"params": key}, x, console_kwargs={"width": 400}))

def run_custom_conv():

    key = jax.random.PRNGKey(0)
    B, C, H, W = 1, 1, 25, 25
    x = jax.random.uniform(jax.random.PRNGKey(0), shape=(B, H, W, C), minval=-1.0, maxval=1.0)

    model = CustomConv(features=32, kernel_size=(3, 3), padding='VALID')
    fmap, variables = model.init_with_output({"params": key}, x)

    fig, axs = plt.subplots(1, 1, figsize=(8, 6))
    fig.tight_layout(pad=5.0)  # Adjust the spacing between subplots

    # Plot each histogram on a different subplot
    summary_stats(f"fmap 0", fmap.reshape(-1), axs)
    print()

    plt.show()

    print(model.tabulate({"params": key}, x, console_kwargs={"width": 400}))

if __name__ == '__main__':
    print('running custom conv:')
    run_custom_conv()
    print('running MPD:')
    run_MPD()

Here's the Flax output:

running custom conv:
Stats for fmap 0:
shape: [16928]
mean: -0.03363 min: -1.40942 max: 1.28537 std: 0.37116

                                      CustomConv Summary                                      
┏━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ path   ┃ module     ┃ inputs             ┃ outputs             ┃ params                    ┃
┡━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│        │ CustomConv │ float32[1,25,25,1] │ float32[1,23,23,32] │                           │
├────────┼────────────┼────────────────────┼─────────────────────┼───────────────────────────┤
│ Conv_0 │ Conv       │ float32[1,25,25,1] │ float32[1,23,23,32] │ bias: float32[32]         │
│        │            │                    │                     │ kernel: float32[3,3,1,32] │
│        │            │                    │                     │                           │
│        │            │                    │                     │ 320 (1.3 KB)              │
├────────┼────────────┼────────────────────┼─────────────────────┼───────────────────────────┤
│        │            │                    │               Total │ 320 (1.3 KB)              │
└────────┴────────────┴────────────────────┴─────────────────────┴───────────────────────────┘

                                Total Parameters: 320 (1.3 KB)                                

running MPD:
Stats for fmap 0:
shape: [1, 7351, 2, 32]
mean: 0.15735 min: -0.14764 max: 1.46771 std: 0.26314

Stats for fmap 1:
shape: [1, 2451, 2, 128]
mean: 0.05720 min: -0.08959 max: 0.69397 std: 0.10307

Stats for fmap 2:
shape: [1, 817, 2, 512]
mean: 0.02382 min: -0.03053 max: 0.30742 std: 0.04272

Stats for fmap 3:
shape: [1, 273, 2, 1024]
mean: 0.01169 min: -0.01467 max: 0.13217 std: 0.01918

Stats for fmap 4:
shape: [1, 273, 2, 1024]
mean: 0.00552 min: -0.00686 max: 0.06548 std: 0.00945

Stats for fmap 5:
shape: [1, 273, 2, 1]
mean: 0.02159 min: 0.01417 max: 0.02803 std: 0.00248

                                                              MPD Summary                                                               
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ path                ┃ module     ┃ inputs                ┃ outputs                 ┃ params                                          ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│                     │ MPD        │ float32[1,1,44100]    │ - float32[1,7351,2,32]  │                                                 │
│                     │            │                       │ - float32[1,2451,2,128] │                                                 │
│                     │            │                       │ - float32[1,817,2,512]  │                                                 │
│                     │            │                       │ - float32[1,273,2,1024] │                                                 │
│                     │            │                       │ - float32[1,273,2,1024] │                                                 │
│                     │            │                       │ - float32[1,273,2,1]    │                                                 │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ WeightNorm_0        │ WeightNorm │ float32[1,22051,2,1]  │ float32[1,7351,2,32]    │ CustomConv_0/Conv_0/kernel/scale: float32[32]   │
│                     │            │                       │                         │                                                 │
│                     │            │                       │                         │ 32 (128 B)                                      │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ CustomConv_0        │ CustomConv │ float32[1,22051,2,1]  │ float32[1,7351,2,32]    │                                                 │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ CustomConv_0/Conv_0 │ Conv       │ float32[1,22051,2,1]  │ float32[1,7351,2,32]    │ bias: float32[32]                               │
│                     │            │                       │                         │ kernel: float32[5,1,1,32]                       │
│                     │            │                       │                         │                                                 │
│                     │            │                       │                         │ 192 (768 B)                                     │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ LeakyReLU_0         │ LeakyReLU  │ float32[1,7351,2,32]  │ float32[1,7351,2,32]    │                                                 │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ WeightNorm_1        │ WeightNorm │ float32[1,7351,2,32]  │ float32[1,2451,2,128]   │ CustomConv_1/Conv_0/kernel/scale: float32[128]  │
│                     │            │                       │                         │                                                 │
│                     │            │                       │                         │ 128 (512 B)                                     │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ CustomConv_1        │ CustomConv │ float32[1,7351,2,32]  │ float32[1,2451,2,128]   │                                                 │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ CustomConv_1/Conv_0 │ Conv       │ float32[1,7351,2,32]  │ float32[1,2451,2,128]   │ bias: float32[128]                              │
│                     │            │                       │                         │ kernel: float32[5,1,32,128]                     │
│                     │            │                       │                         │                                                 │
│                     │            │                       │                         │ 20,608 (82.4 KB)                                │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ LeakyReLU_1         │ LeakyReLU  │ float32[1,2451,2,128] │ float32[1,2451,2,128]   │                                                 │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ WeightNorm_2        │ WeightNorm │ float32[1,2451,2,128] │ float32[1,817,2,512]    │ CustomConv_2/Conv_0/kernel/scale: float32[512]  │
│                     │            │                       │                         │                                                 │
│                     │            │                       │                         │ 512 (2.0 KB)                                    │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ CustomConv_2        │ CustomConv │ float32[1,2451,2,128] │ float32[1,817,2,512]    │                                                 │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ CustomConv_2/Conv_0 │ Conv       │ float32[1,2451,2,128] │ float32[1,817,2,512]    │ bias: float32[512]                              │
│                     │            │                       │                         │ kernel: float32[5,1,128,512]                    │
│                     │            │                       │                         │                                                 │
│                     │            │                       │                         │ 328,192 (1.3 MB)                                │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ LeakyReLU_2         │ LeakyReLU  │ float32[1,817,2,512]  │ float32[1,817,2,512]    │                                                 │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ WeightNorm_3        │ WeightNorm │ float32[1,817,2,512]  │ float32[1,273,2,1024]   │ CustomConv_3/Conv_0/kernel/scale: float32[1024] │
│                     │            │                       │                         │                                                 │
│                     │            │                       │                         │ 1,024 (4.1 KB)                                  │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ CustomConv_3        │ CustomConv │ float32[1,817,2,512]  │ float32[1,273,2,1024]   │                                                 │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ CustomConv_3/Conv_0 │ Conv       │ float32[1,817,2,512]  │ float32[1,273,2,1024]   │ bias: float32[1024]                             │
│                     │            │                       │                         │ kernel: float32[5,1,512,1024]                   │
│                     │            │                       │                         │                                                 │
│                     │            │                       │                         │ 2,622,464 (10.5 MB)                             │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ LeakyReLU_3         │ LeakyReLU  │ float32[1,273,2,1024] │ float32[1,273,2,1024]   │                                                 │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ WeightNorm_4        │ WeightNorm │ float32[1,273,2,1024] │ float32[1,273,2,1024]   │ CustomConv_4/Conv_0/kernel/scale: float32[1024] │
│                     │            │                       │                         │                                                 │
│                     │            │                       │                         │ 1,024 (4.1 KB)                                  │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ CustomConv_4        │ CustomConv │ float32[1,273,2,1024] │ float32[1,273,2,1024]   │                                                 │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ CustomConv_4/Conv_0 │ Conv       │ float32[1,273,2,1024] │ float32[1,273,2,1024]   │ bias: float32[1024]                             │
│                     │            │                       │                         │ kernel: float32[5,1,1024,1024]                  │
│                     │            │                       │                         │                                                 │
│                     │            │                       │                         │ 5,243,904 (21.0 MB)                             │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ LeakyReLU_4         │ LeakyReLU  │ float32[1,273,2,1024] │ float32[1,273,2,1024]   │                                                 │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ WeightNorm_5        │ WeightNorm │ float32[1,273,2,1024] │ float32[1,273,2,1]      │ CustomConv_5/Conv_0/kernel/scale: float32[1]    │
│                     │            │                       │                         │                                                 │
│                     │            │                       │                         │ 1 (4 B)                                         │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ CustomConv_5        │ CustomConv │ float32[1,273,2,1024] │ float32[1,273,2,1]      │                                                 │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ CustomConv_5/Conv_0 │ Conv       │ float32[1,273,2,1024] │ float32[1,273,2,1]      │ bias: float32[1]                                │
│                     │            │                       │                         │ kernel: float32[3,1,1024,1]                     │
│                     │            │                       │                         │                                                 │
│                     │            │                       │                         │ 3,073 (12.3 KB)                                 │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│                     │            │                       │                   Total │ 8,221,154 (32.9 MB)                             │
└─────────────────────┴────────────┴───────────────────────┴─────────────────────────┴─────────────────────────────────────────────────┘

                                                 Total Parameters: 8,221,154 (32.9 MB)       

and its two graphs (images not preserved): JAX histograms of the custom-conv output and of the six MPD feature maps.
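As a sanity check that the two MPD definitions describe the same network, the feature-map lengths and per-layer parameter counts in both summaries can be reproduced by hand. This sketch assumes the standard output-length formula floor((L + 2p - k) / s) + 1 and counts, per weight-normed layer, one scale value per output channel (g in PyTorch, .../kernel/scale in Flax), the kernel, and the bias. PyTorch's 224 for the first layer splits the same way as Flax's 32 + 192.

```python
# (in_channels, out_channels, kernel_height, stride, padding) for each WNConv2d in MPD
LAYERS = [
    (1, 32, 5, 3, 2),
    (32, 128, 5, 3, 2),
    (128, 512, 5, 3, 2),
    (512, 1024, 5, 3, 2),
    (1024, 1024, 5, 1, 2),
    (1024, 1, 3, 1, 1),
]

def conv_out_len(length, kernel, stride, padding):
    # Standard convolution output length along one axis (no dilation)
    return (length + 2 * padding - kernel) // stride + 1

def mpd_summary(num_samples=44100, period=2):
    # pad_to_period pads by period - t % period, i.e. a full period when t is divisible
    padded = num_samples + (period - num_samples % period)
    length = padded // period
    lengths, params = [], []
    for c_in, c_out, k, s, p in LAYERS:
        length = conv_out_len(length, k, s, p)
        scale = c_out                # g (PyTorch) or kernel/scale (Flax)
        kernel = c_out * c_in * k    # kernel width is 1
        bias = c_out
        lengths.append(length)
        params.append(scale + kernel + bias)
    return lengths, params

lengths, params = mpd_summary()
print(lengths)               # [7351, 2451, 817, 273, 273, 273]
print(params, sum(params))   # [224, 20736, 328704, 2623488, 5244928, 3074] 8221154
```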

DBraun commented 3 months ago

There's an explanation for 1/sqrt(3): the variance of a uniform distribution on [-1, 1] is 1/3, so its standard deviation is 1/sqrt(3). I hope that's a clue to why PyTorch seems to handle WeightNorm one way and Flax another.
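That observation can be carried one step further. Assuming PyTorch's default fan-in init, each of the n = fan_in entries of an output filter v is drawn from U(-1/sqrt(n), 1/sqrt(n)), whose variance is (1/sqrt(n))² / 3 = 1/(3n). Summing over the n entries gives E[||v||²] = 1/3, so ||v|| is close to 1/sqrt(3) regardless of layer size. Since torch.nn.utils.weight_norm starts g at exactly ||v||, while Flax's WeightNorm starts its scale at scale_init (ones by default), this suggests constant(1/sqrt(3)) matches PyTorch in expectation, though not filter by filter. The NumPy sketch below checks this for the MPD layer shapes, using a Flax-layout (H, W, I, O) kernel; the per-channel exact_scale it computes is what an exact match would need, and how one would write it into the initialized Flax params is left as an assumption.

```python
import numpy as np

rng = np.random.default_rng(0)

# (in_channels, out_channels, kernel_height) for each WNConv2d in MPD; kernel width is 1
LAYERS = [(1, 32, 5), (32, 128, 5), (128, 512, 5), (512, 1024, 5), (1024, 1024, 5), (1024, 1, 3)]

results = []
for c_in, c_out, kh in LAYERS:
    fan_in = c_in * kh * 1
    bound = 1 / np.sqrt(fan_in)  # PyTorch Conv2d default bound, groups=1
    # Flax kernel layout: (H, W, in_features, out_features)
    kernel = rng.uniform(-bound, bound, size=(kh, 1, c_in, c_out))
    # Exact per-output-channel norm, i.e. the g that PyTorch's weight_norm starts from
    exact_scale = np.sqrt(np.sum(kernel ** 2, axis=(0, 1, 2)))
    results.append((fan_in, exact_scale.mean()))
    print(f"{c_in:>4}->{c_out:<4} fan_in={fan_in:<5} mean ||v|| = {exact_scale.mean():.4f}")

print(f"1/sqrt(3) = {1 / np.sqrt(3):.4f}")
```

For the first layer (fan_in = 5) the per-filter norms scatter noticeably around 1/sqrt(3); for the wider layers they concentrate tightly, which is consistent with the constant working well in practice.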