kkeevin123456 opened this issue 3 years ago
Great question! A few points:

1. is_gaussian=True. Notice that post_half doesn't have any dense layers, so if its inputs aren't Gaussian then its outputs won't be Gaussian either. Unfortunately, NT doesn't support explicitly setting is_gaussian=True in the kernel_fn (since the inputs to the network are assumed to be constants rather than Gaussian random variables). One way to solve this problem is to add a single dense layer at the top of your network.

2. channel_axis. In particular, you also had to set channel_axis=1 in the Dense layers in DenseBlock(..) and ReluNetwork(..). Once this was done there was a shape error because the number of channels differed between the two branches. To solve this I ended up setting the initial dense layer to project down to the latent dimension, but I'm not sure whether this was what you were going for.

In any case, here is a version of the code that should work. Let me know if you run into any trouble!
%pdb on
from jax import random
from neural_tangents import stax
import jax.numpy as np
import neural_tangents as nt
def DenseBlock(neurons):
    return stax.serial(
        stax.Dense(neurons, channel_axis=1), stax.Relu()
    )
def ReluNetwork(latent_dim, hidden_dim, num_layers):
    """Create the network that is embedded in the flow-based model.

    Args:
        latent_dim: input and output dimension
        hidden_dim: width of the ReluNetwork
        num_layers: depth of the ReluNetwork
    Returns:
        stax.serial(ReluNetwork)
    """
    blocks = [DenseBlock(hidden_dim)]
    for _ in range(num_layers):
        blocks += [DenseBlock(hidden_dim)]
    blocks += [stax.Dense(latent_dim, channel_axis=1)]
    return stax.serial(*blocks)
def lower_path(input_dim):
    helf_dim = input_dim//2
    # pre_half's rhs
    rhs1 = np.identity(helf_dim)
    rhs1 = np.pad(rhs1, ((0, 0), (0, helf_dim)))
    rhs1 = np.reshape(rhs1, (*rhs1.shape, 1))
    # post_half's rhs
    rhs2 = np.identity(helf_dim)
    rhs2 = np.pad(rhs2, ((helf_dim, 0), (helf_dim, 0)))
    rhs2 = np.reshape(rhs2, (*rhs2.shape, 1))
    rhs4 = np.identity(helf_dim)
    rhs4 = np.pad(rhs4, ((helf_dim, 0), (0, 0)))
    rhs4 = np.reshape(rhs4, (*rhs4.shape, 1))
    blocks = [
        stax.DotGeneral(
            rhs=rhs1,
            dimension_numbers=(((2,), (1,)), ((), ())),
            channel_axis=1
        ),
        stax.DotGeneral(
            rhs=np.array([1]),
            dimension_numbers=(((3,), (0,)), ((), ())),
            channel_axis=1
        )]
    blocks += [ReluNetwork(latent_dim=helf_dim, hidden_dim=helf_dim//2, num_layers=4)]
    blocks += [
        stax.DotGeneral(
            rhs=rhs4,
            dimension_numbers=(((2,), (1,)), ((), ())),
            channel_axis=1
        ),
        stax.DotGeneral(
            rhs=np.array([1]),
            dimension_numbers=(((3,), (0,)), ((), ())),
            channel_axis=1
        )]
    pre_half = stax.serial(
        *blocks
    )
    post_half = stax.serial(
        stax.DotGeneral(
            rhs=rhs2,
            dimension_numbers=(((2,), (1,)), ((), ())),
            channel_axis=1
        ),
        stax.DotGeneral(
            rhs=np.array([1]),
            dimension_numbers=(((3,), (0,)), ((), ())),
            channel_axis=1
        )
    )
    return stax.serial(stax.Dense(helf_dim, channel_axis=1),
                       stax.FanOut(2),
                       stax.parallel(pre_half, post_half),
                       stax.FanInSum()
                       )
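# A minimal sanity-check sketch, assuming input_dim = 4 (so helf_dim = 2): the fixed
# DotGeneral "masks" above act as follows -- rhs1 selects the first half of the feature
# axis, rhs2 keeps the second half in place while zeroing the first, and rhs4 embeds a
# helf_dim-vector back into the second half.
demo_v = np.array([1., 2., 3., 4.])
demo_rhs1 = np.pad(np.identity(2), ((0, 0), (0, 2)))     # shape (2, 4)
demo_rhs2 = np.pad(np.identity(2), ((2, 0), (2, 0)))     # shape (4, 4)
demo_rhs4 = np.pad(np.identity(2), ((2, 0), (0, 0)))     # shape (4, 2)
print(demo_rhs1 @ demo_v)                # [1. 2.]        -> first half
print(demo_rhs2 @ demo_v)                # [0. 0. 3. 4.]  -> second half, kept in place
print(demo_rhs4 @ np.array([5., 6.]))    # [0. 0. 5. 6.]  -> embedded into second half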
def AdditiveCouplingLayer(input_dim, order):
    """The additive coupling layer from the paper.

    Args:
        input_dim: dimensionality of the input
        order: intended ordering of the two halves (not used here)
    Returns:
        stax.serial(AdditiveCouplingLayer)
    """
    helf_dim = input_dim//2
    rhs_matrix = np.identity(helf_dim)
    rhs_matrix = np.pad(rhs_matrix, ((0, helf_dim), (0, helf_dim)))
    rhs_matrix = np.reshape(rhs_matrix, (*rhs_matrix.shape, 1))
    upper_path = stax.serial(
        stax.DotGeneral(
            rhs=rhs_matrix,
            dimension_numbers=(((2,), (1,)), ((), ())),
            channel_axis=1
        ),
        stax.DotGeneral(
            rhs=np.array([1]),
            dimension_numbers=(((3,), (0,)), ((), ())),
            channel_axis=1
        )
    )
    return stax.serial(stax.FanOut(2),
                       stax.parallel(upper_path, lower_path(input_dim)),
                       stax.FanInSum()
                       )
def LogisticPriorLoss(fx, y):
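    # Prior term: averages 0.5 * sum(fx**2, axis=1) + (fx.shape[1] / 2) * log(2*pi)
    # over the remaining axes (a standard-normal negative log-likelihood, despite the name).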
    return np.mean(0.5*np.sum(np.power(fx, 2), axis=1) + fx.shape[1]*0.5*np.log(2*np.pi))
# test
x = np.array([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12], [13, 14, 15, 16, 17, 18]])
x = np.reshape(x, (x.shape[0], 1, *x.shape[1:]))  # x has shape (B, 1, 6); B is batch size
input_dim = x.shape[2]
helf_dim = input_dim//2
init_fn, apply_fn, kernel_fn = lower_path(input_dim=input_dim)
key = random.PRNGKey(1)
_, params = init_fn(key, input_shape=x.shape)
# z_train has the same shape as x_train
z_train = random.normal(key, x.shape)
x_test = np.array([[1, 2, 3, 4, 5, 6]])
x_test = np.reshape(x_test, (x_test.shape[0], 1, *x_test.shape[1:]))
ntk_train_train = kernel_fn(x, x, 'ntk', channel_axis=1, is_gaussian=True)
ntk_test_train = kernel_fn(x_test, x, 'ntk')
predictor = nt.predict.gradient_descent(LogisticPriorLoss, ntk_train_train, z_train)
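# A minimal usage sketch for the predictor above, assuming the kernel/output shapes
# are mutually consistent; t is an arbitrary training time chosen for illustration.
fx_train_0 = apply_fn(params, x)
fx_test_0 = apply_fn(params, x_test)
fx_train_t, fx_test_t = predictor(t=1e3, fx_train_0=fx_train_0,
                                  fx_test_0=fx_test_0, k_test_train=ntk_test_train)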
Hello @sschoenholz @romanngg, thanks for your kind reply!!
But I still have some questions.
Is adding a Dense layer the only solution to this problem? If I add a Dense layer, the output and output.shape will change.
Based on the optimizer.sgd method, the result looks like this: [image omitted]
The image above shows that my architecture only changes the bottom half. Your reply has helped me make a lot of progress.
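For reference, here is a minimal sketch of the optimizer.sgd training mentioned above, reusing apply_fn, params, x and z_train from the snippet earlier; the learning rate, step count and the use of LogisticPriorLoss as the objective are assumptions:

from jax import grad, jit
from jax.example_libraries import optimizers

# Plain SGD on the finite-width network (assumed hyperparameters).
opt_init, opt_update, get_params = optimizers.sgd(1e-2)
opt_state = opt_init(params)

@jit
def train_step(i, opt_state):
    p = get_params(opt_state)
    grads = grad(lambda p_: LogisticPriorLoss(apply_fn(p_, x), z_train))(p)
    return opt_update(i, grads, opt_state)

for i in range(1000):
    opt_state = train_step(i, opt_state)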
Follow up
I observed one weird thing: upper_path in AdditiveCouplingLayer also doesn't have any Dense layer, so why does it work normally?

Sorry for the long delay, a few more observations:
AdditiveCouplingLayer is not called in the code sample above (only lower_path), so I imagine it would have the same problem.
In lower_path, which is tested above, you already have a dense layer in pre_half, but no dense layer in post_half. So you could either add a common dense layer as @sschoenholz suggested, or add a dense layer only somewhere in post_half instead, e.g.
post_half = stax.serial(
    stax.DotGeneral(
        rhs=rhs2,
        dimension_numbers=(((2,), (1,)), ((), ())),
        channel_axis=1
    ),
    stax.DotGeneral(
        rhs=np.array([1]),
        dimension_numbers=(((3,), (0,)), ((), ())),
        channel_axis=1
    ),
    stax.Dense(helf_dim, channel_axis=1),
)
Note that in the case above, the Dense layer will not change the output shape if you set out_dim = helf_dim equal to the number of channels in your input image; it will not affect the pixel structure (but, as you've mentioned, it will indeed change the outputs themselves). However, out_dim must be equal to helf_dim, which is the number of channels output by pre_half - otherwise you would be asking FanInSum to add arrays of different shapes, and I think it only worked above because the input channels have size 1 and were silently broadcast to helf_dim when combined.
Similarly, in the infinite-width limit it's not clear how to avoid adding Dense, since FanInSum must add arrays of the same shape. In the infinite-width limit, which is invoked when you call kernel_fn, pre_half will output infinite-dimensional arrays along the channel axis, since it has dense layers. post_half, without dense layers, will output finite-dimensional arrays with as many channels along the channel axis as the input image. Therefore, arguably, the sum of the two is not well-defined.
In any case, for NTK, if your post_half branch doesn't contain any trainable parameters, it should not influence the NTK (i.e. if f(params, x) = pre_half(params, x) + post_half(x), then NTK(f)(x1, x2, params) = NTK(pre_half)(x1, x2, params)), so IIUC as a workaround you could just compute the NTK of pre_half separately. Alternatively, you could also use nt.empirical_kernel_fn to compute the empirical NTK, which should work for any function/architecture.
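For example, a minimal sketch of the empirical-NTK route, reusing apply_fn, params, x and x_test from the snippet above (trace_axes and the other options are left at their defaults here and may need adjusting for your channel layout):

emp_ntk_fn = nt.empirical_ntk_fn(apply_fn)
emp_ntk_train_train = emp_ntk_fn(x, None, params)   # x2=None computes the kernel of x with itself
emp_ntk_test_train = emp_ntk_fn(x_test, x, params)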
Lmk if this helps!
Hi, when I try to implement a two-layer coupling layer like the image below, I get an error. Do you have any insight into how to fix it?
The error looks like: [error output omitted]
Some directions I have tried:

1. Setting is_gaussian to be True.
2. Using optimizers.sgd to train my network. It works, but I still need to kernelize it.

Here is some code that can reproduce the error:
Many thanks for your kind reply.