google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

predict.gradient_descent wrong prediction dimensions #31

Closed: tancik closed this issue 4 years ago

tancik commented 4 years ago

When calling the `predict` function from `nt.predict.gradient_descent` with inputs of the following shapes: `g_dd` `[256,256]`, `g_dt` `[256,256]`, `fx_train` `[256,1]`, `fx_test` `[256,1]`, the returned tuple of predictions has shapes `([2,256], [0,256])`. Running the same values through `nt.predict.gradient_descent_mse` returns predictions with shapes `([1,256], [1,256])`. I suspect there might be a bug in the following slicing code: https://github.com/google/neural-tangents/blob/38e9ba906f9cac7a6daa270c28b9d0c16f5335be/neural_tangents/predict.py#L278
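Purely as an illustration (hypothetical code, not the library's `predict.py`, and not a confirmed diagnosis of this bug): the reported shapes `[2,256]` and `[0,256]` are exactly what you get if train and test outputs are stacked along the output axis and transposed, but then split along the sample axis at `n_train`. The helper names `split_wrong` and `split_right` are made up for this sketch.

```python
import numpy as np


def split_wrong(fx_train, fx_test):
    """Hypothetical mismatch: stack along the output axis, split along axis 0."""
    n_train = fx_train.shape[0]
    # (256, 1) and (256, 1) -> (256, 2) -> transposed to (2, 256).
    stacked = np.concatenate([fx_train, fx_test], axis=1).T
    # Slicing 2 rows at n_train=256 keeps both rows in the first part, none in the second.
    return stacked[:n_train], stacked[n_train:]


def split_right(fx_train, fx_test):
    """Stack along the sample axis and split at n_train along that same axis."""
    n_train = fx_train.shape[0]
    stacked = np.concatenate([fx_train, fx_test], axis=0)  # (n_train + n_test, k)
    return stacked[:n_train], stacked[n_train:]


if __name__ == "__main__":
    fx_train = np.zeros((256, 1))
    fx_test = np.ones((256, 1))
    print([p.shape for p in split_wrong(fx_train, fx_test)])  # [(2, 256), (0, 256)]
    print([p.shape for p in split_right(fx_train, fx_test)])  # [(256, 1), (256, 1)]
```

The point of the sketch is only that concatenation and slicing must use the same (sample) axis; the actual code path in the library may differ.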

Also, the example in the documentation seems to be outdated:

romanngg commented 4 years ago

Thanks again for reporting this, and sorry for the super long delay! Fixed in https://github.com/google/neural-tangents/commit/a76bbb494f19af4f8c9c1a1b0904e91b105f769e

Example of correct shapes using the new API: https://colab.research.google.com/gist/romanngg/19d98f5bc40714711d02b06a10045757/predict_shapes.ipynb
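For a library-independent sanity check of what the shapes should be, here is a minimal NumPy sketch (not the library's implementation or API). It assumes linearized gradient-flow dynamics on the loss `0.5 * ||f_train - y_train||^2` with learning rate `lr`, for which the closed forms are `f_train(t) = y + expm(-lr*t*K_dd) (f_0 - y)` and `f_test(t) = f_test_0 - K_td K_dd^{-1} (I - expm(-lr*t*K_dd)) (f_0 - y)`. Each prediction therefore keeps the shape of its initial value, e.g. `[256,1]`. The function name `mse_gradient_flow_predict` and the loss normalization are assumptions; the library may scale the loss differently (for example by `1/n`).

```python
import numpy as np


def mse_gradient_flow_predict(k_dd, k_td, y_train, fx_train_0, fx_test_0, t, lr=1.0):
    """Closed-form train/test predictions of a linearized model under gradient flow.

    Assumed loss: 0.5 * ||fx_train - y_train||^2 (sum, not mean).
    Shapes: k_dd (n_train, n_train), k_td (n_test, n_train),
            y_train and fx_train_0 (n_train, k), fx_test_0 (n_test, k).
    """
    # K_dd is symmetric PSD, so expm(-lr*t*K_dd) = V diag(exp(-lr*t*lambda)) V^T.
    evals, evecs = np.linalg.eigh(k_dd)
    evals = np.clip(evals, 0.0, None)  # guard tiny negative values from round-off
    decay = np.exp(-lr * t * evals)

    # Train residual expressed in the kernel eigenbasis.
    proj = evecs.T @ (fx_train_0 - y_train)

    # Train outputs relax toward the targets along each eigendirection.
    fx_train_t = y_train + evecs @ (decay[:, None] * proj)

    # K_dd^{-1} (I - expm(-lr*t*K_dd)); for lambda -> 0 the factor tends to lr*t.
    positive = evals > 1e-12
    safe = np.where(positive, evals, 1.0)
    gain = np.where(positive, (1.0 - decay) / safe, lr * t)
    fx_test_t = fx_test_0 - k_td @ (evecs @ (gain[:, None] * proj))
    return fx_train_t, fx_test_t


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n, d = 256, 512
    x_train = rng.standard_normal((n, d))
    x_test = rng.standard_normal((n, d))
    k_dd = x_train @ x_train.T / d  # [256, 256]
    k_td = x_test @ x_train.T / d   # [256, 256]
    y_train = rng.standard_normal((n, 1))
    fx_train, fx_test = mse_gradient_flow_predict(
        k_dd, k_td, y_train, np.zeros((n, 1)), np.zeros((n, 1)), t=10.0)
    print(fx_train.shape, fx_test.shape)  # (256, 1) (256, 1)
```

Using an eigendecomposition rather than a matrix exponential keeps the sketch NumPy-only and handles a singular `K_dd` through the `lr * t` limit.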