diff --git a/src/avalan/flow/flow.py b/src/avalan/flow/flow.py index 97cabdd5..b2c68ed3 100644 --- a/src/avalan/flow/flow.py +++ b/src/avalan/flow/flow.py @@ -1,7 +1,8 @@ from ..flow.connection import Connection from ..flow.node import Node -from re import match +from collections import deque +import re from typing import Any, Callable @@ -12,10 +13,12 @@ def __init__(self) -> None: self.nodes: dict[str, Node] = {} self.connections: list[Connection] = [] self.outgoing: dict[str, list[Connection]] = {} + self.incoming: dict[str, list[Connection]] = {} def add_node(self, node: Node) -> None: self.nodes[node.name] = node self.outgoing.setdefault(node.name, []) + self.incoming.setdefault(node.name, []) def add_connection( self, @@ -32,6 +35,7 @@ def add_connection( conn = Connection(src, dest, label, conditions, filters) self.connections.append(conn) self.outgoing[src_name].append(conn) + self.incoming.setdefault(dest_name, []).append(conn) def parse_mermaid(self, mermaid: str) -> None: """Populate the flow from a Mermaid diagram. @@ -39,120 +43,194 @@ def parse_mermaid(self, mermaid: str) -> None: Edge labels are accepted in both ``A -- label --> B`` and ``A -->|label| B`` forms. Lines without edges are ignored. """ - lines = [line.strip() for line in mermaid.splitlines() if line.strip()] - if lines and lines[0].lower().startswith("graph"): - lines = lines[1:] - for line in lines: - if "-->" not in line: - continue - left, right = line.split("-->", 1) - left = left.strip() - right = right.strip() - label = None - if "--" in left: - parts = left.split("--", 1) - left, label = parts[0].strip(), parts[1].strip() - elif right.startswith("|"): - label, right = right[1:].split("|", 1) - label = label.strip() - right = right.strip() - src_id, src_lbl, src_shape = self._parse_node(left) - dst_id, dst_lbl, dst_shape = self._parse_node(right) - for nid, nlbl, nshape in [ - (src_id, src_lbl, src_shape), - (dst_id, dst_lbl, dst_shape), - ]: - if nid not in self.nodes: - self.add_node(Node(nid, label=nlbl or nid, shape=nshape)) - else: - node = self.nodes[nid] - if nlbl and node.label == node.name: - node.label = nlbl - if nshape and not node.shape: - node.shape = nshape - self.add_connection(src_id, dst_id, label=label) + for src_text, label, dest_text in self._iter_mermaid_edges(mermaid): + src_id, src_label, src_shape = self._parse_node(src_text) + dest_id, dest_label, dest_shape = self._parse_node(dest_text) + src_node = self._ensure_node(src_id, src_label, src_shape) + dest_node = self._ensure_node(dest_id, dest_label, dest_shape) + self.add_connection(src_node.name, dest_node.name, label=label) def _parse_node(self, text: str) -> tuple[str, str | None, str | None]: - m = match(r"^([A-Za-z0-9_]+)", text) - if not m: + match = re.match(r"^([A-Za-z0-9_]+)", text) + if not match: return text, None, None - nid = m.group(1) - rem = text[len(nid) :].strip() - if not rem: - return nid, None, None + node_id = match.group(1) + remainder = text[len(node_id) :].strip() + if not remainder: + return node_id, None, None shape = None label = None - if rem.startswith("[") and rem.endswith("]"): + if remainder.startswith("[") and remainder.endswith("]"): shape = "rect" - label = rem[1:-1] - elif rem.startswith("(((") and rem.endswith(")))"): + label = remainder[1:-1] + elif remainder.startswith("(((") and remainder.endswith(")))"): shape = "circle" - label = rem[3:-3] - elif rem.startswith("(") and rem.endswith(")"): + label = remainder[3:-3] + elif remainder.startswith("(") and remainder.endswith(")"): shape = "roundrect" - label = rem[1:-1] - elif rem.startswith("{") and rem.endswith("}"): + label = remainder[1:-1] + elif remainder.startswith("{") and remainder.endswith("}"): shape = "diamond" - label = rem[1:-1] + label = remainder[1:-1] else: - label = rem - return nid, label, shape + label = remainder + return node_id, label, shape + + def _ensure_node( + self, name: str, label: str | None, shape: str | None + ) -> Node: + node = self.nodes.get(name) + if node is None: + node = Node(name, label=label or name, shape=shape) + self.add_node(node) + return node + if label and node.label == node.name: + node.label = label + if shape and not node.shape: + node.shape = shape + return node + + def _iter_mermaid_edges( + self, mermaid: str + ) -> list[tuple[str, str | None, str]]: + lines = [line.strip() for line in mermaid.splitlines() if line.strip()] + if lines and lines[0].lower().startswith("graph"): + lines = lines[1:] + edges: list[tuple[str, str | None, str]] = [] + for line in lines: + if "-->" not in line: + continue + left, right = [part.strip() for part in line.split("-->", 1)] + label: str | None = None + if "--" in left: + src_text, label_text = left.split("--", 1) + left = src_text.strip() + label = label_text.strip() + elif right.startswith("|") and "|" in right[1:]: + label_text, dest_text = right[1:].split("|", 1) + label = label_text.strip() + right = dest_text.strip() + edges.append((left, label, right)) + return edges def execute( self, initial_node: str | Node | None = None, initial_data: Any = None, ) -> Any: - # Determine start nodes - if initial_node is None: - indegree = {n: 0 for n in self.nodes} - for c in self.connections: - indegree[c.dest.name] += 1 - start = [self.nodes[n] for n, d in indegree.items() if d == 0] - else: - start = ( - [self.nodes[initial_node]] - if isinstance(initial_node, str) - else [initial_node] - ) # type: ignore - indegree_map = {n: 0 for n in self.nodes} - for c in self.connections: - indegree_map[c.dest.name] += 1 - - buffers: dict[str, dict[str, Any]] = {n: {} for n in self.nodes} - if initial_data is not None and len(start) == 1: - buffers[start[0].name] = {"__init__": initial_data} - - ready: list[Node] = [] - for n in start: - indegree_map[n.name] = 0 - ready.append(n) + start_nodes = self._resolve_start_nodes(initial_node) + if not start_nodes: + raise ValueError( + "Flow has no valid starting node; graph may contain a cycle" + ) + + incoming_counts = { + name: len(self.incoming.get(name, [])) for name in self.nodes + } + indegree = dict(incoming_counts) + buffers: dict[str, dict[str, Any]] = {name: {} for name in self.nodes} + if initial_data is not None and len(start_nodes) == 1: + buffers[start_nodes[0].name] = {"__init__": initial_data} + + queue: deque[Node] = deque() + for node in start_nodes: + indegree[node.name] = 0 + queue.append(node) outputs: dict[str, Any] = {} - while ready: - node = ready.pop(0) - inp = buffers[node.name] - incoming = sum( - 1 for c in self.connections if c.dest.name == node.name - ) - if incoming > 0 and not inp: + processed: set[str] = set() + reachable = self._collect_reachable(start_nodes) + while queue: + node = queue.popleft() + processed.add(node.name) + inputs = buffers[node.name] + if incoming_counts[node.name] > 0 and not inputs: outputs[node.name] = None else: - outputs[node.name] = node.execute(inp) - outval = outputs[node.name] - if outval is not None: - for c in self.outgoing.get(node.name, []): - if c.check_conditions(outval): - forwarded = c.apply_filters(outval) - buffers[c.dest.name][node.name] = forwarded - for c in self.outgoing.get(node.name, []): - indegree_map[c.dest.name] -= 1 - if indegree_map[c.dest.name] == 0: - ready.append(c.dest) + outputs[node.name] = node.execute(inputs) + out_value = outputs[node.name] + for connection in self.outgoing.get(node.name, []): + indegree[connection.dest.name] -= 1 + if out_value is not None and connection.check_conditions( + out_value + ): + forwarded = connection.apply_filters(out_value) + buffers[connection.dest.name][node.name] = forwarded + if indegree[connection.dest.name] == 0: + queue.append(connection.dest) + + if processed != reachable: + remaining = sorted(reachable - processed) + raise ValueError( + "Flow contains a cycle involving: " + ", ".join(remaining) + ) + + cycle_nodes = self._detect_cycle_nodes(start_nodes) + if cycle_nodes: + remaining = sorted(cycle_nodes) + raise ValueError( + "Flow contains a cycle involving: " + ", ".join(remaining) + ) terminal = { - n: outputs[n] for n, outs in self.outgoing.items() if not outs + name: outputs[name] + for name, outs in self.outgoing.items() + if not outs } if len(terminal) == 1: return next(iter(terminal.values())) return terminal + + def _resolve_start_nodes( + self, initial_node: str | Node | None + ) -> list[Node]: + if initial_node is None: + return [ + self.nodes[name] + for name, inbound in self.incoming.items() + if not inbound + ] + if isinstance(initial_node, Node): + if initial_node.name not in self.nodes: + raise KeyError(f"Unknown node: {initial_node.name}") + return [self.nodes[initial_node.name]] + if initial_node not in self.nodes: + raise KeyError(f"Unknown node: {initial_node}") + return [self.nodes[initial_node]] + + def _collect_reachable(self, start_nodes: list[Node]) -> set[str]: + reachable: set[str] = set() + stack = [node.name for node in start_nodes] + while stack: + name = stack.pop() + if name in reachable: + continue + reachable.add(name) + for connection in self.outgoing.get(name, []): + stack.append(connection.dest.name) + return reachable + + def _detect_cycle_nodes(self, start_nodes: list[Node]) -> set[str]: + visited: set[str] = set() + recursion_stack: set[str] = set() + cycle_nodes: set[str] = set() + + def dfs(name: str) -> bool: + if name in recursion_stack: + cycle_nodes.add(name) + return True + if name in visited: + return False + visited.add(name) + recursion_stack.add(name) + found_cycle = False + for connection in self.outgoing.get(name, []): + if dfs(connection.dest.name): + cycle_nodes.add(name) + found_cycle = True + recursion_stack.remove(name) + return found_cycle + + for node in start_nodes: + dfs(node.name) + return cycle_nodes diff --git a/src/avalan/flow/node.py b/src/avalan/flow/node.py index a894526b..ad596a85 100644 --- a/src/avalan/flow/node.py +++ b/src/avalan/flow/node.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from typing import TYPE_CHECKING, Any, Callable if TYPE_CHECKING: @@ -15,7 +13,7 @@ def __init__( input_schema: type | None = None, output_schema: type | None = None, func: Callable[..., Any] | None = None, - subgraph: Flow | None = None, + subgraph: "Flow | None" = None, ) -> None: self.name: str = name self.label: str = label or name @@ -23,7 +21,7 @@ def __init__( self.input_schema: type | None = input_schema self.output_schema: type | None = output_schema self.func: Callable[..., Any] | None = func - self.subgraph: Flow | None = subgraph + self.subgraph: "Flow | None" = subgraph def execute(self, inputs: dict[str, Any]) -> Any: # Delegate to subgraph if present diff --git a/tests/flow/flow_test.py b/tests/flow/flow_test.py index 8eb75914..e52ec31b 100644 --- a/tests/flow/flow_test.py +++ b/tests/flow/flow_test.py @@ -97,6 +97,27 @@ def times_two(inputs): self.assertEqual(result, 4) self.assertEqual(executed, ["A", "B", "C"]) + def test_skip_node_without_inputs(self): + executed = [] + + def start(_): + executed.append("A") + return "ignored" + + def should_not_run(_): + executed.append("B") + raise AssertionError("Callback should not run") + + flow = Flow() + flow.add_node(Node("A", func=start)) + flow.add_node(Node("B", func=should_not_run)) + flow.add_connection("A", "B", conditions=[lambda _: False]) + + result = flow.execute() + + self.assertIsNone(result) + self.assertEqual(executed, ["A"]) + def test_execute_with_initial_node(self): executed = [] @@ -126,6 +147,26 @@ def times_two(inputs): self.assertEqual(result, 12) self.assertEqual(executed, ["B", "C"]) + def test_execute_raises_when_no_start_nodes(self): + flow = Flow() + flow.add_node(Node("A")) + flow.add_node(Node("B")) + flow.add_connection("A", "B") + flow.add_connection("B", "A") + with self.assertRaises(ValueError) as context: + flow.execute() + self.assertIn("cycle", str(context.exception)) + + def test_execute_detects_cycle_with_initial_node(self): + flow = Flow() + flow.add_node(Node("A")) + flow.add_node(Node("B")) + flow.add_connection("A", "B") + flow.add_connection("B", "A") + with self.assertRaises(ValueError) as context: + flow.execute(initial_node="A", initial_data=1) + self.assertIn("cycle", str(context.exception)) + class FlowAddConnectionTestCase(TestCase): def test_add_connection_unknown_src(self):