diff --git a/src/diffa/cli.py b/src/diffa/cli.py index 512764f..688428e 100644 --- a/src/diffa/cli.py +++ b/src/diffa/cli.py @@ -57,6 +57,17 @@ def cli(): type=str, help="Target table name.", ) +@click.option( + "--diff-dimensions", + multiple=True, + type=str, + help="Diff dimension columns.", +) +@click.option( + "--full-diff", + is_flag=True, + help="Full diff mode. Re-run the diff from the beginning.", +) def data_diff( *, source_db_uri: str = None, @@ -68,6 +79,8 @@ def data_diff( target_database: str = None, target_schema: str = "public", target_table: str, + diff_dimensions: tuple = None, + full_diff: bool = False, ): config_manager = ConfigManager().configure( source_database=source_database, @@ -79,6 +92,8 @@ def data_diff( source_db_uri=source_db_uri, target_db_uri=target_db_uri, diffa_db_uri=diffa_db_uri, + diff_dimension_cols=list(diff_dimensions) if diff_dimensions else None, + full_diff=full_diff, ) run_manager = RunManager(config_manager=config_manager) check_manager = CheckManager(config_manager=config_manager) diff --git a/src/diffa/config.py b/src/diffa/config.py index 9f6d8de..bd15db4 100644 --- a/src/diffa/config.py +++ b/src/diffa/config.py @@ -3,6 +3,7 @@ from datetime import date from enum import Enum from urllib.parse import urlparse +from typing import List, Optional from diffa.utils import Logger @@ -11,7 +12,7 @@ DIFFA_DB_SCHEMA = "diffa" DIFFA_DB_TABLE = "diffa_checks" DIFFA_CHECK_RUNS_TABLE = "diffa_check_runs" -DIFFA_BEGIN_DATE = date(2024, 1, 1) +DIFFA_BEGIN_DATE = date(2020, 6, 1) # Matching with Ascenda start date class ExitCode(Enum): @@ -103,11 +104,20 @@ def get_db_table(self): class SourceConfig(DBConfig): """A class to handle the configs for the Source DBs""" + def __init__(self, *args, diff_dimension_cols: Optional[List[str]] = None, **kwargs): + super().__init__(*args, **kwargs) + self.diff_dimension_cols = diff_dimension_cols or [] - + def get_diff_dimension_cols(self): + return self.diff_dimension_cols class 
DiffaConfig(DBConfig): """A class to handle the configs for the Diffa DB""" + def __init__(self, *args, full_diff: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.full_diff = full_diff + def is_full_diff(self): + return self.full_diff class ConfigManager: """Manage all the configuration needed for Diffa Operations""" @@ -143,21 +153,26 @@ def configure( target_schema: str = "public", target_table: str, diffa_db_uri: str = None, + diff_dimension_cols: List[str] = None, + full_diff: bool = False, ): self.source.update( db_uri=source_db_uri, db_name=source_database, db_schema=source_schema, db_table=source_table, + diff_dimension_cols=diff_dimension_cols, ) self.target.update( db_uri=target_db_uri, db_name=target_database, db_schema=target_schema, db_table=target_table, + diff_dimension_cols=diff_dimension_cols, ) self.diffa_check.update( db_uri=diffa_db_uri, + full_diff=full_diff, ) self.diffa_check_run.update( db_uri=diffa_db_uri, @@ -203,7 +218,7 @@ def save_config(self, source_uri: str, target_uri: str, diffa_uri: str): logger.info("Configuration saved to successfully.") def __getattr__(self, __name: str) -> DBConfig: - """Dynamically access DBConfig attributes (e.g config_manager.source.database)""" + """Dynamically access DBConfig attributes (e.g config_manager.source.get_db_name())""" if __name in self.config: return self.config[__name] diff --git a/src/diffa/db/data_models.py b/src/diffa/db/data_models.py index c864adc..559954d 100644 --- a/src/diffa/db/data_models.py +++ b/src/diffa/db/data_models.py @@ -1,6 +1,7 @@ from datetime import date -from typing import Optional -from dataclasses import dataclass, field +from typing import Optional, List, Tuple, Any +from dataclasses import dataclass, fields, make_dataclass +from functools import reduce import uuid from sqlalchemy import ( @@ -148,51 +149,114 @@ def validate_status(self): return self -@dataclass +@dataclass(frozen=True) class CountCheck: """A single count check in Source/Target 
Database""" cnt: int check_date: date + @classmethod + def create_with_dimensions(cls, dimension_cols: Optional[List[str]] = None): + """Factory method to create a CountCheck class with dimension fields""" + + return make_dataclass( + cls.__name__, + [(col, str) for col in sorted(dimension_cols)] if dimension_cols else [], + bases=(cls,), + frozen=True, + ) + + @classmethod + def get_base_fields(cls) -> List[Tuple[str, type]]: + return [("check_date", date), ("cnt", int)] + + @classmethod + def get_dimension_fields(cls) -> List[Tuple[str, type]]: + base_fields = {name for name, _ in cls.get_base_fields()} + + return [(f.name, f.type) for f in fields(cls) if f.name not in base_fields] + + def get_dimension_values(self): + # check_date is still considered as a dimension field. In fact, it's a main dimension field. + return { + f[0]: getattr(self, f[0]) + for f in self.get_dimension_fields() + [("check_date", date)] + } + + def to_flatten_dimension_format(self) -> dict: + return {tuple(self.get_dimension_values().items()): self} + -@dataclass class MergedCountCheck: """A merged count check after checking count in Source/Target Databases""" - source_count: int - target_count: int - check_date: date - is_valid: bool = field(init=False) - - def __post_init__(self): - self.is_valid = True if self.source_count <= self.target_count else False + def __init__( + self, + source_count: int, + target_count: int, + check_date: date, + is_valid: Optional[bool] = None, + **kwargs: Any, + ): + self.source_count = source_count + self.target_count = target_count + self.check_date = check_date + for key, value in kwargs.items(): + setattr(self, key, value) + + self.is_valid = ( + is_valid if is_valid is not None else source_count <= target_count + ) def __eq__(self, other): if not isinstance(other, MergedCountCheck): return NotImplemented - return ( - self.source_count == other.source_count - and self.target_count == other.target_count - and self.check_date == other.check_date - and 
self.is_valid == other.is_valid + return self.__dict__ == other.__dict__ + + def __lt__(self, other): + if not isinstance(other, MergedCountCheck): + return NotImplemented + dynamic_fields = [ + f + for f in self.__dict__.keys() + if f not in ["source_count", "target_count", "check_date", "is_valid"] + ] + precedence = ( + ["check_date"] + + dynamic_fields + + ["source_count", "target_count", "is_valid"] + ) + + return tuple(getattr(self, f) for f in precedence) < tuple( + getattr(other, f) for f in precedence + ) + + def __str__(self): + return f"MergedCountCheck({", ".join(f"{k}={v!r}" for k, v in self.__dict__.items())})" + + @classmethod + def create_with_dimensions(cls, dimension_fields: List[Tuple[str, type]]): + """Factory method to dynamically create a MergedCountCheck with a CountCheck schema""" + + return type( + cls.__name__, + (cls,), + reduce( + lambda x, y: x | y, map(lambda x: {x[0]: x[1]}, dimension_fields), {} + ), ) @classmethod def from_counts( cls, source: Optional[CountCheck] = None, target: Optional[CountCheck] = None ): - if source and target: - if source.check_date != target.check_date: - raise ValueError("Source and target counts are not for the same date.") - elif not source and not target: - raise ValueError("Both source and target counts are missing.") - - check_date = source.check_date if source else target.check_date - source_count = source.cnt if source else 0 - target_count = target.cnt if target else 0 + count_check = source if source else target + merged_count_check_values = count_check.get_dimension_values() + merged_count_check_values["source_count"] = source.cnt if source else 0 + merged_count_check_values["target_count"] = target.cnt if target else 0 - return cls(source_count, target_count, check_date) + return cls(**merged_count_check_values) def to_diffa_check_schema( self, diff --git a/src/diffa/db/diffa_check.py b/src/diffa/db/diffa_check.py index af24a2f..23c031b 100644 --- a/src/diffa/db/diffa_check.py +++ 
b/src/diffa/db/diffa_check.py @@ -6,7 +6,7 @@ from sqlalchemy.dialects.postgresql import insert from diffa.db.connect import DiffaConnection -from diffa.config import DBConfig, ConfigManager, DIFFA_BEGIN_DATE +from diffa.config import DiffaConfig, ConfigManager, DIFFA_BEGIN_DATE from diffa.db.data_models import ( DiffaCheckSchema, DiffaCheck, @@ -20,7 +20,7 @@ class DiffaCheckDatabase: """SQLAlchemy Database Adapter for Diffa state management""" - def __init__(self, db_config: DBConfig): + def __init__(self, db_config: DiffaConfig): self.db_config = db_config self.conn = DiffaConnection(self.db_config.get_db_config()) @@ -103,6 +103,7 @@ class DiffaCheckService: def __init__(self, config_manager: ConfigManager): self.config_manager = config_manager self.diffa_db = DiffaCheckDatabase(self.config_manager.diffa_check) + self.is_full_diff = self.config_manager.diffa_check.is_full_diff() def get_last_check_date(self) -> date: @@ -115,8 +116,16 @@ def get_last_check_date(self) -> date: target_table=self.config_manager.target.get_db_table(), ) - check_date = latest_check["check_date"] if latest_check else DIFFA_BEGIN_DATE - logger.info(f"Last check date: {check_date}") + if not self.is_full_diff: + check_date = ( + latest_check["check_date"] if latest_check else DIFFA_BEGIN_DATE + ) + logger.info(f"Last check date: {check_date}") + else: + check_date = DIFFA_BEGIN_DATE + logger.info( + f"Full diff mode is enabled. Checking from the beginning. 
Last check date: {check_date}" + ) return check_date @@ -134,7 +143,9 @@ def get_invalid_check_dates(self) -> Iterable[date]: invalid_check_dates = [ invalid_check["check_date"] for invalid_check in invalid_checks ] - if len(invalid_check_dates) > 0: + if self.is_full_diff: + return None + elif len(invalid_check_dates) > 0: logger.info( f"The number of invalid check dates is: {len(invalid_check_dates)}" ) diff --git a/src/diffa/db/source_target.py b/src/diffa/db/source_target.py index 96cdad9..9c418fe 100644 --- a/src/diffa/db/source_target.py +++ b/src/diffa/db/source_target.py @@ -1,12 +1,13 @@ from datetime import date -from typing import List, Iterable +from typing import List, Iterable, Optional from concurrent.futures import ThreadPoolExecutor +from functools import partial import psycopg2.extras from diffa.utils import Logger from diffa.db.connect import PostgresConnection -from diffa.config import DBConfig +from diffa.config import SourceConfig from diffa.db.data_models import CountCheck from diffa.config import ConfigManager @@ -16,7 +17,7 @@ class SourceTargetDatabase: """Base class for the Source Target DB handling""" - def __init__(self, db_config: DBConfig) -> None: + def __init__(self, db_config: SourceConfig) -> None: self.db_config = db_config self.conn = PostgresConnection(self.db_config.get_db_config()) @@ -34,7 +35,10 @@ def _execute_query(self, query: str, sql_params: tuple = None): raise e def _build_count_query( - self, latest_check_date: date, invalid_check_dates: List[date] + self, + latest_check_date: date, + invalid_check_dates: List[date], + diff_dimension_cols: Optional[List[str]] = None, ): backfill_where_clause = ( f" (created_at::DATE IN ({','.join([f"'{date}'" for date in invalid_check_dates])})) OR" @@ -47,21 +51,44 @@ def _build_count_query( created_at::DATE <= CURRENT_DATE - INTERVAL '2 DAY' ) """ + group_by_diff_dimensions_clause = ( + f", {','.join(diff_dimension_cols)}" if diff_dimension_cols else "" + ) + 
select_diff_dimensions_clause = ( + f", {','.join([f'{col}::text' for col in diff_dimension_cols])}" + if diff_dimension_cols + else "" + ) + return f""" SELECT created_at::DATE as check_date, - COUNT(*) AS cnt + COUNT(*) AS cnt + {select_diff_dimensions_clause} FROM {self.db_config.get_db_schema()}.{self.db_config.get_db_table()} WHERE {backfill_where_clause} {catchup_where_clause} - GROUP BY created_at::DATE + GROUP BY created_at::DATE + {group_by_diff_dimensions_clause} ORDER BY created_at::DATE ASC """ def count(self, latest_check_date: date, invalid_check_dates: List[date]): - count_query = self._build_count_query(latest_check_date, invalid_check_dates) + if self.db_config.get_diff_dimension_cols(): + count_query = self._build_count_query( + latest_check_date, + invalid_check_dates, + self.db_config.get_diff_dimension_cols(), + ) + logger.warning( + "Diff dimensions are enabled. May impact the performance of the query" + ) + else: + count_query = self._build_count_query( + latest_check_date, invalid_check_dates + ) logger.info( f"Executing the count query on {self.db_config.get_db_scheme()}: {count_query}" ) @@ -77,8 +104,15 @@ def __init__(self, config_manager: ConfigManager): def get_counts( self, last_check_date: date, invalid_check_dates: Iterable[date] ) -> Iterable[CountCheck]: - def to_count_check(count_dict: dict) -> CountCheck: - return CountCheck(**count_dict) + def to_count_check( + count_dict: dict, diff_dimension_cols: Optional[List[str]] = None + ) -> CountCheck: + if diff_dimension_cols: + return CountCheck.create_with_dimensions(diff_dimension_cols)( + **count_dict + ) + else: + return CountCheck(**count_dict) with ThreadPoolExecutor(max_workers=2) as executor: future_source_count = executor.submit( @@ -92,4 +126,16 @@ def to_count_check(count_dict: dict) -> CountCheck: future_source_count.result(), future_target_count.result(), ) - return map(to_count_check, source_counts), map(to_count_check, target_counts) + return map( + partial( + 
to_count_check, + diff_dimension_cols=self.source_db.db_config.get_diff_dimension_cols(), + ), + source_counts, + ), map( + partial( + to_count_check, + diff_dimension_cols=self.target_db.db_config.get_diff_dimension_cols(), + ), + target_counts, + ) diff --git a/src/diffa/managers/check_manager.py b/src/diffa/managers/check_manager.py index 28049c2..e84fcaf 100644 --- a/src/diffa/managers/check_manager.py +++ b/src/diffa/managers/check_manager.py @@ -1,4 +1,7 @@ from typing import Iterable +from datetime import date +from collections import defaultdict +from functools import reduce from diffa.db.data_models import CountCheck, MergedCountCheck from diffa.db.diffa_check import DiffaCheckService @@ -19,10 +22,10 @@ def __init__(self, config_manager: ConfigManager): def data_diff(self): """This will interupt the process when there are invalid diff found.""" - if self.compare_tables(): - logger.error("There is an invalid diff between source and target.") + if not self.compare_tables(): + logger.error("❌ There is an invalid diff between source and target.") raise InvalidDiffException - logger.info("There is no invalid diff between source and target.") + logger.info("✅ There is no invalid diff between source and target.") def compare_tables(self): - """Data-diff comparison service. 
Will return True if there is any invalid diff.""" + """Data-diff comparison service. Will return True only if no invalid diff is found.""" @@ -45,6 +48,7 @@ def compare_tables(self): last_check_date, invalid_check_dates ) merged_count_checks = self._merge_count_checks(source_counts, target_counts) + merged_by_date = self._merge_by_check_date(merged_count_checks) # Step 4: Save the merged count checks to the diffa database self.diffa_check_service.save_diffa_checks( @@ -57,20 +61,84 @@ def compare_tables(self): target_schema=self.cm.target.get_db_schema(), target_table=self.cm.target.get_db_table(), ), - merged_count_checks, + merged_by_date.values(), ) ) + # Step 5: Build and log the check summary + self._build_check_summary(merged_count_checks, merged_by_date) + - # Return True if there is any invalid diff + # Return True only if no invalid diff is found - return self._check_if_invalid_diff(merged_count_checks) + return self._check_if_valid_diff(merged_by_date.values()) + + def _check_if_valid_diff(self, merged_by_date: list[MergedCountCheck]) -> bool: + return all(mcc.is_valid for mcc in merged_by_date) + + def _build_check_summary( + self, + merged_count_checks: Iterable[MergedCountCheck], + merged_by_date: dict[date, MergedCountCheck], + ): + stats_by_day = { + check_date: { + "detailed_msgs": self._get_check_messages( + self._get_checks_by_date(merged_count_checks, check_date) + ), + "summary_msg": self._get_check_messages([mcc])[0], + } + for check_date, mcc in merged_by_date.items() + } + + summary_lines = [ + f""" + - {check_date}: + summary: + {stats['summary_msg']} + detailed: + {stats['detailed_msgs']} + """ + for check_date, stats in stats_by_day.items() + ] + stats_summary = "\n".join(summary_lines) if summary_lines else "No stats available" - def _check_if_invalid_diff( - self, merged_count_checks: Iterable[MergedCountCheck] - ) -> bool: - for merged_count_check in merged_count_checks: - if not merged_count_check.is_valid: - return True - return False + logger.info( + f""" + Data-diff comparison result: + Summary: + - Total days checked: {len(stats_by_day)} + - Stats by day: + 
{stats_summary} + """ + ) + + @staticmethod + def _get_check_messages(merged_count_checks: Iterable[MergedCountCheck]): + return [ + f"{'✅ No Diff' if mcc.is_valid else '❌ Diff'} {mcc}" + for mcc in merged_count_checks + ] + + @staticmethod + def _get_checks_by_date( + merged_count_checks: Iterable[MergedCountCheck], check_date: date + ) -> list[MergedCountCheck]: + return [mcc for mcc in merged_count_checks if mcc.check_date == check_date] + + @staticmethod + def _merge_by_check_date( + merged_count_checks: Iterable[MergedCountCheck], + ) -> dict[date, MergedCountCheck]: + merged = defaultdict( + lambda: dict(check_date=None, source_count=0, target_count=0, is_valid=True) + ) + for mcc in merged_count_checks: + entry = merged[mcc.check_date] + entry["source_count"] += mcc.source_count + entry["target_count"] += mcc.target_count + entry["is_valid"] &= mcc.is_valid + entry["check_date"] = mcc.check_date + + return {cd: MergedCountCheck(**data) for cd, data in merged.items()} def _merge_count_checks( self, source_counts: Iterable[CountCheck], target_counts: Iterable[CountCheck] @@ -83,18 +151,26 @@ def _merge_count_checks( Output [(1,0), (2,2), (0,4), (5,5), (6,0), (0,7)] """ - source_dict = {count.check_date: count for count in source_counts} - target_dict = {count.check_date: count for count in target_counts} + source_dict = reduce( + lambda x, y: x | y, + map(lambda x: x.to_flatten_dimension_format(), source_counts), + {}, + ) + target_dict = reduce( + lambda x, y: x | y, + map(lambda x: x.to_flatten_dimension_format(), target_counts), + {}, + ) - all_dates = set(source_dict.keys()) | set(target_dict.keys()) + all_dims = set(source_dict.keys()) | set(target_dict.keys()) merged_count_checks = [] - for check_date in all_dates: - source_count = source_dict.get(check_date) - target_count = target_dict.get(check_date) + for dim in all_dims: + source_count = source_dict.get(dim) + target_count = target_dict.get(dim) merged_count_check = MergedCountCheck.from_counts( 
source_count, target_count ) merged_count_checks.append(merged_count_check) - return merged_count_checks + return sorted(merged_count_checks, key=lambda x: x.check_date) diff --git a/tests/managers/test_check_manager.py b/tests/managers/test_check_manager.py index b419c75..b6370f9 100644 --- a/tests/managers/test_check_manager.py +++ b/tests/managers/test_check_manager.py @@ -71,7 +71,34 @@ def check_manager(): ) ], ), - # Case 4: Checking dates are in neither source nor target + # Case 4: Checking different dates in source and target + ( + [ + CountCheck( + cnt=200, + check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), + ) + ], + [ + CountCheck( + cnt=200, + check_date=datetime.strptime("2024-01-02", "%Y-%m-%d").date(), + ) + ], + [ + MergedCountCheck( + source_count=200, + target_count=0, + check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), + ), + MergedCountCheck( + source_count=0, + target_count=200, + check_date=datetime.strptime("2024-01-02", "%Y-%m-%d").date(), + ), + ], + ), + # Case 5: Checking dates are in neither source nor target ([], [], []), ], ) @@ -81,60 +108,295 @@ def test__merge_count_check( merged_counts = check_manager._merge_count_checks(source_counts, target_counts) assert expected_merged_counts == merged_counts +@pytest.mark.parametrize( + "source_counts, target_counts, expected_merged_counts", + [ + # Case 1: Checking dates are in both source and target + ( + [ + CountCheck.create_with_dimensions(["status", "country"])( + cnt=100, + check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), + status="True", + country="US" + ) + ], + [ + CountCheck.create_with_dimensions(["status", "country"])( + cnt=200, + check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), + status="True", + country="US" + ) + ], + [ + MergedCountCheck.create_with_dimensions(["status", "country"])( + source_count=100, + target_count=200, + check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), + status="True", + country="US" + 
) + ], + ), + # Case 2: Checking dates are in source only + ( + [ + CountCheck.create_with_dimensions( + ["status", "country"])( + cnt=100, + check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), + status="True", + country="US" + ) + ], + [], + [ + MergedCountCheck.create_with_dimensions(["status", "country"])( + source_count=100, + target_count=0, + check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), + status="True", + country="US" + ) + ], + ), + # Case 3: Checking dates are in target only + ( + [], + [ + CountCheck.create_with_dimensions(["status", "country"])( + cnt=200, + check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), + status="True", + country="US" + ) + ], + [ + MergedCountCheck.create_with_dimensions(["status", "country"])( + source_count=0, + target_count=200, + check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), + status="True", + country="US" + ) + ], + ), + # Case 4: Checking different dates in source and target + ( + [ + CountCheck.create_with_dimensions(["status", "country"])( + cnt=200, + check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), + status="True", + country="US" + ), + CountCheck.create_with_dimensions(["status", "country"])( + cnt=200, + check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), + status="True", + country="Singapore" + ) + ], + [ + CountCheck.create_with_dimensions(["status", "country"])( + cnt=200, + check_date=datetime.strptime("2024-01-02", "%Y-%m-%d").date(), + status="False", + country="US" + ) + ], + [ + MergedCountCheck.create_with_dimensions(["status", "country"])( + source_count=200, + target_count=0, + check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), + status="True", + country="Singapore" + ), + MergedCountCheck.create_with_dimensions(["status", "country"])( + source_count=200, + target_count=0, + check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), + status="True", + country="US" + ), + 
MergedCountCheck.create_with_dimensions(["status", "country"])( + source_count=0, + target_count=200, + check_date=datetime.strptime("2024-01-02", "%Y-%m-%d").date(), + status="False", + country="US" + ), + ], + ), + # Case 5: Checking dates are in neither source nor target + ([], [], []), + + ], +) +def test__merge_count_check_with_dimensions(check_manager, source_counts, target_counts, expected_merged_counts): + merged_counts = check_manager._merge_count_checks(source_counts, target_counts) + assert expected_merged_counts == merged_counts + @pytest.mark.parametrize( - "merged_count_checks, expected_result", + "merged_count_checks, expected_merged_by_date", [ - # Case 1: All merged count checks are valid - [ + # Case 1: Merge count checks by check date with base dimension field + ( [ MergedCountCheck( source_count=100, - target_count=100, - check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), + target_count=200, + check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date() ), MergedCountCheck( + source_count=300, + target_count=400, + check_date=datetime.strptime("2024-01-02", "%Y-%m-%d").date() + ), + ], + { + datetime.strptime("2024-01-01", "%Y-%m-%d").date(): MergedCountCheck( + source_count=100, + target_count=200, + is_valid=True, + check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date() + ), + datetime.strptime("2024-01-02", "%Y-%m-%d").date(): MergedCountCheck( + source_count=300, + target_count=400, + is_valid=True, + check_date=datetime.strptime("2024-01-02", "%Y-%m-%d").date() + ), + } + ), + # Case 2: Merge count checks by check date with 1 dimension field (happy case) + ( + [ + MergedCountCheck.create_with_dimensions(["status"])( source_count=100, - target_count=150, + target_count=200, + check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), + status="True" + ), + MergedCountCheck.create_with_dimensions(["status"])( + source_count=200, + target_count=300, + check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), + 
status="False" + ), + MergedCountCheck.create_with_dimensions(["status"])( + source_count=400, + target_count=300, check_date=datetime.strptime("2024-01-02", "%Y-%m-%d").date(), + status="True" ), ], - False, - ], - # Case 2: All merged count checks are invalid - [ + { + datetime.strptime("2024-01-01", "%Y-%m-%d").date(): MergedCountCheck( + source_count=300, + target_count=500, + is_valid=True, + check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), + ), + datetime.strptime("2024-01-02", "%Y-%m-%d").date(): MergedCountCheck( + source_count=400, + target_count=300, + is_valid=False, + check_date=datetime.strptime("2024-01-02", "%Y-%m-%d").date(), + ), + } + ), + # Case 3: Merge count checks by check date with 2 dimension fields (unhappy case: dimension failure => invalid diff) + ( [ - MergedCountCheck( - source_count=150, + MergedCountCheck.create_with_dimensions(["status", "country"])( + source_count=100, + target_count=200, + check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), + status="True", + country="US" + ), + MergedCountCheck.create_with_dimensions(["status", "country"])( + source_count=200, target_count=100, check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), + status="False", + country="US" + ), + MergedCountCheck.create_with_dimensions(["status", "country"])( + source_count=300, + target_count=400, + check_date=datetime.strptime("2024-01-02", "%Y-%m-%d").date(), + status="True", + country="Singapore" ), ], - True, - ], - # Case 3: Mixed valid and invalid merged count checks - [ + { + datetime.strptime("2024-01-01", "%Y-%m-%d").date(): MergedCountCheck( + source_count=300, + target_count=300, + is_valid=False, + check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), + ), + datetime.strptime("2024-01-02", "%Y-%m-%d").date(): MergedCountCheck( + source_count=300, + target_count=400, + is_valid=True, + check_date=datetime.strptime("2024-01-02", "%Y-%m-%d").date(), + ), + } + ), + ] +) +def 
test__merge_by_check_date(check_manager, merged_count_checks, expected_merged_by_date): + merged_by_date = check_manager._merge_by_check_date(merged_count_checks) + assert expected_merged_by_date == merged_by_date + +@pytest.mark.parametrize( + "merged_by_date, expected_is_valid_diff", + [ + # Case 1: Happy case + ( [ MergedCountCheck( source_count=100, - target_count=100, - check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), + target_count=200, + is_valid=True, + check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date() ), MergedCountCheck( - source_count=150, - target_count=100, - check_date=datetime.strptime("2024-01-02", "%Y-%m-%d").date(), + source_count=200, + target_count=200, + is_valid=True, + check_date=datetime.strptime("2024-01-02", "%Y-%m-%d").date() ), + ], + True + ), + # Case 2: Unhappy case + ( + [ MergedCountCheck( source_count=100, - target_count=150, - check_date=datetime.strptime("2024-01-03", "%Y-%m-%d").date(), + target_count=200, + is_valid=True, + check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date() + ), + MergedCountCheck( + source_count=200, + target_count=200, + is_valid=False, + check_date=datetime.strptime("2024-01-02", "%Y-%m-%d").date() ), ], - True, - ], + False + ) ], + ) -def test__check_if_invalid_diff(check_manager, merged_count_checks, expected_result): - is_invalid_diff = check_manager._check_if_invalid_diff(merged_count_checks) - assert is_invalid_diff == expected_result +def test__check_if_valid_diff(check_manager, merged_by_date, expected_is_valid_diff): + is_valid_diff = check_manager._check_if_valid_diff(merged_by_date) + assert is_valid_diff == expected_is_valid_diff