google-deepmind / dm_control

Google DeepMind's software stack for physics-based simulation and Reinforcement Learning environments, using MuJoCo.
Apache License 2.0
3.82k stars 669 forks source link

Use of MjSpec in environment #494

Open LilianLaporte opened 1 month ago

LilianLaporte commented 1 month ago

Hello,

I would like to create an RL pipeline to find a policy that stacks rebars in the most stable way depending on their shape. Because I want to stack them, I need to spawn them one by one. One way (and the easiest for me) would be to use the MjSpec class from MuJoCo and recompile the model every time I want to spawn a rebar. However, when I started building the training environment, I realized that it was not trivial to combine this MjSpec workflow with dm_control's Physics class. So my question is: is there an easy way to do this, or do I have to rewrite the Environment class?

Thank you

Example

Here is an example of my problem with a simple script.

xml file ```xml ```
"""Example: spawn a U-shaped stirrup mid-simulation by editing an MjSpec.

The script loads a scene and steps it. After ``add_body_time`` seconds it adds
a stirrup body to the spec and recompiles while keeping the simulation state.
It renders the run to an MP4 file and plots the z position of the last geom.
"""
import os
import sys
import time

# Add the parent directory to sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import mujoco
import mujoco.viewer


def display_video(frames, videoname='animation.mp4', framerate=30):
    """Save a list of frames as an MP4 video.

    Args:
        frames (list): frames to write, e.g. the arrays returned by
            ``mujoco.Renderer.render()``; each must unpack as
            ``height, width, channels`` via ``.shape``.
        videoname (str, optional): output path of the video. Missing parent
            directories are created. Defaults to 'animation.mp4'.
        framerate (int, optional): frames per second. Defaults to 30.

    Raises:
        ValueError: if ``frames`` is empty.
    """
    # Imported locally so the function does not depend on module-level
    # imports that may not have run yet when it is called.
    import matplotlib
    import matplotlib.pyplot as plt
    from matplotlib import animation

    if not frames:
        raise ValueError("display_video() needs at least one frame.")

    height, width, _ = frames[0].shape
    dpi = 10
    orig_backend = matplotlib.get_backend()
    matplotlib.use('Agg')  # Switch to headless 'Agg' to inhibit figure rendering.
    fig, ax = plt.subplots(1, 1, figsize=(width / dpi, height / dpi), dpi=dpi)
    matplotlib.use(orig_backend)  # Switch back to the original backend.
    ax.set_axis_off()
    ax.set_aspect('equal')
    ax.set_position([0, 0, 1, 1])
    im = ax.imshow(frames[0])

    def update(frame):
        im.set_data(frame)
        return [im]

    interval = 1000 / framerate
    anim = animation.FuncAnimation(fig=fig, func=update, frames=frames,
                                   interval=interval, blit=True, repeat=False)

    # anim.save() fails if the target directory does not exist.
    directory = os.path.dirname(videoname)
    if directory:
        os.makedirs(directory, exist_ok=True)
    anim.save(videoname, writer='ffmpeg', fps=framerate)
    plt.close(fig)


# Half-length of each capsule segment of the default stirrup, in chain order
# (long leg, bend, base, bend, long leg).
_STIRRUP_HALF_LENGTHS = (0.2825, 0.0212132, 0.06, 0.0212132, 0.2825)
# Quaternion (w, x, y, z) for a 45-degree rotation about z, applied at every
# link after the first one.
_BEND_QUAT = (0.92388, 0, 0, 0.382683)
# -90-degree rotation about y: turns the capsule axis (local z) onto the
# segment's x axis, along which the capsule is offset.
_CAPSULE_QUAT = (0.707107, 0, -0.707107, 0)
_REBAR_RGBA = (0.8, 0.2, 0.1, 1)
_REBAR_RADIUS = 0.005


def _add_stirrup_segment(parent, prefix, index, half_length, pos, quat=None):
    """Add one capsule link (body, ball joint, geom, site) under ``parent``.

    Returns the new body so the next link can be attached to it.
    """
    body = parent.add_body()
    body.name = f"{prefix}0B{index}"
    body.pos = list(pos)
    if quat is not None:
        body.quat = list(quat)

    joint = body.add_joint()
    joint.name = f"{prefix}0J{index}"
    joint.type = mujoco.mjtJoint.mjJNT_BALL
    joint.group = 3
    joint.pos = [0, 0, 0]
    joint.armature = 100
    joint.damping = 1
    joint.stiffness = 100
    joint.range = [0, 2]

    geom = body.add_geom()
    geom.name = f"{prefix}G{index}"
    geom.size = [_REBAR_RADIUS, half_length, 0]
    geom.pos = [half_length, 0, 0]
    geom.quat = list(_CAPSULE_QUAT)
    geom.type = mujoco.mjtGeom.mjGEOM_CAPSULE
    geom.rgba = list(_REBAR_RGBA)

    site = body.add_site()
    site.name = f"{prefix}S{index}"
    site.pos = [0, 0, 0]
    site.group = 3
    return body


def add_body_in_spec(spec, name="stirrup2", pos=(0.5, 0.1, 0.5)):
    """
    Add a free-floating stirrup body to the mujoco model.

    Args:
        spec: mujoco.MjSpec, the mujoco model specification.
        name: str, name of the root body. It also prefixes every child
            element ("<name>0B<i>", "<name>0J<i>", "<name>G<i>",
            "<name>S<i>"). Use a distinct name for each stirrup added to the
            same spec, otherwise element names collide.
        pos: 3-sequence, initial position of the root body.

    Returns:
        spec: mujoco.MjSpec, the same spec object, modified in place.

    Note:
        The body is a classic U-shape stirrup built as a chain of capsules
        connected by ball joints.
        TODO: make this process adaptive to every kind of stirrup.
    """
    root = spec.worldbody.add_body()
    root.name = name
    root.pos = list(pos)
    root.quat = [0.707107, 0, -0.707107, 0]
    root.mass = 1
    root.inertia = [0.001, 0.001, 0.001]
    root.ipos = [0.28, 0, 0.005]
    free_joint = root.add_joint()
    free_joint.type = mujoco.mjtJoint.mjJNT_FREE

    parent = root
    previous_half = None
    for index, half_length in enumerate(_STIRRUP_HALF_LENGTHS):
        if previous_half is None:
            link_pos, link_quat = [0, 0.005, 0], None
        else:
            # Each link starts at the far end of the previous capsule.
            link_pos, link_quat = [2 * previous_half, 0, 0], _BEND_QUAT
        parent = _add_stirrup_segment(parent, name, index, half_length,
                                      link_pos, link_quat)
        previous_half = half_length
    return spec


def main():
    """Run the simulation, spawn a stirrup mid-run, save video and plot."""
    import matplotlib.pyplot as plt

    spec = mujoco.MjSpec()
    # NOTE(review): older MuJoCo releases load into the instance (returning
    # nothing useful), while newer ones expose MjSpec.from_file as a static
    # constructor returning a new spec. Accept both forms.
    loaded = spec.from_file("src/rebar_mujoco/environments/xml_files/mujoco.xml")
    if isinstance(loaded, mujoco.MjSpec):
        spec = loaded
    model = spec.compile()
    data = mujoco.MjData(model)

    framerate = 40  # (Hz)
    # Use the timestep the model actually integrates with, so duration,
    # add_body_time and framerate stay in sync with the simulated time.
    timestep = model.opt.timestep  # (seconds)
    duration = 5  # (seconds)
    width, height = 500, 500

    renderer = mujoco.Renderer(model, height=height, width=width)
    video_name = "src/rebar_mujoco/tests/tmp/test_mjSpec.mp4"
    scene_option = mujoco.MjvOption()
    frames = []
    frame_count = 0

    camera = mujoco.MjvCamera()
    mujoco.mjv_defaultFreeCamera(model, camera)
    camera.distance = 2.5
    camera.elevation = -20
    camera.lookat = (0, 0, 0.5)

    add_body_time = 2  # seconds
    add_body = False
    z_pos_vec = []
    nb_step = 0
    try:
        while nb_step * timestep < duration:
            mujoco.mj_step(model, data)
            sim_time = nb_step * timestep
            if sim_time > add_body_time and not add_body:
                print("Adding body")
                add_body = True
                spec = add_body_in_spec(spec)
                ## Recompile model and data while maintaining the state.
                model, data = spec.recompile(model, data)
                # The renderer was created for the previous model; rebuild it.
                renderer.close()
                renderer = mujoco.Renderer(model, height=height, width=width)
            if frame_count < sim_time * framerate:
                z_pos_vec.append(data.geom_xpos[-1][2])
                renderer.update_scene(data, camera=camera, scene_option=scene_option)
                frames.append(renderer.render())
                frame_count += 1
            nb_step += 1
    finally:
        renderer.close()

    display_video(frames, video_name, framerate)
    print(f"Video saved at {video_name}")

    plt.plot(z_pos_vec)
    plt.show()


if __name__ == '__main__':
    main()

Context: