google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Gym examples do not work on Mac M1 #270

Closed: Robokan closed this issue 1 year ago.

Robokan commented 1 year ago

I tried all the other Brax examples on my Mac M1 and they work great. However, with the Gym examples I don't get very far. Any ideas? Here is a simple example that fails:

```python
#@title Import Brax and some helper modules

import functools
import time

from IPython.display import HTML, Image
import gym

import brax

from brax import envs
from brax import jumpy as jp
from brax.envs import to_torch
from brax.io import html
from brax.io import image
import jax
from jax import numpy as jnp
import torch

v = torch.ones(1)

#@title Visualizing pre-included Brax environments { run: "auto" }
#@markdown Select an environment to preview it below:

environment = "ant"  # @param ['ant', 'halfcheetah', 'hopper', 'humanoid', 'reacher', 'walker2d', 'fetch', 'grasp', 'ur5e']
env = envs.create(env_name=environment)
state = env.reset(rng=jp.random_prngkey(seed=0))

HTML(html.render(env.sys, [state.qp]))

rollout = []
for i in range(100):
  # wiggle sinusoidally with a phase shift per actuator
  action = jp.sin(i * jp.pi / 15 + jp.arange(0, env.action_size) * jp.pi)
  state = env.step(state, action)
```

I get the following error on the line `state = env.step(state, action)`:

```
Exception has occurred: TracerArrayConversionError
The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>
The error occurred while tracing the function f at /Users/eric/miniconda3/envs/py39/lib/python3.9/site-packages/brax/envs/wrappers.py:79 for scan. This concrete value was not available in Python because it depends on the value of the argument 'state'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
  File "/Users/eric/Documents/development/deepLearning/OpenAI/Brax/pytorchGPU/example1.py", line 35, in <module>
    state = env.step(state, action)
jax._src.traceback_util.UnfilteredStackTrace: jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>
The error occurred while tracing the function f at /Users/eric/miniconda3/envs/py39/lib/python3.9/site-packages/brax/envs/wrappers.py:79 for scan. This concrete value was not available in Python because it depends on the value of the argument 'state'.
```
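For context, this class of error appears when NumPy code tries to pull a concrete value out of a JAX Tracer while a function is being traced, here by `scan` inside Brax's wrappers. The minimal sketch below is independent of Brax and only reproduces the same mechanism with `jax.lax.scan`; the function names `step_with_numpy` and `step_with_jnp` are hypothetical, and this is not a reconstruction of what Brax's `wrappers.py` actually did.

```python
import numpy as np
import jax
import jax.numpy as jnp


def step_with_numpy(carry, _):
    # np.asarray asks for a concrete array, but inside lax.scan `carry` is an
    # abstract Tracer, so JAX raises TracerArrayConversionError here.
    return carry + np.asarray(carry), None


def step_with_jnp(carry, _):
    # jax.numpy operations stay symbolic, so tracing succeeds.
    return carry + jnp.asarray(carry), None


init = jnp.ones(3, dtype=jnp.float32)

try:
    jax.lax.scan(step_with_numpy, init, None, length=2)
    failed = False
except jax.errors.TracerArrayConversionError as err:
    failed = True
    print(type(err).__name__)

# Two doubling steps starting from ones: 1 -> 2 -> 4.
out, _ = jax.lax.scan(step_with_jnp, init, None, length=2)
print(out)  # [4. 4. 4.]
```

The fix in user code is usually to keep traced values in `jax.numpy` rather than converting them with NumPy; in this thread the problem was resolved on the library side by upgrading Brax.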

ozhanozen commented 1 year ago

Is it an option for you to convert the Brax environment to a Gym environment first (i.e., using the OpenAI Gym wrapper)? There is an example of this in this notebook. When I use this wrapper on my M1, it works and I don't get the error you reported.

If you get the error "module 'jax' has no attribute 'dlpack'", add `import jax.dlpack` to the import section of your notebook (see #260).
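The `import jax.dlpack` workaround just imports the submodule explicitly so that the `jax.dlpack` attribute is guaranteed to exist; presumably Brax's `to_torch` wrapper needs it because it hands arrays between JAX and PyTorch via DLPack. A minimal sketch using JAX alone (no Brax or PyTorch), assuming a recent JAX release in which `jax.dlpack.from_dlpack` accepts any object that implements `__dlpack__`:

```python
import jax
import jax.dlpack  # explicit submodule import, the workaround suggested above
import jax.numpy as jnp

x = jnp.arange(4.0)

# Round-trip a JAX array through the DLPack protocol. Recent JAX arrays
# implement __dlpack__, so from_dlpack can consume them directly
# (older JAX versions expected a DLPack capsule instead).
y = jax.dlpack.from_dlpack(x)

print(y)  # [0. 1. 2. 3.]
```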

erikfrey commented 1 year ago

Thanks @ozhanozen for posting these pointers. We just released v0.0.16, which includes a fix for the dlpack issue.

@Robokan, it's a bit difficult to parse the error here, but as @ozhanozen said, you should be able to run the code in Brax Environments; I just ran it myself to be sure.

Robokan commented 1 year ago

0.0.16 fixed this. It works great now.