This repository was archived by the owner on Jan 6, 2026. It is now read-only.

Lift the requirement of human fingering with RP1M. #23

@zhaoyi11

Hello,

Thanks for the great work! We recently released a paper, RP1M (https://arxiv.org/abs/2408.11048, cc @clthegoat), which introduces a reward term based on optimal transport (OT) that lets the agent play MIDI files without human fingering annotations. Would it be possible to integrate this method into this repo, so that people can conveniently use RoboPianist to play songs beyond the PIG dataset?

Here are some comparison results from the paper, along with a short plan for the code changes. Please let me know your thoughts.

Results:
[Figure: comparison results from the RP1M paper]

Modifications:
I plan to change these lines:

if not self._disable_fingering_reward:
    self._reward_fn.add("fingering_reward", self._compute_fingering_reward)

to:

if not self._disable_fingering_reward:
    # Human fingering is available: keep the existing fingering reward.
    self._reward_fn.add("fingering_reward", self._compute_fingering_reward)
else:
    # No human fingering: use the OT reward instead.
    self._reward_fn.add("ot_reward", self._compute_ot_reward)

where _compute_ot_reward is defined as:

from scipy.optimize import linear_sum_assignment

def _compute_ot_reward(self, physics: mjcf.Physics) -> float:
    """OT reward from RP1M (https://arxiv.org/abs/2408.11048)."""
    # Fingertip positions of both hands (left hand first, then right).
    fingertip_pos = [physics.bind(finger).xpos.copy() for finger in self.left_hand.fingertip_sites]
    fingertip_pos += [physics.bind(finger).xpos.copy() for finger in self.right_hand.fingertip_sites]

    # Indices of the keys to press at this step (the last goal entry is the sustain pedal).
    keys_to_press = np.flatnonzero(self._goal_current[:-1])
    # No key needs to be pressed at this step: return the maximum reward.
    if keys_to_press.shape[0] == 0:
        return 1.

    # Target point on each key, computed the same way as in RoboPianist's fingering reward.
    key_pos = []
    for key in keys_to_press:
        key_geom = self.piano.keys[key].geom[0]
        key_geom_pos = physics.bind(key_geom).xpos.copy()
        key_geom_pos[-1] += 0.5 * physics.bind(key_geom).size[2]
        key_geom_pos[0] += 0.35 * physics.bind(key_geom).size[0]
        key_pos.append(key_geom_pos.copy())

    # Pairwise distances between fingertips (rows) and target keys (columns).
    dist = np.full((len(fingertip_pos), len(key_pos)), 100.)
    for i, finger in enumerate(fingertip_pos):
        for j, key in enumerate(key_pos):
            dist[i, j] = np.linalg.norm(key - finger)

    # One-to-one finger-to-key assignment that minimizes the total distance.
    row_ind, col_ind = linear_sum_assignment(dist)
    dist = dist[row_ind, col_ind]
    rews = tolerance(
        dist,
        bounds=(0, _FINGER_CLOSE_ENOUGH_TO_KEY),
        margin=(_FINGER_CLOSE_ENOUGH_TO_KEY * 10),
        sigmoid="gaussian",
    )
    return float(np.mean(rews))
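
To make the assignment step easier to check by hand, here is a minimal standard-library sketch of the same idea on toy data. It is not the proposed implementation: `min_cost_assignment` is a hypothetical brute-force stand-in for `scipy.optimize.linear_sum_assignment` (only practical for a handful of fingers and keys), `gaussian_tolerance` is my reading of the Gaussian branch of dm_control's `tolerance` with its default `value_at_margin` of 0.1, and the threshold value 0.01 is an assumed placeholder for `_FINGER_CLOSE_ENOUGH_TO_KEY`.

```python
import itertools
import math

# Assumed placeholder for _FINGER_CLOSE_ENOUGH_TO_KEY (illustration only).
FINGER_CLOSE_ENOUGH_TO_KEY = 0.01


def min_cost_assignment(dist):
    """Brute-force minimum-cost matching of fingers (rows) to keys (columns).

    Hypothetical stand-in for scipy.optimize.linear_sum_assignment.
    Assumes there are at least as many fingers as keys.
    Returns a sorted list of (finger_index, key_index) pairs.
    """
    n_fingers, n_keys = len(dist), len(dist[0])
    best_cost, best_pairs = math.inf, []
    # Each permutation picks a distinct finger for every key.
    for fingers in itertools.permutations(range(n_fingers), n_keys):
        cost = sum(dist[f][k] for k, f in enumerate(fingers))
        if cost < best_cost:
            best_cost = cost
            best_pairs = sorted((f, k) for k, f in enumerate(fingers))
    return best_pairs


def gaussian_tolerance(d, upper, margin, value_at_margin=0.1):
    """1.0 for d in [0, upper], then Gaussian decay that reaches
    value_at_margin at d = upper + margin."""
    if d <= upper:
        return 1.0
    scale = math.sqrt(-2.0 * math.log(value_at_margin))
    return math.exp(-0.5 * (((d - upper) / margin) * scale) ** 2)


def ot_reward(fingertips, keys):
    """Mean shaped reward over the matched finger-key pairs."""
    if not keys:
        return 1.0  # nothing to press at this step
    dist = [[math.dist(f, k) for k in keys] for f in fingertips]
    pairs = min_cost_assignment(dist)
    rews = [
        gaussian_tolerance(
            dist[f][k],
            upper=FINGER_CLOSE_ENOUGH_TO_KEY,
            margin=10 * FINGER_CLOSE_ENOUGH_TO_KEY,
        )
        for f, k in pairs
    ]
    return sum(rews) / len(rews)


# Toy data: three fingertips and two target keys (made-up coordinates, in metres).
fingertips = [(0.0, 0.0, 0.0), (0.1, 0.0, 0.0), (0.2, 0.0, 0.0)]
keys = [(0.2, 0.0, 0.0), (0.0, 0.0, 0.005)]
reward = ot_reward(fingertips, keys)
```

In this toy case the third fingertip is matched to the first key and the first fingertip to the second key; the middle fingertip stays unassigned and does not affect the reward. That is the property that removes the need for fingering labels: the matching is recomputed from the current hand pose at every step.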
