gouchangjiang opened 9 months ago
Hi! Thanks for trying out the examples.

> pip install ."[dev]", we need the quotes

Yeah, I guess only Bash is OK with it; neither zsh (the default on macOS) nor cmd on Windows likes it.
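For reference, here is a minimal sketch of why the quotes matter. It uses Python's standard `shlex` and `fnmatch` modules as a stand-in for the shell. They follow POSIX sh rules, not zsh's exact behaviour, so treat this as an approximation. With the quotes, the shell strips them and pip receives the literal string `.[dev]`. Without them, zsh treats `[dev]` as a glob character class, so `.[dev]` becomes a pattern matching file names such as `.d`, `.e` or `.v`. When nothing matches, zsh aborts with "no matches found", whereas Bash passes an unmatched pattern through unchanged.

```python
import fnmatch
import shlex

# POSIX-style tokenizing: the quoted part is joined to the adjacent "."
# and the quotes are removed, so pip receives the literal string ".[dev]".
quoted = shlex.split('pip install ."[dev]"')
print(quoted)

# Unquoted, "[dev]" is a glob character class: ".[dev]" matches the
# two-character names ".d", ".e" or ".v", not the text ".[dev]" itself.
for name in [".d", ".e", ".v", ".[dev]"]:
    print(name, fnmatch.fnmatchcase(name, ".[dev]"))
```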
> InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (3, 512) and (3, 12)

Indeed, the example code depends on the number of cores present. I'll fix this shortly.
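Here is why the core count matters: a reshape has to preserve the number of elements. The parameter array is (3, 512), i.e. 1536 values, which cannot fill a (3, 12) array of 36. Suppose the points were instead spread over the cores with an inferred axis such as (3, 12, -1). That is an assumption for illustration, not necessarily what the notebook intends. Even then, 1536 is not divisible by 36, while 8 or 16 cores would divide it evenly. The `infer_reshape` helper below is a hypothetical stdlib re-implementation of that rule, not JAX's own code.

```python
import math


def infer_reshape(old_shape, new_shape):
    """Resolve at most one -1 in new_shape the way NumPy/JAX reshape does.

    Raises ValueError when the element counts cannot match.
    (Illustrative re-implementation, not JAX's internal code.)
    """
    total = math.prod(old_shape)
    known = [d for d in new_shape if d != -1]
    n_inferred = len(new_shape) - len(known)
    if n_inferred > 1:
        raise ValueError("at most one dimension may be -1")
    known_size = math.prod(known)
    if n_inferred == 0:
        if known_size != total:
            raise ValueError(
                f"cannot reshape {old_shape} ({total} elements) "
                f"into {new_shape} ({known_size} elements)"
            )
        return tuple(new_shape)
    if known_size == 0 or total % known_size:
        raise ValueError(f"{known_size} does not evenly divide {total}")
    return tuple(total // known_size if d == -1 else d for d in new_shape)


# 16 x 32 = 512 parameter points with 3 parameters each, split over n_cores.
for n_cores in (8, 12, 16):
    try:
        print(n_cores, infer_reshape((3, 512), (3, n_cores, -1)))
    except ValueError as err:
        print(n_cores, "fails:", err)
```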
> to disable GPU on Mac, we can set the environment variable JAX_PLATFORMS=cpu

Thanks! Are you using the experimental Metal backend for JAX on your Mac?

As you mentioned, a third of the test cases failed on jax-metal. I tried again with the latest jax-metal 0.0.4 and got the same result: many tests failed. So I switched to the CPU. Since the default is the GPU, we need to set the environment variable JAX_PLATFORMS=cpu to disable it.
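Here is a minimal sketch of doing the same from inside a notebook or script. It assumes JAX has not been imported yet in that process, because the variable has to be in place before JAX starts up. Exporting it in the shell (e.g. `JAX_PLATFORMS=cpu python script.py`) works too. The `find_spec` guard only keeps the snippet runnable where JAX is not installed.

```python
import importlib.util
import os

# Set before the first `import jax` in this process so that JAX
# only initializes the CPU backend instead of the default GPU/Metal one.
os.environ["JAX_PLATFORMS"] = "cpu"

if importlib.util.find_spec("jax") is not None:
    import jax

    # With the variable set, only CPU devices should be listed here.
    print(jax.devices())
```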
There are some small issues:
The second example, "Parallel parameter space exploration", is platform-dependent. I am running it on a 12-core M2 Pro, and it reports that the sizes do not match. I tried changing the shape to make it divisible by 12, but that caused new problems. Details are listed below. By the way, to disable the GPU on a Mac, we can set the environment variable JAX_PLATFORMS=cpu.
```
InconclusiveDimensionOperation            Traceback (most recent call last)
/Users/cjgou/cjgou/vbjax-example/examples.ipynb Cell 5 line 6
      4 log_ks, etas = np.mgrid[-9.0:0.0:16j, -4.0:-6.0:32j]
      5 pars = np.c_[np.exp(log_ks.ravel()),np.ones(512)*sig_i, etas.ravel()].T.copy()
----> 6 pars = pars.reshape((3, vb.cores))
      7 result = run_batches(pars)
      8 pl.imshow(result.reshape((16, 32)), vmin=0.2, vmax=0.7)

File ~/.virtualenvs/vbjax/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:145, in _reshape(a, order, *args)
    143 newshape = _compute_newshape(a, args[0] if len(args) == 1 else args)
    144 if order == "C":
--> 145   return lax.reshape(a, newshape, None)
    146 elif order == "F":
    147   dims = list(range(a.ndim)[::-1])

File ~/.virtualenvs/vbjax/lib/python3.10/site-packages/jax/_src/lax/lax.py:857, in reshape(operand, new_sizes, dimensions)
    854 else:
    855   dyn_shape, static_new_sizes = _extract_tracers_dyn_shape(new_sizes)
--> 857   return reshape_p.bind(
    858     operand, *dyn_shape, new_sizes=tuple(static_new_sizes),
    859     dimensions=None if dims is None or same_dims else dims)

File ~/.virtualenvs/vbjax/lib/python3.10/site-packages/jax/_src/core.py:380, in Primitive.bind(self, *args, **params)
    377 def bind(self, *args, **params):
...
   1850 if sz1 % sz2:
-> 1851   raise InconclusiveDimensionOperation(f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}")
   1852 return sz1 // sz2

InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (3, 512) and (3, 12)
```
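One possible workaround is sketched below. It is a hypothetical approach, not the fix that was made in vbjax. The idea is to pad the 512-point grid up to the next multiple of the core count, split it into equal per-core batches, and drop the padding afterwards. On a 12-core M2 Pro that means 4 extra points, since 516 = 12 × 43. The names `pad_to_multiple` and `split_across_cores` are illustrative. Plain lists stand in for the NumPy parameter array, and the "simulation" just returns each point unchanged.

```python
def pad_to_multiple(values, n_cores, fill):
    """Append copies of `fill` until len(values) is a multiple of n_cores.

    Returns (padded_list, number_of_padding_entries).
    """
    n_pad = (-len(values)) % n_cores
    return values + [fill] * n_pad, n_pad


def split_across_cores(values, n_cores):
    """Split a list whose length is a multiple of n_cores into n_cores equal batches."""
    if len(values) % n_cores:
        raise ValueError("pad the input to a multiple of n_cores first")
    per_core = len(values) // n_cores
    return [values[i * per_core:(i + 1) * per_core] for i in range(n_cores)]


# 16 x 32 = 512 parameter points, as in the notebook's np.mgrid grid.
n_points, n_cores = 16 * 32, 12
points = list(range(n_points))  # stand-in for the (k, sig_i, eta) parameter triples

# Pad with a copy of a real point so the extra simulations still get valid parameters.
padded, n_pad = pad_to_multiple(points, n_cores, fill=points[-1])
batches = split_across_cores(padded, n_cores)

# Stand-in for running each batch on its own core, then flattening the results.
results = [p for batch in batches for p in batch]
results = results[:n_points]  # drop the padded entries before reshaping to (16, 32)
print(n_pad, len(batches), len(batches[0]))
```

With the NumPy arrays in the notebook, the same idea would be to concatenate `n_pad` copies of the last column of `pars` (shape (3, 512), one column per point) before splitting it across cores. The last `n_pad` entries of the result would then be discarded before the `reshape((16, 32))` used for plotting.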