Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 111 additions & 48 deletions dagster_sqlmesh/console.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import inspect
import logging
import textwrap
import typing as t
import unittest
import uuid
Expand Down Expand Up @@ -148,9 +147,10 @@ class Plan(BaseConsoleEvent):
@dataclass(kw_only=True)
class LogTestResults(BaseConsoleEvent):
result: unittest.result.TestResult
output: str | None
output: str | None = None
target_dialect: str


@dataclass(kw_only=True)
class ShowSQL(BaseConsoleEvent):
sql: str
Expand Down Expand Up @@ -221,7 +221,7 @@ class ShowTableDiffSummary(BaseConsoleEvent):

@dataclass(kw_only=True)
class PlanBuilt(BaseConsoleEvent):
plan: SQLMeshPlan
plan: SQLMeshPlan

ConsoleEvent = (
StartPlanEvaluation
Expand Down Expand Up @@ -277,6 +277,8 @@ class PlanBuilt(BaseConsoleEvent):
]

T = t.TypeVar("T")
EventType = t.TypeVar("EventType", bound=BaseConsoleEvent)


def get_console_event_by_name(
event_name: str,
Expand All @@ -303,7 +305,7 @@ def __init_subclass__(cls):
for known_event in known_events_classes:
assert inspect.isclass(known_event), "event must be a class"
known_events.append(known_event.__name__)


# Iterate through all the available abstract methods in console
for method_name in Console.__abstractmethods__:
Expand All @@ -319,7 +321,7 @@ def __init_subclass__(cls):
# events has it's values checked. The dataclass should define the
# required fields and everything else should be sent to a catchall
# argument in the dataclass for the event

# Convert method name from snake_case to camel case
camel_case_method_name = "".join(
word.capitalize()
Expand All @@ -329,7 +331,9 @@ def __init_subclass__(cls):
if camel_case_method_name in known_events:
logger.debug(f"Creating {method_name} for {camel_case_method_name}")
signature = inspect.signature(getattr(Console, method_name))
handler = cls.create_event_handler(method_name, camel_case_method_name, signature)
event_cls = get_console_event_by_name(camel_case_method_name)
assert event_cls is not None, f"Event {camel_case_method_name} not found"
handler = cls.create_event_handler(method_name, event_cls, signature)
setattr(cls, method_name, handler)
else:
logger.debug(f"Creating {method_name} for unknown event")
Expand All @@ -338,51 +342,23 @@ def __init_subclass__(cls):
setattr(cls, method_name, handler)

@classmethod
def create_event_handler(cls, method_name: str, event_name: str, signature: inspect.Signature):
func_signature, call_params = cls.create_signatures_and_params(signature)
def create_event_handler(cls, method_name: str, event_cls: type[BaseConsoleEvent], signature: inspect.Signature) -> t.Callable[..., None]:
"""Create a GeneratedCallable for known events."""
def handler(self: IntrospectingConsole, *args: t.Any, **kwargs: t.Any) -> None:
callable_handler = GeneratedCallable(self, event_cls, signature, method_name)
return callable_handler(*args, **kwargs)

event_handler_str = textwrap.dedent(f"""
def {method_name}({", ".join(func_signature)}):
self.publish_known_event('{event_name}', {", ".join(call_params)})
""")
exec(event_handler_str)
return t.cast(t.Callable[[t.Any], t.Any], locals()[method_name])
return handler

@classmethod
def create_signatures_and_params(cls, signature: inspect.Signature):
func_signature: list[str] = []
call_params: list[str] = []
for param_name, param in signature.parameters.items():
if param_name == "self":
func_signature.append("self")
continue

if param.default is inspect._empty:
param_type_name = param.annotation
if not isinstance(param_type_name, str):
param_type_name = param_type_name.__name__
func_signature.append(f"{param_name}: '{param_type_name}'")
else:
default_value = param.default
param_type_name = param.annotation
if not isinstance(param_type_name, str):
param_type_name = param_type_name.__name__
if isinstance(param.default, str):
default_value = f"'{param.default}'"
func_signature.append(f"{param_name}: '{param_type_name}' = {default_value}")
call_params.append(f"{param_name}={param_name}")
return (func_signature, call_params)

@classmethod
def create_unknown_event_handler(cls, method_name: str, signature: inspect.Signature):
func_signature, call_params = cls.create_signatures_and_params(signature)
def create_unknown_event_handler(cls, method_name: str, signature: inspect.Signature) -> t.Callable[..., None]:
"""Create an UnknownEventCallable for unknown events."""
def handler(self: IntrospectingConsole, *args: t.Any, **kwargs: t.Any) -> None:
callable_handler = UnknownEventCallable(self, method_name, signature)
return callable_handler(*args, **kwargs)

event_handler_str = textwrap.dedent(f"""
def {method_name}({", ".join(func_signature)}):
self.publish_unknown_event('{method_name}', {", ".join(call_params)})
""")
exec(event_handler_str)
return t.cast(t.Callable[[t.Any], t.Any], locals()[method_name])
return handler

def __init__(self, log_override: logging.Logger | None = None) -> None:
self._handlers: dict[str, ConsoleEventHandler] = {}
Expand All @@ -394,16 +370,17 @@ def __init__(self, log_override: logging.Logger | None = None) -> None:
def publish_known_event(self, event_name: str, **kwargs: t.Any) -> None:
console_event = get_console_event_by_name(event_name)
assert console_event is not None, f"Event {event_name} not found"

expected_kwargs_fields = console_event.__dataclass_fields__
expected_kwargs: dict[str, t.Any] = {}
unknown_args: dict[str, t.Any] = {}

for key, value in kwargs.items():
if key not in expected_kwargs_fields:
unknown_args[key] = value
else:
expected_kwargs[key] = value

event = console_event(**expected_kwargs, unknown_args=unknown_args)

self.publish(event)
Expand Down Expand Up @@ -446,6 +423,92 @@ def capture_built_plan(self, plan: SQLMeshPlan) -> None:
"""Capture the built plan and publish a PlanBuilt event."""
self.publish(PlanBuilt(plan=plan))


class GeneratedCallable(t.Generic[EventType]):
"""A callable that dynamically handles console method invocations and converts them to events."""

def __init__(
self,
console: IntrospectingConsole,
event_cls: type[EventType],
original_signature: inspect.Signature,
method_name: str
):
self.console = console
self.event_cls = event_cls
self.original_signature = original_signature
self.method_name = method_name

def __call__(self, *args: t.Any, **kwargs: t.Any) -> None:
"""Create an instance of the event class with the provided arguments."""
# Bind arguments to the original signature
try:
bound = self.original_signature.bind(*args, **kwargs)
bound.apply_defaults()
except TypeError as e:
# If binding fails, collect all args/kwargs as unknown
self.console.logger.warning(f"Failed to bind arguments for {self.method_name}: {e}")
unknown_args = {str(i): arg for i, arg in enumerate(args[1:])} # Skip 'self'
unknown_args.update(kwargs)
self._create_and_publish_event({}, unknown_args)
return

# Process bound arguments
bound_args = dict(bound.arguments)
bound_args.pop("self", None) # Remove self from arguments

self._create_and_publish_event(bound_args, {})

def _create_and_publish_event(self, bound_args: dict[str, t.Any], extra_unknown: dict[str, t.Any]) -> None:
"""Create and publish the event with proper argument handling."""
expected_fields = self.event_cls.__dataclass_fields__
expected_kwargs: dict[str, t.Any] = {}
unknown_args: dict[str, t.Any] = {}

# Add any extra unknown args first
unknown_args.update(extra_unknown)

# Process bound arguments
for key, value in bound_args.items():
if key in expected_fields:
expected_kwargs[key] = value
else:
unknown_args[key] = value

# Create and publish the event
event = self.event_cls(**expected_kwargs, unknown_args=unknown_args)
self.console.publish(t.cast(ConsoleEvent, event))


class UnknownEventCallable:
"""A callable for handling unknown console events."""

def __init__(
self,
console: IntrospectingConsole,
method_name: str,
original_signature: inspect.Signature
):
self.console = console
self.method_name = method_name
self.original_signature = original_signature

def __call__(self, *args: t.Any, **kwargs: t.Any) -> None:
"""Handle unknown event method calls."""
# Bind arguments to the original signature
try:
bound = self.original_signature.bind(*args, **kwargs)
bound.apply_defaults()
bound_args = dict(bound.arguments)
bound_args.pop("self", None) # Remove self from arguments
except TypeError:
# If binding fails, collect all args/kwargs
bound_args = {str(i): arg for i, arg in enumerate(args[1:])} # Skip 'self'
bound_args.update(kwargs)

self.console.publish_unknown_event(self.method_name, **bound_args)


class EventConsole(IntrospectingConsole):
"""
A console implementation that manages and publishes events related to
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ readme = "README.md"
requires-python = ">=3.11,<3.13"
dependencies = [
"dagster>=1.7.8",
"sqlmesh<0.188",
"sqlmesh>=0.188",
"pytest>=8.3.2",
"pyarrow>=18.0.0",
"pydantic>=2.11.5",
Expand Down Expand Up @@ -41,7 +41,7 @@ exclude = [
"**/.github",
"**/.vscode",
"**/.idea",
"**/.pytest_cache",
"**/.pytest_cache",
]
pythonVersion = "3.11"
reportUnknownParameterType = true
Expand Down
Loading
Loading