-
-
Notifications
You must be signed in to change notification settings - Fork 276
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Commit to merge another branch
- Loading branch information
Showing
5 changed files
with
64 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from .ner_pii_masking import ner_pii_masking | ||
from .pass_passage_filter import pass_passage_filter | ||
from .threshold_cutoff import similarity_threshold_cutoff |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,35 @@ | ||
from typing import List, Tuple | ||
|
||
from transformers import pipeline | ||
|
||
from autorag.nodes.passagefilter.base import passage_filter_node | ||
|
||
|
||
@passage_filter_node
def ner_pii_masking(contents_list: List[List[str]],
                    scores_list: List[List[float]], ids_list: List[List[str]],
                    ) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]:
    """
    Mask PII in the contents using NER.
    Uses a HF transformers token-classification ("ner") pipeline.

    Every passage is passed through ``mask_pii``; the ids and scores are
    returned exactly as they were received, so no passage is dropped or
    reordered by this node.

    :param contents_list: The list of lists of contents to mask
    :param scores_list: The list of lists of scores retrieved
    :param ids_list: The list of lists of ids retrieved
    :return: Tuple of lists containing the masked contents, ids, and scores
    """
    # ``grouped_entities=True`` is deprecated in transformers; with the default
    # ``ignore_subwords`` it maps to ``aggregation_strategy="simple"``, so the
    # supported argument is used directly to avoid the deprecation warning.
    model = pipeline("ner", aggregation_strategy="simple")

    # The pipeline is built once per call and shared across every passage.
    masked_contents_list = [
        [mask_pii(model, content) for content in contents]
        for contents in contents_list
    ]

    return masked_contents_list, ids_list, scores_list
|
||
|
||
def mask_pii(model, text: str) -> str:
    """
    Replace every entity reported by ``model`` in ``text`` with a tag.

    Each tag has the form ``[<entity_group>_<start>]``, where ``start`` is the
    character offset the model reported for that entity.

    Entities are masked by their reported ``start``/``end`` offsets instead of
    by searching for ``entry["word"]``. Searching for the word rewrites
    unrelated text that merely contains it (``"Al"`` inside ``"Also"``) and
    masks nothing when the reported word is not a verbatim substring of the
    text. String replacement is kept only as a fallback for entries that carry
    no usable offsets.

    :param model: Callable taking the text and returning an iterable of dicts
        with the keys ``entity_group``, ``word`` and ``start`` (and, normally,
        ``end``).
    :param text: The text to mask.
    :return: The masked text with surrounding whitespace stripped. When the
        model reports no entities the text is returned unchanged.
    """
    entities = model(text)
    if not entities:
        # Nothing detected: hand the text back untouched (not even stripped).
        return text

    located = []    # (start, end, tag) for entries with valid offsets
    unlocated = []  # (word, tag) for entries that must fall back to replace()
    for entry in entities:
        tag = f"[{entry['entity_group']}_{entry['start']}]"
        start, end = entry["start"], entry.get("end")
        if start is not None and end is not None and 0 <= start < end <= len(text):
            located.append((start, end, tag))
        else:
            unlocated.append((entry["word"], tag))

    masked = text
    # Work right to left so that earlier offsets stay valid after each splice.
    boundary = len(text)
    for start, end, tag in sorted(located, key=lambda span: span[0], reverse=True):
        # Clamp against the span already masked to the right so overlapping
        # entities never leave part of the original text exposed.
        end = min(end, boundary)
        if start >= end:
            continue
        masked = masked[:start] + tag + masked[end:]
        boundary = start

    for word, tag in unlocated:
        # An empty search string would insert the tag between every character.
        if word:
            masked = masked.replace(word, tag)

    return masked.strip()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from autorag.nodes.passagefilter import ner_pii_masking | ||
from tests.autorag.nodes.passagefilter.test_passage_filter_base import base_passage_filter_test, contents_example, \ | ||
ids_example, scores_example, project_dir, previous_result, base_passage_filter_node_test | ||
|
||
|
||
def test_ner_pii_masking():
    """Call the undecorated function directly and check one masked passage."""
    masked, kept_ids, kept_scores = ner_pii_masking.__wrapped__(
        contents_example, scores_example, ids_example)
    expected_passage = "[PER_0] is one of the members of [ORG_34]."
    assert masked[1][3] == expected_passage
    # Hand the outputs to the common passage-filter checks.
    base_passage_filter_test(masked, kept_ids, kept_scores)
|
||
|
||
def test_ner_pii_masking_node():
    """Run the decorated node and pass its result to the common node checks."""
    node_output = ner_pii_masking(
        project_dir=project_dir,
        previous_result=previous_result,
    )
    base_passage_filter_node_test(node_output)