You're almost correct; the precise syntax to pass the arguments would be:
predict_fn = nt.predict.gradient_descent_mse_ensemble(
    kernel_fn, x_train, y_train, pattern=(pattern1, pattern1))
nngp_mean, nngp_covariance = predict_fn(
    x_test=x_test, get='nngp', compute_cov=True, pattern=(pattern2, pattern2))
Note that:
1) You pass kernel_fn_test_test_kwargs to predict_fn, not to gradient_descent_mse_ensemble. In general, you pass test-test arguments together with x_test, and train-train arguments together with x_train.
2) kernel_fn_test_test_kwargs and kernel_fn_train_train_kwargs are dictionaries mapping argument names to their values. Here pattern1/pattern2 are the values, but the argument name is pattern, which is the keyword argument that a kernel_fn containing aggregation layers accepts; hence you want to pass pattern=(pattern1, pattern1).
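For concreteness, here is a minimal, self-contained sketch of that usage; the toy shapes, random data, diag_reg value, and the particular one-Aggregate-layer network are illustrative assumptions, not taken from the thread:

import jax.numpy as jnp
import neural_tangents as nt
from jax import random
from neural_tangents import stax

n_train, n_test, nodes, channels = 8, 4, 10, 3

# Toy graph data: node features plus one "pattern" (e.g. adjacency) matrix
# per graph, for the train and test sets respectively.
k1, k2, k3, k4, k5 = random.split(random.PRNGKey(0), 5)
x_train = random.normal(k1, (n_train, nodes, channels))
x_test = random.normal(k2, (n_test, nodes, channels))
y_train = random.normal(k3, (n_train, 1))
pattern1 = random.bernoulli(k4, 0.5, (n_train, nodes, nodes)).astype(jnp.float32)
pattern2 = random.bernoulli(k5, 0.5, (n_test, nodes, nodes)).astype(jnp.float32)

_, _, kernel_fn = stax.serial(
    stax.Dense(32),
    stax.Aggregate(aggregate_axis=1, batch_axis=0, channel_axis=2),
    stax.Relu(),
    stax.Flatten(),
    stax.Dense(1),
)

# Train-train kernel kwargs go with x_train; test-test kwargs go with x_test.
predict_fn = nt.predict.gradient_descent_mse_ensemble(
    kernel_fn, x_train, y_train, diag_reg=1e-4, pattern=(pattern1, pattern1))
nngp_mean, nngp_covariance = predict_fn(
    x_test=x_test, get='nngp', compute_cov=True, pattern=(pattern2, pattern2))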
Complete example: https://colab.research.google.com/gist/romanngg/94299dfd47f11c393564bb329ade1e96/aggregate_and_predict_example.ipynb
Lmk if you have other questions!
Thanks for the detailed reply! It works!
from neural_tangents import stax

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(1),
    stax.Aggregate(aggregate_axis=1, batch_axis=0, channel_axis=2),
    stax.Relu(),  # could also be stax.Erf()
    stax.Dense(1),
    stax.Aggregate(aggregate_axis=1, batch_axis=0, channel_axis=2),
    stax.Relu(),
    stax.Aggregate(aggregate_axis=1, batch_axis=0, channel_axis=2),
    stax.Flatten()
)
By the way, I see you insert a stax.Dense layer before each activation function. Why is this required? If I delete the dense layer, an error arises: "The input to the activation function must be Gaussian, i.e. a random affine transform is required before the activation function." Based on my understanding, the parameters in the graph layer stax.Aggregate already have Gaussian initialization.
Many thanks!
Actually, our implementation of Aggregate has no trainable parameters; as you can see here, it is just a linear transformation of the inputs with the pattern matrix: https://github.com/google/neural-tangents/blob/f4730831186d242793c1626e142b835239e868d3/neural_tangents/stax.py#L660
So it preserves Gaussianity: Gaussian inputs remain Gaussian, and non-Gaussian inputs remain non-Gaussian.
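For intuition, here is a minimal sketch of what such an aggregation amounts to; this is an illustrative stand-in, not the library's actual implementation, with shapes following the aggregate_axis=1, batch_axis=0, channel_axis=2 layout used above:

import jax.numpy as jnp

def aggregate(inputs, pattern):
    # inputs:  (batch, nodes, channels) node features.
    # pattern: (batch, nodes, nodes) adjacency / aggregation matrices.
    # Each output node is a linear combination of the input nodes, weighted
    # by the pattern matrix; purely linear, with no trainable parameters,
    # so a Gaussian distribution over the inputs stays Gaussian.
    return jnp.einsum('bij,bjc->bic', pattern, inputs)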
Got it, thanks for your fast reply!
Will close this as answered, but please open another issue if you have other questions!
Hi, thanks for this handy library. I have a basic question: I was trying to apply the NNGP to graph convolutional layers. For example, this function, nt.predict.gradient_descent_mse_ensemble, reports the error "cannot unpack non-iterable NoneType object". It seems like I didn't use the two keyword arguments kernel_fn_train_train_kwargs and kernel_fn_test_test_kwargs correctly. Can you give me some suggestions? Thanks again!