JonathanSum closed this issue 2 years ago.
It just turns it into a JAX scalar, i.e. a zero-dimensional array:
```python
import jax.numpy as jnp

jnp.array([1]).reshape().shape
# ()
```
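A minimal sketch of the same idea, using the explicit empty-shape form `jnp.reshape(x, ())` instead of the no-argument method call (our choice, since passing `()` states the target shape outright). Reshaping to `()` only works when the array holds exactly one element; the two-element array below is a hypothetical counterexample, and the exact exception type is an assumption, so both `TypeError` and `ValueError` are caught.

```python
import jax.numpy as jnp

x = jnp.array([1])        # shape (1,): a one-element vector
s = jnp.reshape(x, ())    # target shape () has zero dimensions, i.e. a scalar
print(x.shape, s.shape)   # (1,) ()

# A size-2 array cannot be reshaped to (), because shape () holds exactly one element.
try:
    jnp.reshape(jnp.array([1, 2]), ())
except (TypeError, ValueError) as err:
    print("reshape failed:", err)
```

So the answer to "what shape does it get" is `()`, provided the original array has size 1.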
This function could be written much more efficiently; I just wanted to show it in a non-matrix style.
Thanks for answering. But the code only calls `reshape()`, not `reshape().shape`.
I was just showing that the shape is `()` after the reshape, i.e. a scalar.
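To make "i.e. a scalar" concrete, here is a short sketch (assuming a JAX version whose arrays expose `.item()`) showing that the shape-`()` result has zero dimensions, holds one element, converts to a plain Python number, and broadcasts like a scalar in arithmetic.

```python
import jax.numpy as jnp

s = jnp.array([1]).reshape(())   # zero-dimensional array, shape ()
print(s.ndim, s.size)            # 0 1

value = s.item()                 # convert the single element to a plain Python number

# A 0-d array broadcasts against any shape, just like a Python scalar would.
summed = s + jnp.ones(3)
print(summed.shape)              # (3,)
```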
Hi. I see that `reshape` is called without any arguments. I guess this is not an error. If so, can someone explain what shape the result will have, given the original shape?