tancik closed this 4 years ago
Thanks again for reporting this, and sorry for the super long delay! This is fixed in https://github.com/google/neural-tangents/commit/a76bbb494f19af4f8c9c1a1b0904e91b105f769e
Example of correct shapes using the new API: https://colab.research.google.com/gist/romanngg/19d98f5bc40714711d02b06a10045757/predict_shapes.ipynb
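To make the expected shapes concrete, here is a minimal NumPy sketch of closed-form predictions for continuous-time gradient descent on the MSE loss `0.5 * ||f_train - y_train||^2` with a fixed kernel. It is not the library's implementation: the helper name `gradient_flow_mse_predict`, its argument order, the loss scaling, and the choice of `g_td` as the [n_test, n_train] test-train kernel are all assumptions for illustration. Under these assumptions, with the [256,256] kernels and [256,1] outputs from the report, both predictions keep the [256,1] shape of `fx_train` and `fx_test`.

```python
import numpy as np


def gradient_flow_mse_predict(g_dd, g_td, y_train, fx_train_0, fx_test_0, t,
                              learning_rate=1.0):
    """Hypothetical helper: closed-form outputs at time t under gradient flow
    on the loss 0.5 * ||f_train - y_train||^2 with a fixed kernel.

    g_dd: [n_train, n_train] symmetric PSD train-train kernel
    g_td: [n_test, n_train] test-train kernel
    y_train, fx_train_0: [n_train, n_out]; fx_test_0: [n_test, n_out]
    """
    evals, evecs = np.linalg.eigh(g_dd)
    x = -learning_rate * evals * t
    # Eigenvalues of exp(-lr * g_dd * t).
    decay = np.exp(x)
    # Eigenvalues of g_dd^{-1} (exp(-lr * g_dd * t) - I); the limit is -lr * t
    # when an eigenvalue is exactly zero.
    nonzero = evals != 0
    coeff = np.where(nonzero, np.expm1(x) / np.where(nonzero, evals, 1.0),
                     -learning_rate * t)
    # Initial residual projected onto the kernel eigenbasis: [n_train, n_out].
    proj = evecs.T @ (fx_train_0 - y_train)
    fx_train_t = y_train + evecs @ (decay[:, None] * proj)
    fx_test_t = fx_test_0 + g_td @ (evecs @ (coeff[:, None] * proj))
    return fx_train_t, fx_test_t


# Usage with the shapes from the report: 256 train and 256 test points, 1 output.
rng = np.random.default_rng(0)
x_train = rng.normal(size=(256, 64))
x_test = rng.normal(size=(256, 64))
g_dd = x_train @ x_train.T / 64 + np.eye(256)  # [256, 256], positive definite
g_td = x_test @ x_train.T / 64                 # [256, 256]
y_train = rng.normal(size=(256, 1))
fx_train_0 = np.zeros((256, 1))
fx_test_0 = np.zeros((256, 1))

fx_train_t, fx_test_t = gradient_flow_mse_predict(
    g_dd, g_td, y_train, fx_train_0, fx_test_0, t=10.0)
print(fx_train_t.shape, fx_test_t.shape)  # (256, 1) (256, 1)
```

Diagonalizing `g_dd` once lets the same eigenbasis produce both the train-set decay and the test-set correction, and `np.expm1` keeps the terms for small eigenvalues accurate.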
When calling `predict` in `nt.predict.gradient_descent` with variables of the following dimensions:

- `g_dd`: [256,256]
- `g_dt`: [256,256]
- `fx_train`: [256,1]
- `fx_test`: [256,1]

the tuple of predictions has shapes `([2,256], [0,256])`. Running the same values through `nt.predict_gradient_descent_mse` returns predictions with shapes `([1,256], [1,256])`. I am curious whether there might be a bug in the following slicing code: https://github.com/google/neural-tangents/blob/38e9ba906f9cac7a6daa270c28b9d0c16f5335be/neural_tangents/predict.py#L278

Also, the example documentation seems to be outdated: