google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0
2.28k stars 226 forks

NotImplementedError: When I use stax.DotGeneral #121

Open kkeevin123456 opened 3 years ago

kkeevin123456 commented 3 years ago

Hi, when I try to implement a two-layer coupling layer like the image below, I get this error. Do you have any insight into how to fix it?

[image: two-layer coupling layer architecture]
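To make the goal concrete, here is roughly what the additive coupling layer is meant to compute, written in plain jax.numpy rather than stax (m stands for the ReluNetwork applied to the first half of the features):

    import jax.numpy as jnp

    def additive_coupling(x, m):
        # NICE-style additive coupling: keep the first half of the features,
        # shift the second half by m(first half).
        d = x.shape[-1] // 2
        x1, x2 = x[..., :d], x[..., d:]
        return jnp.concatenate([x1, x2 + m(x1)], axis=-1)

    # quick check with m = identity
    x = jnp.arange(6.0).reshape(1, 6)
    print(additive_coupling(x, lambda h: h))  # [[0. 1. 2. 3. 5. 7.]]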

The error looks like:

[image: NotImplementedError traceback]

Some directions I have tried:

Here is some code that reproduces the error:

    from jax import random
    from neural_tangents import stax
    import jax.numpy as np
    import neural_tangents as nt

    def DenseBlock(neurons):
        return stax.serial(
            stax.Dense(neurons), stax.Relu()
        )

    def ReluNetwork(latent_dim, hidden_dim, num_layers):
        """Create the network which is embedd in flow_base model

        Args:
            latent_dim: input and output dim
            hidden_dim: the width dim of the ReluNetwork
            num_layers: depth of the ReluNetwork

        Returns:
            stax.serial(ReluNetwork)
        """
        blocks = [DenseBlock(hidden_dim)]
        for _ in range(num_layers):
            blocks += [DenseBlock(hidden_dim)]
        blocks += [stax.Dense(latent_dim)]

        return stax.serial(*blocks)

    def lower_path(input_dim):
        helf_dim = input_dim//2
        # pre_half's rhs
        rhs1 = np.identity(helf_dim)
        rhs1 = np.pad(rhs1, ((0, 0), (0, helf_dim)))
        rhs1 = np.reshape(rhs1, (*rhs1.shape, 1))

        # post_half's rhs
        rhs2 = np.identity(helf_dim)
        rhs2 = np.pad(rhs2, ((helf_dim, 0), (helf_dim, 0)))
        rhs2 = np.reshape(rhs2, (*rhs2.shape, 1))

        rhs4 = np.identity(helf_dim)
        rhs4 = np.pad(rhs4, ((helf_dim, 0), (0, 0)))
        rhs4 = np.reshape(rhs4, (*rhs4.shape, 1))

        blocks = [
            stax.DotGeneral(
                    rhs = rhs1,
                    dimension_numbers = (((2,), (1,)), ((), ())),
                    channel_axis = 1
                ), 
            stax.DotGeneral(
                    rhs = np.array([1]),
                    dimension_numbers = (((3,), (0,)), ((), ())),
                    channel_axis = 1
                )]
        blocks += [ReluNetwork(latent_dim=helf_dim, hidden_dim=helf_dim//2, num_layers=4)]
        blocks += [
            stax.DotGeneral(
                    rhs = rhs4,
                    dimension_numbers = (((2,), (1,)), ((), ())),
                    channel_axis = 1
                ), 
            stax.DotGeneral(
                    rhs = np.array([1]),
                    dimension_numbers = (((3,), (0,)), ((), ())),
                    channel_axis = 1
                )]

        pre_half = stax.serial(
            *blocks
        )
        post_half = stax.serial(
            stax.DotGeneral(
                    rhs = rhs2,
                    dimension_numbers = (((2,), (1,)), ((), ())),
                    channel_axis = 1
                ), 
            stax.DotGeneral(
                    rhs = np.array([1]),
                    dimension_numbers = (((3,), (0,)), ((), ())),
                    channel_axis = 1
                )
        )
        return stax.serial(stax.FanOut(2),
                           stax.parallel(pre_half, post_half),
                           stax.FanInSum()
                          )

    def AdditiveCouplingLayer(input_dim, order):
        """the additive couplinglayer in the paper

        Args:
            input_dim: input and output dimension
            order: not used in this snippet

        Returns:
            stax.serial(AdditiveCouplingLayer)
        """
        helf_dim = input_dim//2

        rhs_matrix = np.identity(helf_dim)
        rhs_matrix = np.pad(rhs_matrix, ((0, helf_dim), (0, helf_dim)))
        rhs_matrix = np.reshape(rhs_matrix, (*rhs_matrix.shape, 1))

        upper_path = stax.serial(
            stax.DotGeneral(
                    rhs = rhs_matrix,
                    dimension_numbers = (((2,), (1,)), ((), ())),
                    channel_axis = 1
                ), 
            stax.DotGeneral(
                    rhs = np.array([1]),
                    dimension_numbers = (((3,), (0,)), ((), ())),
                    channel_axis = 1
                )
        )

        return stax.serial(stax.FanOut(2),
                           stax.parallel(upper_path, lower_path(input_dim)),
                           stax.FanInSum()
                          )

    def LogisticPriorLoss(fx, y):
        return np.mean((0.5*np.sum(np.power(fx, 2), axis=1) + fx.shape[1]*0.5*np.log(2*np.pi)))

    # test
    x = np.array([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12], [13, 14, 15, 16, 17, 18]])
    x = np.reshape(x, (x.shape[0], 1, *x.shape[1:]))
    input_dim = x.shape[2]  # x has shape (B, 1, 6); B is the batch size
    helf_dim = input_dim//2

    init_fn, apply_fn, kernel_fn = lower_path(input_dim=input_dim)

    key = random.PRNGKey(1)
    _, params = init_fn(key, input_shape=x.shape)

    # z_train.dim = x_train.dim
    z_train = random.normal(key, x.shape)
    x_test = np.array([[1, 2, 3, 4, 5, 6]])
    x_test = np.reshape(x_test, (x_test.shape[0], 1, *x_test.shape[1:]))

    ntk_train_train = kernel_fn(x, x, 'ntk', channel_axis=1, is_gaussian=True)
    ntk_test_train = kernel_fn(x_test, x, 'ntk')
    predictor = nt.predict.gradient_descent(LogisticPriorLoss, ntk_train_train, z_train)
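For reference, the first pair of DotGeneral layers in lower_path is intended to behave like the following standalone jax.lax.dot_general calls (a small shape check, independent of stax; here rhs1 picks out the first half of the feature axis):

    from jax import lax
    import jax.numpy as jnp

    x_demo = jnp.arange(1.0, 19.0).reshape(3, 1, 6)           # (batch, 1, input_dim)
    half = 3
    rhs1_demo = jnp.pad(jnp.eye(half), ((0, 0), (0, half)))   # (3, 6): rows select features 0..2
    rhs1_demo = rhs1_demo[..., None]                           # (3, 6, 1)

    # First DotGeneral: contract the feature axis (lhs axis 2) against rhs axis 1.
    y = lax.dot_general(x_demo, rhs1_demo, (((2,), (1,)), ((), ())))    # (3, 1, 3, 1)
    # Second DotGeneral: contract away the trailing singleton axis.
    y = lax.dot_general(y, jnp.array([1.0]), (((3,), (0,)), ((), ())))  # (3, 1, 3)
    print(y)  # the first three features of each example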

Many thanks for your kind reply.

sschoenholz commented 3 years ago

Great question!

A few points.

  1. You were on the right track with setting is_gaussian=True. Notice that post_half doesn't have any dense layers, so if its inputs aren't Gaussian then its outputs won't be Gaussian either. Unfortunately, NT doesn't support explicitly setting is_gaussian=True in the kernel_fn (since the inputs to the network are assumed to be constants rather than Gaussian random variables). One way to solve this problem is to add a single dense layer at the top of your network (see the small sketch after this list).
  2. There were some inconsistencies in setting channel_axis. In particular, you also had to set channel_axis=1 in the Dense layers in DenseBlock(..) and ReluNetwork(..). When this was done there was a shape error where the number of channels was a bit different between the two branches. To solve this I ended up setting the initial dense layer to project down to the latent dimension, but I'm not sure whether this was what you were going for.
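To illustrate point 1 in isolation, here is a tiny toy example (a hypothetical branch-and-sum network, not the coupling layer): the Identity branch has no dense layers of its own, so it only carries a Gaussian signal because of the Dense layer placed before the FanOut.

    from jax import random
    from neural_tangents import stax

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(8),      # the extra Dense at the top: makes both branch inputs Gaussian
        stax.FanOut(2),
        stax.parallel(
            stax.serial(stax.Dense(8), stax.Relu(), stax.Dense(8)),  # branch with dense layers
            stax.Identity(),                                         # branch with none
        ),
        stax.FanInSum(),
    )

    x_toy = random.normal(random.PRNGKey(0), (4, 16))
    print(kernel_fn(x_toy, x_toy, 'ntk').shape)  # (4, 4)

Removing the leading Dense should reproduce the same kind of non-Gaussian-branch failure you hit in lower_path.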

In any case, here is a version of the code that should work. Let me know if you run into any trouble!

from jax import random
from neural_tangents import stax
import jax.numpy as np
import neural_tangents as nt

def DenseBlock(neurons):
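    # channel_axis=1 here (and in ReluNetwork below) matches the channel_axis used by
    # the DotGeneral layers; see point 2 above.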
    return stax.serial(
        stax.Dense(neurons, channel_axis=1), stax.Relu()
    )

def ReluNetwork(latent_dim, hidden_dim, num_layers):
    """Create the network which is embedd in flow_base model

    Args:
        latent_dim: input and output dim
        hidden_dim: the width dim of the ReluNetwork
        num_layers: depth of the ReluNetwork

    Returns:
        stax.serial(ReluNetwork)
    """
    blocks = [DenseBlock(hidden_dim)]
    for _ in range(num_layers):
        blocks += [DenseBlock(hidden_dim)]
    blocks += [stax.Dense(latent_dim, channel_axis=1)]

    return stax.serial(*blocks)

def lower_path(input_dim):
    helf_dim = input_dim//2
    # pre_half's rhs
    rhs1 = np.identity(helf_dim)
    rhs1 = np.pad(rhs1, ((0, 0), (0, helf_dim)))
    rhs1 = np.reshape(rhs1, (*rhs1.shape, 1))

    # post_half's rhs
    rhs2 = np.identity(helf_dim)
    rhs2 = np.pad(rhs2, ((helf_dim, 0), (helf_dim, 0)))
    rhs2 = np.reshape(rhs2, (*rhs2.shape, 1))

    rhs4 = np.identity(helf_dim)
    rhs4 = np.pad(rhs4, ((helf_dim, 0), (0, 0)))
    rhs4 = np.reshape(rhs4, (*rhs4.shape, 1))

    blocks = [
        stax.DotGeneral(
                rhs = rhs1,
                dimension_numbers = (((2,), (1,)), ((), ())),
                channel_axis = 1
            ), 
        stax.DotGeneral(
                rhs = np.array([1]),
                dimension_numbers = (((3,), (0,)), ((), ())),
                channel_axis = 1
            )]
    blocks += [ReluNetwork(latent_dim=helf_dim, hidden_dim=helf_dim//2, num_layers=4)]
    blocks += [
        stax.DotGeneral(
                rhs = rhs4,
                dimension_numbers = (((2,), (1,)), ((), ())),
                channel_axis = 1
            ), 
        stax.DotGeneral(
                rhs = np.array([1]),
                dimension_numbers = (((3,), (0,)), ((), ())),
                channel_axis = 1
            )]

    pre_half = stax.serial(
        *blocks
    )

    post_half = stax.serial(
        stax.DotGeneral(
                rhs = rhs2,
                dimension_numbers = (((2,), (1,)), ((), ())),
                channel_axis = 1
            ), 
        stax.DotGeneral(
                rhs = np.array([1]),
                dimension_numbers = (((3,), (0,)), ((), ())),
                channel_axis = 1
            )
    )
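    # The Dense layer below is the one added at the top of the network (point 1); projecting
    # down to helf_dim also keeps the channel counts of the two branches consistent (point 2).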
    return stax.serial(stax.Dense(helf_dim, channel_axis=1),
                       stax.FanOut(2),
                       stax.parallel(pre_half, post_half),
                       stax.FanInSum()
                      )

def AdditiveCouplingLayer(input_dim, order):
    """the additive couplinglayer in the paper

    Args:
        input_dim: input and output dimension
        order: not used in this snippet

    Returns:
        stax.serial(AdditiveCouplingLayer)
    """
    helf_dim = input_dim//2

    rhs_matrix = np.identity(helf_dim)
    rhs_matrix = np.pad(rhs_matrix, ((0, helf_dim), (0, helf_dim)))
    rhs_matrix = np.reshape(rhs_matrix, (*rhs_matrix.shape, 1))

    upper_path = stax.serial(
        stax.DotGeneral(
                rhs = rhs_matrix,
                dimension_numbers = (((2,), (1,)), ((), ())),
                channel_axis = 1
            ), 
        stax.DotGeneral(
                rhs = np.array([1]),
                dimension_numbers = (((3,), (0,)), ((), ())),
                channel_axis = 1
            )
    )

    return stax.serial(stax.FanOut(2),
                       stax.parallel(upper_path, lower_path(input_dim)),
                       stax.FanInSum()
                      )

def LogisticPriorLoss(fx, y):
    return np.mean((0.5*np.sum(np.power(fx, 2), axis=1) + fx.shape[1]*0.5*np.log(2*np.pi)))

# test
x = np.array([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12], [13, 14, 15, 16, 17, 18]])
x = np.reshape(x, (x.shape[0], 1, *x.shape[1:]))
input_dim = x.shape[2]  # x has shape (B, 1, 6); B is the batch size
helf_dim = input_dim//2

init_fn, apply_fn, kernel_fn = lower_path(input_dim=input_dim)

key = random.PRNGKey(1)
_, params = init_fn(key, input_shape=x.shape)

# z_train.dim = x_train.dim
z_train = random.normal(key, x.shape)
x_test = np.array([[1, 2, 3, 4, 5, 6]])
x_test = np.reshape(x_test, (x_test.shape[0], 1, *x_test.shape[1:]))

ntk_train_train = kernel_fn(x, x, 'ntk', channel_axis=1, is_gaussian=True)
ntk_test_train = kernel_fn(x_test, x, 'ntk')
predictor = nt.predict.gradient_descent(LogisticPriorLoss, ntk_train_train, z_train)

kkeevin123456 commented 3 years ago

Hello @sschoenholz @romanngg, thanks for your kind reply!

But I still have some questions. Is adding a Dense layer the only solution to this problem? If I add a Dense layer, the output and its shape will change.

Based on the optimizer.sgd method, the result looks something like this:

[image: training results using optimizer.sgd]

The image above shows that my architecture only changes the bottom half. Your reply helped me make a lot of progress.

kkeevin123456 commented 3 years ago

Follow up

I observed one weird thing:

romanngg commented 3 years ago

Sorry for the long delay, a few more observations:

Lmk if this helps!