jianglongye / cgf

Learning Continuous Grasping Function with a Dexterous Hand from Human Demonstrations, RA-L 2023 & IROS 2023
https://jianglongye.com/cgf
MIT License

Failure to implement motion-retargeting to another dexterous hand... #5

Closed. CodingCatMountain closed this issue 1 month ago.

CodingCatMountain commented 2 months ago

@jianglongye Hi. Since there is no Allegro hand in my lab, I tried to adapt retargeting.py to the dexterous hand we do have, the Inspire hand. However, after making some modifications in retargeting.py so that it runs successfully (e.g., the URDF path, the fingertip link names, and the rescale factor), I found the resulting motion of the Inspire hand to be very odd. Below is a video of the motion of the Inspire hand:

https://github.com/user-attachments/assets/f1f4491d-64c3-476d-890c-bbf88879d58d

May I ask for your guidance on how to make the motion of the Inspire hand more natural? Looking forward to your reply.

Yours sincerely, Kacun.

jianglongye commented 1 month ago

Hi! Thanks for reaching out, and sorry for the delay.

Since we haven't tested the code with the Inspire hand, I'm not sure about the exact issue you're facing. However, here are a few things to double-check that might help:

  1. In [this line](https://github.com/jianglongye/cgf/blob/main/scripts/retargeting.py#L44), have you set the correct link mapping for the Inspire hand? Make sure it's updated correctly for your hand's structure (see the sketch below for what such a mapping might look like).
  2. In [this line](https://github.com/jianglongye/cgf/blob/main/scripts/retargeting.py#L56), the rotation matrix is hard-coded for the Allegro hand. Check whether the rest pose (where all joint positions are zero) of the Inspire hand matches that of the Allegro hand; if not, you might need to adjust it.
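
For illustration, a minimal mapping for the Inspire hand might look like the sketch below. The link names are assumptions based on a typical Inspire URDF, so adjust them to whatever your URDF actually defines:

```python
# Hypothetical MANO-joint-index -> Inspire-link-name mapping (adjust to your URDF).
# In the 21-joint MANO convention, index 0 is the wrist and 4/8/12/16/20 are the
# thumb/index/middle/ring/pinky fingertips.
INSPIRE_LINK_MAPPING = {
    0: "hand_base_link",
    4: "thumb_tip",
    8: "index_tip",
    12: "middle_tip",
    16: "ring_tip",
    20: "pinky_tip",
}
```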

The retargeting process itself is straightforward. The goal is to ensure the joint positions of the MANO hand and the robot hand are spatially close. Does the loss curve look normal during optimization? If the loss curve is okay, the retargeting should work. In that case, the issue could be with the visualization script.
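
In pseudocode, the per-frame objective is just a keypoint-matching loss minimized with a gradient-based optimizer. Below is a minimal sketch (not the repo's exact code); `fk_fn` and `init_params` are hypothetical placeholders for the robot hand's differentiable forward kinematics and the initial wrist pose / joint angles:

```python
import torch

def retarget_frame(target_positions, fk_fn, init_params, num_iters=1000, lr=1e-3):
    """Fit robot hand parameters so its keypoints match the MANO keypoints.

    target_positions: (K, 3) MANO joint positions for this frame.
    fk_fn(params) -> (K, 3): differentiable forward kinematics of the robot hand.
    """
    params = init_params.clone().requires_grad_(True)
    optimizer = torch.optim.Adam([params], lr=lr)
    loss_fn = torch.nn.SmoothL1Loss(beta=0.01)
    for i in range(num_iters):
        optimizer.zero_grad()
        loss = loss_fn(fk_fn(params), target_positions)
        loss.backward()
        optimizer.step()
        if i % 100 == 0:
            # the loss should decrease steadily; a flat or exploding curve usually
            # means a bad link mapping or a wrong initial rotation
            print(f"iter {i}: loss = {loss.item():.6f}")
    return params.detach()
```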

jianglongye commented 1 month ago

Additionally, you might want to check out this repo (dex-retargeting), which seems to have been tested with the Inspire hand.

CodingCatMountain commented 1 month ago

Hi @jianglongye, thanks for replying.

For the first point, the Inspire hand doesn't have a palm link, so I mapped link_hand_indices[0] to the hand_base_link of the Inspire hand, which is its wrist. I don't think this is the problem, since index 0 in the MANO hand model is also the wrist. Am I right?

For the second point, I have checked the rotation matrix, and I think it can also be used for the Inspire right hand, because the base coordinate frame of the Inspire hand is the same as that of the Allegro hand, and the rotation matrix just aligns the coordinate frame of the MANO right hand with that of the Allegro hand. Am I right?

As for the repo you recommended, I know about it too, but I have not tried to merge dex-retargeting into this repo yet. : ]

jianglongye commented 1 month ago

Apologies for the delayed response. I have now implemented the retargeting for the Inspire hand using the DexYCB dataset. Here is the visualization:

https://github.com/user-attachments/assets/d0dea06d-faeb-449b-a16d-6f4746acac9b

Here is the single-file implementation:

```python
import os
import time
from dataclasses import dataclass, field
from typing import Dict, Tuple

os.environ["OMP_NUM_THREADS"] = "4"
os.environ["OPENBLAS_NUM_THREADS"] = "4"
os.environ["MKL_NUM_THREADS"] = "6"
os.environ["VECLIB_MAXIMUM_THREADS"] = "4"
os.environ["NUMEXPR_NUM_THREADS"] = "6"

import numpy as np
import torch
import tyro
import yaml
from fastdev.robo.robot_model import RobotModel
from fastdev.sim_webui import SimWebUI
from fastdev.smplx import build_mano_layer, build_mano_layer_manopth, transform_mano_pose
from fastdev.utils import timeit
from fastdev.utils.tensor_utils import to_torch
from fastdev.xform import (
    axis_angle_vector_to_matrix,
    coord_conversion,
    inverse_tf_mat,
    matrix_to_axis_angle_vector,
    matrix_to_rotation_6d,
    quaternion_real_to_first,
    quaternion_to_matrix,
    rot_tl_to_tf_mat,
    rotation_6d_to_matrix,
)

_SUBJECTS = [
    "20200709-subject-01",
    "20200813-subject-02",
    "20200820-subject-03",
    "20200903-subject-04",
    "20200908-subject-05",
    "20200918-subject-06",
    "20200928-subject-07",
    "20201002-subject-08",
    "20201015-subject-09",
    "20201022-subject-10",
]

_YCB_CLASSES = {
    1: "002_master_chef_can",
    2: "003_cracker_box",
    3: "004_sugar_box",
    4: "005_tomato_soup_can",
    5: "006_mustard_bottle",
    6: "007_tuna_fish_can",
    7: "008_pudding_box",
    8: "009_gelatin_box",
    9: "010_potted_meat_can",
    10: "011_banana",
    11: "019_pitcher_base",
    12: "021_bleach_cleanser",
    13: "024_bowl",
    14: "025_mug",
    15: "035_power_drill",
    16: "036_wood_block",
    17: "037_scissors",
    18: "040_large_marker",
    19: "051_large_clamp",
    20: "052_extra_large_clamp",
    21: "061_foam_brick",
}

@dataclass
class Args:
    data_root: str = "data/dexycb"
    right_hand_urdf_path: str = "assets/robot_description/inspire_hand/inspire_hand_right.urdf"
    left_hand_urdf_path: str = "assets/robot_description/inspire_hand/inspire_hand_left.urdf"
    right_hand_coord_spec: str = "x: front, y: left, z: up"
    left_hand_coord_spec: str = "x: front, y: left, z: up"

    # this mapping is from mano joint index to robot hand joint name
    joints_mapping: Dict[int, str] = field(
        default_factory=lambda: {
            0: "hand_base_link",
            4: "thumb_tip",
            8: "index_tip",
            12: "middle_tip",
            16: "ring_tip",
            20: "pinky_tip",
        }
    )

    joints_weight: Tuple[float, ...] = (1.0, 2.0, 1.0, 1.0, 1.0, 1.0)

@timeit("extract_metas")
def extract_metas(data_root: str, moving_threshold: float = 0.004) -> Dict[str, Dict]:
    subject_seq_list = []
    for subject_id in _SUBJECTS:
        seq_ids = os.listdir(os.path.join(data_root, subject_id))
        assert len(seq_ids) == 100, f"Each subject in DexYCB should have 100 sequences. subject: {subject_id}"
        subject_seq_list.extend([(subject_id, x) for x in seq_ids])
    subject_seq_list = sorted(subject_seq_list, key=lambda x: os.path.join(*x))
    metas = {}
    for subject_id, seq_id in subject_seq_list:
        seq_pose_path = os.path.join(data_root, subject_id, seq_id, "pose.npz")
        seq_meta_path = os.path.join(data_root, subject_id, seq_id, "meta.yml")
        with open(seq_meta_path, "r") as f:
            seq_meta = yaml.safe_load(f)
        seq_pose = np.load(seq_pose_path)
        mano_calib_path = os.path.join(data_root, "calibration", f"mano_{seq_meta['mano_calib'][0]}", "mano.yml")
        with open(mano_calib_path, "r") as f:
            betas = yaml.safe_load(f)["betas"]

        seq_pose_y = seq_pose["pose_y"]
        seq_pose_m = seq_pose["pose_m"]
        seq_ycb_ids = seq_meta["ycb_ids"]

        static_flags = np.all(np.all(seq_pose_y == seq_pose_y[0], axis=2), axis=0)
        # all seqs in dex-ycb contain only one moving object
        if np.count_nonzero(static_flags) != len(seq_ycb_ids) - 1:
            raise ValueError(f"There is not only one moving object in subject: {subject_id}, seq: {seq_id}.")

        target_index = int(np.argmax(np.logical_not(static_flags)))
        target_ycb_id = seq_ycb_ids[target_index]
        assert target_index == seq_meta["ycb_grasp_ind"]

        target_pose = seq_pose_y[:, target_index]
        target_tl = target_pose[:, 4:]
        target_static_flags = np.logical_or(
            target_tl > target_tl[0] + np.ones(3) * moving_threshold,
            target_tl < target_tl[0] - np.ones(3) * moving_threshold,
        )

        # end_frame is exclusive
        end_frame = int(np.argmax(np.any(target_static_flags, axis=1)))

        non_zero_pose_frames = np.any(seq_pose_m != 0, axis=-1).squeeze(-1)
        # start_frame is inclusive
        start_frame = int(np.argmax(non_zero_pose_frames))
        padded_non_zero_pose_frames = np.concatenate(
            (np.logical_not(non_zero_pose_frames[start_frame:]), np.asarray([True]))
        )
        end_frame = min(end_frame, int(np.argmax(padded_non_zero_pose_frames)) + start_frame)
        if start_frame >= end_frame:
            print(f"there is an invalid sequence. subject: {subject_id}, seq: {seq_id}")
            print("this sequence will be skipped.")
            continue

        result = {
            "subject_id": subject_id,
            "seq_id": seq_id,
            "target_index": target_index,
            "target_ycb_id": target_ycb_id,
            "betas": betas,
            "valid_frames_range": [start_frame, end_frame],  # start is inclusive, end is exclusive
        }
        result.update(seq_meta)

        metas[seq_id] = result

    # sort by key
    metas = {k: v for k, v in sorted(list(metas.items()))}
    return metas

@timeit("transform_pose_m_to_object_coord")
def transform_pose_m_to_object_coord(data_root: str, metas: Dict):
    left_mano_layer = build_mano_layer(hand_side="left", use_pca=True, num_pca_comps=45)
    right_mano_layer = build_mano_layer(hand_side="right", use_pca=True, num_pca_comps=45)

    pose_m_object_coord = {}
    for seq_id in metas:
        seq_meta = metas[seq_id]
        data_dir = os.path.join(data_root, seq_meta["subject_id"], seq_meta["seq_id"])
        mano_calib_path = os.path.join(data_root, "calibration", f"mano_{seq_meta['mano_calib'][0]}", "mano.yml")

        pose_path = os.path.join(data_dir, "pose.npz")
        pose = np.load(pose_path, allow_pickle=True)
        pose_m, pose_y = pose["pose_m"], pose["pose_y"]
        target_index = seq_meta["target_index"]

        with open(mano_calib_path, "r") as f:
            mano_calib = yaml.safe_load(f)
        mano_side = seq_meta["mano_sides"][0]
        mano_betas = torch.tensor(mano_calib["betas"], dtype=torch.float32)
        valid_pose_m = torch.from_numpy(
            pose_m[seq_meta["valid_frames_range"][0] : seq_meta["valid_frames_range"][1]]
        ).squeeze(1)
        valid_pose_y = torch.from_numpy(
            pose_y[seq_meta["valid_frames_range"][0] : seq_meta["valid_frames_range"][1]]
        ).squeeze(1)

        mano_layer = left_mano_layer if mano_side == "left" else right_mano_layer

        y_rot = quaternion_to_matrix(quaternion_real_to_first(valid_pose_y[..., target_index, :4]))
        y_tl = valid_pose_y[..., target_index, 4:]
        inv_y_tf = inverse_tf_mat(rot_tl_to_tf_mat(rot_mat=y_rot, tl=y_tl))
        # inv_y_rot = quaternion_invert(quaternion_real_to_first(valid_pose_y[..., target_index, :4]))
        # inv_y_transl = -rotate_points(valid_pose_y[..., target_index, 4:], quaternion_to_matrix(inv_y_rot))

        tf_glb_ori, tf_transl = transform_mano_pose(
            mano_layer,
            betas=mano_betas[None],
            global_orient=valid_pose_m[..., :3],
            transl=valid_pose_m[..., 48:51],
            tf_rot=matrix_to_axis_angle_vector(inv_y_tf[:, :3, :3]),
            tf_transl=inv_y_tf[:, :3, 3],
        )
        tf_valid_pose_m = torch.cat((tf_glb_ori, valid_pose_m[..., 3:48], tf_transl), dim=-1).detach().cpu().numpy()

        pose_m_object_coord[seq_id] = tf_valid_pose_m
    return pose_m_object_coord

@timeit("retargeting")
def retargeting(pose_m_object_coord: Dict[str, np.ndarray], metas: Dict[str, Dict], args: Args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
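    # axis conventions of the MANO wrist frame for each hand side; together with the robot
    # hand's coord spec they define the rotation used to initialize the robot wrist orientation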
    right_mano_coord_spec = "x: down, y: back, z: left"
    left_mano_coord_spec = "x: up, y: back, z: right"
    joints_weight = torch.tensor(args.joints_weight, dtype=torch.float32, device=device)

    results = {}
    for side in ["left", "right"]:
        mano_layer = build_mano_layer(hand_side=side, use_pca=True, num_pca_comps=45).to(device)
        mano_layer_pth = build_mano_layer_manopth(hand_side=side, use_pca=True, num_pca_comps=45).to(device)
        urdf_path = args.left_hand_urdf_path if side == "left" else args.right_hand_urdf_path
        robot_model = RobotModel.from_urdf_or_mjcf_path(urdf_path, device=device)

        num_side_seqs = sum([metas[seq_id]["mano_sides"][0] == side for seq_id in metas])
        side_metas = [metas[seq_id] for seq_id in metas if metas[seq_id]["mano_sides"][0] == side]
        max_num_frames = max([pose_m_object_coord[seq_meta["seq_id"]].shape[0] for seq_meta in side_metas])
        pose_m = np.zeros((num_side_seqs, max_num_frames, 51))
        pose_m_mask = np.zeros((num_side_seqs, max_num_frames), dtype=bool)
        for seq_idx, seq_meta in enumerate(side_metas):
            seq_id = seq_meta["seq_id"]
            pose_m[seq_idx, : pose_m_object_coord[seq_id].shape[0]] = pose_m_object_coord[seq_id]
            pose_m_mask[seq_idx, : pose_m_object_coord[seq_id].shape[0]] = True

        pose_m = torch.from_numpy(pose_m).float().to(device)
        pose_m_mask = torch.from_numpy(pose_m_mask).bool().to(device)
        betas = torch.tensor([seq_meta["betas"] for seq_meta in side_metas], dtype=torch.float32).to(device)

        init_hand_out = mano_layer(
            global_orient=pose_m[:, 0, :3],
            hand_pose=pose_m[:, 0, 3:48],
            betas=betas,
            transl=pose_m[:, 0, 48:51],
            return_verts=True,
        )
        init_tl = init_hand_out.joints[:, 0]
        mano_coord_spec = left_mano_coord_spec if side == "left" else right_mano_coord_spec
        robot_coord_spec = args.left_hand_coord_spec if side == "left" else args.right_hand_coord_spec
        robot_to_mano_rot = torch.from_numpy(coord_conversion(robot_coord_spec, mano_coord_spec)).to(device)
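        # initialize the robot wrist rotation from the MANO global orientation and optimize it
        # in the continuous 6D rotation representation (better behaved under gradient descent)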
        init_rot = axis_angle_vector_to_matrix(pose_m[:, 0, :3]) @ robot_to_mano_rot
        init_rot_6d = matrix_to_rotation_6d(init_rot)
        init_joint_angles = torch.mean(robot_model.joint_limits, dim=-1).repeat(init_rot.shape[0], 1)

        mano_joint_indices = torch.tensor(list(args.joints_mapping.keys()), dtype=torch.int32).to(device)
        robot_link_indices = [robot_model.link_names.index(name) for name in args.joints_mapping.values()]
        robot_link_indices = torch.tensor(robot_link_indices, dtype=torch.int32).to(device)

        result_variables = torch.zeros(
            (num_side_seqs, max_num_frames, 3 + 6 + init_joint_angles.shape[-1]), device=device
        )
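        # solve frames sequentially: frame 0 starts from the MANO-derived initialization,
        # and each subsequent frame is warm-started from the previous frame's solution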
        for frame_idx in range(max_num_frames):
            frame_mask = pose_m_mask[:, frame_idx]
            if frame_idx == 0:
                robot_tl = init_tl[frame_mask].detach().requires_grad_(True)
                robot_rot = init_rot_6d[frame_mask].detach().requires_grad_(True)
                robot_joint_angles = init_joint_angles[frame_mask].detach().requires_grad_(True)
            else:
                robot_tl = result_variables[frame_mask, frame_idx - 1, :3].detach().requires_grad_(True)
                robot_rot = result_variables[frame_mask, frame_idx - 1, 3:9].detach().requires_grad_(True)
                robot_joint_angles = result_variables[frame_mask, frame_idx - 1, 9:].detach().requires_grad_(True)

            optimizer = torch.optim.Adam([robot_tl, robot_rot, robot_joint_angles], lr=1e-3)
            loss_fn = torch.nn.SmoothL1Loss(beta=0.01, reduction="none")

            mano_verts, mano_joint = mano_layer_pth(
                pose_m[frame_mask, frame_idx, :48], betas[frame_mask], pose_m[frame_mask, frame_idx, 48:51]
            )
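            # the manopth layer returns vertices/joints in millimeters; convert to meters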
            mano_verts /= 1000
            mano_joint /= 1000

            target_joint_positions = mano_joint[:, mano_joint_indices]

            for iter_idx in range(1000 if frame_idx == 0 else 100):
                optimizer.zero_grad()
                robot_link_poses = robot_model.forward_kinematics(
                    joint_values=robot_joint_angles,
                    root_poses=rot_tl_to_tf_mat(rot_mat=rotation_6d_to_matrix(robot_rot), tl=robot_tl),
                )
                robot_link_positions = robot_link_poses[:, robot_link_indices, :3, 3]

                pos_loss = loss_fn(robot_link_positions, target_joint_positions)
                weighted_loss = (pos_loss * joints_weight[None, :, None]).mean()

                loss = weighted_loss
                loss.backward()
                optimizer.step()

                if iter_idx % 20 == 0:
                    print(f"iter {iter_idx}: loss = {loss.item()}")

            result_variables[frame_mask, frame_idx, :3] = robot_tl.detach()
            result_variables[frame_mask, frame_idx, 3:9] = robot_rot.detach()
            result_variables[frame_mask, frame_idx, 9:] = robot_joint_angles.detach()

        results[side] = result_variables.cpu().numpy()
    return results

if __name__ == "__main__":
    args = tyro.cli(Args)

    metas = extract_metas(args.data_root)

    pose_m_object_coord = transform_pose_m_to_object_coord(args.data_root, metas)

    robot_results = retargeting(pose_m_object_coord, metas, args)

    # --------- visualize retargeted pose_m ---------
    viz_side = "left"
    side_variables = robot_results[viz_side]
    side_metas = [metas[seq_id] for seq_id in metas if metas[seq_id]["mano_sides"][0] == viz_side]
    mano_layer_pth = build_mano_layer_manopth(hand_side=viz_side, use_pca=True, num_pca_comps=45)
    urdf_path = args.left_hand_urdf_path if viz_side == "left" else args.right_hand_urdf_path
    robot_model = RobotModel.from_urdf_or_mjcf_path(urdf_path)
    robot_link_meshes = robot_model.get_link_trimesh_meshes()

    webui = SimWebUI()
    for scene_idx in range(20):
        seq_id = side_metas[scene_idx]["seq_id"]
        num_valid_frames = pose_m_object_coord[seq_id].shape[0]

        viz_pose_m = torch.from_numpy(pose_m_object_coord[seq_id])
        viz_betas = torch.tensor(side_metas[scene_idx]["betas"], dtype=torch.float32)

        target_ycb_id = _YCB_CLASSES[side_metas[scene_idx]["target_ycb_id"]]
        obj_mesh_path = os.path.join(args.data_root, "models", f"{target_ycb_id}", "textured_simple.obj")
        mesh_asset_id = webui.add_mesh_asset(mesh_path=obj_mesh_path)
        webui.set_mesh_state(asset_id=mesh_asset_id, scene_index=scene_idx, frame_range=(0, viz_pose_m.shape[0]))

        hand_vert, hand_joint = mano_layer_pth(
            viz_pose_m[:, :48],
            viz_betas[None].expand(viz_pose_m.shape[0], -1),
            viz_pose_m[:, 48:51],
        )
        hand_vert /= 1000
        for frame_idx in range(viz_pose_m.shape[0]):
            asset_id = webui.add_point_cloud_asset(hand_vert[frame_idx])
            webui.set_point_cloud_state(
                asset_id=asset_id,
                scene_index=scene_idx,
                frame_range=(frame_idx, frame_idx + 1),
                point_size=0.002,
                color="blue",
            )
        robot_asset_id = webui.add_robot_asset(urdf_path)
        root_poses = rot_tl_to_tf_mat(
            rot_mat=rotation_6d_to_matrix(to_torch(side_variables[scene_idx, :num_valid_frames, 3:9])),
            tl=to_torch(side_variables[scene_idx, :num_valid_frames, :3]),
        )
        webui.set_robot_state(
            asset_id=robot_asset_id,
            scene_index=scene_idx,
            joint_values=side_variables[scene_idx, :num_valid_frames, 9:],
            root_poses=root_poses,
        )
    while True:
        time.sleep(1)
```

Here is the URDF file I use. Please note that this script relies on several functions from my private repository `fastdev`. While the repo itself isn't public, you can install the package from PyPI with `pip install fastdev==0.1.8`. Also, make sure to modify the dataset/URDF paths in the script as needed.

The current implementation was done in a rush, so there are several things that could be improved, such as adding a smoothness loss for the trajectory.
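
For example, a temporal smoothness term could be added inside the per-frame optimization loop for frames after the first, penalizing deviation from the previous frame's solution. A rough sketch (an assumption, not implemented above; the weight is hypothetical and would need tuning):

```python
# inside the inner optimization loop, for frame_idx > 0
smooth_weight = 0.1  # hypothetical weight

prev_tl = result_variables[frame_mask, frame_idx - 1, :3]      # previous wrist translation
prev_rot = result_variables[frame_mask, frame_idx - 1, 3:9]    # previous wrist rotation (6D)
prev_angles = result_variables[frame_mask, frame_idx - 1, 9:]  # previous joint angles

smooth_loss = (
    (robot_tl - prev_tl).pow(2).mean()
    + (robot_rot - prev_rot).pow(2).mean()
    + (robot_joint_angles - prev_angles).pow(2).mean()
)
loss = weighted_loss + smooth_weight * smooth_loss
```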