Robokan closed this issue 1 year ago
Is it an option for you to convert the Brax environment to a Gym environment first (i.e., using the OpenAI Gym wrapper)? There is an example of this in this notebook. When I use the wrapper on my M1, it works, and I avoid the error you posted.
If you hit the error "module 'jax' has no attribute 'dlpack'", add "import jax.dlpack" to the import section of your notebook (see #260).
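As a rough sketch of that workaround, assuming only that JAX is installed: importing the submodule explicitly guarantees that `jax.dlpack` is bound as an attribute of the `jax` package, whichever JAX version is in use. The helper name `dlpack_is_available` is hypothetical and exists only for this illustration.

```python
import jax
import jax.dlpack  # explicit submodule import; this is the workaround


def dlpack_is_available() -> bool:
    """Return True if jax.dlpack is reachable as an attribute of jax."""
    # After the explicit import above, the attribute lookup succeeds.
    return hasattr(jax, "dlpack") and hasattr(jax.dlpack, "from_dlpack")


print(dlpack_is_available())
```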
Thanks @ozhanozen for posting these pointers. We just released v0.0.16, which includes a fix for the dlpack issue.
@Robokan it's a bit difficult to parse the error here but as @ozhanozen said, you should be able to run the code in Brax Environments - I just ran it myself to be sure.
0.0.16 fixed this. It works great now.
I tried all the other Brax examples on an M1 Mac and they work great. However, I don't get very far with the Gym examples. Any ideas? Here is the simple example that fails.
#@title Import Brax and some helper modules
import functools
import time

from IPython.display import HTML, Image
import gym

import brax
from brax import envs
from brax import jumpy as jp
from brax.envs import to_torch
from brax.io import html
from brax.io import image
import jax
from jax import numpy as jnp
import torch

v = torch.ones(1)

#@title Visualizing pre-included Brax environments { run: "auto" }
#@markdown Select an environment to preview it below:
environment = "ant"  # @param ['ant', 'halfcheetah', 'hopper', 'humanoid', 'reacher', 'walker2d', 'fetch', 'grasp', 'ur5e']
env = envs.create(env_name=environment)
state = env.reset(rng=jp.random_prngkey(seed=0))

HTML(html.render(env.sys, [state.qp]))

rollout = []
for i in range(100):
  # wiggle sinusoidally with a phase shift per actuator
  action = jp.sin(i * jp.pi / 15 + jp.arange(0, env.action_size) * jp.pi)
  state = env.step(state, action)
On the line "state = env.step(state, action)" I get the following error:
Exception has occurred: TracerArrayConversionError
The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>
The error occurred while tracing the function f at /Users/eric/miniconda3/envs/py39/lib/python3.9/site-packages/brax/envs/wrappers.py:79 for scan. This concrete value was not available in Python because it depends on the value of the argument 'state'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
File "/Users/eric/Documents/development/deepLearning/OpenAI/Brax/pytorchGPU/example1.py", line 35, in
state = env.step(state, action)
jax._src.traceback_util.UnfilteredStackTrace: jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>
The error occurred while tracing the function f at /Users/eric/miniconda3/envs/py39/lib/python3.9/site-packages/brax/envs/wrappers.py:79 for scan. This concrete value was not available in Python because it depends on the value of the argument 'state'.
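For readers unfamiliar with this class of error, here is a minimal sketch that reproduces a TracerArrayConversionError outside of Brax. It is not a diagnosis of what happens in brax/envs/wrappers.py. It only illustrates the general mechanism the message describes: a NumPy conversion applied to a value that JAX is still tracing. The function names are hypothetical, and `jax.jit` stands in for the `scan` mentioned in the traceback.

```python
import jax
import jax.numpy as jnp
import numpy as np


def step_with_numpy(state):
    # np.asarray needs concrete values, so it fails on a traced input.
    return np.asarray(state) * 2.0


def step_with_jnp(state):
    # jax.numpy operations accept traced values, so tracing succeeds.
    return jnp.asarray(state) * 2.0


def trace_error_name(fn):
    """Trace `fn` with jax.jit and return the raised exception's class name, or None."""
    try:
        jax.jit(fn)(jnp.ones(3))
    except Exception as exc:
        return type(exc).__name__
    return None


print(trace_error_name(step_with_numpy))
print(trace_error_name(step_with_jnp))
```

The first call hits the same exception type as in the report, while the second traces cleanly because it stays within jax.numpy.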