Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: improve downstream node performance #21

Merged
merged 40 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
cbcad42
feat: improve downstream nodes performance with local search
jaapschoutenalliander Jan 23, 2025
3d72bab
test: add testing on sorting
jaapschoutenalliander Jan 24, 2025
da0c33d
chore: update performance test
jaapschoutenalliander Jan 28, 2025
6d791f5
Add .find_connected method and alter .get_downstream_nodes
Thijss Jan 29, 2025
fc99bc0
Update test_get_downstream_nodes
Thijss Jan 29, 2025
7cb6fa2
Update performance tests
Thijss Jan 29, 2025
098e3ac
Merge branch 'feat/update-performance-tests' into feature/downstream_…
Thijss Jan 29, 2025
5300faf
Merge remote-tracking branch 'origin/main' into feature/downstream_no…
Thijss Feb 3, 2025
3ccf1e9
rename to find_first_connected and update docstring
Thijss Feb 4, 2025
09595b9
Feature: add Grid.from_txt method
Thijss Feb 5, 2025
8441e89
support unordered branches
Thijss Feb 5, 2025
a1c17b5
update documentation
Thijss Feb 5, 2025
0dddcf2
update gitignore
Thijss Feb 5, 2025
960a09d
update documentation
Thijss Feb 5, 2025
ce2e0dd
fix test
Thijss Feb 5, 2025
6c39904
Merge branch 'feat/add-from-txt' into feat/improve-downstream-perform…
Thijss Feb 5, 2025
fa34b94
update downstream tests
Thijss Feb 5, 2025
5d541fb
ruff
Thijss Feb 5, 2025
a7bdd9d
switch to regex for better text support
Thijss Feb 5, 2025
4866e6b
add support for both list[str] and str
Thijss Feb 5, 2025
2860f04
add test for docstring
Thijss Feb 5, 2025
8124ef2
switch to args input
Thijss Feb 5, 2025
9d406ba
fix constants for graph performance tests (#28)
Thijss Feb 4, 2025
a6f0681
fix: delete_branch3 used wrong argument name (#29)
vincentkoppen Feb 4, 2025
6ead9c5
chore: remove unused/nonfunctional cache on graphcontainer (#30)
vincentkoppen Feb 4, 2025
5bcb179
Feature: add Grid.from_txt method
Thijss Feb 5, 2025
3f1df6a
support unordered branches
Thijss Feb 5, 2025
7693c0a
update documentation
Thijss Feb 5, 2025
a8a5c2c
update gitignore
Thijss Feb 5, 2025
3fe3a6a
update documentation
Thijss Feb 5, 2025
92eba6c
fix test
Thijss Feb 5, 2025
9dabf71
update downstream tests
Thijss Feb 5, 2025
afc8a17
merge
Thijss Feb 5, 2025
2b99ea3
add downstream test
Thijss Feb 5, 2025
b5713b0
add TestFindFirstConnected
Thijss Feb 5, 2025
84863dc
remove re module implementation
Thijss Feb 5, 2025
0018e2f
bump minor
Thijss Feb 5, 2025
82ec35e
Merge branch 'feat/add-from-txt' into feature/downstream_node_perform…
Thijss Feb 5, 2025
e7b9ab3
re-add grid logic from downstream nodes
Thijss Feb 5, 2025
e352085
merge main
Thijss Feb 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions src/power_grid_model_ds/_core/model/graphs/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,41 @@ def get_connected(
nodes_to_ignore=self._externals_to_internals(nodes_to_ignore),
inclusive=inclusive,
)

return self._internals_to_externals(nodes)

def find_connected(self, node_id: int, candidate_node_ids: list[int]) -> int:
"""Find a connection between a node and a list of candidate nodes.

Note:
If multiple candidate nodes are connected to the node, the first one found is returned.
There is no guarantee that the same candidate node will be returned each time.

Raises:
MissingNodeError: if no connected node is found
ValueError: if the node_id is in candidate_node_ids
"""
internal_node_id = self.external_to_internal(node_id)
internal_candidates = self._externals_to_internals(candidate_node_ids)
if internal_node_id in internal_candidates:
raise ValueError("node_id cannot be in candidate_node_ids")
return self.internal_to_external(self._find_connected(internal_node_id, internal_candidates))

def get_downstream_nodes(self, node_id: int, stop_node_ids: list[int], inclusive: bool = False) -> list[int]:
Thijss marked this conversation as resolved.
Show resolved Hide resolved
"""Find all nodes connected to the node_id
args:
node_id: node id to start the search from
stop_node_ids: list of node ids to stop the search at
inclusive: whether to include the given node id in the result
returns:
list of node ids sorted by distance, downstream of to the node id
"""
connected_node = self.find_connected(node_id, stop_node_ids)
path, _ = self.get_shortest_path(node_id, connected_node)
_, upstream_node, *_ = path # path is at least 2 elements long or find_connected would have raised an error
vincentkoppen marked this conversation as resolved.
Show resolved Hide resolved

return self.get_connected(node_id, [upstream_node], inclusive)

def find_fundamental_cycles(self) -> list[list[int]]:
"""Find all fundamental cycles in the graph.
Returns:
Expand Down Expand Up @@ -273,6 +306,9 @@ def _branch_is_relevant(self, branch: BranchArray) -> bool:
@abstractmethod
def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bool = False) -> list[int]: ...

@abstractmethod
def _find_connected(self, node_id: int, candidate_node_ids: list[int]) -> int: ...

@abstractmethod
def _has_branch(self, from_node_id, to_node_id) -> bool: ...

Expand Down
22 changes: 21 additions & 1 deletion src/power_grid_model_ds/_core/model/graphs/models/rustworkx.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import rustworkx as rx
from rustworkx import NoEdgeBetweenNodes
from rustworkx.visit import BFSVisitor, PruneSearch
from rustworkx.visit import BFSVisitor, PruneSearch, StopSearch

from power_grid_model_ds._core.model.graphs.errors import MissingBranchError, MissingNodeError, NoPathBetweenNodes
from power_grid_model_ds._core.model.graphs.models._rustworkx_search import find_fundamental_cycles_rustworkx
Expand Down Expand Up @@ -99,6 +99,13 @@ def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bo

return connected_nodes

def _find_connected(self, node_id: int, candidate_node_ids: list[int]) -> int:
Thijss marked this conversation as resolved.
Show resolved Hide resolved
visitor = _NodeFinder(candidate_nodes=candidate_node_ids)
rx.bfs_search(self._graph, [node_id], visitor)
if visitor.found_node is None:
raise MissingNodeError(f"node {node_id} is not connected to any of the candidate nodes")
return visitor.found_node

def _find_fundamental_cycles(self) -> list[list[int]]:
"""Find all fundamental cycles in the graph using Rustworkx.

Expand All @@ -117,3 +124,16 @@ def discover_vertex(self, v):
if v in self.nodes_to_ignore:
raise PruneSearch
self.nodes.append(v)


class _NodeFinder(BFSVisitor):
"""Visitor that stops the search when a candidate node is found"""

def __init__(self, candidate_nodes: list[int]):
self.candidate_nodes = candidate_nodes
self.found_node: int | None = None

def discover_vertex(self, v):
if v in self.candidate_nodes:
self.found_node = v
raise StopSearch
12 changes: 6 additions & 6 deletions src/power_grid_model_ds/_core/model/grids/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ def get_nearest_substation_node(self, node_id: int):

def get_downstream_nodes(self, node_id: int, inclusive: bool = False):
"""Get the downstream nodes from a node.
Assuming each node has a single feeding substation and the grid is radial

Example:
given this graph: [1] - [2] - [3] - [4], with 1 being a substation node
Expand All @@ -349,15 +350,14 @@ def get_downstream_nodes(self, node_id: int, inclusive: bool = False):
Returns:
list[int]: The downstream nodes.
"""
substation_node_id = self.get_nearest_substation_node(node_id).id.item()
substation_nodes = self.node.filter(node_type=NodeType.SUBSTATION_NODE.value)

if node_id == substation_node_id:
if node_id in substation_nodes.id:
raise NotImplementedError("get_downstream_nodes is not implemented for substation nodes!")

path_to_substation, _ = self.graphs.active_graph.get_shortest_path(node_id, substation_node_id)
upstream_node = path_to_substation[1]

return self.graphs.active_graph.get_connected(node_id, nodes_to_ignore=[upstream_node], inclusive=inclusive)
return self.graphs.active_graph.get_downstream_nodes(
node_id=node_id, stop_node_ids=list(substation_nodes.id), inclusive=inclusive
)

def cache(self, cache_dir: Path, cache_name: str, compress: bool = True):
"""Cache Grid to a folder
Expand Down
18 changes: 9 additions & 9 deletions tests/performance/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,18 @@
"dtype = [('id', '<i8'), ('test_int', '<i8'), ('test_float', '<f8'), ('test_str', '<U50'), ('test_bool', '?')]; "
)

SETUP_CODES = {
"structured": "import numpy as np;" + NUMPY_DTYPE + "input_array = np.zeros({array_size}, dtype=dtype)",
"rec": "import numpy as np;" + NUMPY_DTYPE + "input_array = np.recarray(({array_size},),dtype=dtype)",
"fancy": "from tests.conftest import FancyTestArray; input_array=FancyTestArray.zeros({array_size});"
+ "import numpy as np;input_array.id = np.arange({array_size})",
ARRAY_SETUP_CODES = {
"structured": "import numpy as np;" + NUMPY_DTYPE + "input_array = np.zeros({size}, dtype=dtype)",
"rec": "import numpy as np;" + NUMPY_DTYPE + "input_array = np.recarray(({size},),dtype=dtype)",
"fancy": "from tests.conftest import FancyTestArray; input_array=FancyTestArray.zeros({size});"
+ "import numpy as np;input_array.id = np.arange({size})",
}

GRAPH_SETUP_CODES = {
"rustworkx": "from power_grid_model_ds.model.grids.base import Grid;"
+ "from power_grid_model_ds.data_source.generator.grid_generators import RadialGridGenerator;"
+ "from power_grid_model_ds.model.graphs.models import RustworkxGraphModel;"
+ "grid=RadialGridGenerator(nr_nodes={graph_size}, grid_class=Grid, graph_model=RustworkxGraphModel).run()",
"rustworkx": "from power_grid_model_ds import Grid;"
+ "from power_grid_model_ds.generators import RadialGridGenerator;"
+ "from power_grid_model_ds.graph_models import RustworkxGraphModel;"
+ "grid=RadialGridGenerator(nr_nodes={size}, grid_class=Grid, graph_model=RustworkxGraphModel).run()",
}

SINGLE_REPEATS = 1000
Expand Down
101 changes: 29 additions & 72 deletions tests/performance/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,100 +4,57 @@

import inspect
import timeit
from typing import Generator
from itertools import product
from typing import Generator, Union

from tests.performance._constants import GRAPH_SETUP_CODES, SETUP_CODES


def do_performance_test(code_to_test: str | dict[str, str], array_sizes: list[int], repeats: int):
"""Run the performance test for the given code."""

def do_performance_test(
code_to_test: Union[str, dict[str, str], list[str]],
size_list: list[int],
repeats: int,
setup_codes: dict[str, str],
):
"""Generalized performance test runner."""
print(f"{'-' * 20} {inspect.stack()[1][3]} {'-' * 20}")

for array_size in array_sizes:
for size in size_list:
formatted_setup_codes = {key: code.format(size=size) for key, code in setup_codes.items()}
if isinstance(code_to_test, dict):
code_to_test_list = [code_to_test[variant].format(array_size=array_size) for variant in SETUP_CODES]
else:
code_to_test_list = [code_to_test.format(array_size=array_size)] * len(SETUP_CODES)
print(f"\n\tArray size: {array_size}\n")
setup_codes = [setup_code.format(array_size=array_size) for setup_code in SETUP_CODES.values()]
timings = _get_timings(setup_codes, code_to_test_list, repeats)

if code_to_test == "pass":
_print_timings(timings, list(SETUP_CODES.keys()), setup_codes)
code_to_test_list = [code_to_test[variant].format(size=size) for variant in setup_codes]
test_generator = zip(formatted_setup_codes.items(), code_to_test_list)
elif isinstance(code_to_test, list):
code_to_test_list = [code.format(size=size) for code in code_to_test]
test_generator = product(formatted_setup_codes.items(), code_to_test_list)
else:
_print_timings(timings, list(SETUP_CODES.keys()), code_to_test_list)
print()
test_generator = product(formatted_setup_codes.items(), [code_to_test.format(size=size)])

print(f"\n\tsize: {size}\n")

def do_graph_test(code_to_test: str | dict[str, str], graph_sizes: list[int], repeats: int):
"""Run the performance test for the given code."""
timings = _get_timings(test_generator, repeats=repeats)
_print_timings(timings)

print(f"{'-' * 20} {inspect.stack()[1][3]} {'-' * 20}")

for graph_size in graph_sizes:
if isinstance(code_to_test, dict):
code_to_test_list = [code_to_test[variant] for variant in GRAPH_SETUP_CODES]
else:
code_to_test_list = [code_to_test] * len(GRAPH_SETUP_CODES)
print(f"\n\tGraph size: {graph_size}\n")
setup_codes = [setup_code.format(graph_size=graph_size) for setup_code in GRAPH_SETUP_CODES.values()]
timings = _get_timings(setup_codes, code_to_test_list, repeats)

if code_to_test == "pass":
_print_graph_timings(timings, list(GRAPH_SETUP_CODES.keys()), setup_codes)
else:
_print_graph_timings(timings, list(GRAPH_SETUP_CODES.keys()), code_to_test_list)
print()


def _print_test_code(code: str | dict[str, str], repeats: int):
print(f"{'-' * 40}")
if isinstance(code, dict):
for variant, code_variant in code.items():
print(f"{variant}")
print(f"\t{code_variant} (x {repeats})")
return
print(f"{code} (x {repeats})")


def _print_graph_timings(timings: Generator, graph_types: list[str], code_list: list[str]):
for graph_type, timing, code in zip(graph_types, timings, code_list):
if ";" in code:
code = code.split(";")[-1]

code = code.replace("\n", " ").replace("\t", " ")
code = f"{graph_type}: " + code

if isinstance(timing, Exception):
print(f"\t\t{code.ljust(100)} | Not supported")
continue
print(f"\t\t{code.ljust(100)} | {sum(timing):.2f}s")


def _print_timings(timings: Generator, array_types: list[str], code_list: list[str]):
for array, timing, code in zip(array_types, timings, code_list):
if ";" in code:
code = code.split(";")[-1]

code = code.replace("\n", " ").replace("\t", " ")
array_name = f"{array}_array"
code = code.replace("input_array", array_name)
def _print_timings(timings: Generator):
for key, code, timing in timings:
code = code.split(";")[-1].replace("\n", " ").replace("\t", " ")
code = f"{key}: {code}"

if isinstance(timing, Exception):
print(f"\t\t{code.ljust(100)} | Not supported")
continue
print(f"\t\t{code.ljust(100)} | {sum(timing):.2f}s")


def _get_timings(setup_codes: list[str], test_codes: list[str], repeats: int):
def _get_timings(test_generator, repeats: int):
"""Return a generator with the timings for each array type."""
for setup_code, test_code in zip(setup_codes, test_codes):
for (key, setup_code), test_code in test_generator:
if test_code == "pass":
yield timeit.repeat(setup_code, number=repeats)
yield key, "intialise", timeit.repeat(setup_code, number=repeats)
else:
try:
yield timeit.repeat(test_code, setup_code, number=repeats)
yield key, test_code, timeit.repeat(test_code, setup_code, number=repeats)
# pylint: disable=broad-exception-caught
except Exception as error: # noqa
yield error
yield key, test_code, error
Loading