Skip to content
15 changes: 15 additions & 0 deletions src/diffa/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,17 @@ def cli():
type=str,
help="Target table name.",
)
@click.option(
"--diff-dimensions",
multiple=True,
type=str,
help="Diff dimension columns.",
)
@click.option(
"--full-diff",
is_flag=True,
help="Full diff mode. Re-run the diff from the beginning.",
)
def data_diff(
*,
source_db_uri: str = None,
Expand All @@ -68,6 +79,8 @@ def data_diff(
target_database: str = None,
target_schema: str = "public",
target_table: str,
diff_dimensions: tuple = None,
full_diff: bool = False,
):
config_manager = ConfigManager().configure(
source_database=source_database,
Expand All @@ -79,6 +92,8 @@ def data_diff(
source_db_uri=source_db_uri,
target_db_uri=target_db_uri,
diffa_db_uri=diffa_db_uri,
diff_dimension_cols=list(diff_dimensions) if diff_dimensions else None,
full_diff=full_diff,
)
run_manager = RunManager(config_manager=config_manager)
check_manager = CheckManager(config_manager=config_manager)
Expand Down
21 changes: 18 additions & 3 deletions src/diffa/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from datetime import date
from enum import Enum
from urllib.parse import urlparse
from typing import List, Optional

from diffa.utils import Logger

Expand All @@ -11,7 +12,7 @@
DIFFA_DB_SCHEMA = "diffa"
DIFFA_DB_TABLE = "diffa_checks"
DIFFA_CHECK_RUNS_TABLE = "diffa_check_runs"
DIFFA_BEGIN_DATE = date(2024, 1, 1)
DIFFA_BEGIN_DATE = date(2020, 6, 1) # Matching with Ascenda start date


class ExitCode(Enum):
Expand Down Expand Up @@ -103,11 +104,20 @@ def get_db_table(self):

class SourceConfig(DBConfig):
    """DBConfig for the Source DBs, extended with the diff dimension columns."""

    def __init__(self, *args, diff_dimension_cols: Optional[List[str]] = None, **kwargs):
        """Forward everything to DBConfig and remember the dimension columns.

        A missing or empty ``diff_dimension_cols`` is stored as an empty list.
        """
        super().__init__(*args, **kwargs)
        if not diff_dimension_cols:
            diff_dimension_cols = []
        self.diff_dimension_cols = diff_dimension_cols

    def get_diff_dimension_cols(self):
        """Return the stored list of diff dimension columns."""
        return self.diff_dimension_cols
class DiffaConfig(DBConfig):
    """DBConfig for the Diffa DB, extended with the full-diff switch."""

    def is_full_diff(self):
        """Return the ``full_diff`` flag given at construction time."""
        return self.full_diff

    def __init__(self, *args, full_diff: bool = False, **kwargs):
        """Forward everything to DBConfig and record the ``full_diff`` flag."""
        super().__init__(*args, **kwargs)
        self.full_diff = full_diff

class ConfigManager:
"""Manage all the configuration needed for Diffa Operations"""
Expand Down Expand Up @@ -143,21 +153,26 @@ def configure(
target_schema: str = "public",
target_table: str,
diffa_db_uri: str = None,
diff_dimension_cols: List[str] = None,
full_diff: bool = False,
):
self.source.update(
db_uri=source_db_uri,
db_name=source_database,
db_schema=source_schema,
db_table=source_table,
diff_dimension_cols=diff_dimension_cols,
)
self.target.update(
db_uri=target_db_uri,
db_name=target_database,
db_schema=target_schema,
db_table=target_table,
diff_dimension_cols=diff_dimension_cols,
)
self.diffa_check.update(
db_uri=diffa_db_uri,
full_diff=full_diff,
)
self.diffa_check_run.update(
db_uri=diffa_db_uri,
Expand Down Expand Up @@ -203,7 +218,7 @@ def save_config(self, source_uri: str, target_uri: str, diffa_uri: str):
logger.info("Configuration saved to successfully.")

def __getattr__(self, __name: str) -> DBConfig:
"""Dynamically access DBConfig attributes (e.g config_manager.source.database)"""
"""Dynamically access DBConfig attributes (e.g config_manager.source.get_db_name())"""

if __name in self.config:
return self.config[__name]
Expand Down
116 changes: 90 additions & 26 deletions src/diffa/db/data_models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from datetime import date
from typing import Optional
from dataclasses import dataclass, field
from typing import Optional, List, Tuple, Any
from dataclasses import dataclass, fields, make_dataclass
from functools import reduce
import uuid

from sqlalchemy import (
Expand Down Expand Up @@ -148,51 +149,114 @@ def validate_status(self):
return self


@dataclass
@dataclass(frozen=True)
class CountCheck:
    """A single count check in Source/Target Database.

    The base record holds ``cnt`` and ``check_date``. Classes produced by
    ``create_with_dimensions`` add one ``str`` field per dimension column.
    """

    cnt: int
    check_date: date

    @classmethod
    def create_with_dimensions(cls, dimension_cols: Optional[List[str]] = None):
        """Build a frozen subclass of ``cls`` with one ``str`` field per dimension.

        Args:
            dimension_cols: Dimension column names; ``None`` or empty yields a
                subclass without extra fields.

        Returns:
            A new dataclass named after ``cls`` whose extra fields are sorted
            by column name.
        """
        # make_dataclass raises TypeError on a repeated field name, so a
        # dimension listed twice is collapsed before the fields are declared.
        unique_cols = sorted(set(dimension_cols)) if dimension_cols else []

        return make_dataclass(
            cls.__name__,
            [(col, str) for col in unique_cols],
            bases=(cls,),
            frozen=True,
        )

    @classmethod
    def get_base_fields(cls) -> List[Tuple[str, type]]:
        """Return the (name, type) pairs every CountCheck has."""
        return [("check_date", date), ("cnt", int)]

    @classmethod
    def get_dimension_fields(cls) -> List[Tuple[str, type]]:
        """Return the (name, type) pairs of all fields that are not base fields."""
        base_names = {name for name, _ in cls.get_base_fields()}

        return [(f.name, f.type) for f in fields(cls) if f.name not in base_names]

    def get_dimension_values(self):
        """Map each dimension field name, then ``check_date``, to this record's value."""
        # check_date is still considered as a dimension field. In fact, it's a main dimension field.
        names = [name for name, _ in self.get_dimension_fields()] + ["check_date"]
        return {name: getattr(self, name) for name in names}

    def to_flatten_dimension_format(self) -> dict:
        """Return a one-entry dict keyed by the tuple of (name, value) dimension pairs."""
        return {tuple(self.get_dimension_values().items()): self}


@dataclass
class MergedCountCheck:
"""A merged count check after checking count in Source/Target Databases"""

source_count: int
target_count: int
check_date: date
is_valid: bool = field(init=False)

def __post_init__(self):
self.is_valid = True if self.source_count <= self.target_count else False
def __init__(
self,
source_count: int,
target_count: int,
check_date: date,
is_valid: Optional[bool] = None,
**kwargs: Any,
):
self.source_count = source_count
self.target_count = target_count
self.check_date = check_date
for key, value in kwargs.items():
setattr(self, key, value)

self.is_valid = (
is_valid if is_valid is not None else source_count <= target_count
)

def __eq__(self, other):
if not isinstance(other, MergedCountCheck):
return NotImplemented
return (
self.source_count == other.source_count
and self.target_count == other.target_count
and self.check_date == other.check_date
and self.is_valid == other.is_valid
return self.__dict__ == other.__dict__

def __lt__(self, other):
if not isinstance(other, MergedCountCheck):
return NotImplemented
dynamic_fields = [
f
for f in self.__dict__.keys()
if f not in ["source_count", "target_count", "check_date", "is_valid"]
]
precedence = (
["check_date"]
+ dynamic_fields
+ ["source_count", "target_count", "is_valid"]
)

return tuple(getattr(self, f) for f in precedence) < tuple(
getattr(other, f) for f in precedence
)

def __str__(self):
return f"MergedCountCheck({", ".join(f"{k}={v!r}" for k, v in self.__dict__.items())})"

@classmethod
def create_with_dimensions(cls, dimension_fields: List[Tuple[str, type]]):
"""Factory method to dynamically create a MergedCountCheck with a CountCheck schema"""

return type(
cls.__name__,
(cls,),
reduce(
lambda x, y: x | y, map(lambda x: {x[0]: x[1]}, dimension_fields), {}
),
)

@classmethod
def from_counts(
cls, source: Optional[CountCheck] = None, target: Optional[CountCheck] = None
):
if source and target:
if source.check_date != target.check_date:
raise ValueError("Source and target counts are not for the same date.")
elif not source and not target:
raise ValueError("Both source and target counts are missing.")

check_date = source.check_date if source else target.check_date
source_count = source.cnt if source else 0
target_count = target.cnt if target else 0
count_check = source if source else target
merged_count_check_values = count_check.get_dimension_values()
merged_count_check_values["source_count"] = source.cnt if source else 0
merged_count_check_values["target_count"] = target.cnt if target else 0

return cls(source_count, target_count, check_date)
return cls(**merged_count_check_values)

def to_diffa_check_schema(
self,
Expand Down
21 changes: 16 additions & 5 deletions src/diffa/db/diffa_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sqlalchemy.dialects.postgresql import insert

from diffa.db.connect import DiffaConnection
from diffa.config import DBConfig, ConfigManager, DIFFA_BEGIN_DATE
from diffa.config import DiffaConfig, ConfigManager, DIFFA_BEGIN_DATE
from diffa.db.data_models import (
DiffaCheckSchema,
DiffaCheck,
Expand All @@ -20,7 +20,7 @@
class DiffaCheckDatabase:
"""SQLAlchemy Database Adapter for Diffa state management"""

def __init__(self, db_config: DBConfig):
def __init__(self, db_config: DiffaConfig):
self.db_config = db_config
self.conn = DiffaConnection(self.db_config.get_db_config())

Expand Down Expand Up @@ -103,6 +103,7 @@ class DiffaCheckService:
def __init__(self, config_manager: ConfigManager):
    """Keep the config manager, open the Diffa check database and cache the full-diff flag."""
    self.config_manager = config_manager

    # Both the database adapter and the full-diff flag come from the same
    # diffa_check config entry.
    diffa_check_config = self.config_manager.diffa_check
    self.diffa_db = DiffaCheckDatabase(diffa_check_config)
    self.is_full_diff = diffa_check_config.is_full_diff()

def get_last_check_date(self) -> date:

Expand All @@ -115,8 +116,16 @@ def get_last_check_date(self) -> date:
target_table=self.config_manager.target.get_db_table(),
)

check_date = latest_check["check_date"] if latest_check else DIFFA_BEGIN_DATE
logger.info(f"Last check date: {check_date}")
if not self.is_full_diff:
check_date = (
latest_check["check_date"] if latest_check else DIFFA_BEGIN_DATE
)
logger.info(f"Last check date: {check_date}")
else:
check_date = DIFFA_BEGIN_DATE
logger.info(
f"Full diff mode is enabled. Checking from the beginning. Last check date: {check_date}"
)

return check_date

Expand All @@ -134,7 +143,9 @@ def get_invalid_check_dates(self) -> Iterable[date]:
invalid_check_dates = [
invalid_check["check_date"] for invalid_check in invalid_checks
]
if len(invalid_check_dates) > 0:
if self.is_full_diff:
return None
elif len(invalid_check_dates) > 0:
logger.info(
f"The number of invalid check dates is: {len(invalid_check_dates)}"
)
Expand Down
Loading