diff --git a/skit_labels/cli.py b/skit_labels/cli.py
index 9dba2df..4d6572a 100644
--- a/skit_labels/cli.py
+++ b/skit_labels/cli.py
@@ -319,12 +319,18 @@ def build_cli():
     return parser
 
 
-def upload_dataset(input_file, url, token, job_id, data_source, data_label = None):
+def upload_dataset(input_file, url, token, job_id, data_source, data_label=None, tagging_type=None):
     input_file = utils.add_data_label(input_file, data_label)
     if data_source == const.SOURCE__DB:
         fn = commands.upload_dataset_to_db
     elif data_source == const.SOURCE__LABELSTUDIO:
-        fn = commands.upload_dataset_to_labelstudio
+        # Reject a malformed file before anything is uploaded; the early return
+        # has the same two-value shape as the (errors, df_size) pair below.
+        if tagging_type:
+            is_valid, error = utils.validate_input_data(tagging_type, input_file)
+            if not is_valid:
+                return error, None
+
+        fn = commands.upload_dataset_to_labelstudio
     errors, df_size = asyncio.run(
         fn(
             input_file,
diff --git a/skit_labels/constants.py b/skit_labels/constants.py
index ae80e08..b7fe6ef 100644
--- a/skit_labels/constants.py
+++ b/skit_labels/constants.py
@@ -120,4 +120,10 @@ FROM_NAME_INTENT = "tag"
 
 CHOICES = "choices"
 TAXONOMY = "taxonomy"
-VALUE = "value"
\ No newline at end of file
+VALUE = "value"
+
+# Column headers an input CSV must carry, keyed by tagging type. Headers are
+# lower-cased before comparison, so keep these entries lower-case.
+EXPECTED_COLUMNS_MAPPING = {
+    "conversation_tagging": ['situation_id', 'situation_str', 'call'],
+}
diff --git a/skit_labels/utils.py b/skit_labels/utils.py
index 7979eac..793d8d9 100644
--- a/skit_labels/utils.py
+++ b/skit_labels/utils.py
@@ -10,7 +10,7 @@ from datetime import datetime
 
 import pandas as pd
 from typing import Union
-
+from skit_labels.constants import EXPECTED_COLUMNS_MAPPING
 
 LOG_LEVELS = ["CRITICAL", "ERROR", "WARNING", "SUCCESS", "INFO", "DEBUG", "TRACE"]
 
@@ -110,3 +110,58 @@ def add_data_label(input_file: str, data_label: Optional[str] = None) -> str:
     df = df.assign(data_label=data_label)
     df.to_csv(input_file, index=False)
     return input_file
+
+
+def validate_headers(input_file, tagging_type, ignored_columns=("data_label",)):
+    """
+    Check a CSV's column headers against those expected for a tagging type.
+
+    Headers are compared case-insensitively and regardless of order. Columns
+    named in `ignored_columns` are tolerated unless the tagging type expects
+    them: `upload_dataset` calls `add_data_label` before validating, and that
+    helper can write a `data_label` column into the same file.
+
+    Returns `(is_match, mismatch_headers)`; `mismatch_headers` lists, sorted,
+    every header that is missing from or unexpected in the file. A tagging
+    type with no entry in `EXPECTED_COLUMNS_MAPPING` matches trivially.
+    """
+    expected_headers = EXPECTED_COLUMNS_MAPPING.get(tagging_type)
+    if expected_headers is None:
+        return True, []
+
+    expected_headers = sorted(expected_headers)
+    ignored = {str(name).lower() for name in ignored_columns} - set(expected_headers)
+
+    # nrows=0 parses the header row only; the data rows are not needed here.
+    df = pd.read_csv(input_file, nrows=0)
+    lowered = (str(header).lower() for header in df.columns)
+    column_headers = sorted(header for header in lowered if header not in ignored)
+    logger.info(f"column_headers: {column_headers}")
+    logger.info(f"expected_headers: {expected_headers}")
+
+    is_match = column_headers == expected_headers
+    logger.info(f"Is match: {is_match}")
+    if is_match:
+        return True, []
+
+    # Sorted so the reported mismatch is stable from run to run.
+    mismatch_headers = sorted(set(column_headers) ^ set(expected_headers))
+    return False, mismatch_headers
+
+
+def validate_input_data(tagging_type, input_file):
+    """
+    Validate an input file for the given tagging type.
+
+    Returns `(is_valid, error)`; `error` is an empty string when the file is
+    valid or when no expected columns are registered for `tagging_type`.
+    """
+    if tagging_type not in EXPECTED_COLUMNS_MAPPING:
+        return True, ''
+
+    is_match, mismatch_headers = validate_headers(input_file, tagging_type)
+    if is_match:
+        return True, ''
+
+    error = f'Headers in the input file does not match the expected fields. Mismatched fields = {mismatch_headers}'
+    return False, error