AmoArt opened this issue 2 years ago
`jax.experimental.maps` does have `mesh`, as long as you are on `jax<=0.3.7`:
```
❯ python3.9
Python 3.9.13 (main, Jun 8 2022, 09:45:57)
[GCC 11.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> import jaxlib
>>> jax.__version__
'0.2.12'
>>> jaxlib.__version__
'0.1.68'
>>> jax.experimental.maps.mesh
<function mesh at 0x7f6346581940>
>>> from jax.experimental import maps
>>> maps.mesh
<function mesh at 0x7f6346581940>
```
Line 461, `with maps.mesh(devices, ("dp", "mp")):`, should be written as `with maps.Mesh(devices, ("dp", "mp")):`. Otherwise it fails with an `AttributeError` saying that `jax.experimental.maps` has no attribute `mesh`.
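Instead of editing line 461 by hand for one jax version, you could pick whichever API the installed jax provides. This is a minimal sketch. It assumes, based on the reports in this thread, that older releases expose the lowercase `maps.mesh(devices, axis_names)` context manager and newer ones only expose the `maps.Mesh` class, whose instances can be used in a `with` block. The helper name `mesh_context` is hypothetical and is not part of jax.

```python
def mesh_context(maps, devices, axis_names):
    """Return a context manager that activates a named device mesh.

    Hypothetical helper (not part of jax). `maps` is the already-imported
    `jax.experimental.maps` module. It is passed in as an argument so the
    selection logic does not itself import jax.
    """
    # Older jax (reported in this thread as jax<=0.3.7) exposes a lowercase
    # `mesh` function that acts as a context manager.
    legacy_mesh = getattr(maps, "mesh", None)
    if legacy_mesh is not None:
        return legacy_mesh(devices, axis_names)
    # Newer jax only has the `Mesh` class. Per this issue, an instance of it
    # can be used directly in a `with` statement.
    return maps.Mesh(devices, axis_names)
```

Line 461 would then read `with mesh_context(maps, devices, ("dp", "mp")):`. Checking for the attribute, rather than comparing version strings, avoids hard-coding the 0.3.7 cutoff.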