Skip to content

Commit 0f2bcc1

Browse files
lingvo-botcopybara-github
authored andcommitted
Add a parameter used to decide which indexes to ignore when computing the metric.
PiperOrigin-RevId: 736960394
1 parent 03fd809 commit 0f2bcc1

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

lingvo/core/metrics.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -826,16 +826,25 @@ class GroupPairAUCMetric(AUCMetric):
826826
be treated as a separate 3rd group rather than part of the 1st group.
827827
"""
828828

829-
def UpdateRaw(self, group_ids, target, logits, weight=None):
829+
def UpdateRaw(self, group_ids, target, logits, weight=None, ignore_ids=None):
830830
"""Updates the metrics.
831831
832832
Args:
833833
group_ids: An array to specify the group identity.
834834
target: An array to specify the groundtruth float values.
835835
logits: An array to specify the raw prediction logits.
836836
weight: An array to specify the sample weight for the auc computation.
837+
ignore_ids: An array to specify the indexes to ignore.
837838
"""
838839

840+
if ignore_ids is not None:
841+
mask = np.asarray(ignore_ids) == 0
842+
group_ids = (np.asarray(group_ids)[mask]).tolist()
843+
target = (np.asarray(target)[mask]).tolist()
844+
logits = (np.asarray(logits)[mask]).tolist()
845+
if weight is not None:
846+
weight = (np.asarray(weight)[mask]).tolist()
847+
839848
assert self._samples <= 0
840849

841850
sigmoid = lambda x: 1.0 / (1.0 + np.exp(-x))

0 commit comments

Comments
 (0)