google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

Question: is this the correct way to combine stax.Aggregate and nt.predict? #118

Closed by bestadcarry 3 years ago

bestadcarry commented 3 years ago

Hi, thanks for this handy library. I have a basic question: I was trying to apply the NNGP to graph convolutional layers. For example,

# graph-nngp
import jax
import numpy as np
import neural_tangents as nt
from neural_tangents import stax

init_fn, apply_fn, kernel_fn = stax.serial(
  stax.Aggregate(aggregate_axis=1, batch_axis=0, channel_axis=2), stax.Relu(),  # stax.Erf() is an alternative
  stax.Aggregate(aggregate_axis=1, batch_axis=0, channel_axis=2), stax.Relu(),
  stax.Aggregate(aggregate_axis=1, batch_axis=0, channel_axis=2)
  )

# pattern1 is the pattern for the training set, pattern2 for the test set
pattern1 = jax.numpy.array(np.broadcast_to(graph_matrix, (n1, p, p)))
pattern2 = jax.numpy.array(np.broadcast_to(graph_matrix, (n2, p, p)))

# inference
predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train, y_train, kernel_fn_train_train_kwargs=(pattern1, pattern1), kernel_fn_test_test_kwargs=(pattern2, pattern2))
nngp_mean, nngp_covariance = predict_fn(x_test=x_test, get='nngp', compute_cov=True)

The call to nt.predict.gradient_descent_mse_ensemble reports the error "cannot unpack non-iterable NoneType object". It seems I didn't use the two keywords kernel_fn_train_train_kwargs and kernel_fn_test_test_kwargs correctly. Can you give me some suggestions?

Thanks again!

romanngg commented 3 years ago

You're almost correct; the precise syntax for passing the arguments would be:

predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train, y_train, pattern=(pattern1, pattern1))
nngp_mean, nngp_covariance = predict_fn(x_test=x_test, get='nngp', compute_cov=True, pattern=(pattern2, pattern2))

Note that:

1) You pass kernel_fn_test_test_kwargs to predict_fn, not to gradient_descent_mse_ensemble. In general, test-test arguments go together with x_test, and train-train arguments together with x_train.

2) kernel_fn_test_test_kwargs and kernel_fn_train_train_kwargs are dictionaries mapping argument names to their values. Here pattern1/pattern2 are the values, but the argument name is pattern, which is the keyword argument that a kernel_fn with aggregation layers accepts; hence you want to pass pattern=(pattern1, pattern1).
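To make the routing concrete, here is a rough sketch (illustrative only, not the library's internal code) of the kernels that end up being computed with these keyword arguments; note how the test-train kernel pairs pattern2 for the first input with pattern1 for the second:

# Rough sketch of the kernels computed under the hood with these kwargs
# (illustrative only, not the library's actual internals):
k_train_train = kernel_fn(x_train, x_train, 'nngp', pattern=(pattern1, pattern1))
k_test_train = kernel_fn(x_test, x_train, 'nngp', pattern=(pattern2, pattern1))
k_test_test = kernel_fn(x_test, x_test, 'nngp', pattern=(pattern2, pattern2))
# The NNGP posterior then follows the usual GP formulas, e.g. the mean is
# k_test_train @ inv(k_train_train) @ y_train.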

Complete example: https://colab.research.google.com/gist/romanngg/94299dfd47f11c393564bb329ade1e96/aggregate_and_predict_example.ipynb

Lmk if you have other questions!

bestadcarry commented 3 years ago

Thanks for the detailed reply! It works!

init_fn, apply_fn, kernel_fn = stax.serial(
  stax.Dense(1),
  stax.Aggregate(aggregate_axis=1, batch_axis=0, channel_axis=2), stax.Relu(),  # or stax.Erf()
  stax.Dense(1),
  stax.Aggregate(aggregate_axis=1, batch_axis=0, channel_axis=2), stax.Relu(),
  stax.Aggregate(aggregate_axis=1, batch_axis=0, channel_axis=2),
  stax.Flatten()
  )

By the way, I see you insert a stax.Dense layer before each activation function. Why is this required? If I delete the Dense layers, an error arises: "The input to the activation function must be Gaussian, i.e. a random affine transform is required before the activation function." My understanding was that the graph layer stax.Aggregate already has Gaussian-initialized parameters.

Many thanks!

romanngg commented 3 years ago

Actually, our implementation of Aggregate has no trainable parameters; as you can see here, it's just a linear transformation of the inputs by the pattern matrix: https://github.com/google/neural-tangents/blob/f4730831186d242793c1626e142b835239e868d3/neural_tangents/stax.py#L660

So it preserves the distribution of its inputs: Gaussian inputs remain Gaussian, but non-Gaussian inputs remain non-Gaussian. That's why a Dense layer (a random affine transform) is needed in front of each activation to make the pre-activations Gaussian.
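To illustrate, a minimal sketch (not from the thread): with no Dense layer in front of the nonlinearity, evaluating the kernel raises the Gaussianity error quoted above, while prepending stax.Dense(1) provides the random affine transform that makes the pre-activations Gaussian:

from neural_tangents import stax

# Evaluating this kernel_fn raises "the input to the activation function
# must be Gaussian": Aggregate has no parameters, so the raw (non-Gaussian)
# inputs reach Relu unchanged.
_, _, bad_kernel_fn = stax.serial(
  stax.Aggregate(aggregate_axis=1, batch_axis=0, channel_axis=2),
  stax.Relu())

# Prepending Dense (a random affine transform) makes the pre-activations
# Gaussian, so the Relu kernel is well defined.
_, _, good_kernel_fn = stax.serial(
  stax.Dense(1),
  stax.Aggregate(aggregate_axis=1, batch_axis=0, channel_axis=2),
  stax.Relu())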

bestadcarry commented 3 years ago

Got it, thanks for your fast reply!

romanngg commented 3 years ago

Will close this as answered, but please open another issue if you have other questions!