LeCAR-Lab / model-based-diffusion

Official implementation for the paper "Model-based Diffusion for Trajectory Optimization". Model-based diffusion (MBD) is a novel diffusion-based trajectory optimization framework that employs a dynamics model to run the reverse denoising process to generate high-quality trajectories.
https://lecar-lab.github.io/mbd/
Apache License 2.0

unexpected keyword argument 'dt' #1

Closed chaozheUB closed 5 months ago

chaozheUB commented 5 months ago

Thank you for sharing this interesting work!

When trying to run the code, I encountered the following error, which seems to indicate that the Brax base.System does not have dt as a property (see here):

  File "...model-based-diffusion/mbd/planners/mbd_planner.py", line 172, in run_diffusion
    mbd.utils.render_us, step_env_jit, env.sys.replace(dt=env.dt)
  File ".../model-based-diffusion/.mbd_venv/lib/python3.10/site-packages/mujoco/mjx/_src/dataclasses.py", line 61, in replace
    return dataclasses.replace(self, **updates)
  File "/usr/lib/python3.10/dataclasses.py", line 1453, in replace
    return obj.__class__(**changes)
TypeError: System.__init__() got an unexpected keyword argument 'dt'

Any idea why this is the case? I can get around it by not replacing the dt field, but the results look weird on a few of the environments.

Could it be because of the Brax version? Mine is the latest, 0.10.4. Thanks!
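
For readers following the traceback, here is a minimal sketch of the failure mode, assuming only what the traceback shows: replace ends up in dataclasses.replace, which passes the updates to the class constructor. The Opt and System classes below are hypothetical stand-ins and are not the Brax or MJX types.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Opt:
    # Hypothetical stand-in for a nested options object.
    timestep: float


@dataclasses.dataclass(frozen=True)
class System:
    # Hypothetical stand-in: note there is no top-level `dt` field.
    opt: Opt


sys_ = System(opt=Opt(timestep=0.005))

try:
    # `dt` is not a field of System, so the constructor rejects it.
    dataclasses.replace(sys_, dt=0.05)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(error_message)
```

The exact wording of the message differs between Python versions (compare the two tracebacks in this thread), but in both cases it names the unexpected keyword argument.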

jc-bao commented 5 months ago

That's weird. Which environment are you using?

jc-bao commented 5 months ago

Oh, I see. It seems you have a MuJoCo installation, which makes Brax use MJX. I personally use the positional backend.

chaozheUB commented 5 months ago

Hi @jc-bao, thank you for the quick feedback. I don't think I installed or configured MuJoCo myself; it must have been installed and configured with Brax. May I ask how to switch to the positional backend? Thanks!

jc-bao commented 5 months ago

I should have configured that. Could you tell me which environment you are using?

purewater0901 commented 5 months ago

@jc-bao Hello, I also got the same error message.

override temp_sample to 0.1
init sigma = 6.30e-01
Diffusing: 100%|██████████████████████████████████████████| 99/99 [12:25<00:00,  7.53s/it, rew=1.55e+00]
Traceback (most recent call last):
  File "/home/yutaka/rl_workspace/model-based-diffusion/mbd/planners/mbd_planner.py", line 184, in <module>
    rew_final = run_diffusion(args=tyro.cli(Args))
  File "/home/yutaka/rl_workspace/model-based-diffusion/mbd/planners/mbd_planner.py", line 172, in run_diffusion
    mbd.utils.render_us, step_env_jit, env.sys.replace(dt=env.dt)
  File "/home/yutaka/anaconda3/envs/mbd/lib/python3.9/site-packages/mujoco/mjx/_src/dataclasses.py", line 61, in replace
    return dataclasses.replace(self, **updates)
  File "/home/yutaka/anaconda3/envs/mbd/lib/python3.9/dataclasses.py", line 1284, in replace
    return obj.__class__(**changes)
TypeError: __init__() got an unexpected keyword argument 'dt'

My environment is

Ubuntu 22.04, python = 3.9, brax = 0.10.4

jc-bao commented 5 months ago

Oh, I meant: can you tell me your task, like ant, halfcheetah, or something else?

chaozheUB commented 5 months ago

Hi @jc-bao, I was running on both Ubuntu (my setup is the same as @purewater0901's) and Mac (same Python and Brax versions). I tried all the environments (ant, halfcheetah, hopper, humanoidstandup, humanoidrun, walker2d, pushT), and they all produce the same error. If I skip the dt replacement and pass env.sys as is, the program runs, but the results look very different. I guess the time scale matters?

jc-bao commented 5 months ago

How long does it take to run ant? I think the backend on your side may not be correct 🤔 It takes me ~10s on a 4070 Ti.

chaozheUB commented 5 months ago

I think it is also fast on my side, and I only have a 3080. The error pops up after the training. I agree the backend is probably off. As you pointed out, mine (and I think @purewater0901's as well) is using the MuJoCo backend rather than positional. I just want to see how we can configure it to use the same backend.

jc-bao commented 5 months ago

Your planning time indicates that you might be using the wrong backend:

Diffusing: 100%|██████████████████████████████████████████| 99/99 [12:25<00:00, 7.53s/it, rew=1.55e+00]

For me, it is ~20s.

I have configured the backend in mbd/envs/__init__.py; you can take a look here.

chaozheUB commented 5 months ago

Thank you @jc-bao! On my Ubuntu machine with a 3080 (even without parallelization), I believe the speed is similar.

python mbd_planner.py --env-name ant
2024-06-07 17:44:56.098764: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.0 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
override temp_sample to 0.1
init sigma = 6.30e-01
Diffusing: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 99/99 [00:25<00:00,  3.91it/s, rew=3.85e+00]
Does not have dt argument.
final reward = 4.08e+00

Please note that the additional print "Does not have dt argument." comes from the following code, which I added to bypass the error and finish the run.

if hasattr(env.sys, "dt"):
    render_us = functools.partial(
        mbd.utils.render_us, step_env_jit, env.sys.replace(dt=env.dt)
    )
else:
    print("Does not have dt argument.")
    render_us = functools.partial(mbd.utils.render_us, step_env_jit, env.sys)

But I got the following results: rollout.zip. I think the positional backend is indeed being requested, so this may be attributable to a version issue.

Could you share the package versions in your setup?

Thanks!

jc-bao commented 5 months ago

My Brax version is 0.10.4.

chaozheUB commented 5 months ago

Not sure if this is the root cause, but it seems that base in Brax does not have a dt property; it has opt.timestep, which together with _n_frames gives dt.

The error goes away after I update this part of the script, based on the test for the positional backend here.
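
To make the relationship concrete, here is a minimal sketch, assuming the environment step dt equals the physics timestep multiplied by the number of physics substeps per environment step, as the comment above suggests. The function names and numeric values are hypothetical and chosen only for illustration.

```python
import math


def env_dt(timestep: float, n_frames: int) -> float:
    """Duration of one environment step, assuming n_frames physics substeps."""
    return timestep * n_frames


def physics_timestep(dt: float, n_frames: int) -> float:
    """Inverse relation: per-substep timestep recovered from the env step."""
    return dt / n_frames


# Hypothetical values, not taken from any specific environment.
timestep, n_frames = 0.005, 10
dt = env_dt(timestep, n_frames)
recovered = physics_timestep(dt, n_frames)

# Dividing dt by n_frames gives back the original physics timestep.
assert math.isclose(recovered, timestep)
print(dt, recovered)
```

Under this assumption, writing env.dt / n_frames into opt.timestep leaves the physics step unchanged, while writing env.dt makes each physics substep n_frames times longer.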

render_us = functools.partial(
    mbd.utils.render_us,
    step_env_jit,
    env.sys.tree_replace({'opt.timestep': env.dt / env._n_frames}),
)

Not sure whether it should be the version above or the one below. Both work, but the following gives a more reasonable video in terms of speed. rollout_humanoid_run.zip

render_us = functools.partial(
    mbd.utils.render_us,
    step_env_jit,
    env.sys.tree_replace({'opt.timestep': env.dt}),
)

rollout_humanoid_run.zip

Anyway, I don't think this is critical, as it happens only in the rendering phase. Thank you very much for the feedback.
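
For anyone unfamiliar with replacing a nested field through a dotted path such as 'opt.timestep', the idea can be sketched with plain dataclasses. This is an illustration only: replace_by_path, Opt, and System are hypothetical names, and this is not how Brax implements tree_replace.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class Opt:
    # Hypothetical stand-in for a nested options object.
    timestep: float


@dataclasses.dataclass(frozen=True)
class System:
    # Hypothetical stand-in for a system holding nested options.
    opt: Opt


def replace_by_path(obj: Any, path: str, value: Any) -> Any:
    """Return a copy of obj with the field at the dotted path set to value."""
    head, _, rest = path.partition(".")
    if not rest:
        # Last path component: replace the field directly.
        return dataclasses.replace(obj, **{head: value})
    # Otherwise rebuild the child first, then the parent around it.
    child = replace_by_path(getattr(obj, head), rest, value)
    return dataclasses.replace(obj, **{head: child})


sys_ = System(opt=Opt(timestep=0.005))
new_sys = replace_by_path(sys_, "opt.timestep", 0.05)
print(sys_.opt.timestep, new_sys.opt.timestep)
```

The original object is left untouched, which matches the immutable style of the frozen dataclasses used here; only the fields along the path are rebuilt.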

jc-bao commented 5 months ago

I see. Maybe it is because of the Brax update. I will check it later. Thanks!

jc-bao commented 5 months ago

You are right! I have just updated it. It seems to be a change introduced in 0.10.4. Thanks for the feedback!

purewater0901 commented 5 months ago

@jc-bao @chaozheUB Thank you for resolving the issue!