alemi closed this issue 1 year ago.
Thank you for your report. However, I can run this model successfully with the latest version of BrainPy. Could you tell me which version of BrainPy you are using?
We fixed a JAX compatibility issue earlier in #431.
Perhaps updating BrainPy will solve this issue:
pip install brainpy -U
Thank you for your reply. My code still raised the error after updating to the latest versions of BrainPy and brainpy-datasets:
pip install brainpy -U
pip install brainpy-datasets -U
However, after I also ran the following command,
pip install brainpylib -U
the error disappeared.
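Since the outcome here depended on which packages were installed, it may help to report the exact versions alongside such an error. A minimal sketch using only the standard library's `importlib.metadata`; the package list is an assumption based on the packages mentioned in this thread plus `jax`/`jaxlib`, and the helper name `installed_versions` is hypothetical:

```python
from importlib.metadata import version, PackageNotFoundError

# Assumed list of packages relevant to this issue (pip distribution names).
PACKAGES = ["brainpy", "brainpylib", "brainpy-datasets", "jax", "jaxlib"]


def installed_versions(packages=PACKAGES):
    """Return {package: installed version string, or None if not installed}."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            # The package is not installed in the current environment.
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
```

Running this in the same environment (e.g. the same Colab runtime) that produced the traceback shows whether `brainpylib` is present and which JAX release is in use.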
I get an error when I run highdim_RNN_Analysis.py. It occurs on the following line, which computes the Jacobian:
Here is the error message:
AttributeError                            Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/jax/_src/core.py in __getattr__(self, name)
    696     try:
--> 697       attr = getattr(self.aval, name)
    698     except AttributeError as err:
AttributeError: 'ShapedArray' object has no attribute 'split'
The above exception was the direct cause of the following exception:
AttributeError                            Traceback (most recent call last)
7 frames
in <cell line: 179>()
177 # Computing the Jacobian and Plot distribution of eigenvalues
178 # ---
--> 179 finder.compute_jacobians({'h': finder._fixed_points['h'][:20]}, plot=True, num_col=5)
/usr/local/lib/python3.10/dist-packages/brainpy/_src/analysis/highdim/slow_points.py in compute_jacobians(self, points, stack_dict_var, plot, num_col, len_col, len_row)
    578
    579     # get Jacobian matrix
--> 580     jacobian = self._get_f_jocabian(stack_dict_var)(points)
    581
    582     # visualization
/usr/local/lib/python3.10/dist-packages/brainpy/_src/analysis/highdim/slow_points.py in jacobian_func(x)
    797     if isinstance(self.target, DynamicalSystem):
    798       def jacobian_func(x):
--> 799         r = f_jac(x)
    800         for k, v in self.excluded_vars.items():
    801           v.value = self.excluded_data[k]
/usr/local/lib/python3.10/dist-packages/brainpy/_src/math/object_transform/autograd.py in __call__(self, *args, **kwargs)
    216       )
    217     else:
--> 218       rets = jax.eval_shape(
    219         self._transform,
    220         [v.value for v in self._grad_vars],  # variables for gradients
/usr/local/lib/python3.10/dist-packages/brainpy/_src/math/object_transform/autograd.py in jacfun(*args, **kwargs)
    478     y, pullback = _vjp(f_partial, dyn_args, has_aux=False)
    479     tree_map(partial(_check_output_dtype_jacrev, holomorphic), y)
--> 480     jac = vmap(pullback)(_std_basis(y))
    481     jac = jac[0] if isinstance(argnums, int) else jac
    482     example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
/usr/local/lib/python3.10/dist-packages/brainpy/_src/math/object_transform/autograd.py in _std_basis(pytree)
    459   dtype = dtypes.result_type(*leaves)
    460   flat_basis = jax.numpy.eye(ndim, dtype=dtype)
--> 461   return _unravel_array_into_pytree(pytree, 1, flat_basis)
    462
    463
/usr/local/lib/python3.10/dist-packages/brainpy/_src/math/object_transform/autograd.py in _unravel_array_into_pytree(pytree, axis, arr, is_leaf)
    449   axis = axis % arr.ndim
    450   shapes = [arr.shape[:axis] + np.shape(l) + arr.shape[axis + 1:] for l in leaves]
--> 451   parts = arr.split(np.cumsum(safe_map(np.size, leaves[:-1])), axis)
    452   reshaped_parts = [x.reshape(shape) for x, shape in zip(parts, shapes)]
    453   return tree_unflatten(treedef, reshaped_parts, )
/usr/local/lib/python3.10/dist-packages/jax/_src/core.py in __getattr__(self, name)
    697       attr = getattr(self.aval, name)
    698     except AttributeError as err:
--> 699       raise AttributeError(
    700         f"{self.__class__.__name__} has no attribute {name}"
    701       ) from err
AttributeError: DynamicJaxprTracer has no attribute split
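The last frames suggest where it fails: `_unravel_array_into_pytree` calls `arr.split(...)` as a method, and because the call runs inside `jax.eval_shape`, `arr` is a `DynamicJaxprTracer`, which in the JAX version used here does not expose `split`. A minimal sketch of the same step written with the function form `jnp.split(...)`, which works on tracers as well as concrete arrays. The helper name `unravel_array_into_pytree` and the example pytree are hypothetical; this illustrates the failing operation and is not necessarily how BrainPy itself fixed it (the workaround reported above was installing brainpylib).

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten


def unravel_array_into_pytree(pytree, axis, arr):
    """Split `arr` along `axis` into blocks shaped like the leaves of `pytree`.

    Hypothetical helper modeled on BrainPy's private `_unravel_array_into_pytree`.
    It calls the function `jnp.split(arr, ...)` instead of the method
    `arr.split(...)`, so it also works when `arr` is a tracer
    (e.g. inside `jax.jit` or `jax.eval_shape`).
    """
    leaves, treedef = tree_flatten(pytree)
    axis = axis % arr.ndim
    # Each block gets the leaf's shape in place of dimension `axis`.
    shapes = [arr.shape[:axis] + np.shape(l) + arr.shape[axis + 1:] for l in leaves]
    # Split points are the cumulative sizes of all leaves except the last.
    split_points = np.cumsum([np.size(l) for l in leaves[:-1]], dtype=int).tolist()
    parts = jnp.split(arr, split_points, axis=axis)
    reshaped = [part.reshape(shape) for part, shape in zip(parts, shapes)]
    return tree_unflatten(treedef, reshaped)


# Example: an "output" pytree with 3 + 4 = 7 scalar entries, and the 7x7
# identity used as the standard basis when building a reverse-mode Jacobian.
y = {"h": jnp.zeros((3,)), "v": jnp.zeros((2, 2))}
basis = jnp.eye(7)

# Works on concrete arrays ...
blocks = unravel_array_into_pytree(y, 1, basis)
# ... and under tracing, which is where `arr.split(...)` failed in the report.
traced = jax.eval_shape(lambda a: unravel_array_into_pytree(y, 1, a), basis)
print(blocks["h"].shape, blocks["v"].shape)  # (7, 3) (7, 2, 2)
print(traced["h"].shape, traced["v"].shape)  # (7, 3) (7, 2, 2)
```

The split indices are plain Python integers computed from the leaf shapes, so they stay static during tracing; only the array being split is traced.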