chaozheUB closed this issue 5 months ago.
That's weird. Which environment are you using?
Oh, I see. It seems you have a MuJoCo installation, which makes Brax use `mjx`. I personally use the `positional` backend.
Hi @jc-bao, thank you for the quick feedback. I don't think I installed or configured MuJoCo myself; it must have been installed and configured along with Brax. May I ask how to switch to the `positional` backend? Thanks!
I should have configured that already. Could you tell me which environment you are using?
@jc-bao Hello, I also get the same error message:
```
override temp_sample to 0.1
init sigma = 6.30e-01
Diffusing: 100%|██████████████████████████████████████████| 99/99 [12:25<00:00, 7.53s/it, rew=1.55e+00]
Traceback (most recent call last):
  File "/home/yutaka/rl_workspace/model-based-diffusion/mbd/planners/mbd_planner.py", line 184, in <module>
    rew_final = run_diffusion(args=tyro.cli(Args))
  File "/home/yutaka/rl_workspace/model-based-diffusion/mbd/planners/mbd_planner.py", line 172, in run_diffusion
    mbd.utils.render_us, step_env_jit, env.sys.replace(dt=env.dt)
  File "/home/yutaka/anaconda3/envs/mbd/lib/python3.9/site-packages/mujoco/mjx/_src/dataclasses.py", line 61, in replace
    return dataclasses.replace(self, **updates)
  File "/home/yutaka/anaconda3/envs/mbd/lib/python3.9/dataclasses.py", line 1284, in replace
    return obj.__class__(**changes)
TypeError: __init__() got an unexpected keyword argument 'dt'
```
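The traceback shows `env.sys.replace` delegating to the standard library's `dataclasses.replace` (via `mujoco/mjx/_src/dataclasses.py`), which forwards every keyword to the class constructor. So any keyword the class does not declare as a field raises exactly this `TypeError`. A minimal reproduction without Brax, where `System` is a hypothetical stand-in for the MJX-backed system class, not the real one:

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class System:
    """Hypothetical stand-in for the MJX-backed Brax System: it has no `dt` field."""
    timestep: float


def try_replace_dt(system: System, dt: float) -> str:
    """Attempt the same call the planner makes and report the outcome."""
    try:
        dataclasses.replace(system, dt=dt)
    except TypeError as err:
        return f"TypeError: {err}"
    return "ok"


system = System(timestep=0.01)
print(try_replace_dt(system, 0.05))
print(hasattr(system, "dt"))  # False: the class simply has no such attribute
```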
My environment: Ubuntu 22.04, Python 3.9, Brax 0.10.4.
Oh, I mean: which task are you running? Like `ant`, `halfcheetah`, or something else?
Hi @jc-bao, I ran it on both Ubuntu (the same setup as @purewater0901) and macOS (same Python and Brax versions).
I tried all the environments (`ant`, `halfcheetah`, `hopper`, `humanoidstandup`, `humanoidrun`, `walker2d`, `pushT`) and they all fail with the same error. If I skip the `replace(dt=...)` step and pass `env.sys` directly, the program runs, but the results look very different. I guess the time scale matters?
How long does it take you to run `ant`? I think the backend on your side may not be correct 🤔 It takes me ~10 s on a 4070 Ti.
I think it is also fast for me, and I only have a 3080. The error pops up after the diffusion planning finishes.
I agree the backend is probably off. As you pointed out, mine (and I think @purewater0901's as well) is using the `mujoco` backend rather than `positional`. I just want to know how to configure things to use the same backend.
Your planning time suggests that you might be using the wrong backend:
```
Diffusing: 100%|██████████████████████████████████████████| 99/99 [12:25<00:00, 7.53s/it, rew=1.55e+00]
```
For me, it is ~20 s.
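The size of the gap can be read straight off the tqdm line: 99 iterations at 7.53 s/it versus roughly 20 s in total on the maintainer's side. A quick check of that arithmetic, using only the numbers posted above; a slowdown of this size is consistent with, though not proof of, a different physics backend:

```python
# Numbers taken from the logs in this thread.
ITERATIONS = 99
SLOW_SEC_PER_IT = 7.53      # "99/99 [12:25<00:00, 7.53s/it]"
REFERENCE_TOTAL_SEC = 20.0  # "For me, it is ~20s."

slow_total = ITERATIONS * SLOW_SEC_PER_IT
mm, ss = divmod(round(slow_total), 60)  # convert to the mm:ss format tqdm prints

print(f"slow run: {slow_total:.1f} s ({mm}:{ss:02d})")
print(f"reference cost: {REFERENCE_TOTAL_SEC / ITERATIONS:.2f} s/it")
print(f"slowdown: ~{slow_total / REFERENCE_TOTAL_SEC:.0f}x")
```

The recomputed total matches the 12:25 elapsed time that tqdm reports.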
I have configured the backend in `mbd/envs/__init__.py`; you can take a look here.
Thank you @jc-bao! On my Ubuntu machine with a 3080 (even without XLA parallel compilation), I believe the speed is similar:
```
python mbd_planner.py --env-name ant
2024-06-07 17:44:56.098764: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.0 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
override temp_sample to 0.1
init sigma = 6.30e-01
Diffusing: 100%|██████████████████████████████████████████| 99/99 [00:25<00:00, 3.91it/s, rew=3.85e+00]
Does not have dt argument.
final reward = 4.08e+00
```
Note that the extra "Does not have dt argument." line comes from the following code, which I added to bypass the error and finish the run:
```python
# Workaround: only call replace(dt=...) when the system actually has a `dt` field.
if hasattr(env.sys, "dt"):
    render_us = functools.partial(
        mbd.utils.render_us, step_env_jit, env.sys.replace(dt=env.dt)
    )
else:
    # On Brax 0.10.4 the MJX-backed System has no `dt`, so pass it through unchanged.
    print("Does not have dt argument.")
    render_us = functools.partial(mbd.utils.render_us, step_env_jit, env.sys)
```
But I got the following results: rollout.zip. I think the `positional` backend is indeed requested, but this may be due to a version issue.
Could you share the package versions used in your setup?
Thanks!
My Brax version is 0.10.4.
Not sure if this is the root cause, but it seems that `base.System` in Brax does not have a `dt` property. It does have `opt.timestep`, which multiplied by `_n_frames` gives the `dt`.
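Put differently, `env.dt = sys.opt.timestep * n_frames`, and Brax's `tree_replace` addresses the nested `opt.timestep` field with a dotted path instead of a top-level `dt`. The sketch below mimics both with plain frozen dataclasses. `Options`, `System`, `tree_replace`, and the numeric values are hypothetical stand-ins for illustration, not Brax's implementation:

```python
import dataclasses
from typing import Any, Dict


@dataclasses.dataclass(frozen=True)
class Options:
    """Hypothetical stand-in for the options struct that holds the physics timestep."""
    timestep: float


@dataclasses.dataclass(frozen=True)
class System:
    """Hypothetical stand-in for brax.base.System: no top-level `dt` field."""
    opt: Options


def tree_replace(obj: Any, updates: Dict[str, Any]) -> Any:
    """Return a copy of `obj` with dotted-path fields replaced, e.g. {'opt.timestep': 0.01}."""
    for path, value in updates.items():
        head, _, rest = path.partition(".")
        if rest:
            # Replace inside the nested dataclass, then rebuild the parent around it.
            child = tree_replace(getattr(obj, head), {rest: value})
            obj = dataclasses.replace(obj, **{head: child})
        else:
            obj = dataclasses.replace(obj, **{head: value})
    return obj


N_FRAMES = 5  # hypothetical number of physics substeps per environment step
system = System(opt=Options(timestep=0.01))
env_dt = system.opt.timestep * N_FRAMES  # the value env.dt reports

# Writing env.dt / n_frames back restores the original physics timestep ...
same = tree_replace(system, {"opt.timestep": env_dt / N_FRAMES})
# ... while writing env.dt makes the timestep as long as a whole environment step.
coarse = tree_replace(system, {"opt.timestep": env_dt})
print(env_dt, same.opt.timestep, coarse.opt.timestep)
```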
The error goes away after I update this part of the script, based on the test for the `positional` backend here:
```python
render_us = functools.partial(
    mbd.utils.render_us, step_env_jit, env.sys.tree_replace({'opt.timestep': env.dt / env._n_frames})
)
```
I'm not sure whether it should be the version above or the one below. Both work, but the following one gives a more reasonable video playback speed: rollout_humanoid_run.zip
```python
render_us = functools.partial(
    mbd.utils.render_us, step_env_jit, env.sys.tree_replace({'opt.timestep': env.dt})
)
```
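One plausible explanation for the speed difference, assuming the renderer uses `sys.opt.timestep` as the time between consecutive stored states and that `render_us` stores one state per environment step (both are assumptions, not confirmed in this thread): each stored frame then really spans `env.dt` of simulated time. A sketch of the arithmetic with hypothetical numbers:

```python
def playback_seconds(num_states: int, frame_interval: float) -> float:
    """Length of an animation that shows one stored state every `frame_interval` seconds."""
    return num_states * frame_interval


N_FRAMES = 5      # hypothetical physics substeps per environment step
ENV_DT = 0.05     # hypothetical env.dt: simulated seconds covered by one stored state
NUM_STATES = 100  # hypothetical rollout length, one state per environment step

simulated = NUM_STATES * ENV_DT
too_fast = playback_seconds(NUM_STATES, ENV_DT / N_FRAMES)  # opt.timestep = env.dt / n_frames
real_time = playback_seconds(NUM_STATES, ENV_DT)            # opt.timestep = env.dt
print(f"simulated {simulated:.1f} s -> played back in {too_fast:.1f} s vs {real_time:.1f} s")
```

Under those assumptions, `env.dt / env._n_frames` plays the rollout `n_frames` times faster than real time, while `env.dt` plays it at real-time speed, which would match the observation that the second variant looks more natural.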
Anyway, I don't think this is critical, as it only affects the rendering phase. Thank you very much for the feedback.
I see. Maybe it is due to a Brax update. I will check it later. Thanks!
You are right! I have just updated it. It seems to be a change introduced in 0.10.4. Thanks for the feedback!
@jc-bao @chaozheUB Thank you for resolving the issue!
Thank you for sharing this interesting work!
When trying to run the code, I encountered the following issue, which seems to indicate that Brax's `base.System` does not have `dt` as a property (see here). Any idea why this is the case? I could work around it by not replacing the `dt` field, but the results look weird on a few of the environments.
Could it be because of the Brax version? Mine is the latest, 0.10.4. Thanks!