Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Claude Instructions for circepy

## Python Environment
- Always use virtualenv for Python operations (don't rely on system Python or unauthenticated pip installs)
- Activate the virtual environment before running Python commands or installing packages

## Starting tasks - record testing state

At the start of any task, record the state of tests as a baseline. It is not your job to fix pre-existing issues unless otherwise specified.

Run tests with multiprocess for speed and store the state:
```bash
pytest -n auto --tb=short -v --json-report --json-report-file=.test_baseline.json
```

If the test state file is not created, check that pytest-xdist and pytest-json-report are installed in the virtualenv.

## Pre-completion Checklist
Before completing any task:

1. Re-run pytest to verify no regressions:
```bash
pytest -n auto --tb=short -v --json-report --json-report-file=.test_final.json
```

Compare `.test_baseline.json` with `.test_final.json` — the final state should not show new failures.

2. Run git pre-commit checks:
```bash
git pre-commit run --all-files
```

If pre-commit checks fail, fix the issues and re-run until they pass.

## Git Workflow
- Do not run `git commit` — the user will handle commits
- Run pre-commit checks to validate code quality before marking tasks complete
20 changes: 10 additions & 10 deletions circe/check/checkers/base_corelated_criteria_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def _internal_check(self, expression: "CohortExpression", reporter: WarningRepor
if inclusion_rule.expression and inclusion_rule.expression.criteria_list:
for criteria in inclusion_rule.expression.criteria_list:
# Skip if criteria is still a dict (shouldn't happen after deserialization, but be defensive)
if isinstance(criteria, dict):
continue
if isinstance(criteria, dict): # type: ignore[unreachable]
continue # type: ignore[unreachable]
group_name = f"{self.INCLUSION_RULE}{inclusion_rule.name}"
self._check_criteria(criteria, group_name, reporter)
if hasattr(criteria, "criteria") and criteria.criteria:
Expand All @@ -60,16 +60,16 @@ def _check_criteria_group(self, criteria: "Criteria", group_name: str, reporter:
reporter: The warning reporter to use
"""
# Skip if criteria is still a dict (not yet deserialized)
if isinstance(criteria, dict):
return
if isinstance(criteria, dict): # type: ignore[unreachable]
return # type: ignore[unreachable]

if hasattr(criteria, "correlated_criteria") and criteria.correlated_criteria:
correlated = criteria.correlated_criteria
if hasattr(correlated, "criteria_list") and correlated.criteria_list:
for corelated_criteria in correlated.criteria_list:
# Skip dicts
if isinstance(corelated_criteria, dict):
continue
if isinstance(corelated_criteria, dict): # type: ignore[unreachable]
continue # type: ignore[unreachable]
self._check_criteria(corelated_criteria, group_name, reporter)
if hasattr(corelated_criteria, "criteria") and corelated_criteria.criteria:
self._check_criteria_group(corelated_criteria.criteria, group_name, reporter)
Expand All @@ -78,8 +78,8 @@ def _check_criteria_group(self, criteria: "Criteria", group_name: str, reporter:
if hasattr(group, "criteria_list") and group.criteria_list:
for corelated_criteria in group.criteria_list:
# Skip dicts
if isinstance(corelated_criteria, dict):
continue
if isinstance(corelated_criteria, dict): # type: ignore[unreachable]
continue # type: ignore[unreachable]
self._check_criteria(corelated_criteria, group_name, reporter)
if hasattr(corelated_criteria, "criteria") and corelated_criteria.criteria:
self._check_criteria_group(corelated_criteria.criteria, group_name, reporter)
Expand All @@ -99,7 +99,7 @@ def _check_criteria(
"""
# Skip if criteria is still a dict (not yet deserialized)
# This can happen when Pydantic doesn't fully deserialize polymorphic types
if isinstance(criteria, dict):
return
if isinstance(criteria, dict): # type: ignore[unreachable]
return # type: ignore[unreachable]

raise NotImplementedError("Subclasses must implement _check_criteria")
4 changes: 2 additions & 2 deletions circe/check/checkers/comparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def compare_to(filter_val: "ObservationFilter", window: "Window") -> int:
An integer representing the comparison result
"""
if filter_val is None or window is None:
return 0
return 0 # type: ignore[unreachable]

range1 = filter_val.post_days + filter_val.prior_days
range2_start = 0
Expand All @@ -144,7 +144,7 @@ def is_before(window: "Window") -> bool:
True if the window is before, False otherwise
"""
if window is None:
return False
return False # type: ignore[unreachable]
return Comparisons.is_before_endpoint(window.start) and not Comparisons.is_after_endpoint(window.end)

@staticmethod
Expand Down
4 changes: 3 additions & 1 deletion circe/check/checkers/death_time_window_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
Reference: JAVA_CLASS_MAPPINGS.md for Java equivalents.
"""

from typing import Any

from ..operations.operations import Operations
from ..utils.criteria_name_helper import CriteriaNameHelper
from ..warning_severity import WarningSeverity
Expand Down Expand Up @@ -122,7 +124,7 @@ def _check_criteria(
"""
name = f"{group_name} {CriteriaNameHelper.get_criteria_name(criteria.criteria)}"

match_result = Operations.match(criteria.criteria)
match_result: Any = Operations.match(criteria.criteria)
match_result.is_a(Death)
match_result.then(
lambda death: (
Expand Down
8 changes: 5 additions & 3 deletions circe/check/checkers/drug_era_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
Reference: JAVA_CLASS_MAPPINGS.md for Java equivalents.
"""

from typing import Any

from ..operations.operations import Operations
from ..warning_severity import WarningSeverity
from .base_corelated_criteria_check import BaseCorelatedCriteriaCheck
Expand Down Expand Up @@ -53,15 +55,15 @@ def _check_criteria(
reporter: The warning reporter to use
"""
# Handle case where criteria is still a dict (not yet deserialized)
if isinstance(criteria, dict):
if isinstance(criteria, dict): # type: ignore[unreachable]
# Skip validation for dict-based criteria - they need to be deserialized first
return
return # type: ignore[unreachable]

# Ensure criteria has a criteria attribute
if not hasattr(criteria, "criteria") or not criteria.criteria:
return

match_result = Operations.match(criteria.criteria)
match_result: Any = Operations.match(criteria.criteria)
match_result.is_a(DrugEra)
match_result.then(
lambda c: (
Expand Down
4 changes: 3 additions & 1 deletion circe/check/checkers/exit_criteria_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
Reference: JAVA_CLASS_MAPPINGS.md for Java equivalents.
"""

from typing import Any

from ..operations.operations import Operations
from .base_check import BaseCheck
from .warning_reporter import WarningReporter
Expand Down Expand Up @@ -39,7 +41,7 @@ def _check(self, expression: "CohortExpression", reporter: WarningReporter) -> N
expression: The cohort expression to check
reporter: The warning reporter to use
"""
match_result = Operations.match(expression.end_strategy)
match_result: Any = Operations.match(expression.end_strategy)
match_result.is_a(CustomEraStrategy)
match_result.then(
lambda s: (
Expand Down
4 changes: 3 additions & 1 deletion circe/check/checkers/exit_criteria_days_offset_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
Reference: JAVA_CLASS_MAPPINGS.md for Java equivalents.
"""

from typing import Any

from ..operations.operations import Operations
from ..warning_severity import WarningSeverity
from .base_check import BaseCheck
Expand Down Expand Up @@ -48,7 +50,7 @@ def _check(self, expression: "CohortExpression", reporter: WarningReporter) -> N
expression: The cohort expression to check
reporter: The warning reporter to use
"""
match_result = Operations.match(expression.end_strategy)
match_result: Any = Operations.match(expression.end_strategy)
match_result.is_a(DateOffsetStrategy)
match_result.then(
lambda s: (
Expand Down
4 changes: 3 additions & 1 deletion circe/check/checkers/initial_event_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
Reference: JAVA_CLASS_MAPPINGS.md for Java equivalents.
"""

from typing import Any

from ..operations.operations import Operations
from .base_check import BaseCheck
from .warning_reporter import WarningReporter
Expand Down Expand Up @@ -37,7 +39,7 @@ def _check(self, expression: "CohortExpression", reporter: WarningReporter) -> N
expression: The cohort expression to check
reporter: The warning reporter to use
"""
match_result = Operations.match(expression)
match_result: Any = Operations.match(expression)
match_result.when(
lambda e: (
e.primary_criteria is None
Expand Down
4 changes: 3 additions & 1 deletion circe/check/checkers/no_exit_criteria_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
Reference: JAVA_CLASS_MAPPINGS.md for Java equivalents.
"""

from typing import Any

from ..operations.operations import Operations
from ..warning_severity import WarningSeverity
from .base_check import BaseCheck
Expand Down Expand Up @@ -46,7 +48,7 @@ def _check(self, expression: "CohortExpression", reporter: WarningReporter) -> N
expression: The cohort expression to check
reporter: The warning reporter to use
"""
match_result = Operations.match(expression)
match_result: Any = Operations.match(expression)
match_result.when(
lambda e: (
e.primary_criteria
Expand Down
4 changes: 2 additions & 2 deletions circe/check/checkers/ocurrence_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
Reference: JAVA_CLASS_MAPPINGS.md for Java equivalents.
"""

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from ..operations.operations import Operations
from ..warning_severity import WarningSeverity
Expand Down Expand Up @@ -52,6 +52,6 @@ def _check_criteria(
reporter: The warning reporter to use
"""
if criteria.occurrence:
match_result = Operations.match(criteria.occurrence)
match_result: Any = Operations.match(criteria.occurrence)
match_result.when(lambda o: o.type == self.AT_LEAST and o.count == 0)
match_result.then(lambda o: reporter(self.AT_LEAST_0_WARNING))
4 changes: 2 additions & 2 deletions circe/check/checkers/range_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ def _check_inclusion_rules(self, expression: "CohortExpression", reporter: Warni
if rule.expression and rule.expression.criteria_list:
for criteria in rule.expression.criteria_list:
# Handle both dict and CorelatedCriteria objects
if isinstance(criteria, dict):
start_window = criteria.get("startWindow") or criteria.get("start_window")
if isinstance(criteria, dict): # type: ignore[unreachable]
start_window = criteria.get("startWindow") or criteria.get("start_window") # type: ignore[unreachable]
end_window = criteria.get("endWindow") or criteria.get("end_window")
else:
start_window = getattr(criteria, "start_window", None) or getattr(
Expand Down
8 changes: 4 additions & 4 deletions circe/check/checkers/range_checker_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
Reference: JAVA_CLASS_MAPPINGS.md for Java equivalents.
"""

from typing import Callable, Optional
from typing import Any, Callable, Optional

from ..constants import Constants
from ..operations.operations import Operations
Expand Down Expand Up @@ -613,7 +613,7 @@ def warning(template: str) -> None:

if isinstance(range_val, DateRange):
# Date range checks
match_result = Operations.match(range_val)
match_result: Any = Operations.match(range_val)
match_result.when(lambda r: r.value is not None and not Comparisons.is_date_valid(r.value)).then(
lambda x: warning(self.WARNING_DATE_IS_INVALID)
)
Expand All @@ -639,7 +639,7 @@ def warning(template: str) -> None:
)
elif isinstance(range_val, NumericRange):
# Numeric range checks
match_result = Operations.match(range_val)
match_result: Any = Operations.match(range_val)
match_result.when(lambda r: r.op is not None and r.op.endswith("bt")).then(
lambda r: (
Operations.match(r)
Expand Down Expand Up @@ -673,7 +673,7 @@ def check_range(self, period: Optional["Period"], criteria_name: str, attribute:
def warning(template: str) -> None:
self._reporter(template, self._group_name, criteria_name, attribute)

match_result = Operations.match(period)
match_result: Any = Operations.match(period)
match_result.when(
lambda x: x.start_date is not None and not Comparisons.is_date_valid(x.start_date)
).then(lambda x: warning(self.WARNING_DATE_IS_INVALID))
Expand Down
4 changes: 2 additions & 2 deletions circe/check/checkers/time_window_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
Reference: JAVA_CLASS_MAPPINGS.md for Java equivalents.
"""

from typing import Optional
from typing import Any, Optional

from ..operations.operations import Operations
from ..utils.criteria_name_helper import CriteriaNameHelper
Expand Down Expand Up @@ -77,7 +77,7 @@ def _check_criteria(
"""
name = f"{group_name} {CriteriaNameHelper.get_criteria_name(criteria.criteria)}"

match_result = Operations.match(criteria)
match_result: Any = Operations.match(criteria)
match_result.when(
lambda c: (
c.start_window is not None
Expand Down
2 changes: 1 addition & 1 deletion circe/check/utils/criteria_name_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,5 +90,5 @@ def get_criteria_name(criteria) -> str:
.is_a(PayerPlanPeriod)
.then_return(lambda c: Constants.Criteria.PAYER_PLAN_PERIOD)
.value()
or "unknown criteria"
or "unknown criteria" # type: ignore[unreachable]
)
2 changes: 0 additions & 2 deletions circe/cohortdefinition/builders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,6 @@ def get_criteria_sql_with_options(self, criteria: T, options: Optional[BuilderOp
)
else:
query = query.replace("@additionalColumns", "")
else:
query = query.replace("@additionalColumns", "")

return query

Expand Down
2 changes: 0 additions & 2 deletions circe/cohortdefinition/builders/visit_occurrence.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,6 @@ def resolve_where_clauses(

return where_clauses

return where_clauses

def embed_ordinal_expression(
self,
query: str,
Expand Down
2 changes: 1 addition & 1 deletion circe/cohortdefinition/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def instance_is_pydantic(o):
# Generate Imports
import_lines = []
# Group by module
module_map = {}
module_map: dict[str, list[str]] = {}
for cls in required_classes:
mod = cls.__module__
if mod not in module_map:
Expand Down
4 changes: 2 additions & 2 deletions circe/cohortdefinition/cohort_expression_query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1471,9 +1471,9 @@ def get_criteria_sql(self, criteria: Criteria, options: Optional[BuilderOptions]
Java equivalent: Various getCriteriaSql methods
"""
# Handle case where criteria is still a dict (shouldn't happen, but be defensive)
if isinstance(criteria, dict):
if isinstance(criteria, dict): # type: ignore[unreachable]
# Try to deserialize it - import here to avoid circular dependency issues
from .criteria import ConditionEra as CE
from .criteria import ConditionEra as CE # type: ignore[unreachable]
from .criteria import ConditionOccurrence as CO
from .criteria import Death as D
from .criteria import DeviceExposure as DevE
Expand Down
10 changes: 1 addition & 9 deletions circe/cohortdefinition/criteria.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,14 +271,6 @@ def _serialize_polymorphic(self, serializer, info):

return {self.__class__.__name__: data}

# Get the serialized data using default serialization
data = serializer(self)
# Wrap in class name for polymorphic deserialization in Java
# Only wrap if this is a subclass (not the base Criteria class)
if self.__class__.__name__ != "Criteria":
return {self.__class__.__name__: data}
return data

def accept(self, dispatcher: Any, options: Optional[Any] = None) -> str:
"""Accept method for visitor pattern."""
return dispatcher.get_criteria_sql(self, options)
Expand Down Expand Up @@ -1146,7 +1138,7 @@ def deserialize_criteria_list(cls, v: Any) -> Any:
# Helper window normalizer (same as before)
def normalize_window(window_dict: dict) -> dict:
if not isinstance(window_dict, dict):
return window_dict
return window_dict # type: ignore[unreachable]
normalized = {}
if "UseEventEnd" in window_dict:
normalized["useEventEnd"] = window_dict["UseEventEnd"]
Expand Down
2 changes: 1 addition & 1 deletion circe/cohortdefinition/printfriendly/markdown_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def _format_number(self, value: Union[int, float]) -> str:
Formatted string (e.g. "1,500" or "1.5")
"""
if value is None:
return ""
return "" # type: ignore[unreachable]

# If matches integer, convert to int for clean formatting
if isinstance(value, float) and value.is_integer():
Expand Down
Loading
Loading