Skip to content

Commit

Permalink
Add .find_connected method and alter .get_downstream_nodes
Browse files Browse the repository at this point in the history
Signed-off-by: Thijs Baaijen <[email protected]>
  • Loading branch information
Thijss committed Jan 29, 2025
1 parent da0c33d commit 81b8d23
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 17 deletions.
25 changes: 18 additions & 7 deletions src/power_grid_model_ds/_core/model/graphs/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,19 @@ def get_connected(

return self._internals_to_externals(nodes)

def find_connected(self, node_id: int, candidate_node_ids) -> int:
"""Returns the first (!) node in candidate_node_ids that is connected to node_id
Raises:
MissingNodeError: if no connected node is found
GraphError: if the
"""
internal_node_id = self.external_to_internal(node_id)
internal_candidates = self._externals_to_internals(candidate_node_ids)
if internal_node_id in internal_candidates:
raise ValueError("node_id cannot be in candidate_node_ids")
return self.internal_to_external(self._find_connected(internal_node_id, internal_candidates))

def get_downstream_nodes(self, node_id: int, stop_node_ids: list[int], inclusive: bool = False) -> list[int]:
"""Find all nodes connected to the node_id
args:
Expand All @@ -247,13 +260,11 @@ def get_downstream_nodes(self, node_id: int, stop_node_ids: list[int], inclusive
returns:
list of node ids sorted by distance, downstream of to the node id
"""
downstream_nodes = self._get_downstream_nodes(
node_id=self.external_to_internal(node_id),
stop_node_ids=self._externals_to_internals(stop_node_ids),
inclusive=inclusive,
)
connected_node = self.find_connected(node_id, stop_node_ids)
path, _ = self.get_shortest_path(node_id, connected_node)
_, upstream_node, *_ = path # path is at least 2 elements long or find_connected would have raised an error

return self._internals_to_externals(downstream_nodes)
return self.get_connected(node_id, [upstream_node], inclusive)

def find_fundamental_cycles(self) -> list[list[int]]:
"""Find all fundamental cycles in the graph.
Expand Down Expand Up @@ -292,7 +303,7 @@ def _branch_is_relevant(self, branch: BranchArray) -> bool:
def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bool = False) -> list[int]: ...

@abstractmethod
def _get_downstream_nodes(self, node_id: int, stop_node_ids: list[int], inclusive: bool = False) -> list[int]: ...
def _find_connected(self, node_id: int, candidate_node_ids: list[int]) -> int: ...

@abstractmethod
def _has_branch(self, from_node_id, to_node_id) -> bool: ...
Expand Down
29 changes: 19 additions & 10 deletions src/power_grid_model_ds/_core/model/graphs/models/rustworkx.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import rustworkx as rx
from rustworkx import NoEdgeBetweenNodes
from rustworkx.visit import BFSVisitor, PruneSearch
from rustworkx.visit import BFSVisitor, PruneSearch, StopSearch

from power_grid_model_ds._core.model.graphs.errors import MissingBranchError, MissingNodeError, NoPathBetweenNodes
from power_grid_model_ds._core.model.graphs.models._rustworkx_search import find_fundamental_cycles_rustworkx
Expand Down Expand Up @@ -99,14 +99,12 @@ def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bo

return connected_nodes

def _get_downstream_nodes(self, node_id: int, stop_node_ids: list[int], inclusive: bool = False) -> list[int]:
visitor = _NodeVisitor(stop_node_ids)
def _find_connected(self, node_id: int, candidate_node_ids: list[int]) -> int:
visitor = _NodeFinder(candidate_nodes=candidate_node_ids)
rx.bfs_search(self._graph, [node_id], visitor)
connected_nodes = visitor.nodes
path_to_substation, _ = self._get_shortest_path(node_id, visitor.discovered_nodes_to_ignore[0])
if inclusive:
_ = path_to_substation.pop(0)
return [node for node in connected_nodes if node not in path_to_substation]
if visitor.found_node is None:
raise MissingNodeError(f"node {node_id} is not connected to any of the candidate nodes")
return visitor.found_node

def _find_fundamental_cycles(self) -> list[list[int]]:
"""Find all fundamental cycles in the graph using Rustworkx.
Expand All @@ -121,10 +119,21 @@ class _NodeVisitor(BFSVisitor):
def __init__(self, nodes_to_ignore: list[int]):
self.nodes_to_ignore = nodes_to_ignore
self.nodes: list[int] = []
self.discovered_nodes_to_ignore: list[int] = []

def discover_vertex(self, v):
if v in self.nodes_to_ignore:
self.discovered_nodes_to_ignore.append(v)
raise PruneSearch
self.nodes.append(v)


class _NodeFinder(BFSVisitor):
"""Visitor that stops the search when a candidate node is found"""

def __init__(self, candidate_nodes: list[int]):
self.candidate_nodes = candidate_nodes
self.found_node: int | None = None

def discover_vertex(self, v):
if v in self.candidate_nodes:
self.found_node = v
raise StopSearch

0 comments on commit 81b8d23

Please sign in to comment.