Lift the requirement of human fingering with RP1M. #23
Description
Hello,
Thanks for the great work! We recently released a paper, RP1M (https://arxiv.org/abs/2408.11048, cc @clthegoat), which introduces a reward term based on optimal transport that lets the agent play MIDI files without human fingering annotations. Would it be possible to integrate this method into this repo, so that people can conveniently use RoboPianist to play songs beyond the PIG dataset?
Here are some comparison results from the paper as well as a short plan for the modification of the code. Please let me know your thoughts.
Modifications:
I plan to change these lines
robopianist/robopianist/suite/tasks/piano_with_shadow_hands.py
Lines 135 to 136 in d9cde23
    if not self._disable_fingering_reward:
        self._reward_fn.add("fingering_reward", self._compute_fingering_reward)
as:
    if not self._disable_fingering_reward:
        # Use the fingering reward when human fingering is available.
        self._reward_fn.add("fingering_reward", self._compute_fingering_reward)
    else:
        # Otherwise fall back to the OT reward.
        self._reward_fn.add("ot_reward", self._compute_ot_reward)

where _compute_ot_reward is defined as:
    from scipy.optimize import linear_sum_assignment

    def _compute_ot_reward(self, physics: mjcf.Physics) -> float:
        """OT reward calculation from RP1M (https://arxiv.org/abs/2408.11048)."""
        # Calculate the fingertip positions of both hands.
        fingertip_pos = [physics.bind(finger).xpos.copy() for finger in self.left_hand.fingertip_sites]
        fingertip_pos += [physics.bind(finger).xpos.copy() for finger in self.right_hand.fingertip_sites]
        # Find the piano keys that should be pressed at this step.
        keys_to_press = np.flatnonzero(self._goal_current[:-1])
        # If no key needs to be pressed, return the maximum reward.
        if keys_to_press.shape[0] == 0:
            return 1.
        # Calculate the target position on each key (same as RoboPianist).
        key_pos = []
        for key in keys_to_press:
            key_geom = self.piano.keys[key].geom[0]
            key_geom_pos = physics.bind(key_geom).xpos.copy()
            key_geom_pos[-1] += 0.5 * physics.bind(key_geom).size[2]
            key_geom_pos[0] += 0.35 * physics.bind(key_geom).size[0]
            key_pos.append(key_geom_pos.copy())
        # Calculate the distance between every fingertip and every key.
        dist = np.full((len(fingertip_pos), len(key_pos)), 100.)
        for i, finger in enumerate(fingertip_pos):
            for j, key in enumerate(key_pos):
                dist[i, j] = np.linalg.norm(key - finger)
        # Find the finger-to-key assignment with the smallest total distance.
        row_ind, col_ind = linear_sum_assignment(dist)
        dist = dist[row_ind, col_ind]
        rews = tolerance(
            dist,
            bounds=(0, _FINGER_CLOSE_ENOUGH_TO_KEY),
            margin=(_FINGER_CLOSE_ENOUGH_TO_KEY * 10),
            sigmoid="gaussian",
        )
        return float(np.mean(rews))
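
For anyone who wants to try the matching step without a MuJoCo environment, here is a minimal standalone sketch of the same idea using only NumPy and SciPy. The names `ot_reward` and `gaussian_tolerance` are hypothetical helpers written for this illustration; `gaussian_tolerance` is a simplified stand-in for the `tolerance` function used above rather than the actual implementation, and the value of `FINGER_CLOSE_ENOUGH_TO_KEY` is an assumption.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# Assumed value, for illustration only; the real constant lives in the task module.
FINGER_CLOSE_ENOUGH_TO_KEY = 0.01


def gaussian_tolerance(dist, upper, margin, value_at_margin=0.1):
    """Simplified stand-in for a Gaussian tolerance with bounds (0, upper).

    Returns 1 for distances within `upper` and decays smoothly beyond it,
    reaching `value_at_margin` at a distance of `upper + margin`.
    """
    dist = np.asarray(dist, dtype=float)
    scale = np.sqrt(-2.0 * np.log(value_at_margin))
    excess = np.maximum(dist - upper, 0.0) / margin
    return np.exp(-0.5 * (excess * scale) ** 2)


def ot_reward(fingertip_pos, key_pos):
    """Reward based on the minimum-cost matching between fingertips and keys."""
    fingertip_pos = np.asarray(fingertip_pos, dtype=float)
    key_pos = np.asarray(key_pos, dtype=float)
    # No key to press: return the maximum reward, as in the method above.
    if key_pos.shape[0] == 0:
        return 1.0
    # Pairwise Euclidean distances, shape (num_fingers, num_keys).
    dist = np.linalg.norm(fingertip_pos[:, None, :] - key_pos[None, :, :], axis=-1)
    # Assignment that minimizes the total finger-to-key distance.
    row_ind, col_ind = linear_sum_assignment(dist)
    matched = dist[row_ind, col_ind]
    rews = gaussian_tolerance(
        matched,
        upper=FINGER_CLOSE_ENOUGH_TO_KEY,
        margin=10 * FINGER_CLOSE_ENOUGH_TO_KEY,
    )
    return float(np.mean(rews))


# Three fingertips, two target keys; two of the fingertips sit exactly on the keys.
fingers = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0]]
keys = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
print(ot_reward(fingers, keys))
```

Because the assignment is solved jointly, each key is matched to a distinct fingertip, and fingertips left unmatched (here the one at x = 1) do not affect the reward.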