Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
262 changes: 170 additions & 92 deletions src/avalan/flow/flow.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from ..flow.connection import Connection
from ..flow.node import Node

from re import match
from collections import deque
import re
from typing import Any, Callable


Expand All @@ -12,10 +13,12 @@ def __init__(self) -> None:
self.nodes: dict[str, Node] = {}
self.connections: list[Connection] = []
self.outgoing: dict[str, list[Connection]] = {}
self.incoming: dict[str, list[Connection]] = {}

def add_node(self, node: Node) -> None:
self.nodes[node.name] = node
self.outgoing.setdefault(node.name, [])
self.incoming.setdefault(node.name, [])

def add_connection(
self,
Expand All @@ -32,127 +35,202 @@ def add_connection(
conn = Connection(src, dest, label, conditions, filters)
self.connections.append(conn)
self.outgoing[src_name].append(conn)
self.incoming.setdefault(dest_name, []).append(conn)

def parse_mermaid(self, mermaid: str) -> None:
"""Populate the flow from a Mermaid diagram.

Edge labels are accepted in both ``A -- label --> B`` and
``A -->|label| B`` forms. Lines without edges are ignored.
"""
lines = [line.strip() for line in mermaid.splitlines() if line.strip()]
if lines and lines[0].lower().startswith("graph"):
lines = lines[1:]
for line in lines:
if "-->" not in line:
continue
left, right = line.split("-->", 1)
left = left.strip()
right = right.strip()
label = None
if "--" in left:
parts = left.split("--", 1)
left, label = parts[0].strip(), parts[1].strip()
elif right.startswith("|"):
label, right = right[1:].split("|", 1)
label = label.strip()
right = right.strip()
src_id, src_lbl, src_shape = self._parse_node(left)
dst_id, dst_lbl, dst_shape = self._parse_node(right)
for nid, nlbl, nshape in [
(src_id, src_lbl, src_shape),
(dst_id, dst_lbl, dst_shape),
]:
if nid not in self.nodes:
self.add_node(Node(nid, label=nlbl or nid, shape=nshape))
else:
node = self.nodes[nid]
if nlbl and node.label == node.name:
node.label = nlbl
if nshape and not node.shape:
node.shape = nshape
self.add_connection(src_id, dst_id, label=label)
for src_text, label, dest_text in self._iter_mermaid_edges(mermaid):
src_id, src_label, src_shape = self._parse_node(src_text)
dest_id, dest_label, dest_shape = self._parse_node(dest_text)
src_node = self._ensure_node(src_id, src_label, src_shape)
dest_node = self._ensure_node(dest_id, dest_label, dest_shape)
self.add_connection(src_node.name, dest_node.name, label=label)

def _parse_node(self, text: str) -> tuple[str, str | None, str | None]:
m = match(r"^([A-Za-z0-9_]+)", text)
if not m:
match = re.match(r"^([A-Za-z0-9_]+)", text)
if not match:
return text, None, None
nid = m.group(1)
rem = text[len(nid) :].strip()
if not rem:
return nid, None, None
node_id = match.group(1)
remainder = text[len(node_id) :].strip()
if not remainder:
return node_id, None, None
shape = None
label = None
if rem.startswith("[") and rem.endswith("]"):
if remainder.startswith("[") and remainder.endswith("]"):
shape = "rect"
label = rem[1:-1]
elif rem.startswith("(((") and rem.endswith(")))"):
label = remainder[1:-1]
elif remainder.startswith("(((") and remainder.endswith(")))"):
shape = "circle"
label = rem[3:-3]
elif rem.startswith("(") and rem.endswith(")"):
label = remainder[3:-3]
elif remainder.startswith("(") and remainder.endswith(")"):
shape = "roundrect"
label = rem[1:-1]
elif rem.startswith("{") and rem.endswith("}"):
label = remainder[1:-1]
elif remainder.startswith("{") and remainder.endswith("}"):
shape = "diamond"
label = rem[1:-1]
label = remainder[1:-1]
else:
label = rem
return nid, label, shape
label = remainder
return node_id, label, shape

def _ensure_node(
self, name: str, label: str | None, shape: str | None
) -> Node:
node = self.nodes.get(name)
if node is None:
node = Node(name, label=label or name, shape=shape)
self.add_node(node)
return node
if label and node.label == node.name:
node.label = label
if shape and not node.shape:
node.shape = shape
return node

def _iter_mermaid_edges(
self, mermaid: str
) -> list[tuple[str, str | None, str]]:
lines = [line.strip() for line in mermaid.splitlines() if line.strip()]
if lines and lines[0].lower().startswith("graph"):
lines = lines[1:]
edges: list[tuple[str, str | None, str]] = []
for line in lines:
if "-->" not in line:
continue
left, right = [part.strip() for part in line.split("-->", 1)]
label: str | None = None
if "--" in left:
src_text, label_text = left.split("--", 1)
left = src_text.strip()
label = label_text.strip()
elif right.startswith("|") and "|" in right[1:]:
label_text, dest_text = right[1:].split("|", 1)
label = label_text.strip()
right = dest_text.strip()
edges.append((left, label, right))
return edges

def execute(
self,
initial_node: str | Node | None = None,
initial_data: Any = None,
) -> Any:
# Determine start nodes
if initial_node is None:
indegree = {n: 0 for n in self.nodes}
for c in self.connections:
indegree[c.dest.name] += 1
start = [self.nodes[n] for n, d in indegree.items() if d == 0]
else:
start = (
[self.nodes[initial_node]]
if isinstance(initial_node, str)
else [initial_node]
) # type: ignore
indegree_map = {n: 0 for n in self.nodes}
for c in self.connections:
indegree_map[c.dest.name] += 1

buffers: dict[str, dict[str, Any]] = {n: {} for n in self.nodes}
if initial_data is not None and len(start) == 1:
buffers[start[0].name] = {"__init__": initial_data}

ready: list[Node] = []
for n in start:
indegree_map[n.name] = 0
ready.append(n)
start_nodes = self._resolve_start_nodes(initial_node)
if not start_nodes:
raise ValueError(
"Flow has no valid starting node; graph may contain a cycle"
)

incoming_counts = {
name: len(self.incoming.get(name, [])) for name in self.nodes
}
indegree = dict(incoming_counts)
buffers: dict[str, dict[str, Any]] = {name: {} for name in self.nodes}
if initial_data is not None and len(start_nodes) == 1:
buffers[start_nodes[0].name] = {"__init__": initial_data}

queue: deque[Node] = deque()
for node in start_nodes:
indegree[node.name] = 0
queue.append(node)

outputs: dict[str, Any] = {}
while ready:
node = ready.pop(0)
inp = buffers[node.name]
incoming = sum(
1 for c in self.connections if c.dest.name == node.name
)
if incoming > 0 and not inp:
processed: set[str] = set()
reachable = self._collect_reachable(start_nodes)
while queue:
node = queue.popleft()
processed.add(node.name)
inputs = buffers[node.name]
if incoming_counts[node.name] > 0 and not inputs:
outputs[node.name] = None
else:
outputs[node.name] = node.execute(inp)
outval = outputs[node.name]
if outval is not None:
for c in self.outgoing.get(node.name, []):
if c.check_conditions(outval):
forwarded = c.apply_filters(outval)
buffers[c.dest.name][node.name] = forwarded
for c in self.outgoing.get(node.name, []):
indegree_map[c.dest.name] -= 1
if indegree_map[c.dest.name] == 0:
ready.append(c.dest)
outputs[node.name] = node.execute(inputs)
out_value = outputs[node.name]
for connection in self.outgoing.get(node.name, []):
indegree[connection.dest.name] -= 1
if out_value is not None and connection.check_conditions(
out_value
):
forwarded = connection.apply_filters(out_value)
buffers[connection.dest.name][node.name] = forwarded
if indegree[connection.dest.name] == 0:
queue.append(connection.dest)

if processed != reachable:
remaining = sorted(reachable - processed)
raise ValueError(
"Flow contains a cycle involving: " + ", ".join(remaining)
)

cycle_nodes = self._detect_cycle_nodes(start_nodes)
if cycle_nodes:
remaining = sorted(cycle_nodes)
raise ValueError(
"Flow contains a cycle involving: " + ", ".join(remaining)
)

terminal = {
n: outputs[n] for n, outs in self.outgoing.items() if not outs
name: outputs[name]
for name, outs in self.outgoing.items()
if not outs
}
if len(terminal) == 1:
return next(iter(terminal.values()))
return terminal

def _resolve_start_nodes(
self, initial_node: str | Node | None
) -> list[Node]:
if initial_node is None:
return [
self.nodes[name]
for name, inbound in self.incoming.items()
if not inbound
]
if isinstance(initial_node, Node):
if initial_node.name not in self.nodes:
raise KeyError(f"Unknown node: {initial_node.name}")
return [self.nodes[initial_node.name]]
if initial_node not in self.nodes:
raise KeyError(f"Unknown node: {initial_node}")
return [self.nodes[initial_node]]

def _collect_reachable(self, start_nodes: list[Node]) -> set[str]:
reachable: set[str] = set()
stack = [node.name for node in start_nodes]
while stack:
name = stack.pop()
if name in reachable:
continue
reachable.add(name)
for connection in self.outgoing.get(name, []):
stack.append(connection.dest.name)
return reachable

def _detect_cycle_nodes(self, start_nodes: list[Node]) -> set[str]:
visited: set[str] = set()
recursion_stack: set[str] = set()
cycle_nodes: set[str] = set()

def dfs(name: str) -> bool:
if name in recursion_stack:
cycle_nodes.add(name)
return True
if name in visited:
return False
visited.add(name)
recursion_stack.add(name)
found_cycle = False
for connection in self.outgoing.get(name, []):
if dfs(connection.dest.name):
cycle_nodes.add(name)
found_cycle = True
recursion_stack.remove(name)
return found_cycle

for node in start_nodes:
dfs(node.name)
return cycle_nodes
6 changes: 2 additions & 4 deletions src/avalan/flow/node.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable

if TYPE_CHECKING:
Expand All @@ -15,15 +13,15 @@ def __init__(
input_schema: type | None = None,
output_schema: type | None = None,
func: Callable[..., Any] | None = None,
subgraph: Flow | None = None,
subgraph: "Flow | None" = None,
) -> None:
self.name: str = name
self.label: str = label or name
self.shape: str | None = shape
self.input_schema: type | None = input_schema
self.output_schema: type | None = output_schema
self.func: Callable[..., Any] | None = func
self.subgraph: Flow | None = subgraph
self.subgraph: "Flow | None" = subgraph

def execute(self, inputs: dict[str, Any]) -> Any:
# Delegate to subgraph if present
Expand Down
41 changes: 41 additions & 0 deletions tests/flow/flow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,27 @@ def times_two(inputs):
self.assertEqual(result, 4)
self.assertEqual(executed, ["A", "B", "C"])

def test_skip_node_without_inputs(self):
executed = []

def start(_):
executed.append("A")
return "ignored"

def should_not_run(_):
executed.append("B")
raise AssertionError("Callback should not run")

flow = Flow()
flow.add_node(Node("A", func=start))
flow.add_node(Node("B", func=should_not_run))
flow.add_connection("A", "B", conditions=[lambda _: False])

result = flow.execute()

self.assertIsNone(result)
self.assertEqual(executed, ["A"])

def test_execute_with_initial_node(self):
executed = []

Expand Down Expand Up @@ -126,6 +147,26 @@ def times_two(inputs):
self.assertEqual(result, 12)
self.assertEqual(executed, ["B", "C"])

def test_execute_raises_when_no_start_nodes(self):
flow = Flow()
flow.add_node(Node("A"))
flow.add_node(Node("B"))
flow.add_connection("A", "B")
flow.add_connection("B", "A")
with self.assertRaises(ValueError) as context:
flow.execute()
self.assertIn("cycle", str(context.exception))

def test_execute_detects_cycle_with_initial_node(self):
flow = Flow()
flow.add_node(Node("A"))
flow.add_node(Node("B"))
flow.add_connection("A", "B")
flow.add_connection("B", "A")
with self.assertRaises(ValueError) as context:
flow.execute(initial_node="A", initial_data=1)
self.assertIn("cycle", str(context.exception))


class FlowAddConnectionTestCase(TestCase):
def test_add_connection_unknown_src(self):
Expand Down