google-research / robopianist

[CoRL '23] Dexterous piano playing with deep reinforcement learning.
https://kzakka.com/robopianist/
Apache License 2.0
586 stars 47 forks

Lift the requirement of human fingering with RP1M. #23

Open zhaoyi11 opened 3 months ago

zhaoyi11 commented 3 months ago

Hello,

Thanks for the great work! We recently released a paper, RP1M (https://arxiv.org/abs/2408.11048, cc @clthegoat), that includes a reward term based on optimal transport, which lets the agent play MIDI files without human fingering annotations. We would like to know whether this method could be integrated into this repo, so that people can conveniently use RoboPianist to play songs beyond the PIG dataset.
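
Here is a toy sketch of the matching step for readers new to the idea. In the proposed _compute_ot_reward, each key to press is paired with its own fingertip so that the total fingertip-to-key distance is minimal. This is a one-to-one assignment, which the proposal solves with scipy.optimize.linear_sum_assignment. The brute-force helper min_cost_assignment below is a hypothetical stand-in that needs only the standard library. It is practical only for a few fingers and keys, and it is not the RP1M implementation.

```python
import itertools
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float, float]


def min_cost_assignment(fingers: Sequence[Point],
                        keys: Sequence[Point]) -> Tuple[List[int], List[float]]:
    """Brute-force minimum-cost one-to-one matching of keys to distinct fingers.

    Hypothetical toy stand-in for scipy.optimize.linear_sum_assignment.
    Returns (assignment, distances), where assignment[j] is the finger index
    matched to key j and distances[j] is that finger-key distance.
    """
    if len(keys) > len(fingers):
        raise ValueError("need at least as many fingers as keys")
    best_cost, best_perm = math.inf, None
    # Try every way of giving each key its own distinct finger.
    for perm in itertools.permutations(range(len(fingers)), len(keys)):
        cost = sum(math.dist(fingers[i], keys[j]) for j, i in enumerate(perm))
        if cost < best_cost:
            best_cost, best_perm = cost, perm
    dists = [math.dist(fingers[i], keys[j]) for j, i in enumerate(best_perm)]
    return list(best_perm), dists


# Both keys are nearest to finger 1, but each key still gets a distinct finger.
assignment, dists = min_cost_assignment(
    [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)], [(0.6, 0.0, 0.0), (1.5, 0.0, 0.0)]
)
```

In this example a greedy "nearest finger" rule would send both keys to finger 1. The assignment instead gives key 0 to finger 0 and key 1 to finger 1. This is how the reward can guide the hands toward all active keys without a fingering label.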

Here are some comparison results from the paper, as well as a short plan for modifying the code. Please let me know your thoughts.

Results: [figure: comparison results from the RP1M paper]

Modifications: I plan to change these lines https://github.com/google-research/robopianist/blob/d9cde23e46cb30ebb8eeebb375a9c52191238a30/robopianist/suite/tasks/piano_with_shadow_hands.py#L135-L136 to:

if not self._disable_fingering_reward:
    # Human fingering is available: keep the original fingering reward.
    self._reward_fn.add("fingering_reward", self._compute_fingering_reward)
else:
    # No human fingering: use the OT-based reward from RP1M instead.
    self._reward_fn.add("ot_reward", self._compute_ot_reward)
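
The flag's effect can be shown with a minimal sketch. It assumes that RoboPianist's composite reward registers named terms with add(name, fn) and sums them when computing the reward. MiniCompositeReward and build_reward are hypothetical names used only for illustration. The point of the change is that with disable_fingering_reward=True the agent still receives a key-reaching shaping term (the OT term) rather than losing that term entirely.

```python
from typing import Callable, Dict, List

RewardFn = Callable[[object], float]


class MiniCompositeReward:
    """Hypothetical stand-in for a composite reward: named terms that are summed."""

    def __init__(self) -> None:
        self._terms: Dict[str, RewardFn] = {}

    def add(self, name: str, fn: RewardFn) -> None:
        self._terms[name] = fn

    def compute(self, physics: object) -> float:
        # Total reward is the sum of all registered terms (an assumption here).
        return sum(fn(physics) for fn in self._terms.values())

    @property
    def names(self) -> List[str]:
        return list(self._terms)


def build_reward(disable_fingering_reward: bool,
                 fingering_fn: RewardFn, ot_fn: RewardFn) -> MiniCompositeReward:
    reward = MiniCompositeReward()
    if not disable_fingering_reward:
        # Fingering labels available: register the original shaping term.
        reward.add("fingering_reward", fingering_fn)
    else:
        # No fingering labels: register the fingering-free OT term instead.
        reward.add("ot_reward", ot_fn)
    return reward
```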

where _compute_ot_reward is defined as:

import numpy as np
from dm_control import mjcf
from dm_control.utils.rewards import tolerance
from scipy.optimize import linear_sum_assignment

# _FINGER_CLOSE_ENOUGH_TO_KEY is the module-level constant already defined in
# piano_with_shadow_hands.py.

def _compute_ot_reward(self, physics: mjcf.Physics) -> float:
    """OT reward calculation from RP1M https://arxiv.org/abs/2408.11048"""
    # Calculate fingertip positions (left hand first, then right hand).
    fingertip_pos = [physics.bind(finger).xpos.copy() for finger in self.left_hand.fingertip_sites]
    fingertip_pos += [physics.bind(finger).xpos.copy() for finger in self.right_hand.fingertip_sites]

    # Keys to press at this timestep; the last goal entry is the sustain pedal.
    keys_to_press = np.flatnonzero(self._goal_current[:-1])
    # If no key needs to be pressed, return the maximum reward.
    if keys_to_press.shape[0] == 0:
        return 1.0

    # Target point on each key, computed the same way as in RoboPianist.
    key_pos = []
    for key in keys_to_press:
        key_geom = self.piano.keys[key].geom[0]
        key_geom_pos = physics.bind(key_geom).xpos.copy()
        key_geom_pos[-1] += 0.5 * physics.bind(key_geom).size[2]
        key_geom_pos[0] += 0.35 * physics.bind(key_geom).size[0]
        key_pos.append(key_geom_pos)

    # Pairwise Euclidean distances between fingers and keys, shape (num_fingers, num_keys).
    dist = np.linalg.norm(
        np.asarray(fingertip_pos)[:, None, :] - np.asarray(key_pos)[None, :, :], axis=-1
    )

    # Optimal one-to-one finger-key assignment minimizing the total distance.
    row_ind, col_ind = linear_sum_assignment(dist)
    dist = dist[row_ind, col_ind]
    rews = tolerance(
        dist,
        bounds=(0, _FINGER_CLOSE_ENOUGH_TO_KEY),
        margin=(_FINGER_CLOSE_ENOUGH_TO_KEY * 10),
        sigmoid="gaussian",
    )
    return float(np.mean(rews))
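
The last step maps each matched distance to a reward in [0, 1] using dm_control's tolerance with a Gaussian sigmoid, then averages the results. The sketch below re-implements that mapping in plain Python so the numbers can be checked by hand. It follows dm_control.utils.rewards.tolerance with its default value_at_margin=0.1, as we understand that function. close_enough=0.01 is an assumed placeholder for _FINGER_CLOSE_ENOUGH_TO_KEY, and ot_reward_from_distances is a hypothetical helper.

```python
import math
from typing import Sequence


def gaussian_tolerance(x: float, upper: float, margin: float,
                       value_at_margin: float = 0.1) -> float:
    """Gaussian-sigmoid tolerance for bounds=(0, upper), mirroring dm_control's tolerance()."""
    lower = 0.0
    if lower <= x <= upper:
        return 1.0  # inside the bounds: full reward
    # Normalized distance outside the bounds, in units of the margin.
    d = (lower - x if x < lower else x - upper) / margin
    # Scale chosen so that the reward equals value_at_margin when d == 1.
    scale = math.sqrt(-2.0 * math.log(value_at_margin))
    return math.exp(-0.5 * (d * scale) ** 2)


def ot_reward_from_distances(matched_dists: Sequence[float],
                             close_enough: float = 0.01) -> float:
    """Mean shaped reward over matched finger-key distances (hypothetical helper)."""
    if not matched_dists:
        return 1.0  # no keys to press, as in the proposed method
    rews = [gaussian_tolerance(d, close_enough, 10 * close_enough) for d in matched_dists]
    return sum(rews) / len(rews)
```

A fingertip within close_enough of its matched key earns the full reward of 1. The reward falls to value_at_margin (0.1) at close_enough plus the margin, so one perfectly placed finger and one finger at that distance average to 0.55.
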
kevinzakka commented 3 months ago

@zhaoyi11 I was going to email you all and ask whether you'd be down to integrate your labeling pipeline into the repo 😂 So, to answer your question: yes! I'd be more than happy to help; please feel free to submit a PR! Love the paper btw!

zhaoyi11 commented 3 months ago

Thanks @kevinzakka! Great, I will prepare the PR asap.