ins-amu / vbjax

A nascent Jax-based package for virtual brain modeling.
Apache License 2.0

the second example, Parallel parameter space exploration, is platform-dependent #46

Open gouchangjiang opened 9 months ago

gouchangjiang commented 9 months ago

There are some small issues:

  1. The installation command should be pip install ."[dev]" (the quotes are needed).
  2. The second example, Parallel parameter space exploration, is platform-dependent. I am running it on a 12-core M2 Pro, and it reports that the sizes do not match. I tried changing the shape to make it divisible by 12, but that introduces new problems; details are listed below. By the way, to disable the GPU on a Mac, set the environment variable JAX_PLATFORMS=cpu.

    InconclusiveDimensionOperation            Traceback (most recent call last)
    /Users/cjgou/cjgou/vbjax-example/examples.ipynb Cell 5 line 6
          4 log_ks, etas = np.mgrid[-9.0:0.0:16j, -4.0:-6.0:32j]
          5 pars = np.c_[np.exp(log_ks.ravel()),np.ones(512)*sig_i, etas.ravel()].T.copy()
    ----> 6 pars = pars.reshape((3, vb.cores))
          7 result = run_batches(pars)
          8 pl.imshow(result.reshape((16, 32)), vmin=0.2, vmax=0.7)

    File ~/.virtualenvs/vbjax/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:145, in _reshape(a, order, *args)
        143 newshape = _compute_newshape(a, args[0] if len(args) == 1 else args)
        144 if order == "C":
    --> 145   return lax.reshape(a, newshape, None)
        146 elif order == "F":
        147   dims = list(range(a.ndim)[::-1])

    File ~/.virtualenvs/vbjax/lib/python3.10/site-packages/jax/_src/lax/lax.py:857, in reshape(operand, new_sizes, dimensions)
        854 else:
        855   dyn_shape, static_new_sizes = _extract_tracers_dyn_shape(new_sizes)
    --> 857 return reshape_p.bind(
        858   operand, *dyn_shape, new_sizes=tuple(static_new_sizes),
        859   dimensions=None if dims is None or same_dims else dims)

    File ~/.virtualenvs/vbjax/lib/python3.10/site-packages/jax/_src/core.py:380, in Primitive.bind(self, *args, **params)
        377 def bind(self, *args, **params):
    ...
       1850 if sz1 % sz2:
    -> 1851   raise InconclusiveDimensionOperation(f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}")
       1852 return sz1 // sz2

    InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (3, 512) and (3, 12)
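A core-count-independent variant of the batching step, sketched under assumptions rather than as the fix that went into vbjax: pad the 512 parameter columns up to the next multiple of the core count by repeating the last column, reshape to (3, n_cores, per_core), and drop the padded entries from the result afterwards. The helper name pad_and_split, the placeholder value of sig_i, and the stand-in for run_batches are hypothetical; in the notebook, n_cores would presumably be vb.cores as in the traceback. NumPy is used so the sketch runs anywhere; jax.numpy offers the same pad and reshape calls.

```python
import numpy as np


def pad_and_split(pars, n_cores):
    """Hypothetical helper: split (n_params, n_total) parameter columns into
    (n_params, n_cores, per_core), padding so any core count works."""
    n_params, n_total = pars.shape
    per_core = -(-n_total // n_cores)        # ceiling division
    n_pad = per_core * n_cores - n_total     # extra columns to fill the last core
    # Repeat the last parameter set as padding instead of inventing values.
    padded = np.pad(pars, ((0, 0), (0, n_pad)), mode="edge")
    return padded.reshape((n_params, n_cores, per_core)), n_total


# Same 16 x 32 grid as in the example notebook.
log_ks, etas = np.mgrid[-9.0:0.0:16j, -4.0:-6.0:32j]
sig_i = 0.1  # hypothetical placeholder; the notebook's value is not shown in the issue
pars = np.c_[np.exp(log_ks.ravel()), np.ones(512) * sig_i, etas.ravel()].T.copy()

batched, n_total = pad_and_split(pars, n_cores=12)  # in the notebook: vb.cores
print(batched.shape)  # (3, 12, 43)

# Stand-in for run_batches: one scalar per parameter set, shape (12, 43).
result = batched[0] * batched[2]
# Flatten in the order the columns were split, drop padding, restore the grid.
grid = result.reshape(-1)[:n_total].reshape((16, 32))
```

Repeating the last column (mode="edge") keeps the padded runs on a parameter set the model already handles, at the cost of a few wasted simulations (4 of 516 on 12 cores).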

maedoc commented 9 months ago

Hi! Thanks for trying out the examples.

> pip install ."[dev]", we need the quotes

Yeah, I guess only Bash is OK with the unquoted form; neither zsh (the default on macOS) nor cmd on Windows likes it.
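The quoting matters because zsh treats the unquoted [dev] as a glob pattern and, finding no matching file, aborts with "no matches found", while Bash by default passes an unmatched pattern through unchanged. When the install is scripted from Python instead, passing the command as an argument list avoids the shell entirely. The helper pip_install_cmd below is hypothetical; shlex.join only shows how the command would be quoted for a POSIX shell.

```python
import shlex
import sys


def pip_install_cmd(extras="dev"):
    """Hypothetical helper: build `pip install .[extras]` as an argument list."""
    return [sys.executable, "-m", "pip", "install", f".[{extras}]"]


cmd = pip_install_cmd()
# Run as a list, e.g. subprocess.run(cmd, check=True), no shell is involved,
# so the brackets never reach a glob expander and need no quoting.
# For copy-pasting into zsh or bash, the bracketed argument gets quoted:
print(shlex.join(cmd))  # ends with: install '.[dev]'
```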

> InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (3, 512) and (3, 12)

Indeed, the example code depends on the number of cores present. I'll fix this shortly.

> GPU on Mac, we can set the environment variable JAX_PLATFORMS=cpu

Thanks! Are you using the experimental Metal backend for Jax on your Mac?

gouchangjiang commented 9 months ago

As you mentioned, a third of the test cases failed on jax-metal. I tried again with the latest jax-metal 0.0.4 and got the same result: many tests failed. So I switched to the CPU. Since the default is the GPU, disabling it requires setting the environment variable JAX_PLATFORMS=cpu.
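A minimal sketch of doing this from inside a notebook rather than the shell (where JAX_PLATFORMS=cpu jupyter lab would do the same), assuming JAX has not yet been imported in the running process: JAX reads the variable when it sets up its backends, so it has to be set before the first import jax.

```python
import os

# Must run before the first `import jax` in this process (e.g. first notebook
# cell); setting it after JAX has initialized its backends has no effect.
os.environ["JAX_PLATFORMS"] = "cpu"

import jax  # noqa: E402  (deliberately imported after setting the variable)

print(jax.default_backend())  # cpu
print(jax.devices())          # only CPU devices are listed
```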