Skip to content

Commit facd752

Browse files
committed
Update Python models to allow extracting any set of graph outputs
1 parent 7d01b7f commit facd752

3 files changed

Lines changed: 40 additions & 17 deletions

File tree

mediapipe/python/solutions/face_mesh.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# Lint as: python3
1616
"""MediaPipe FaceMesh."""
1717

18-
from typing import NamedTuple
18+
from typing import NamedTuple, Optional, Tuple
1919

2020
import numpy as np
2121

@@ -249,7 +249,8 @@ def __init__(self,
249249
static_image_mode=False,
250250
max_num_faces=2,
251251
min_detection_confidence=0.5,
252-
min_tracking_confidence=0.5):
252+
min_tracking_confidence=0.5,
253+
outputs: Optional[Tuple[str]] = ('multi_face_landmarks',)):
253254
"""Initializes a MediaPipe FaceMesh object.
254255
255256
Args:
@@ -274,6 +275,9 @@ def __init__(self,
274275
robustness of the solution, at the expense of a higher latency. Ignored
275276
if "static_image_mode" is True, where face detection simply runs on
276277
every image. Default to 0.5.
278+
outputs: A list of the graph output stream names to observe. If the list
279+
is empty, all the output streams listed in the graph config will be
280+
automatically observed by default.
277281
"""
278282
super().__init__(
279283
binary_graph_path=BINARYPB_FILE_PATH,
@@ -287,7 +291,7 @@ def __init__(self,
287291
'facelandmarkcpu__ThresholdingCalculator.threshold':
288292
min_tracking_confidence,
289293
},
290-
outputs=['multi_face_landmarks'])
294+
outputs=list(outputs) if outputs else [])
291295

292296
def process(self, image: np.ndarray) -> NamedTuple:
293297
"""Processes an RGB image and returns the face landmarks on each detected face.
@@ -300,8 +304,12 @@ def process(self, image: np.ndarray) -> NamedTuple:
300304
ValueError: If the input image is not three channel RGB.
301305
302306
Returns:
303-
A NamedTuple object with a "multi_face_landmarks" field that contains the
304-
face landmarks on each detected face.
307+
A NamedTuple object with fields corresponding to the set of outputs passed to the
308+
constructor. Fields may include:
309+
"multi_hand_landmarks" The face landmarks on each detected face
310+
"face_detections" The detected faces
311+
"face_rects_from_landmarks" Regions of interest calculated based on landmarks
312+
"face_rects_from_detections" Regions of interest calculated based on face detections
305313
"""
306314

307315
return super().process(input_data={'image': image})

mediapipe/python/solutions/hands.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
"""MediaPipe Hands."""
1717

1818
import enum
19-
from typing import NamedTuple
19+
from typing import NamedTuple, Optional, Tuple
2020

2121
import numpy as np
2222

@@ -168,7 +168,8 @@ def __init__(self,
168168
static_image_mode=False,
169169
max_num_hands=2,
170170
min_detection_confidence=0.7,
171-
min_tracking_confidence=0.5):
171+
min_tracking_confidence=0.5,
172+
outputs: Optional[Tuple[str]] = ('multi_hand_landmarks', 'multi_handedness')):
172173
"""Initializes a MediaPipe Hand object.
173174
174175
Args:
@@ -193,6 +194,9 @@ def __init__(self,
193194
robustness of the solution, at the expense of a higher latency. Ignored
194195
if "static_image_mode" is True, where hand detection simply runs on
195196
every image. Default to 0.5.
197+
outputs: A tuple of the graph output stream names to observe. If the tuple
198+
is empty, all the output streams listed in the graph config will be
199+
automatically observed by default.
196200
"""
197201
super().__init__(
198202
binary_graph_path=BINARYPB_FILE_PATH,
@@ -206,7 +210,7 @@ def __init__(self,
206210
'handlandmarkcpu__ThresholdingCalculator.threshold':
207211
min_tracking_confidence,
208212
},
209-
outputs=['multi_hand_landmarks', 'multi_handedness'])
213+
outputs=list(outputs) if outputs else [])
210214

211215
def process(self, image: np.ndarray) -> NamedTuple:
212216
"""Processes an RGB image and returns the hand landmarks and handedness of each detected hand.
@@ -219,10 +223,13 @@ def process(self, image: np.ndarray) -> NamedTuple:
219223
ValueError: If the input image is not three channel RGB.
220224
221225
Returns:
222-
A NamedTuple object with two fields: a "multi_hand_landmarks" field that
223-
contains the hand landmarks on each detected hand and a "multi_handedness"
224-
field that contains the handedness (left v.s. right hand) of the detected
225-
hand.
226+
A NamedTuple object with fields corresponding to the set of outputs passed to the
227+
constructor. Fields may include:
228+
"multi_hand_landmarks" The hand landmarks on each detected hand
229+
"multi_handedness" The handedness (left v.s. right hand) of the detected hand
230+
"palm_detections" The detected palms
231+
"hand_rects" Regions of interest calculated based on landmarks
232+
"hand_rects_from_palm_detections" Regions of interest calculated based on palm detections
226233
"""
227234

228235
return super().process(input_data={'image': image})

mediapipe/python/solutions/pose.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
"""MediaPipe Pose."""
1717

1818
import enum
19-
from typing import NamedTuple
19+
from typing import NamedTuple, Optional, Tuple
2020

2121
import numpy as np
2222

@@ -159,7 +159,8 @@ class Pose(SolutionBase):
159159
def __init__(self,
160160
static_image_mode=False,
161161
min_detection_confidence=0.5,
162-
min_tracking_confidence=0.5):
162+
min_tracking_confidence=0.5,
163+
outputs: Optional[Tuple[str]] = ('pose_landmarks',)):
163164
"""Initializes a MediaPipe Pose object.
164165
165166
Args:
@@ -181,6 +182,9 @@ def __init__(self,
181182
increase robustness of the solution, at the expense of a higher latency.
182183
Ignored if "static_image_mode" is True, where person detection simply
183184
runs on every image. Default to 0.5.
185+
outputs: A list of the graph output stream names to observe. If the list
186+
is empty, all the output streams listed in the graph config will be
187+
automatically observed by default.
184188
"""
185189
super().__init__(
186190
binary_graph_path=BINARYPB_FILE_PATH,
@@ -193,7 +197,7 @@ def __init__(self,
193197
'poselandmarkupperbodycpu__poselandmarkupperbodybyroicpu__ThresholdingCalculator.threshold':
194198
min_tracking_confidence,
195199
},
196-
outputs=['pose_landmarks'])
200+
outputs=list(outputs) if outputs else [])
197201

198202
def process(self, image: np.ndarray) -> NamedTuple:
199203
"""Processes an RGB image and returns the pose landmarks on the most prominent person detected.
@@ -206,8 +210,12 @@ def process(self, image: np.ndarray) -> NamedTuple:
206210
ValueError: If the input image is not three channel RGB.
207211
208212
Returns:
209-
A NamedTuple object with a "pose_landmarks" field that contains the pose
210-
landmarks on the most prominent person detected.
213+
A NamedTuple object with fields corresponding to the set of outputs passed to the
214+
constructor. Fields may include:
215+
"pose_landmarks" The pose landmarks on the most prominent person detected
216+
"pose_detection" The detected pose
217+
"pose_rect_from_landmarks" Region of interest calculated based on landmarks
218+
"pose_rect_from_detection" Region of interest calculated based on pose detection
211219
"""
212220

213221
return super().process(input_data={'image': image})

0 commit comments

Comments
 (0)