lishulincug opened this issue 2 years ago
Upgrade your Python version to 3.8 (python==3.8).
The keepdims argument of np.argmax is a new feature in NumPy 1.22, and NumPy 1.22 dropped support for Python 3.7 (https://github.com/numpy/numpy/pull/19665).
However, the Python version given for the conda environment in the README is 3.6, so you have to upgrade Python to get a NumPy that supports keepdims here.
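To see whether your environment is affected, you can print the installed versions. This is a minimal sketch; the helper name supports_argmax_keepdims is hypothetical, and it relies only on the fact that np.argmax accepts keepdims starting with NumPy 1.22.

```python
import sys

import numpy as np


def supports_argmax_keepdims() -> bool:
    """Hypothetical helper: True if np.argmax accepts keepdims (NumPy >= 1.22)."""
    # Compare only major.minor so suffixes like "0rc1" in the patch part are ignored.
    major, minor = (int(x) for x in np.__version__.split(".")[:2])
    return (major, minor) >= (1, 22)


print("Python", sys.version.split()[0], "| NumPy", np.__version__)
print("np.argmax keepdims supported:", supports_argmax_keepdims())
```

If this reports False, np.argmax(..., keepdims=True) will raise a TypeError like the one in this issue.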
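Alternatively, without upgrading, replace the np.argmax(..., keepdims=True) line with:
ind_t = np.argmax(weights_b, axis=-1)  # [N,4], the reduced last axis is dropped
ind = np.expand_dims(ind_t, axis=-1)  # [N,4,1], same result as keepdims=True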
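The workaround generalizes to any integer axis: np.expand_dims re-inserts the axis that np.argmax removed, with length 1. Below is a minimal sketch; the helper name argmax_keepdims and the array shape used for weights_b are hypothetical, chosen only for illustration, and axis=None is not handled.

```python
import numpy as np


def argmax_keepdims(a, axis=-1):
    """Hypothetical helper mimicking np.argmax(a, axis=axis, keepdims=True).

    Works on NumPy versions older than 1.22. `axis` must be an int (not None).
    """
    ind_t = np.argmax(a, axis=axis)          # reduced axis is removed
    return np.expand_dims(ind_t, axis=axis)  # put it back with length 1


# Hypothetical weights with shape [N, 4, K] where N=5, K=3.
rng = np.random.default_rng(0)
weights_b = rng.random((5, 4, 3))

ind = argmax_keepdims(weights_b, axis=-1)
print(ind.shape)  # (5, 4, 1)
```

Because the result keeps the reduced axis, it broadcasts against the original array the same way the keepdims=True result would.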
When I run python stage3.py, I get this error:
ind = np.argmax(weights_b, axis=-1, keepdims=True) #[N,4,1]
TypeError: argmax() got an unexpected keyword argument 'keep_dims'