google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Running on M1 MacBook with Python 3.9? #71

Closed slerman12 closed 3 years ago

slerman12 commented 3 years ago

I get this error when running the PyTorch Colab code:

[2021-10-14 17:20:52,534][absl][INFO] - Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
[2021-10-14 17:20:52,535][absl][INFO] - Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host
[2021-10-14 17:20:52,536][absl][INFO] - Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
[2021-10-14 17:20:52,536][absl][WARNING] - No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
LLVM ERROR: 64-bit code requested on a subtarget that doesn't support it!
erikfrey commented 3 years ago

Hi Sam, indeed M1 is not yet supported by JAX, but there is a lot of activity around this issue:

https://github.com/google/jax/issues/5501

Poking through that issue, it does seem that some folks have gotten JAX to work either by running in emulation mode or by building custom wheels, so that might be worth a shot.
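
Whether emulation or a native custom wheel applies depends on what kind of interpreter is actually running. A minimal sketch, assuming that `platform.machine()` reports `arm64` for a native Apple Silicon Python and that the macOS `sysctl.proc_translated` key reads `1` under Rosetta 2; both helper names are hypothetical, not part of JAX or Brax.

```python
import platform
import subprocess
from typing import Optional


def rosetta_translated() -> Optional[bool]:
    """True under Rosetta 2, False if native, None if the lookup fails."""
    try:
        result = subprocess.run(
            ["sysctl", "-n", "sysctl.proc_translated"],
            capture_output=True, text=True, check=True,
        )
    except (OSError, subprocess.CalledProcessError):
        # No sysctl binary, or the key does not exist (not macOS, or an Intel Mac).
        return None
    return result.stdout.strip() == "1"


def describe_interpreter(system: str, machine: str, translated: Optional[bool]) -> str:
    """Classify the running Python (hypothetical helper for choosing a jaxlib build)."""
    if system != "Darwin":
        return "not macOS"
    if machine == "arm64":
        return "native arm64"
    if machine == "x86_64":
        # Assumption: a failed lookup (None) is treated as "not translated".
        if translated:
            return "x86_64 under Rosetta 2 emulation"
        return "x86_64 (not translated, e.g. an Intel Mac)"
    return f"unrecognized machine type: {machine}"


if __name__ == "__main__":
    print(describe_interpreter(platform.system(), platform.machine(), rosetta_translated()))
```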

EelcoHoogendoorn commented 3 years ago

Heh. Here I was thinking of buying the next MacBook update, but I was worried about shenanigans like this. I've only just mentally healed from the horrors of the 32/64-bit switch, so I'm going to sit this one out for a while.

erwincoumans commented 3 years ago

Brax can work with the custom wheel for jaxlib 0.1.70 (not the latest 0.1.73), but it is a hassle to install on M1 (arm64). You may want to install a virtual machine instead (Ubuntu on Parallels under M1 works great!) or use Colab.

The PPO and SAC tests fail with the latest jax, so use jax 0.2.20 (see report). Running all the tests on M1 with JAX on CPU, only CapsuleTest.test_capsule_hits_ground fails; all others pass.
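
Because the working combination hinges on exact versions, a small check can catch drift (for example, a later `pip install` upgrading jax) before running the tests. A sketch using the standard-library `importlib.metadata` (Python 3.8+); the pins are the jax 0.2.20 / jaxlib 0.1.70 pair from this comment, and the function names are hypothetical.

```python
from importlib import metadata
from typing import Dict, Iterable, List

# Versions reported working on M1 in this thread (jaxlib from a custom wheel).
PINS = {"jax": "0.2.20", "jaxlib": "0.1.70"}


def installed_versions(names: Iterable[str]) -> Dict[str, str]:
    """Map each installed distribution name to its version; skip missing ones."""
    found = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            pass
    return found


def pin_mismatches(installed: Dict[str, str], pins: Dict[str, str]) -> List[str]:
    """Describe every pinned package that is missing or at a different version."""
    problems = []
    for name, wanted in pins.items():
        have = installed.get(name)
        if have is None:
            problems.append(f"{name}: not installed (want {wanted})")
        elif have != wanted:
            problems.append(f"{name}: {have} installed, want {wanted}")
    return problems


if __name__ == "__main__":
    for line in pin_mismatches(installed_versions(PINS), PINS) or ["pins OK"]:
        print(line)
```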

Here are some snippets from my history today:

# Homebrew Python 3.9, first on PATH
brew install python@3.9
echo 'export PATH="/opt/homebrew/opt/python@3.9/bin:$PATH"' >> ~/.zshrc
source ~/.zshrc
# Build prerequisites, then build numpy from source for arm64
pip3 install cython pybind11
pip3 install --no-binary :all: --no-use-pep517 numpy
# The scipy source build needs OpenBLAS and a Fortran compiler (gcc provides gfortran)
brew install openblas gcc
export OPENBLAS=/opt/homebrew/opt/openblas/lib/
export CPPFLAGS="-I/opt/homebrew/opt/openblas/include"
export LDFLAGS="-L/opt/homebrew/opt/openblas/lib"
pip3 install pythran
pip3 install --no-binary :all: --no-use-pep517 scipy
pip3 install protobuf
pip3 install gym
# jaxlib 0.1.70 from the custom M1 wheel index; jax pinned to 0.2.20
python3 -m pip install jaxlib==0.1.70 -f "https://dfm.io/custom-wheels/jaxlib/index.html"
pip3 install jax==0.2.20
pip3 install --no-binary :all: --no-use-pep517 brax
pip3 install transforms3d
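
After these installs, a quick way to confirm that jax and jaxlib import and compile on this machine is a tiny jitted computation. A minimal sketch; the helper name is hypothetical. It only reports import failures: a hard abort such as the `LLVM ERROR` in the original report kills the interpreter instead of raising, so it cannot be caught from Python.

```python
from typing import Optional


def jax_cpu_smoke_test() -> Optional[dict]:
    """Run a tiny jitted computation; return None if JAX cannot be imported."""
    try:
        import jax
        import jax.numpy as jnp
        import numpy as np
    except ImportError:
        return None
    doubled = jax.jit(lambda v: v * 2.0)(jnp.arange(3.0))
    return {
        "jax": jax.__version__,
        # On an M1 without GPU support this should list only "cpu".
        "platforms": sorted({d.platform for d in jax.devices()}),
        "doubled": np.asarray(doubled).tolist(),
    }


if __name__ == "__main__":
    print(jax_cpu_smoke_test())
```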

Resulting packages (pip3 list):

erwincoumans@ErwinMacbookM1 tests % pip3 list
Package                Version  Location
---------------------- -------- ----------------------------
absl-py                0.13.0
beniget                0.4.1
brax                   0.0.5    /Users/erwincoumans/dev/brax
chex                   0.0.8
cloudpickle            2.0.0
cycler                 0.10.0
Cython                 0.29.24
decorator              5.1.0
dm-tree                0.1.6
flatbuffers            2.0
flax                   0.3.5
gast                   0.5.2
gym                    0.21.0
jax                    0.2.20
jaxlib                 0.1.70
kiwisolver             1.3.2
matplotlib             3.4.3
msgpack                1.0.2
numpy                  1.21.2
opt-einsum             3.3.0
optax                  0.0.9
Pillow                 8.4.0
pip                    21.2.4
ply                    3.11
protobuf               3.19.0
pybind11               2.8.0
pybullet               3.1.3
pyparsing              2.4.7
python-dateutil        2.8.2
pythran                0.10.0
scipy                  1.7.1
setuptools             57.4.0
six                    1.16.0
tensorflow-probability 0.14.1
toolz                  0.11.1
transforms3d           0.3.1
typing-extensions      3.10.0.2
wheel                  0.37.0

and then

erwincoumans@ErwinMacbookM1 tests % python3 urdf_test.py     
/opt/homebrew/lib/python3.9/site-packages/jax/_src/lib/__init__.py:32: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
  warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
Running tests under Python 3.9.7: /opt/homebrew/opt/python@3.9/bin/python3.9
[ RUN      ] UrdfTest.test_build
[       OK ] UrdfTest.test_build
----------------------------------------------------------------------
Ran 1 test in 0.000s

OK
erwincoumans@ErwinMacbookM1 tests % python3 env_test.py      
/opt/homebrew/lib/python3.9/site-packages/jax/_src/lib/__init__.py:32: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
  warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
Running tests under Python 3.9.7: /opt/homebrew/opt/python@3.9/bin/python3.9
[ RUN      ] EnvTest.testSpeed0 ('ant', 1000)
I1020 21:11:23.405458 4372266304 xla_bridge.py:231] Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: 
I1020 21:11:23.405616 4372266304 xla_bridge.py:231] Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: "cuda". Available platform names are: Host Interpreter
I1020 21:11:23.405805 4372266304 xla_bridge.py:231] Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.
W1020 21:11:23.405834 4372266304 xla_bridge.py:236] No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
I1020 21:11:49.438776 4372266304 env_test.py:74] ant SPS 71764.89 [72224.61471279911, 72096.87745458583, 72142.22719112478, 71517.88816708465, 70842.81148642051]
[       OK ] EnvTest.testSpeed0 ('ant', 1000)
[ RUN      ] EnvTest.testSpeed1 ('fetch', 1000)
/opt/homebrew/lib/python3.9/site-packages/jax/_src/ops/scatter.py:382: DeprecationWarning: index_update is deprecated. Use x.at[idx].set(y) instead.
  warnings.warn("index_update is deprecated. Use x.at[idx].set(y) instead.",
/opt/homebrew/lib/python3.9/site-packages/jax/_src/ops/scatter.py:382: DeprecationWarning: index_update is deprecated. Use x.at[idx].set(y) instead.
  warnings.warn("index_update is deprecated. Use x.at[idx].set(y) instead.",
I1020 21:12:23.984752 4372266304 env_test.py:74] fetch SPS 72623.71 [73695.32873411497, 72104.74974307347, 72102.357857387, 73242.55239500025, 71973.55669716113]
[       OK ] EnvTest.testSpeed1 ('fetch', 1000)
----------------------------------------------------------------------
Ran 2 tests in 60.626s

OK

erwincoumans@ErwinMacbookM1 tests % python3 physics_test.py 
/opt/homebrew/lib/python3.9/site-packages/jax/_src/lib/__init__.py:32: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
  warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
Running tests under Python 3.9.7: /opt/homebrew/opt/python@3.9/bin/python3.9
[ RUN      ] Actuator1DTest.test_1d_angle_actuator0 (15.0)
I1020 21:21:13.422533 4338416960 xla_bridge.py:231] Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: 
I1020 21:21:13.422983 4338416960 xla_bridge.py:231] Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host
I1020 21:21:13.423403 4338416960 xla_bridge.py:231] Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.
W1020 21:21:13.423434 4338416960 xla_bridge.py:236] No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[       OK ] Actuator1DTest.test_1d_angle_actuator0 (15.0)
[ RUN      ] Actuator1DTest.test_1d_angle_actuator1 (30.0)
[       OK ] Actuator1DTest.test_1d_angle_actuator1 (30.0)
[ RUN      ] Actuator1DTest.test_1d_angle_actuator2 (45.0)
[       OK ] Actuator1DTest.test_1d_angle_actuator2 (45.0)
[ RUN      ] Actuator1DTest.test_1d_angle_actuator3 (90.0)
[       OK ] Actuator1DTest.test_1d_angle_actuator3 (90.0)
[ RUN      ] Actuator2DTest.test_2d_angle_actuator0 (15.0, 30.0)
[       OK ] Actuator2DTest.test_2d_angle_actuator0 (15.0, 30.0)
[ RUN      ] Actuator2DTest.test_2d_angle_actuator1 (45.0, 90.5)
[       OK ] Actuator2DTest.test_2d_angle_actuator1 (45.0, 90.5)
[ RUN      ] Actuator2DTest.test_2d_angle_actuator2 (-120, 60.0)
[       OK ] Actuator2DTest.test_2d_angle_actuator2 (-120, 60.0)
[ RUN      ] Actuator2DTest.test_2d_angle_actuator3 (30.0, -120.0)
[       OK ] Actuator2DTest.test_2d_angle_actuator3 (30.0, -120.0)
[ RUN      ] Actuator2DTest.test_2d_angle_actuator4 (-150, -130)
[       OK ] Actuator2DTest.test_2d_angle_actuator4 (-150, -130)
[ RUN      ] Actuator2DTest.test_2d_angle_actuator5 (130, 165)
[       OK ] Actuator2DTest.test_2d_angle_actuator5 (130, 165)
[ RUN      ] Actuator3DTest.test_3d_torque_actuator0 ((15, 15, 15), [(1, 0, 0), (0, 1, 0), (0, 0, 1), (1, 1, 1)])
[       OK ] Actuator3DTest.test_3d_torque_actuator0 ((15, 15, 15), [(1, 0, 0), (0, 1, 0), (0, 0, 1), (1, 1, 1)])
[ RUN      ] Actuator3DTest.test_3d_torque_actuator1 ((35, 40, 75), [(1, 0, 0), (0, 1, 0), (0, 0, 1), (1, 1, 1)])
[       OK ] Actuator3DTest.test_3d_torque_actuator1 ((35, 40, 75), [(1, 0, 0), (0, 1, 0), (0, 0, 1), (1, 1, 1)])
[ RUN      ] Actuator3DTest.test_3d_torque_actuator2 ((80, 45, 30), [(1, 0, 0), (0, 1, 0), (0, 0, 1), (1, 1, 1)])
[       OK ] Actuator3DTest.test_3d_torque_actuator2 ((80, 45, 30), [(1, 0, 0), (0, 1, 0), (0, 0, 1), (1, 1, 1)])
[ RUN      ] BodyTest.test_projectile_motion
[       OK ] BodyTest.test_projectile_motion
[ RUN      ] BoxTest.test_box_hits_ground
[       OK ] BoxTest.test_box_hits_ground
[ RUN      ] BoxTest.test_box_slide
[       OK ] BoxTest.test_box_slide
[ RUN      ] CapsuleTest.test_capsule_hits_capsule
[       OK ] CapsuleTest.test_capsule_hits_capsule
[ RUN      ] CapsuleTest.test_capsule_hits_ground
[  FAILED  ] CapsuleTest.test_capsule_hits_ground
[ RUN      ] HeightMapTest.test_box_stays_on_heightMap
[       OK ] HeightMapTest.test_box_stays_on_heightMap
[ RUN      ] JointTest.test_pendulum_period0 (2.0, 0.125, 0.0625)
[       OK ] JointTest.test_pendulum_period0 (2.0, 0.125, 0.0625)
[ RUN      ] JointTest.test_pendulum_period1 (5.0, 0.125, 0.03125)
[       OK ] JointTest.test_pendulum_period1 (5.0, 0.125, 0.03125)
[ RUN      ] JointTest.test_pendulum_period2 (1.0, 0.0625, 0.1)
[       OK ] JointTest.test_pendulum_period2 (1.0, 0.0625, 0.1)
[ RUN      ] SphereTest.test_sphere_hits_ground
[       OK ] SphereTest.test_sphere_hits_ground
======================================================================
FAIL: test_capsule_hits_ground (__main__.CapsuleTest)
CapsuleTest.test_capsule_hits_ground
A capsule falls onto the ground and stops.
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/erwincoumans/dev/brax/brax/tests/physics_test.py", line 180, in test_capsule_hits_ground
    self.assertAlmostEqual(qp.pos[2, 2], 0.25, 2)  # rolls to side from y rot
AssertionError: DeviceArray(0.33739132, dtype=float32) != 0.25 within 2 places (DeviceArray(0.08739132, dtype=float32) difference)
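
The one failure is a tolerance check, not a crash. `unittest`'s `assertAlmostEqual(first, second, places)` passes when `round(abs(first - second), places) == 0`, so with `places=2` the effective tolerance is roughly 0.005, and the observed difference of about 0.087 is far outside it rather than float32 rounding noise. The hypothetical helper below reproduces that check with the standard library:

```python
import unittest


def almost_equal_passes(first: float, second: float, places: int) -> bool:
    """Return True if unittest's assertAlmostEqual(first, second, places) would pass."""
    try:
        # TestCase() can be instantiated without naming a test method (Python 3.2+).
        unittest.TestCase().assertAlmostEqual(first, second, places)
    except AssertionError:
        return False
    return True


# Numbers from the failing physics_test.py assertion: got 0.33739132, expected 0.25.
print(almost_equal_passes(0.33739132, 0.25, 2))  # False: |diff| = 0.087..., rounds to 0.09
print(almost_equal_passes(0.254, 0.25, 2))       # True: |diff| = 0.004, rounds to 0.0
```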

erwincoumans@ErwinMacbookM1 tests % python3 ars_test.py
/opt/homebrew/lib/python3.9/site-packages/jax/_src/lib/__init__.py:32: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
  warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
Running tests under Python 3.9.7: /opt/homebrew/opt/python@3.9/bin/python3.9
[ RUN      ] ARSTest.testModelEncoding0 (True)
I1020 21:20:45.580332 4373380416 xla_bridge.py:231] Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: 
I1020 21:20:45.580879 4373380416 xla_bridge.py:231] Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: "cuda". Available platform names are: Host Interpreter
I1020 21:20:45.581276 4373380416 xla_bridge.py:231] Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.
W1020 21:20:45.581306 4373380416 xla_bridge.py:236] No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
I1020 21:20:45.581456 4373380416 ars.py:94] Device count: 1, process count: 1 (id 0), local device count: 1, devices to be used count: 1
I1020 21:20:45.970596 4373380416 ars.py:231] starting iteration 0 0.39330220222473145
I1020 21:20:46.131521 4373380416 ars.py:265] Step 0 metrics {'eval/episode_reward': DeviceArray(0., dtype=float32), 'eval/completed_episodes': DeviceArray(128., dtype=float32), 'eval/episode_length': DeviceArray(128., dtype=float32), 'train/completed_episodes': 0, 'speed/sps': 0, 'speed/eval_sps': 104354.10925679021, 'speed/training_walltime': 0, 'speed/eval_walltime': 0.15700101852416992, 'speed/timestamp': 0}
I1020 21:20:46.768952 4373380416 ars.py:231] starting iteration 1 1.1916530132293701
I1020 21:20:46.769227 4373380416 ars.py:265] Step 15360 metrics {'eval/episode_reward': DeviceArray(0., dtype=float32), 'train/eval_scores_mean': DeviceArray(0., dtype=float32), 'train/eval_scores_std': DeviceArray(0., dtype=float32), 'train/params_norm': DeviceArray(nan, dtype=float32), 'train/reward_std': DeviceArray(0., dtype=float32), 'train/weights': DeviceArray(0.33333334, dtype=float32), 'eval/completed_episodes': DeviceArray(128., dtype=float32), 'eval/episode_length': DeviceArray(128., dtype=float32), 'train/completed_episodes': 120, 'speed/sps': 24107.387278162638, 'speed/eval_sps': 99162304.09235209, 'speed/training_walltime': 0.6371500492095947, 'speed/eval_walltime': 0.1571650505065918, 'speed/timestamp': 0.6371500492095947}
[       OK ] ARSTest.testModelEncoding0 (True)
[ RUN      ] ARSTest.testModelEncoding1 (False)
I1020 21:20:46.850713 4373380416 ars.py:94] Device count: 1, process count: 1 (id 0), local device count: 1, devices to be used count: 1
I1020 21:20:46.853812 4373380416 ars.py:231] starting iteration 0 0.003113985061645508
I1020 21:20:47.007452 4373380416 ars.py:265] Step 0 metrics {'eval/episode_reward': DeviceArray(0., dtype=float32), 'eval/completed_episodes': DeviceArray(128., dtype=float32), 'eval/episode_length': DeviceArray(128., dtype=float32), 'train/completed_episodes': 0, 'speed/sps': 0, 'speed/eval_sps': 106690.19820742955, 'speed/training_walltime': 0, 'speed/eval_walltime': 0.1535630226135254, 'speed/timestamp': 0}
I1020 21:20:47.632093 4373380416 ars.py:231] starting iteration 1 0.7813839912414551
I1020 21:20:47.632388 4373380416 ars.py:265] Step 15360 metrics {'eval/episode_reward': DeviceArray(0., dtype=float32), 'train/eval_scores_mean': DeviceArray(0., dtype=float32), 'train/eval_scores_std': DeviceArray(0., dtype=float32), 'train/params_norm': DeviceArray(nan, dtype=float32), 'train/reward_std': DeviceArray(0., dtype=float32), 'train/weights': DeviceArray(0.33333334, dtype=float32), 'eval/completed_episodes': DeviceArray(128., dtype=float32), 'eval/episode_length': DeviceArray(128., dtype=float32), 'train/completed_episodes': 120, 'speed/sps': 24601.19768363211, 'speed/eval_sps': 91019174.48476821, 'speed/training_walltime': 0.6243607997894287, 'speed/eval_walltime': 0.15374112129211426, 'speed/timestamp': 0.6243607997894287}
[       OK ] ARSTest.testModelEncoding1 (False)
----------------------------------------------------------------------
Ran 2 tests in 2.063s

OK

erwincoumans@ErwinMacbookM1 tests % python3 testvmap.py 
/opt/homebrew/lib/python3.9/site-packages/jax/_src/lib/__init__.py:32: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
  warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[[ 2.  2.  2.  2.  2.]
 [-2. -2. -2. -2. -2.]
 [ 0.  0.  0.  0.  0.]]

erwincoumans@ErwinMacbookM1 tests % python3 sac_test.py 
/opt/homebrew/lib/python3.9/site-packages/jax/lib/__init__.py:31: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
  warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
Running tests under Python 3.9.7: /opt/homebrew/opt/python@3.9/bin/python3.9
[ RUN      ] SACTest.testModelEncoding0 (True)
I1020 21:29:40.672581 4341775680 xla_bridge.py:236] Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: 
I1020 21:29:40.672731 4341775680 xla_bridge.py:236] Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: "cuda". Available platform names are: Host Interpreter
I1020 21:29:40.672922 4341775680 xla_bridge.py:236] Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.
W1020 21:29:40.672973 4341775680 xla_bridge.py:240] No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
I1020 21:29:40.673007 4341775680 sac.py:164] Device count: 1, process count: 1 (id 0), local device count: 1, devices to be used count: 1
I1020 21:29:41.531414 4341775680 sac.py:505] step 0
I1020 21:29:42.099275 4341775680 sac.py:538] {'eval/episode_reward': DeviceArray(70.57749, dtype=float32), 'eval/completed_episodes': DeviceArray(128., dtype=float32), 'eval/avg_episode_length': DeviceArray(128., dtype=float32), 'speed/sps': 0, 'speed/eval_sps': 29770.471356100352, 'speed/training_walltime': 0, 'speed/eval_walltime': 0.5503418445587158, 'training/grad_updates': DeviceArray(0., dtype=float32)}
I1020 21:30:32.333238 4341775680 sac.py:505] step 18304
I1020 21:30:32.364439 4341775680 sac.py:538] {'eval/episode_reward': DeviceArray(91.18119, dtype=float32), 'training/actor_loss': DeviceArray(-5.891935, dtype=float32), 'training/alpha': DeviceArray(0.31349677, dtype=float32), 'training/alpha_loss': DeviceArray(0.3696652, dtype=float32), 'training/buffer_current_position': DeviceArray(13312., dtype=float32), 'training/buffer_current_size': DeviceArray(13312., dtype=float32), 'training/critic_loss': DeviceArray(0.00215272, dtype=float32), 'eval/completed_episodes': DeviceArray(128., dtype=float32), 'eval/avg_episode_length': DeviceArray(128., dtype=float32), 'speed/sps': DeviceArray(369.54416, dtype=float32), 'speed/eval_sps': 728245.9940442758, 'speed/training_walltime': 50.23360586166382, 'speed/eval_walltime': 0.5728359222412109, 'training/grad_updates': DeviceArray(10112., dtype=float32)}
I1020 21:30:32.365164 4341775680 sac.py:583] total steps: 18304.0
[       OK ] SACTest.testModelEncoding0 (True)
[ RUN      ] SACTest.testModelEncoding1 (False)
I1020 21:30:32.624608 4341775680 sac.py:164] Device count: 1, process count: 1 (id 0), local device count: 1, devices to be used count: 1
I1020 21:30:32.644865 4341775680 sac.py:505] step 0
I1020 21:30:33.174823 4341775680 sac.py:538] {'eval/episode_reward': DeviceArray(70.57749, dtype=float32), 'eval/completed_episodes': DeviceArray(128., dtype=float32), 'eval/avg_episode_length': DeviceArray(128., dtype=float32), 'speed/sps': 0, 'speed/eval_sps': 30942.467705804775, 'speed/training_walltime': 0, 'speed/eval_walltime': 0.5294966697692871, 'training/grad_updates': DeviceArray(0., dtype=float32)}
I1020 21:31:23.568519 4341775680 sac.py:505] step 18304
I1020 21:31:23.587382 4341775680 sac.py:538] {'eval/episode_reward': DeviceArray(91.18119, dtype=float32), 'training/actor_loss': DeviceArray(-5.891935, dtype=float32), 'training/alpha': DeviceArray(0.31349677, dtype=float32), 'training/alpha_loss': DeviceArray(0.3696652, dtype=float32), 'training/buffer_current_position': DeviceArray(13312., dtype=float32), 'training/buffer_current_size': DeviceArray(13312., dtype=float32), 'training/critic_loss': DeviceArray(0.00215272, dtype=float32), 'eval/completed_episodes': DeviceArray(128., dtype=float32), 'eval/avg_episode_length': DeviceArray(128., dtype=float32), 'speed/sps': DeviceArray(368.29453, dtype=float32), 'speed/eval_sps': 932345.7620275147, 'speed/training_walltime': 50.39337182044983, 'speed/eval_walltime': 0.5470688343048096, 'training/grad_updates': DeviceArray(10112., dtype=float32)}
I1020 21:31:23.588174 4341775680 sac.py:583] total steps: 18304.0
[       OK ] SACTest.testModelEncoding1 (False)
[ RUN      ] SACTest.testTrain
I1020 21:31:23.616513 4341775680 sac.py:164] Device count: 1, process count: 1 (id 0), local device count: 1, devices to be used count: 1
I1020 21:31:23.717451 4341775680 sac.py:505] step 0
I1020 21:31:24.277982 4341775680 sac.py:538] {'eval/episode_reward': DeviceArray(70.57749, dtype=float32), 'eval/completed_episodes': DeviceArray(128., dtype=float32), 'eval/avg_episode_length': DeviceArray(128., dtype=float32), 'speed/sps': 0, 'speed/eval_sps': 29255.98746989552, 'speed/training_walltime': 0, 'speed/eval_walltime': 0.5600199699401855, 'training/grad_updates': DeviceArray(0., dtype=float32)}
I1020 21:31:28.316565 4341775680 sac.py:505] step 1056
I1020 21:31:28.339959 4341775680 sac.py:538] {'eval/episode_reward': DeviceArray(45.33158, dtype=float32), 'training/actor_loss': DeviceArray(-10.71464, dtype=float32), 'training/alpha': DeviceArray(0.86732984, dtype=float32), 'training/alpha_loss': DeviceArray(0.9081256, dtype=float32), 'training/buffer_current_position': DeviceArray(546., dtype=float32), 'training/buffer_current_size': DeviceArray(546., dtype=float32), 'training/critic_loss': DeviceArray(0.6350142, dtype=float32), 'eval/completed_episodes': DeviceArray(128., dtype=float32), 'eval/avg_episode_length': DeviceArray(128., dtype=float32), 'speed/sps': DeviceArray(308.7342, dtype=float32), 'speed/eval_sps': 735305.8276640594, 'speed/training_walltime': 4.038257122039795, 'speed/eval_walltime': 0.5822999477386475, 'training/grad_updates': DeviceArray(1024., dtype=float32)}
I1020 21:31:29.101129 4341775680 sac.py:505] step 2080
I1020 21:31:29.125257 4341775680 sac.py:538] {'eval/episode_reward': DeviceArray(113.68454, dtype=float32), 'training/actor_loss': DeviceArray(-36.29545, dtype=float32), 'training/alpha': DeviceArray(0.6551076, dtype=float32), 'training/alpha_loss': DeviceArray(0.62401754, dtype=float32), 'training/buffer_current_position': DeviceArray(1570., dtype=float32), 'training/buffer_current_size': DeviceArray(1570., dtype=float32), 'training/critic_loss': DeviceArray(0.6124105, dtype=float32), 'eval/completed_episodes': DeviceArray(128., dtype=float32), 'eval/avg_episode_length': DeviceArray(128., dtype=float32), 'speed/sps': DeviceArray(1346.1392, dtype=float32), 'speed/eval_sps': 713027.7632215155, 'speed/training_walltime': 4.798957109451294, 'speed/eval_walltime': 0.6052758693695068, 'training/grad_updates': DeviceArray(2048., dtype=float32)}
I1020 21:31:29.880496 4341775680 sac.py:505] step 3104
I1020 21:31:29.903525 4341775680 sac.py:538] {'eval/episode_reward': DeviceArray(61.385185, dtype=float32), 'training/actor_loss': DeviceArray(-71.23854, dtype=float32), 'training/alpha': DeviceArray(0.50871587, dtype=float32), 'training/alpha_loss': DeviceArray(0.364863, dtype=float32), 'training/buffer_current_position': DeviceArray(2594., dtype=float32), 'training/buffer_current_size': DeviceArray(2594., dtype=float32), 'training/critic_loss': DeviceArray(0.44352925, dtype=float32), 'eval/completed_episodes': DeviceArray(128., dtype=float32), 'eval/avg_episode_length': DeviceArray(128., dtype=float32), 'speed/sps': DeviceArray(1356.7372, dtype=float32), 'speed/eval_sps': 746180.3217981432, 'speed/training_walltime': 5.553715944290161, 'speed/eval_walltime': 0.6272320747375488, 'training/grad_updates': DeviceArray(3072., dtype=float32)}
I1020 21:31:30.658525 4341775680 sac.py:505] step 4128
I1020 21:31:30.682650 4341775680 sac.py:538] {'eval/episode_reward': DeviceArray(137.51349, dtype=float32), 'training/actor_loss': DeviceArray(-108.17761, dtype=float32), 'training/alpha': DeviceArray(0.41276523, dtype=float32), 'training/alpha_loss': DeviceArray(0.23481959, dtype=float32), 'training/buffer_current_position': DeviceArray(3618., dtype=float32), 'training/buffer_current_size': DeviceArray(3618., dtype=float32), 'training/critic_loss': DeviceArray(0.46300048, dtype=float32), 'eval/completed_episodes': DeviceArray(128., dtype=float32), 'eval/avg_episode_length': DeviceArray(128., dtype=float32), 'speed/sps': DeviceArray(1357.1741, dtype=float32), 'speed/eval_sps': 714116.9774082926, 'speed/training_walltime': 6.30824089050293, 'speed/eval_walltime': 0.6501739025115967, 'training/grad_updates': DeviceArray(4096., dtype=float32)}
I1020 21:31:30.683218 4341775680 sac.py:583] total steps: 4128.0
[       OK ] SACTest.testTrain
----------------------------------------------------------------------
Ran 3 tests in 110.022s

OK
erwincoumans commented 3 years ago

The JAX CPU version should now work fine on Mac M1; see https://github.com/google/jax/issues/5501#issuecomment-955590288. Let's close the issue. If it still doesn't work for you, please create a new issue.