Skip to content

Commit

Permalink
just commit for merge another branch
Browse files Browse the repository at this point in the history
  • Loading branch information
bwook00 committed Apr 12, 2024
1 parent d1949fa commit 6f6c18a
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 3 deletions.
1 change: 1 addition & 0 deletions autorag/nodes/passagefilter/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .ner_pii_masking import ner_pii_masking
from .pass_passage_filter import pass_passage_filter
from .threshold_cutoff import similarity_threshold_cutoff
17 changes: 14 additions & 3 deletions autorag/nodes/passagefilter/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import functools
import os
from pathlib import Path
from typing import Union, Tuple, List

import pandas as pd

from autorag.utils import result_to_dataframe, validate_qa_dataset
from autorag.utils import result_to_dataframe, validate_qa_dataset, fetch_contents


# Same as the passage filter for now
Expand Down Expand Up @@ -33,8 +34,18 @@ def wrapper(
assert "retrieved_ids" in previous_result.columns, "previous_result must have retrieved_ids column."
ids = previous_result["retrieved_ids"].tolist()

filtered_contents, filtered_ids, filtered_scores = func(queries=queries, contents_list=contents,
scores_list=scores, ids_list=ids, *args, **kwargs)
if func.__name__ == 'recency_filter':
corpus_df = pd.read_parquet(os.path.join(project_dir, "data", "corpus.parquet"))
metadatas = fetch_contents(corpus_df, ids, column_name='metadata')
times = [[time['last_modified_datetime'] for time in time_list] for time_list in metadatas]
filtered_contents, filtered_ids, filtered_scores \
= func(contents_list=contents, scores_list=scores, ids_list=ids, time_list=times, *args, **kwargs)
elif func.__name__ == 'ner_pii_masking':
filtered_contents, filtered_ids, filtered_scores = func(contents_list=contents,
scores_list=scores, ids_list=ids, *args, **kwargs)
else:
filtered_contents, filtered_ids, filtered_scores = func(queries=queries, contents_list=contents,
scores_list=scores, ids_list=ids, *args, **kwargs)

return filtered_contents, filtered_ids, filtered_scores

Expand Down
33 changes: 33 additions & 0 deletions autorag/nodes/passagefilter/ner_pii_masking.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,35 @@
from typing import List, Tuple

from transformers import pipeline

from autorag.nodes.passagefilter.base import passage_filter_node


@passage_filter_node
def ner_pii_masking(contents_list: List[List[str]],
                    scores_list: List[List[float]], ids_list: List[List[str]],
                    ) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]:
    """
    Mask the named entities found in every retrieved passage.

    A Hugging Face ``ner`` pipeline (with grouped entities) is created on each call
    and every passage is passed through ``mask_pii``.
    Nothing is filtered out: the ids and scores are returned as they were given.

    :param contents_list: The list of lists of contents to mask
    :param scores_list: The list of lists of scores retrieved
    :param ids_list: The list of lists of ids retrieved
    :return: Tuple of lists containing the masked contents, ids, and scores
    """
    ner_model = pipeline("ner", grouped_entities=True)

    # Keep the outer (per query) / inner (per passage) nesting of the input.
    masked_contents_list = [
        [mask_pii(ner_model, passage) for passage in passages]
        for passages in contents_list
    ]

    return masked_contents_list, ids_list, scores_list


def mask_pii(model, text: str) -> str:
    """
    Replace each entity reported by ``model`` with a ``[<entity_group>_<start>]`` tag.

    Entities are masked by the character span (``start``/``end``) the model reports,
    not by searching for the entity's ``word``. Searching by word would also mask
    unrelated occurrences (e.g. "Al" inside "Also") and would miss entities whose
    ``word`` is normalised differently from the source text (e.g. lower-cased).
    Entities that carry no offsets fall back to a plain ``str.replace`` on ``word``.

    :param model: Callable taking the text and returning an iterable of dicts with
        the keys ``entity_group``, ``word``, ``start`` and ``end``.
    :param text: The text to mask.
    :return: The masked text, stripped of surrounding whitespace.
        When the model reports no entity, ``text`` is returned untouched.
    """
    entities = model(text)
    if not entities:
        return text

    with_span, without_span = [], []
    for entity in entities:
        has_span = entity.get("start") is not None and entity.get("end") is not None
        (with_span if has_span else without_span).append(entity)

    # Walk the spans left to right and stitch the untouched text between them,
    # so every offset keeps referring to the original (unmodified) text.
    pieces = []
    cursor = 0
    for entity in sorted(with_span, key=lambda item: item["start"]):
        start, end = entity["start"], entity["end"]
        if start < cursor:
            # Overlaps a span that is already masked.
            continue
        pieces.append(text[cursor:start])
        pieces.append(f"[{entity['entity_group']}_{start}]")
        cursor = end
    pieces.append(text[cursor:])
    masked = "".join(pieces)

    # No offsets available (e.g. tokenizer without offset mapping): best effort by word.
    for entity in without_span:
        tag = f"[{entity['entity_group']}_{entity.get('start')}]"
        masked = masked.replace(entity["word"], tag)

    return masked.strip()
1 change: 1 addition & 0 deletions autorag/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def get_support_modules(module_name: str) -> Callable:
# passage_filter
'pass_passage_filter': ('autorag.nodes.passagefilter', 'pass_passage_filter'),
'similarity_threshold_cutoff': ('autorag.nodes.passagefilter', 'similarity_threshold_cutoff'),
''
# passage_compressor
'tree_summarize': ('autorag.nodes.passagecompressor', 'tree_summarize'),
'pass_compressor': ('autorag.nodes.passagecompressor', 'pass_compressor'),
Expand Down
15 changes: 15 additions & 0 deletions tests/autorag/nodes/passagefilter/test_ner_pii_masking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from autorag.nodes.passagefilter import ner_pii_masking
from tests.autorag.nodes.passagefilter.test_passage_filter_base import base_passage_filter_test, contents_example, \
ids_example, scores_example, project_dir, previous_result, base_passage_filter_node_test


def test_ner_pii_masking():
    """Call the function behind ``ner_pii_masking.__wrapped__`` directly and check one masked passage."""
    undecorated = ner_pii_masking.__wrapped__
    masked_contents, kept_ids, kept_scores = undecorated(contents_example, scores_example, ids_example)
    # Tags carry the entity group and the character offset where the entity started.
    assert masked_contents[1][3] == "[PER_0] is one of the members of [ORG_34]."
    base_passage_filter_test(masked_contents, kept_ids, kept_scores)


def test_ner_pii_masking_node():
    """Run ``ner_pii_masking`` through its node interface and validate the resulting frame."""
    node_output = ner_pii_masking(project_dir=project_dir, previous_result=previous_result)
    base_passage_filter_node_test(node_output)

0 comments on commit 6f6c18a

Please sign in to comment.