Skip to content

Commit

Permalink
Add validation for tagging type
Browse files Browse the repository at this point in the history
  • Loading branch information
d-shree committed Dec 26, 2023
1 parent 3900ac6 commit 857e539
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 4 deletions.
9 changes: 7 additions & 2 deletions skit_labels/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,12 +319,17 @@ def build_cli():
return parser


def upload_dataset(input_file, url, token, job_id, data_source, data_label = None):
def upload_dataset(input_file, url, token, job_id, data_source, data_label = None, tagging_type=None):
input_file = utils.add_data_label(input_file, data_label)
if data_source == const.SOURCE__DB:
fn = commands.upload_dataset_to_db
elif data_source == const.SOURCE__LABELSTUDIO:
fn = commands.upload_dataset_to_labelstudio
if tagging_type:
is_valid, error = utils.validate_input_data(tagging_type, input_file)
if not is_valid:
return error, None

fn = commands.upload_dataset_to_labelstudio
errors, df_size = asyncio.run(
fn(
input_file,
Expand Down
6 changes: 5 additions & 1 deletion skit_labels/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,4 +120,8 @@
FROM_NAME_INTENT = "tag"
CHOICES = "choices"
TAXONOMY = "taxonomy"
VALUE = "value"
VALUE = "value"

EXPECTED_COLUMNS_MAPPING = {
"conversation_tagging": ['situation_id', 'situation_str', 'call']
}
36 changes: 35 additions & 1 deletion skit_labels/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from datetime import datetime
import pandas as pd
from typing import Union

from constants import EXPECTED_COLUMNS_MAPPING

LOG_LEVELS = ["CRITICAL", "ERROR", "WARNING", "SUCCESS", "INFO", "DEBUG", "TRACE"]

Expand Down Expand Up @@ -110,3 +110,37 @@ def add_data_label(input_file: str, data_label: Optional[str] = None) -> str:
df = df.assign(data_label=data_label)
df.to_csv(input_file, index=False)
return input_file


def validate_headers(input_file, tagging_type):
expected_columns_mapping = EXPECTED_COLUMNS_MAPPING
expected_headers = expected_columns_mapping.get(tagging_type)

df = pd.read_csv(input_file)
column_headers = df.columns.to_list()
column_headers = [header.lower() for header in column_headers]
column_headers = sorted(column_headers)
expected_headers = sorted(expected_headers)
logger.info(f"column_headers: {column_headers}")
logger.info(f"expected_headers: {expected_headers}")

is_match = column_headers == expected_headers
mismatch_headers = []
logger.info(f"Is match: {is_match}")

if not is_match:
mismatch_headers_set =set(column_headers).symmetric_difference(set(expected_headers))
mismatch_headers = list(mismatch_headers_set)
return is_match, mismatch_headers


def validate_input_data(tagging_type, input_file):
is_valid = True
error = ''
if tagging_type == 'conversation_tagging':
is_match, mismatch_headers = validate_headers(input_file, tagging_type)
if not is_match:
error = f'Headers in the input file does not match the expected fields. Mismatched fields = {mismatch_headers}'
is_valid = False

return is_valid, error

0 comments on commit 857e539

Please sign in to comment.