Skip to content

Commit

Permalink
Duplicates (#224)
Browse files Browse the repository at this point in the history
  • Loading branch information
caspervdw authored Nov 16, 2022
1 parent bcf134a commit 856f6e6
Show file tree
Hide file tree
Showing 6 changed files with 282 additions and 79 deletions.
5 changes: 5 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ Changelog of threedi-modelchecker
- Migration to schema version 210 also fixes errors 421, 424, 425, 426, 427 by
replacing negative values with NULL.

- All settings checks are now done only on the first global settings entry.

- Added "AllEqual" warnings (codes 330 and further) that check whether grid builder global
settings are all the same in case there are multiple records.

- Added a unique check on v2_manhole.connection_node_id.


Expand Down
69 changes: 69 additions & 0 deletions tests/test_checks_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from tests import factories
from threedi_modelchecker.checks import geo_query
from threedi_modelchecker.checks.base import _sqlalchemy_to_sqlite_types
from threedi_modelchecker.checks.base import AllEqualCheck
from threedi_modelchecker.checks.base import EnumCheck
from threedi_modelchecker.checks.base import ForeignKeyCheck
from threedi_modelchecker.checks.base import GeometryCheck
Expand Down Expand Up @@ -174,6 +175,52 @@ def test_unique_check_multiple_description():
)


def test_all_equal_check(session):
factories.GlobalSettingsFactory(table_step_size=0.5)
factories.GlobalSettingsFactory(table_step_size=0.5)

check = AllEqualCheck(models.GlobalSetting.table_step_size)
invalid_rows = check.get_invalid(session)
assert len(invalid_rows) == 0


def test_all_equal_check_different_value(session):
factories.GlobalSettingsFactory(table_step_size=0.5)
factories.GlobalSettingsFactory(table_step_size=0.6)
factories.GlobalSettingsFactory(table_step_size=0.5)
factories.GlobalSettingsFactory(table_step_size=0.7)

check = AllEqualCheck(models.GlobalSetting.table_step_size)
invalid_rows = check.get_invalid(session)
assert len(invalid_rows) == 2
assert invalid_rows[0].table_step_size == 0.6
assert invalid_rows[1].table_step_size == 0.7


def test_all_equal_check_null_value(session):
factories.GlobalSettingsFactory(maximum_table_step_size=None)
factories.GlobalSettingsFactory(maximum_table_step_size=None)

check = AllEqualCheck(models.GlobalSetting.maximum_table_step_size)
invalid_rows = check.get_invalid(session)
assert len(invalid_rows) == 0


def test_all_equal_check_null_value_different(session):
factories.GlobalSettingsFactory(maximum_table_step_size=1.0)
factories.GlobalSettingsFactory(maximum_table_step_size=None)

check = AllEqualCheck(models.GlobalSetting.maximum_table_step_size)
invalid_rows = check.get_invalid(session)
assert len(invalid_rows) == 1


def test_all_equal_check_no_records(session):
check = AllEqualCheck(models.GlobalSetting.table_step_size)
invalid_rows = check.get_invalid(session)
assert len(invalid_rows) == 0


def test_null_check(session):
factories.ConnectionNodeFactory.create_batch(5, storage_area=3.0)

Expand Down Expand Up @@ -710,3 +757,25 @@ def test_range_check_invalid(
assert len(invalid_rows) == 1

assert check.description() == msg.format("v2_connection_nodes.storage_area")


def test_check_only_first(session):
factories.GlobalSettingsFactory(dem_obstacle_detection=False)
factories.GlobalSettingsFactory(dem_obstacle_detection=True)

try:
active_settings = Query(models.GlobalSetting.id).limit(1).scalar_subquery()
except AttributeError:
active_settings = Query(models.GlobalSetting.id).limit(1).as_scalar()

check = QueryCheck(
error_code=302,
column=models.GlobalSetting.dem_obstacle_detection,
invalid=Query(models.GlobalSetting).filter(
models.GlobalSetting.id == active_settings,
models.GlobalSetting.dem_obstacle_detection == True,
),
message="v2_global_settings.dem_obstacle_detection is True, while this feature is not supported",
)

assert check.get_invalid(session) == []
17 changes: 17 additions & 0 deletions tests/test_model_checks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from pathlib import Path
from threedi_modelchecker.config import CHECKS
from threedi_modelchecker.model_checks import BaseCheck
from threedi_modelchecker.model_checks import LocalContext
from threedi_modelchecker.model_checks import ThreediModelChecker
from unittest import mock

Expand All @@ -19,3 +22,17 @@ def test_set_base_path(model_checker):
def test_get_model_error_iterator(model_checker):
errors = list(model_checker.errors(level="info"))
assert len(errors) == 0


def id_func(param):
if isinstance(param, BaseCheck):
return "check {}-".format(param.error_code)
return repr(param)


@pytest.mark.filterwarnings("error::")
@pytest.mark.parametrize("check", CHECKS, ids=id_func)
def test_individual_checks(threedi_db, check):
session = threedi_db.get_session()
session.model_checker_context = LocalContext(base_path=threedi_db.base_path)
assert len(check.get_invalid(session)) == 0
15 changes: 15 additions & 0 deletions threedi_modelchecker/checks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,21 @@ def description(self):
return f"{self.column_name} should to be unique"


class AllEqualCheck(BaseCheck):
"""Check all values in `column` are the same, including NULL values."""

def get_invalid(self, session):
val = session.query(self.column).limit(1).scalar()
if val is None:
clause = self.column != None
else:
clause = (self.column != val) | (self.column == None)
return self.to_check(session).filter(clause).all()

def description(self):
return f"{self.column_name} is different and is ignored if it is not in the first record"


class NotNullCheck(BaseCheck):
""" "Check all values in `column` that are not null"""

Expand Down
6 changes: 4 additions & 2 deletions threedi_modelchecker/checks/other.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,10 @@ def get_invalid(self, session: Session) -> List[NamedTuple]:
if total_objects != 1:
invalid_ids.append(bc.id)

return self.to_check(session).filter(
models.BoundaryCondition1D.id.in_(invalid_ids)
return (
self.to_check(session)
.filter(models.BoundaryCondition1D.id.in_(invalid_ids))
.all()
)

def description(self) -> str:
Expand Down
Loading

0 comments on commit 856f6e6

Please sign in to comment.