Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,6 @@ examples/*.json
debug_app/.gemini_cache/
debug_app/user_overrides.json
debug_app/test_results.json

.test_baseline.json
.test_final.json
32 changes: 32 additions & 0 deletions circe/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
CohortExpressionQueryBuilder,
MarkdownRender,
)
from .cohortdefinition.yaml_utils import snake_case_dict_to_cohort_expression
from .vocabulary.concept import ConceptSet

if TYPE_CHECKING:
Expand Down Expand Up @@ -80,6 +81,37 @@ def cohort_expression_from_json(json_str: str) -> CohortExpression:
raise ValueError(f"Invalid cohort expression JSON: {str(e)}") from e


def cohort_expression_from_yaml(yaml_str: str) -> CohortExpression:
"""Load a cohort expression from a YAML string.

Args:
yaml_str: YAML string containing the cohort definition with snake_case field names

Returns:
CohortExpression instance

Raises:
ValueError: If the YAML is invalid or doesn't conform to the schema

Example:
>>> yaml_str = '''
... title: "My Cohort"
... concept_sets: []
... primary_criteria: {...}
... '''
>>> expression = cohort_expression_from_yaml(yaml_str)
"""
import yaml

try:
data = yaml.safe_load(yaml_str)
if data is None:
data = {}
return snake_case_dict_to_cohort_expression(data)
except Exception as e:
raise ValueError(f"Invalid cohort expression YAML: {str(e)}") from e


def build_cohort_query(
expression: CohortExpression,
options: Optional[BuildExpressionQueryOptions] = None,
Expand Down
48 changes: 17 additions & 31 deletions circe/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
import sys
from pathlib import Path

from .api import build_cohort_query, cohort_expression_from_json, cohort_print_friendly
from .api import build_cohort_query, cohort_print_friendly
from .cohortdefinition import BuildExpressionQueryOptions
from .cohortdefinition.code_generator import to_python_code
from .io import load_expression


def main():
Expand All @@ -24,12 +25,12 @@ def main():

# Validate command
validate_parser = subparsers.add_parser("validate", help="Validate a cohort definition")
validate_parser.add_argument("input", help="Input JSON file")
validate_parser.add_argument("input", help="Input JSON or YAML file")
validate_parser.add_argument("--quiet", "-q", action="store_true", help="Only show errors")

# Generate SQL command
sql_parser = subparsers.add_parser("generate-sql", help="Generate SQL from cohort definition")
sql_parser.add_argument("input", help="Input JSON file")
sql_parser.add_argument("input", help="Input JSON or YAML file")
sql_parser.add_argument("--output", "-o", help="Output SQL file (default: stdout)")
sql_parser.add_argument("--cdm-schema", default="@cdm_database_schema", help="CDM schema name")
sql_parser.add_argument(
Expand All @@ -47,7 +48,7 @@ def main():

# Render markdown command
md_parser = subparsers.add_parser("render-markdown", help="Render cohort definition as Markdown")
md_parser.add_argument("input", help="Input JSON file")
md_parser.add_argument("input", help="Input JSON or YAML file")
md_parser.add_argument("--output", "-o", help="Output Markdown file (default: stdout)")
md_parser.add_argument("--no-validate", action="store_true", help="Skip validation")
md_parser.add_argument("--title", "-t", type=str, help="Title to add to markdown document")
Expand All @@ -56,12 +57,12 @@ def main():
source_parser = subparsers.add_parser(
"generate-source", help="Generate Python source code from cohort definition"
)
source_parser.add_argument("input", help="Input JSON file")
source_parser.add_argument("input", help="Input JSON or YAML file")
source_parser.add_argument("--output", "-o", help="Output Python file (default: stdout)")

# Process command (all-in-one)
process_parser = subparsers.add_parser("process", help="Validate, generate SQL and Markdown")
process_parser.add_argument("input", help="Input JSON file")
process_parser.add_argument("input", help="Input JSON or YAML file")
process_parser.add_argument("--sql-output", help="SQL output file")
process_parser.add_argument("--md-output", help="Markdown output file")
process_parser.add_argument("--cdm-schema", default="@cdm_database_schema", help="CDM schema name")
Expand Down Expand Up @@ -101,11 +102,8 @@ def main():

def validate_command(args):
"""Validate a cohort definition."""
# Read JSON
json_str = Path(args.input).read_text()

# Load and validate
expression = cohort_expression_from_json(json_str)
# Load expression (auto-detects JSON or YAML)
expression = load_expression(Path(args.input))

# Run validation checks
warnings = expression.check()
Expand All @@ -131,11 +129,8 @@ def validate_command(args):

def generate_sql_command(args):
"""Generate SQL from cohort definition."""
# Read JSON
json_str = Path(args.input).read_text()

# Load expression
expression = cohort_expression_from_json(json_str)
# Load expression (auto-detects JSON or YAML)
expression = load_expression(Path(args.input))

# Validate if requested
if not args.no_validate:
Expand Down Expand Up @@ -166,11 +161,8 @@ def generate_sql_command(args):

def render_markdown_command(args):
"""Render cohort definition as Markdown."""
# Read JSON
json_str = Path(args.input).read_text()

# Load expression
expression = cohort_expression_from_json(json_str)
# Load expression (auto-detects JSON or YAML)
expression = load_expression(Path(args.input))

# Validate if requested
if not args.no_validate:
Expand All @@ -195,11 +187,8 @@ def render_markdown_command(args):

def process_command(args):
"""Process cohort definition (validate, generate SQL and Markdown)."""
# Read JSON
json_str = Path(args.input).read_text()

# Load expression
expression = cohort_expression_from_json(json_str)
# Load expression (auto-detects JSON or YAML)
expression = load_expression(Path(args.input))

# Validate
warnings = expression.check()
Expand Down Expand Up @@ -237,11 +226,8 @@ def process_command(args):

def generate_source_command(args):
"""Generate Python source code from cohort definition."""
# Read JSON
json_str = Path(args.input).read_text()

# Load expression
expression = cohort_expression_from_json(json_str)
# Load expression (auto-detects JSON or YAML)
expression = load_expression(Path(args.input))

# Generate Source Code
source_code = to_python_code(expression)
Expand Down
103 changes: 103 additions & 0 deletions circe/cohortdefinition/yaml_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""Utilities for YAML conversion with snake_case naming."""

import re
from typing import Any

from circe.cohortdefinition.cohort import CohortExpression


def to_snake_case(name: str) -> str:
"""Convert camelCase or PascalCase string to snake_case.

Args:
name: String in camelCase or PascalCase format

Returns:
String in snake_case format
"""
# Insert underscore before uppercase letters preceded by lowercase
s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
# Insert underscore before uppercase letters preceded by lowercase or numbers
s2 = re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1)
return s2.lower()


def to_pascal_case(name: str) -> str:
"""Convert snake_case string to PascalCase.

Args:
name: String in snake_case format

Returns:
String in PascalCase format
"""
components = name.split("_")
return "".join(x.title() for x in components)


def dict_to_snake_case(data: Any) -> Any:
"""Recursively convert all dict keys from PascalCase/camelCase to snake_case.

Args:
data: Dictionary, list, or primitive value

Returns:
Same structure with all dict keys converted to snake_case
"""
if isinstance(data, dict):
return {to_snake_case(key): dict_to_snake_case(value) for key, value in data.items()}
elif isinstance(data, list):
return [dict_to_snake_case(item) for item in data]
else:
return data


def dict_to_pascal_case(data: Any) -> Any:
"""Recursively convert all dict keys from snake_case to PascalCase.

Args:
data: Dictionary, list, or primitive value

Returns:
Same structure with all dict keys converted to PascalCase
"""
if isinstance(data, dict):
return {to_pascal_case(key): dict_to_pascal_case(value) for key, value in data.items()}
elif isinstance(data, list):
return [dict_to_pascal_case(item) for item in data]
else:
return data


def cohort_expression_to_snake_case(expr: CohortExpression) -> dict[str, Any]:
"""Convert CohortExpression to dict with snake_case field names.

Args:
expr: CohortExpression instance

Returns:
Dictionary representation with all keys in snake_case
"""
# Use model_dump to convert to dict with serialization aliases
expr_dict = expr.model_dump(by_alias=True)
# Convert all keys to snake_case
return dict_to_snake_case(expr_dict)


def snake_case_dict_to_cohort_expression(data: dict[str, Any]) -> CohortExpression:
"""Convert snake_case dict to CohortExpression.

Args:
data: Dictionary with snake_case keys

Returns:
CohortExpression instance
"""
# CohortExpression models have populate_by_name=True which accepts snake_case
# So we can pass the data directly without conversion
try:
return CohortExpression.model_validate(data)
except Exception:
# If that fails, try converting to PascalCase as fallback
pascal_dict = dict_to_pascal_case(data)
return CohortExpression.model_validate(pascal_dict)
45 changes: 39 additions & 6 deletions circe/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
from pathlib import Path
from typing import Any, Union

from .api import cohort_expression_from_json
from .api import cohort_expression_from_json, cohort_expression_from_yaml
from .cohortdefinition import CohortExpression
from .cohortdefinition.yaml_utils import cohort_expression_to_snake_case

ExpressionInput = Union[CohortExpression, Mapping[str, Any], str, Path]

Expand All @@ -25,7 +26,8 @@ def load_expression(value: ExpressionInput) -> CohortExpression:
- CohortExpression
- mapping/dict compatible with CohortExpression
- JSON string
- path to a JSON file
- YAML string
- path to a JSON or YAML file
"""
if isinstance(value, CohortExpression):
return value
Expand All @@ -34,7 +36,11 @@ def load_expression(value: ExpressionInput) -> CohortExpression:
return CohortExpression.model_validate(dict(value))

if isinstance(value, Path):
return cohort_expression_from_json(value.read_text(encoding="utf-8"))
content = value.read_text(encoding="utf-8")
if value.suffix in (".yaml", ".yml"):
return cohort_expression_from_yaml(content)
else:
return cohort_expression_from_json(content)

if isinstance(value, str):
stripped = value.strip()
Expand All @@ -46,17 +52,44 @@ def load_expression(value: ExpressionInput) -> CohortExpression:
# File-system path
path = Path(value)
if path.exists() and path.is_file():
return cohort_expression_from_json(path.read_text(encoding="utf-8"))
content = path.read_text(encoding="utf-8")
if path.suffix in (".yaml", ".yml"):
return cohort_expression_from_yaml(content)
else:
return cohort_expression_from_json(content)

# If it wasn't an existing path, attempt JSON parse for clearer errors.
try:
parsed = json.loads(stripped)
except json.JSONDecodeError as exc:
raise ValueError(
"Expected JSON string or path to a JSON file for cohort expression input."
"Expected JSON string, YAML string, or path to a JSON/YAML file for cohort expression input."
) from exc
return CohortExpression.model_validate(parsed)

raise TypeError(
"Unsupported expression input type. Expected CohortExpression, mapping, JSON string, or Path."
"Unsupported expression input type. Expected CohortExpression, mapping, JSON/YAML string, or Path."
)


def save_expression_as_yaml(expr: CohortExpression, path: str | Path) -> None:
"""Save a CohortExpression as a YAML file with snake_case field names.

Args:
expr: CohortExpression instance to save
path: File path to save the YAML file to
"""
import yaml

path = Path(path)
yaml_dict = cohort_expression_to_snake_case(expr)

# Write to file with nice YAML formatting
with open(path, "w", encoding="utf-8") as f:
yaml.dump(
yaml_dict,
f,
default_flow_style=False,
sort_keys=False,
allow_unicode=True,
)
Loading
Loading