Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions t5/data/glue_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,11 @@ def get_glue_postprocess_fn(builder_config):
return postprocessors.multirc
elif builder_config.name == "record":
return postprocessors.record
elif builder_config.name == "boolq":
return functools.partial(
postprocessors.fuzzy_string_label_to_class_id,
label_classes=builder_config.label_classes,
)
else:
return functools.partial(
postprocessors.string_label_to_class_id,
Expand Down
28 changes: 28 additions & 0 deletions t5/data/postprocessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,34 @@ def string_label_to_class_id(string_label,
return default


def fuzzy_string_label_to_class_id(string_label,
                                   label_classes,
                                   default=-1,
                                   **unused_kwargs):
  """'Fuzzy' version of string_label_to_class_id that matches prefixes.

  This is part of a best effort strategy to decode from models that haven't
  learned to output class labels properly. For example, if string_label is
  'Falseeeee' and label_classes is ['True', 'False'], we return 1 (the index
  of 'False') since 'False' is a prefix of 'Falseeeee'.

  When several labels are prefixes of string_label, the longest one wins, so
  an exact (or more specific) label is never shadowed by a shorter label that
  happens to appear earlier in label_classes. Ties in length are resolved in
  favor of the earliest position.

  Args:
    string_label: an input string that we want to match against labels in
      label_classes.
    label_classes: an ordered collection of strings to be matched against.
      The returned index is the position of the matched label within it.
    default: fallback label if there is no match.
    **unused_kwargs: ignored; accepted so this function can be called with the
      same keyword arguments as other postprocessors.

  Returns:
    label_index: index of matched label or default.
  """
  best_index = default
  best_length = -1

  # enumerate gives the position directly, avoiding a second scan via
  # label_classes.index() and removing the need for an .index() method.
  for index, candidate in enumerate(label_classes):
    # Strict '>' keeps the earliest candidate when lengths are equal.
    if len(candidate) > best_length and string_label.startswith(candidate):
      best_index = index
      best_length = len(candidate)

  return best_index


def multirc(string_label, example=None, is_target=False):
"""Returns dict containing the class with the question index for grouping."""
res = {
Expand Down