From 2a4f61c4c1a0040ca13c9ddbe136323d4628ee70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Wed, 27 Aug 2025 22:21:05 +0200 Subject: [PATCH 1/8] Add GraphSamplingEndpoints --- .../api/graph_sampling_endpoints.py | 164 ++++++++++++++++++ 1 file changed, 164 insertions(+) create mode 100644 graphdatascience/procedure_surface/api/graph_sampling_endpoints.py diff --git a/graphdatascience/procedure_surface/api/graph_sampling_endpoints.py b/graphdatascience/procedure_surface/api/graph_sampling_endpoints.py new file mode 100644 index 000000000..f80cc21f7 --- /dev/null +++ b/graphdatascience/procedure_surface/api/graph_sampling_endpoints.py @@ -0,0 +1,164 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, List, Optional + +from graphdatascience import Graph +from graphdatascience.procedure_surface.api.base_result import BaseResult + + +class GraphSamplingEndpoints(ABC): + """ + Abstract base class defining the API for graph sampling algorithms algorithm. + """ + + @abstractmethod + def rwr( + self, + G: Graph, + graph_name: str, + startNodes: Optional[List[int]] = None, + restartProbability: Optional[float] = None, + samplingRatio: Optional[float] = None, + nodeLabelStratification: Optional[bool] = None, + relationshipWeightProperty: Optional[str] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[Any] = None, + job_id: Optional[Any] = None, + ) -> GraphSamplingResult: + """ + Computes a set of Random Walks with Restart (RWR) for the given graph and stores the result as a new graph in the catalog. + + This method performs a random walk, beginning from a set of nodes (if provided), + where at each step there is a probability to restart back at the original nodes. + The result is turned into a new graph induced by the random walks and stored in the catalog. + + Parameters + ---------- + G : Graph + The input graph on which the Random Walk with Restart (RWR) will be + performed. + graph_name : str + The name of the new graph in the catalog. + startNodes : list of int, optional + A list of node IDs to start the random walk from. If not provided, all + nodes are used as potential starting points. + restartProbability : float, optional + The probability of restarting back to the original node at each step. + Should be a value between 0 and 1. If not specified, a default value is used. + samplingRatio : float, optional + The ratio of nodes to sample during the computation. This value should + be between 0 and 1. If not specified, no sampling is performed. + nodeLabelStratification : bool, optional + If True, the algorithm tries to preserve the label distribution of the original graph in the sampled graph. + relationshipWeightProperty : str, optional + The name of the property on relationships to use as weights during + the random walk. If not specified, the relationships are treated as + unweighted. + relationship_types : Optional[List[str]], default=None + The relationship types used to select relationships for this algorithm run. + node_labels : Optional[List[str]], default=None + The node labels used to select nodes for this algorithm run. + sudo : bool, optional + Override memory estimation limits. Use with caution as this can lead to + memory issues if the estimation is significantly wrong. + log_progress : bool, optional + If True, logs the progress of the computation. + username : str, optional + The username to attribute the procedure run to + concurrency : Any, optional + The number of concurrent threads used for the algorithm execution. + job_id : Any, optional + An identifier for the job that can be used for monitoring and cancellation + + Returns + ------- + GraphSamplingResult + The result of the Random Walk with Restart (RWR), including the sampled + nodes and their scores. + """ + pass + + @abstractmethod + def cnarw( + self, + G: Graph, + graph_name: str, + startNodes: Optional[List[int]] = None, + restartProbability: Optional[float] = None, + samplingRatio: Optional[float] = None, + nodeLabelStratification: Optional[bool] = None, + relationshipWeightProperty: Optional[str] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[Any] = None, + job_id: Optional[Any] = None, + ) -> GraphSamplingResult: + """ + Computes a set of Random Walks with Restart (RWR) for the given graph and stores the result as a new graph in the catalog. + + This method performs a random walk, beginning from a set of nodes (if provided), + where at each step there is a probability to restart back at the original nodes. + The result is turned into a new graph induced by the random walks and stored in the catalog. + + Parameters + ---------- + G : Graph + The input graph on which the Random Walk with Restart (RWR) will be + performed. + graph_name : str + The name of the new graph in the catalog. + startNodes : list of int, optional + A list of node IDs to start the random walk from. If not provided, all + nodes are used as potential starting points. + restartProbability : float, optional + The probability of restarting back to the original node at each step. + Should be a value between 0 and 1. If not specified, a default value is used. + samplingRatio : float, optional + The ratio of nodes to sample during the computation. This value should + be between 0 and 1. If not specified, no sampling is performed. + nodeLabelStratification : bool, optional + If True, the algorithm tries to preserve the label distribution of the original graph in the sampled graph. + relationshipWeightProperty : str, optional + The name of the property on relationships to use as weights during + the random walk. If not specified, the relationships are treated as + unweighted. + relationship_types : Optional[List[str]], default=None + The relationship types used to select relationships for this algorithm run. + node_labels : Optional[List[str]], default=None + The node labels used to select nodes for this algorithm run. + sudo : bool, optional + Override memory estimation limits. Use with caution as this can lead to + memory issues if the estimation is significantly wrong. + log_progress : bool, optional + If True, logs the progress of the computation. + username : str, optional + The username to attribute the procedure run to + concurrency : Any, optional + The number of concurrent threads used for the algorithm execution. + job_id : Any, optional + An identifier for the job that can be used for monitoring and cancellation + + Returns + ------- + GraphSamplingResult + The result of the Random Walk with Restart (RWR), including the sampled + nodes and their scores. + """ + pass + + +class GraphSamplingResult(BaseResult): + graph_name: str + from_graph_name: str + node_count: int + relationship_count: int + start_node_count: int + project_millis: int From f051ff62945c69797f533f57f98a97a656dc05a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Wed, 27 Aug 2025 22:21:52 +0200 Subject: [PATCH 2/8] Implement GraphSamplingArrowEndpoints --- .../arrow/graph_sampling_arrow_endpoints.py | 93 ++++++++++++++++++ .../test_graph_sampling_arrow_endpoints.py | 98 +++++++++++++++++++ 2 files changed, 191 insertions(+) create mode 100644 graphdatascience/procedure_surface/arrow/graph_sampling_arrow_endpoints.py create mode 100644 graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graph_sampling_arrow_endpoints.py diff --git a/graphdatascience/procedure_surface/arrow/graph_sampling_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/graph_sampling_arrow_endpoints.py new file mode 100644 index 000000000..444e6520c --- /dev/null +++ b/graphdatascience/procedure_surface/arrow/graph_sampling_arrow_endpoints.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +from typing import Any, List, Optional + +from graphdatascience import Graph +from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient +from graphdatascience.arrow_client.v2.job_client import JobClient +from graphdatascience.procedure_surface.api.graph_sampling_endpoints import ( + GraphSamplingEndpoints, + GraphSamplingResult, +) +from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter + + +class GraphSamplingArrowEndpoints(GraphSamplingEndpoints): + def __init__(self, arrow_client: AuthenticatedArrowClient): + self._arrow_client = arrow_client + + def rwr( + self, + G: Graph, + graph_name: str, + startNodes: Optional[List[int]] = None, + restartProbability: Optional[float] = None, + samplingRatio: Optional[float] = None, + nodeLabelStratification: Optional[bool] = None, + relationshipWeightProperty: Optional[str] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[Any] = None, + job_id: Optional[Any] = None, + ) -> GraphSamplingResult: + config = ConfigConverter.convert_to_gds_config( + from_graph_name=G.name(), + graph_name=graph_name, + startNodes=startNodes, + restartProbability=restartProbability, + samplingRatio=samplingRatio, + nodeLabelStratification=nodeLabelStratification, + relationshipWeightProperty=relationshipWeightProperty, + relationship_types=relationship_types, + node_labels=node_labels, + sudo=sudo, + log_progress=log_progress, + username=username, + concurrency=concurrency, + job_id=job_id, + ) + + job_id = JobClient.run_job_and_wait(self._arrow_client, "v2/graph.sample.rwr", config) + + return GraphSamplingResult(**JobClient.get_summary(self._arrow_client, job_id)) + + def cnarw( + self, + G: Graph, + graph_name: str, + startNodes: Optional[List[int]] = None, + restartProbability: Optional[float] = None, + samplingRatio: Optional[float] = None, + nodeLabelStratification: Optional[bool] = None, + relationshipWeightProperty: Optional[str] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[Any] = None, + job_id: Optional[Any] = None, + ) -> GraphSamplingResult: + config = ConfigConverter.convert_to_gds_config( + from_graph_name=G.name(), + graph_name=graph_name, + startNodes=startNodes, + restartProbability=restartProbability, + samplingRatio=samplingRatio, + nodeLabelStratification=nodeLabelStratification, + relationshipWeightProperty=relationshipWeightProperty, + relationship_types=relationship_types, + node_labels=node_labels, + sudo=sudo, + log_progress=log_progress, + username=username, + concurrency=concurrency, + job_id=job_id, + ) + + job_id = JobClient.run_job_and_wait(self._arrow_client, "v2/graph.sample.cnarw", config) + + return GraphSamplingResult(**JobClient.get_summary(self._arrow_client, job_id)) diff --git a/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graph_sampling_arrow_endpoints.py b/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graph_sampling_arrow_endpoints.py new file mode 100644 index 000000000..0670936d5 --- /dev/null +++ b/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graph_sampling_arrow_endpoints.py @@ -0,0 +1,98 @@ +from typing import Generator + +import pytest + +from graphdatascience import Graph +from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient +from graphdatascience.procedure_surface.arrow.graph_sampling_arrow_endpoints import GraphSamplingArrowEndpoints +from graphdatascience.tests.integrationV2.procedure_surface.arrow.graph_creation_helper import ( + create_graph, +) + + +@pytest.fixture +def sample_graph(arrow_client: AuthenticatedArrowClient) -> Generator[Graph, None, None]: + gdl = """ + (a :Node {id: 0}) + (b :Node {id: 1}) + (c :Node {id: 2}) + (d :Node {id: 3}) + (e :Node {id: 4}) + (a)-[:REL {weight: 1.0}]->(b) + (b)-[:REL {weight: 2.0}]->(c) + (c)-[:REL {weight: 1.5}]->(d) + (d)-[:REL {weight: 0.5}]->(e) + (e)-[:REL {weight: 1.2}]->(a) + """ + + yield create_graph(arrow_client, "sample_graph", gdl) + arrow_client.do_action("v2/graph.drop", {"graphName": "sample_graph"}) + arrow_client.do_action("v2/graph.drop", {"graphName": "sampled"}) + + +@pytest.fixture +def graph_sampling_endpoints( + arrow_client: AuthenticatedArrowClient, +) -> Generator[GraphSamplingArrowEndpoints, None, None]: + yield GraphSamplingArrowEndpoints(arrow_client) + + +def test_rwr_basic(graph_sampling_endpoints: GraphSamplingArrowEndpoints, sample_graph: Graph) -> None: + result = graph_sampling_endpoints.rwr( + G=sample_graph, graph_name="sampled", startNodes=[0, 1], restartProbability=0.15, samplingRatio=0.8 + ) + + assert result.graph_name == "sampled" + assert result.from_graph_name == sample_graph.name() + assert result.node_count > 0 + assert result.relationship_count >= 0 + assert result.start_node_count > 0 + assert result.project_millis >= 0 + + +def test_rwr_with_weights(graph_sampling_endpoints: GraphSamplingArrowEndpoints, sample_graph: Graph) -> None: + result = graph_sampling_endpoints.rwr( + G=sample_graph, + graph_name="sampled", + startNodes=[0], + restartProbability=0.2, + samplingRatio=0.6, + relationshipWeightProperty="weight", + ) + + assert result.graph_name == "sampled" + assert result.from_graph_name == sample_graph.name() + assert result.node_count > 0 + assert result.start_node_count >= 1 + assert result.project_millis >= 0 + + +def test_rwr_minimal_config(graph_sampling_endpoints: GraphSamplingArrowEndpoints, sample_graph: Graph) -> None: + result = graph_sampling_endpoints.rwr(G=sample_graph, graph_name="sampled") + + assert result.graph_name == "sampled" + assert result.from_graph_name == sample_graph.name() + assert result.node_count > 0 + assert result.project_millis >= 0 + + +def test_cnarw_basic(graph_sampling_endpoints: GraphSamplingArrowEndpoints, sample_graph: Graph) -> None: + result = graph_sampling_endpoints.cnarw( + G=sample_graph, graph_name="sampled", startNodes=[0, 1], restartProbability=0.15, samplingRatio=0.8 + ) + + assert result.graph_name == "sampled" + assert result.from_graph_name == sample_graph.name() + assert result.node_count > 0 + assert result.relationship_count >= 0 + assert result.start_node_count == 2 + assert result.project_millis >= 0 + + +def test_cnarw_minimal_config(graph_sampling_endpoints: GraphSamplingArrowEndpoints, sample_graph: Graph) -> None: + result = graph_sampling_endpoints.cnarw(G=sample_graph, graph_name="sampled") + + assert result.graph_name == "sampled" + assert result.from_graph_name == sample_graph.name() + assert result.node_count > 0 + assert result.project_millis >= 0 From 470bb31f314e6b657866ec85460ad2b0d05c9b6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Wed, 27 Aug 2025 23:10:30 +0200 Subject: [PATCH 3/8] Use snake case names for sampling arguments --- .../api/graph_sampling_endpoints.py | 40 +++++++++---------- .../arrow/graph_sampling_arrow_endpoints.py | 40 +++++++++---------- .../test_graph_sampling_arrow_endpoints.py | 12 +++--- 3 files changed, 46 insertions(+), 46 deletions(-) diff --git a/graphdatascience/procedure_surface/api/graph_sampling_endpoints.py b/graphdatascience/procedure_surface/api/graph_sampling_endpoints.py index f80cc21f7..d8524f5ff 100644 --- a/graphdatascience/procedure_surface/api/graph_sampling_endpoints.py +++ b/graphdatascience/procedure_surface/api/graph_sampling_endpoints.py @@ -17,11 +17,11 @@ def rwr( self, G: Graph, graph_name: str, - startNodes: Optional[List[int]] = None, - restartProbability: Optional[float] = None, - samplingRatio: Optional[float] = None, - nodeLabelStratification: Optional[bool] = None, - relationshipWeightProperty: Optional[str] = None, + start_nodes: Optional[List[int]] = None, + restart_probability: Optional[float] = None, + sampling_ratio: Optional[float] = None, + node_label_stratification: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, relationship_types: Optional[List[str]] = None, node_labels: Optional[List[str]] = None, sudo: Optional[bool] = None, @@ -44,18 +44,18 @@ def rwr( performed. graph_name : str The name of the new graph in the catalog. - startNodes : list of int, optional + start_nodes : list of int, optional A list of node IDs to start the random walk from. If not provided, all nodes are used as potential starting points. - restartProbability : float, optional + restart_probability : float, optional The probability of restarting back to the original node at each step. Should be a value between 0 and 1. If not specified, a default value is used. - samplingRatio : float, optional + sampling_ratio : float, optional The ratio of nodes to sample during the computation. This value should be between 0 and 1. If not specified, no sampling is performed. - nodeLabelStratification : bool, optional + node_label_stratification : bool, optional If True, the algorithm tries to preserve the label distribution of the original graph in the sampled graph. - relationshipWeightProperty : str, optional + relationship_weight_property : str, optional The name of the property on relationships to use as weights during the random walk. If not specified, the relationships are treated as unweighted. @@ -88,11 +88,11 @@ def cnarw( self, G: Graph, graph_name: str, - startNodes: Optional[List[int]] = None, - restartProbability: Optional[float] = None, - samplingRatio: Optional[float] = None, - nodeLabelStratification: Optional[bool] = None, - relationshipWeightProperty: Optional[str] = None, + start_nodes: Optional[List[int]] = None, + restart_probability: Optional[float] = None, + sampling_ratio: Optional[float] = None, + node_label_stratification: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, relationship_types: Optional[List[str]] = None, node_labels: Optional[List[str]] = None, sudo: Optional[bool] = None, @@ -115,18 +115,18 @@ def cnarw( performed. graph_name : str The name of the new graph in the catalog. - startNodes : list of int, optional + start_nodes : list of int, optional A list of node IDs to start the random walk from. If not provided, all nodes are used as potential starting points. - restartProbability : float, optional + restart_probability : float, optional The probability of restarting back to the original node at each step. Should be a value between 0 and 1. If not specified, a default value is used. - samplingRatio : float, optional + sampling_ratio : float, optional The ratio of nodes to sample during the computation. This value should be between 0 and 1. If not specified, no sampling is performed. - nodeLabelStratification : bool, optional + node_label_stratification : bool, optional If True, the algorithm tries to preserve the label distribution of the original graph in the sampled graph. - relationshipWeightProperty : str, optional + relationship_weight_property : str, optional The name of the property on relationships to use as weights during the random walk. If not specified, the relationships are treated as unweighted. diff --git a/graphdatascience/procedure_surface/arrow/graph_sampling_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/graph_sampling_arrow_endpoints.py index 444e6520c..1ef0be1c4 100644 --- a/graphdatascience/procedure_surface/arrow/graph_sampling_arrow_endpoints.py +++ b/graphdatascience/procedure_surface/arrow/graph_sampling_arrow_endpoints.py @@ -20,11 +20,11 @@ def rwr( self, G: Graph, graph_name: str, - startNodes: Optional[List[int]] = None, - restartProbability: Optional[float] = None, - samplingRatio: Optional[float] = None, - nodeLabelStratification: Optional[bool] = None, - relationshipWeightProperty: Optional[str] = None, + start_nodes: Optional[List[int]] = None, + restart_probability: Optional[float] = None, + sampling_ratio: Optional[float] = None, + node_label_stratification: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, relationship_types: Optional[List[str]] = None, node_labels: Optional[List[str]] = None, sudo: Optional[bool] = None, @@ -36,11 +36,11 @@ def rwr( config = ConfigConverter.convert_to_gds_config( from_graph_name=G.name(), graph_name=graph_name, - startNodes=startNodes, - restartProbability=restartProbability, - samplingRatio=samplingRatio, - nodeLabelStratification=nodeLabelStratification, - relationshipWeightProperty=relationshipWeightProperty, + startNodes=start_nodes, + restartProbability=restart_probability, + samplingRatio=sampling_ratio, + nodeLabelStratification=node_label_stratification, + relationshipWeightProperty=relationship_weight_property, relationship_types=relationship_types, node_labels=node_labels, sudo=sudo, @@ -58,11 +58,11 @@ def cnarw( self, G: Graph, graph_name: str, - startNodes: Optional[List[int]] = None, - restartProbability: Optional[float] = None, - samplingRatio: Optional[float] = None, - nodeLabelStratification: Optional[bool] = None, - relationshipWeightProperty: Optional[str] = None, + start_nodes: Optional[List[int]] = None, + restart_probability: Optional[float] = None, + sampling_ratio: Optional[float] = None, + node_label_stratification: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, relationship_types: Optional[List[str]] = None, node_labels: Optional[List[str]] = None, sudo: Optional[bool] = None, @@ -74,11 +74,11 @@ def cnarw( config = ConfigConverter.convert_to_gds_config( from_graph_name=G.name(), graph_name=graph_name, - startNodes=startNodes, - restartProbability=restartProbability, - samplingRatio=samplingRatio, - nodeLabelStratification=nodeLabelStratification, - relationshipWeightProperty=relationshipWeightProperty, + startNodes=start_nodes, + restartProbability=restart_probability, + samplingRatio=sampling_ratio, + nodeLabelStratification=node_label_stratification, + relationshipWeightProperty=relationship_weight_property, relationship_types=relationship_types, node_labels=node_labels, sudo=sudo, diff --git a/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graph_sampling_arrow_endpoints.py b/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graph_sampling_arrow_endpoints.py index 0670936d5..ff1111d77 100644 --- a/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graph_sampling_arrow_endpoints.py +++ b/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graph_sampling_arrow_endpoints.py @@ -39,7 +39,7 @@ def graph_sampling_endpoints( def test_rwr_basic(graph_sampling_endpoints: GraphSamplingArrowEndpoints, sample_graph: Graph) -> None: result = graph_sampling_endpoints.rwr( - G=sample_graph, graph_name="sampled", startNodes=[0, 1], restartProbability=0.15, samplingRatio=0.8 + G=sample_graph, graph_name="sampled", start_nodes=[0, 1], restart_probability=0.15, sampling_ratio=0.8 ) assert result.graph_name == "sampled" @@ -54,10 +54,10 @@ def test_rwr_with_weights(graph_sampling_endpoints: GraphSamplingArrowEndpoints, result = graph_sampling_endpoints.rwr( G=sample_graph, graph_name="sampled", - startNodes=[0], - restartProbability=0.2, - samplingRatio=0.6, - relationshipWeightProperty="weight", + start_nodes=[0], + restart_probability=0.2, + sampling_ratio=0.6, + relationship_weight_property="weight", ) assert result.graph_name == "sampled" @@ -78,7 +78,7 @@ def test_rwr_minimal_config(graph_sampling_endpoints: GraphSamplingArrowEndpoint def test_cnarw_basic(graph_sampling_endpoints: GraphSamplingArrowEndpoints, sample_graph: Graph) -> None: result = graph_sampling_endpoints.cnarw( - G=sample_graph, graph_name="sampled", startNodes=[0, 1], restartProbability=0.15, samplingRatio=0.8 + G=sample_graph, graph_name="sampled", start_nodes=[0, 1], restart_probability=0.15, sampling_ratio=0.8 ) assert result.graph_name == "sampled" From d647c45f5d4d1073d2a2e9550932a1302c9e2ff6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Wed, 27 Aug 2025 23:10:51 +0200 Subject: [PATCH 4/8] Implement cypher sampling endpoints --- .../cypher/graph_sampling_cypher_endpoints.py | 98 ++++++++++++++ .../cypher/cypher_graph_helper.py | 7 + .../test_graph_sampling_cypher_endpoints.py | 125 ++++++++++++++++++ 3 files changed, 230 insertions(+) create mode 100644 graphdatascience/procedure_surface/cypher/graph_sampling_cypher_endpoints.py create mode 100644 graphdatascience/tests/integrationV2/procedure_surface/cypher/cypher_graph_helper.py create mode 100644 graphdatascience/tests/integrationV2/procedure_surface/cypher/test_graph_sampling_cypher_endpoints.py diff --git a/graphdatascience/procedure_surface/cypher/graph_sampling_cypher_endpoints.py b/graphdatascience/procedure_surface/cypher/graph_sampling_cypher_endpoints.py new file mode 100644 index 000000000..cce413e6e --- /dev/null +++ b/graphdatascience/procedure_surface/cypher/graph_sampling_cypher_endpoints.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +from typing import Any, List, Optional + +from ...call_parameters import CallParameters +from ...graph.graph_object import Graph +from ...query_runner.query_runner import QueryRunner +from ..api.graph_sampling_endpoints import GraphSamplingEndpoints, GraphSamplingResult +from ..utils.config_converter import ConfigConverter + + +class GraphSamplingCypherEndpoints(GraphSamplingEndpoints): + def __init__(self, query_runner: QueryRunner): + self._query_runner = query_runner + + def rwr( + self, + G: Graph, + graph_name: str, + start_nodes: Optional[List[int]] = None, + restart_probability: Optional[float] = None, + sampling_ratio: Optional[float] = None, + node_label_stratification: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[Any] = None, + job_id: Optional[Any] = None, + ) -> GraphSamplingResult: + config = ConfigConverter.convert_to_gds_config( + startNodes=start_nodes, + restartProbability=restart_probability, + samplingRatio=sampling_ratio, + nodeLabelStratification=node_label_stratification, + relationshipWeightProperty=relationship_weight_property, + relationshipTypes=relationship_types, + nodeLabels=node_labels, + sudo=sudo, + logProgress=log_progress, + username=username, + concurrency=concurrency, + jobId=job_id, + ) + + params = CallParameters( + graph_name=graph_name, + from_graph_name=G.name(), + config=config, + ) + params.ensure_job_id_in_config() + + result = self._query_runner.call_procedure(endpoint="gds.graph.sample.rwr", params=params).squeeze() + return GraphSamplingResult(**result.to_dict()) + + def cnarw( + self, + G: Graph, + graph_name: str, + start_nodes: Optional[List[int]] = None, + restart_probability: Optional[float] = None, + sampling_ratio: Optional[float] = None, + node_label_stratification: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[Any] = None, + job_id: Optional[Any] = None, + ) -> GraphSamplingResult: + config = ConfigConverter.convert_to_gds_config( + startNodes=start_nodes, + restartProbability=restart_probability, + samplingRatio=sampling_ratio, + nodeLabelStratification=node_label_stratification, + relationshipWeightProperty=relationship_weight_property, + relationshipTypes=relationship_types, + nodeLabels=node_labels, + sudo=sudo, + logProgress=log_progress, + username=username, + concurrency=concurrency, + jobId=job_id, + ) + + params = CallParameters( + graph_name=graph_name, + from_graph_name=G.name(), + config=config, + ) + params.ensure_job_id_in_config() + + result = self._query_runner.call_procedure(endpoint="gds.graph.sample.cnarw", params=params).squeeze() + return GraphSamplingResult(**result.to_dict()) diff --git a/graphdatascience/tests/integrationV2/procedure_surface/cypher/cypher_graph_helper.py b/graphdatascience/tests/integrationV2/procedure_surface/cypher/cypher_graph_helper.py new file mode 100644 index 000000000..33a772740 --- /dev/null +++ b/graphdatascience/tests/integrationV2/procedure_surface/cypher/cypher_graph_helper.py @@ -0,0 +1,7 @@ +from graphdatascience import QueryRunner + + +def delete_all_graphs(query_runner: QueryRunner) -> None: + query_runner.run_cypher( + "CALL gds.graph.list() YIELD graphName CALL gds.graph.drop(graphName) YIELD graphName as g RETURN g" + ) diff --git a/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_graph_sampling_cypher_endpoints.py b/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_graph_sampling_cypher_endpoints.py new file mode 100644 index 000000000..63de3e264 --- /dev/null +++ b/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_graph_sampling_cypher_endpoints.py @@ -0,0 +1,125 @@ +from typing import Generator + +import pytest + +from graphdatascience import Graph, QueryRunner +from graphdatascience.procedure_surface.cypher.graph_sampling_cypher_endpoints import GraphSamplingCypherEndpoints +from graphdatascience.tests.integrationV2.procedure_surface.cypher.cypher_graph_helper import delete_all_graphs + + +@pytest.fixture +def sample_graph(query_runner: QueryRunner) -> Generator[Graph, None, None]: + create_statement = """ + CREATE + (a: Node {id: 0}), + (b: Node {id: 1}), + (c: Node {id: 2}), + (d: Node {id: 3}), + (e: Node {id: 4}), + (a)-[:REL {weight: 1.0}]->(b), + (b)-[:REL {weight: 2.0}]->(c), + (c)-[:REL {weight: 1.5}]->(d), + (d)-[:REL {weight: 0.5}]->(e), + (e)-[:REL {weight: 1.2}]->(a) + """ + + query_runner.run_cypher(create_statement) + + query_runner.run_cypher(""" + MATCH (n) + OPTIONAL MATCH (n)-[r]->(m) + WITH gds.graph.project('g', n, m, {relationshipProperties: {weight: r.weight}}) AS G + RETURN G + """) + + yield Graph("g", query_runner) + + delete_all_graphs(query_runner) + query_runner.run_cypher("MATCH (n) DETACH DELETE n") + + +@pytest.fixture +def graph_sampling_endpoints(query_runner: QueryRunner) -> Generator[GraphSamplingCypherEndpoints, None, None]: + yield GraphSamplingCypherEndpoints(query_runner) + + +def test_rwr_basic(graph_sampling_endpoints: GraphSamplingCypherEndpoints, sample_graph: Graph) -> None: + """Test RWR sampling with basic configuration.""" + result = graph_sampling_endpoints.rwr( + G=sample_graph, graph_name="rwr_sampled", start_nodes=[0, 1], restart_probability=0.15, sampling_ratio=0.8 + ) + + assert result.graph_name == "rwr_sampled" + assert result.from_graph_name == sample_graph.name() + assert result.node_count > 0 + assert result.relationship_count >= 0 + assert result.start_node_count >= 1 + assert result.project_millis >= 0 + + +def test_rwr_with_weights(graph_sampling_endpoints: GraphSamplingCypherEndpoints, sample_graph: Graph) -> None: + """Test RWR sampling with weighted relationships.""" + result = graph_sampling_endpoints.rwr( + G=sample_graph, + graph_name="rwr_weighted", + restart_probability=0.2, + sampling_ratio=0.6, + relationship_weight_property="weight", + ) + + assert result.graph_name == "rwr_weighted" + assert result.from_graph_name == sample_graph.name() + assert result.node_count > 0 + assert result.start_node_count >= 1 + assert result.project_millis >= 0 + + +def test_rwr_minimal_config(graph_sampling_endpoints: GraphSamplingCypherEndpoints, sample_graph: Graph) -> None: + """Test RWR sampling with minimal configuration.""" + result = graph_sampling_endpoints.rwr(G=sample_graph, graph_name="rwr_minimal") + + assert result.graph_name == "rwr_minimal" + assert result.from_graph_name == sample_graph.name() + assert result.node_count > 0 + assert result.project_millis >= 0 + + +def test_cnarw_basic(graph_sampling_endpoints: GraphSamplingCypherEndpoints, sample_graph: Graph) -> None: + """Test CNARW sampling with basic configuration.""" + result = graph_sampling_endpoints.cnarw( + G=sample_graph, graph_name="cnarw_sampled", restart_probability=0.15, sampling_ratio=0.8 + ) + + assert result.graph_name == "cnarw_sampled" + assert result.from_graph_name == sample_graph.name() + assert result.node_count > 0 + assert result.relationship_count >= 0 + assert result.start_node_count >= 1 + assert result.project_millis >= 0 + + +def test_cnarw_with_stratification(graph_sampling_endpoints: GraphSamplingCypherEndpoints, sample_graph: Graph) -> None: + """Test CNARW sampling with node label stratification.""" + result = graph_sampling_endpoints.cnarw( + G=sample_graph, + graph_name="cnarw_stratified", + restart_probability=0.1, + sampling_ratio=0.7, + node_label_stratification=True, + ) + + assert result.graph_name == "cnarw_stratified" + assert result.from_graph_name == sample_graph.name() + assert result.node_count > 0 + assert result.start_node_count >= 1 + assert result.project_millis >= 0 + + +def test_cnarw_minimal_config(graph_sampling_endpoints: GraphSamplingCypherEndpoints, sample_graph: Graph) -> None: + """Test CNARW sampling with minimal configuration.""" + result = graph_sampling_endpoints.cnarw(G=sample_graph, graph_name="cnarw_minimal") + + assert result.graph_name == "cnarw_minimal" + assert result.from_graph_name == sample_graph.name() + assert result.start_node_count >= 1 + assert result.project_millis >= 0 From 3895d595a9c25b9e3a4cb0f2ef2b88df82d5c3e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Wed, 27 Aug 2025 23:12:55 +0200 Subject: [PATCH 5/8] Expose sampling endpoints in catalog endpoints --- .../procedure_surface/api/catalog_endpoints.py | 6 ++++++ .../procedure_surface/arrow/catalog_arrow_endpoints.py | 7 +++++++ 2 files changed, 13 insertions(+) diff --git a/graphdatascience/procedure_surface/api/catalog_endpoints.py b/graphdatascience/procedure_surface/api/catalog_endpoints.py index fecf0b45e..552e58e80 100644 --- a/graphdatascience/procedure_surface/api/catalog_endpoints.py +++ b/graphdatascience/procedure_surface/api/catalog_endpoints.py @@ -9,6 +9,7 @@ from graphdatascience import Graph from graphdatascience.procedure_surface.api.base_result import BaseResult +from graphdatascience.procedure_surface.api.graph_sampling_endpoints import GraphSamplingEndpoints class CatalogEndpoints(ABC): @@ -65,6 +66,11 @@ def filter( """ pass + @property + @abstractmethod + def sample(self) -> GraphSamplingEndpoints: + pass + class GraphListResult(BaseResult): graph_name: str diff --git a/graphdatascience/procedure_surface/arrow/catalog_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/catalog_arrow_endpoints.py index cab59fc88..2e9e2aee6 100644 --- a/graphdatascience/procedure_surface/arrow/catalog_arrow_endpoints.py +++ b/graphdatascience/procedure_surface/arrow/catalog_arrow_endpoints.py @@ -13,6 +13,8 @@ GraphFilterResult, GraphListResult, ) +from graphdatascience.procedure_surface.api.graph_sampling_endpoints import GraphSamplingEndpoints +from graphdatascience.procedure_surface.arrow.graph_sampling_arrow_endpoints import GraphSamplingArrowEndpoints from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter from graphdatascience.query_runner.protocol.project_protocols import ProjectProtocol from graphdatascience.query_runner.termination_flag import TerminationFlag @@ -116,6 +118,10 @@ def filter( return GraphFilterResult(**JobClient.get_summary(self._arrow_client, job_id)) + @property + def sample(self) -> GraphSamplingEndpoints: + return GraphSamplingArrowEndpoints(self._arrow_client) + def _arrow_config(self) -> dict[str, Any]: connection_info = self._arrow_client.advertised_connection_info() @@ -131,6 +137,7 @@ def _arrow_config(self) -> dict[str, Any]: } + class ProjectionResult(BaseResult): graph_name: str node_count: int From e197129cecccc146c6433393d2d577a368213b14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Thu, 28 Aug 2025 12:17:06 +0200 Subject: [PATCH 6/8] Fix tests and code style --- .../procedure_surface/arrow/catalog_arrow_endpoints.py | 1 - .../arrow/test_graph_sampling_arrow_endpoints.py | 9 ++++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/graphdatascience/procedure_surface/arrow/catalog_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/catalog_arrow_endpoints.py index 2e9e2aee6..3bba9a16d 100644 --- a/graphdatascience/procedure_surface/arrow/catalog_arrow_endpoints.py +++ b/graphdatascience/procedure_surface/arrow/catalog_arrow_endpoints.py @@ -137,7 +137,6 @@ def _arrow_config(self) -> dict[str, Any]: } - class ProjectionResult(BaseResult): graph_name: str node_count: int diff --git a/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graph_sampling_arrow_endpoints.py b/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graph_sampling_arrow_endpoints.py index ff1111d77..677f49361 100644 --- a/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graph_sampling_arrow_endpoints.py +++ b/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graph_sampling_arrow_endpoints.py @@ -39,14 +39,14 @@ def graph_sampling_endpoints( def test_rwr_basic(graph_sampling_endpoints: GraphSamplingArrowEndpoints, sample_graph: Graph) -> None: result = graph_sampling_endpoints.rwr( - G=sample_graph, graph_name="sampled", start_nodes=[0, 1], restart_probability=0.15, sampling_ratio=0.8 + G=sample_graph, graph_name="sampled", restart_probability=0.15, sampling_ratio=0.8 ) assert result.graph_name == "sampled" assert result.from_graph_name == sample_graph.name() assert result.node_count > 0 assert result.relationship_count >= 0 - assert result.start_node_count > 0 + assert result.start_node_count >= 1 assert result.project_millis >= 0 @@ -54,7 +54,6 @@ def test_rwr_with_weights(graph_sampling_endpoints: GraphSamplingArrowEndpoints, result = graph_sampling_endpoints.rwr( G=sample_graph, graph_name="sampled", - start_nodes=[0], restart_probability=0.2, sampling_ratio=0.6, relationship_weight_property="weight", @@ -78,14 +77,14 @@ def test_rwr_minimal_config(graph_sampling_endpoints: GraphSamplingArrowEndpoint def test_cnarw_basic(graph_sampling_endpoints: GraphSamplingArrowEndpoints, sample_graph: Graph) -> None: result = graph_sampling_endpoints.cnarw( - G=sample_graph, graph_name="sampled", start_nodes=[0, 1], restart_probability=0.15, sampling_ratio=0.8 + G=sample_graph, graph_name="sampled", restart_probability=0.15, sampling_ratio=0.8 ) assert result.graph_name == "sampled" assert result.from_graph_name == sample_graph.name() assert result.node_count > 0 assert result.relationship_count >= 0 - assert result.start_node_count == 2 + assert result.start_node_count >= 1 assert result.project_millis >= 0 From c5e00b463d22a16683f275191ead1c38cb06cf16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Thu, 28 Aug 2025 21:56:06 +0200 Subject: [PATCH 7/8] Minor cleanups --- .../api/graph_sampling_endpoints.py | 26 +++++++------- .../arrow/graph_sampling_arrow_endpoints.py | 20 +++++------ .../cypher/graph_sampling_cypher_endpoints.py | 36 +++++++++---------- 3 files changed, 41 insertions(+), 41 deletions(-) diff --git a/graphdatascience/procedure_surface/api/graph_sampling_endpoints.py b/graphdatascience/procedure_surface/api/graph_sampling_endpoints.py index d8524f5ff..c58afbfa6 100644 --- a/graphdatascience/procedure_surface/api/graph_sampling_endpoints.py +++ b/graphdatascience/procedure_surface/api/graph_sampling_endpoints.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, List, Optional +from typing import List, Optional from graphdatascience import Graph from graphdatascience.procedure_surface.api.base_result import BaseResult @@ -27,8 +27,8 @@ def rwr( sudo: Optional[bool] = None, log_progress: Optional[bool] = None, username: Optional[str] = None, - concurrency: Optional[Any] = None, - job_id: Optional[Any] = None, + concurrency: Optional[int] = None, + job_id: Optional[str] = None, ) -> GraphSamplingResult: """ Computes a set of Random Walks with Restart (RWR) for the given graph and stores the result as a new graph in the catalog. @@ -59,9 +59,9 @@ def rwr( The name of the property on relationships to use as weights during the random walk. If not specified, the relationships are treated as unweighted. - relationship_types : Optional[List[str]], default=None + relationship_types : list of str, optional The relationship types used to select relationships for this algorithm run. - node_labels : Optional[List[str]], default=None + node_labels : list of str, optional The node labels used to select nodes for this algorithm run. sudo : bool, optional Override memory estimation limits. Use with caution as this can lead to @@ -70,9 +70,9 @@ def rwr( If True, logs the progress of the computation. username : str, optional The username to attribute the procedure run to - concurrency : Any, optional + concurrency : int, optional The number of concurrent threads used for the algorithm execution. - job_id : Any, optional + job_id : str, optional An identifier for the job that can be used for monitoring and cancellation Returns @@ -98,8 +98,8 @@ def cnarw( sudo: Optional[bool] = None, log_progress: Optional[bool] = None, username: Optional[str] = None, - concurrency: Optional[Any] = None, - job_id: Optional[Any] = None, + concurrency: Optional[int] = None, + job_id: Optional[str] = None, ) -> GraphSamplingResult: """ Computes a set of Random Walks with Restart (RWR) for the given graph and stores the result as a new graph in the catalog. @@ -130,9 +130,9 @@ def cnarw( The name of the property on relationships to use as weights during the random walk. If not specified, the relationships are treated as unweighted. - relationship_types : Optional[List[str]], default=None + relationship_types : list of str, optional The relationship types used to select relationships for this algorithm run. - node_labels : Optional[List[str]], default=None + node_labels : list of str, optional The node labels used to select nodes for this algorithm run. sudo : bool, optional Override memory estimation limits. Use with caution as this can lead to @@ -141,9 +141,9 @@ def cnarw( If True, logs the progress of the computation. username : str, optional The username to attribute the procedure run to - concurrency : Any, optional + concurrency : int, optional The number of concurrent threads used for the algorithm execution. - job_id : Any, optional + job_id : str, optional An identifier for the job that can be used for monitoring and cancellation Returns diff --git a/graphdatascience/procedure_surface/arrow/graph_sampling_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/graph_sampling_arrow_endpoints.py index 1ef0be1c4..74cd43400 100644 --- a/graphdatascience/procedure_surface/arrow/graph_sampling_arrow_endpoints.py +++ b/graphdatascience/procedure_surface/arrow/graph_sampling_arrow_endpoints.py @@ -36,11 +36,11 @@ def rwr( config = ConfigConverter.convert_to_gds_config( from_graph_name=G.name(), graph_name=graph_name, - startNodes=start_nodes, - restartProbability=restart_probability, - samplingRatio=sampling_ratio, - nodeLabelStratification=node_label_stratification, - relationshipWeightProperty=relationship_weight_property, + start_nodes=start_nodes, + restart_probability=restart_probability, + sampling_ratio=sampling_ratio, + node_label_stratification=node_label_stratification, + relationship_weight_property=relationship_weight_property, relationship_types=relationship_types, node_labels=node_labels, sudo=sudo, @@ -74,11 +74,11 @@ def cnarw( config = ConfigConverter.convert_to_gds_config( from_graph_name=G.name(), graph_name=graph_name, - startNodes=start_nodes, - restartProbability=restart_probability, - samplingRatio=sampling_ratio, - nodeLabelStratification=node_label_stratification, - relationshipWeightProperty=relationship_weight_property, + start_nodes=start_nodes, + restart_probability=restart_probability, + sampling_ratio=sampling_ratio, + node_label_stratification=node_label_stratification, + relationship_weight_property=relationship_weight_property, relationship_types=relationship_types, node_labels=node_labels, sudo=sudo, diff --git a/graphdatascience/procedure_surface/cypher/graph_sampling_cypher_endpoints.py b/graphdatascience/procedure_surface/cypher/graph_sampling_cypher_endpoints.py index cce413e6e..5dd138b3a 100644 --- a/graphdatascience/procedure_surface/cypher/graph_sampling_cypher_endpoints.py +++ b/graphdatascience/procedure_surface/cypher/graph_sampling_cypher_endpoints.py @@ -31,18 +31,18 @@ def rwr( job_id: Optional[Any] = None, ) -> GraphSamplingResult: config = ConfigConverter.convert_to_gds_config( - startNodes=start_nodes, - restartProbability=restart_probability, - samplingRatio=sampling_ratio, - nodeLabelStratification=node_label_stratification, - relationshipWeightProperty=relationship_weight_property, - relationshipTypes=relationship_types, - nodeLabels=node_labels, + start_nodes=start_nodes, + restart_probability=restart_probability, + sampling_ratio=sampling_ratio, + node_label_stratification=node_label_stratification, + relationship_weight_property=relationship_weight_property, + relationship_types=relationship_types, + node_labels=node_labels, sudo=sudo, - logProgress=log_progress, + log_progress=log_progress, username=username, concurrency=concurrency, - jobId=job_id, + job_id=job_id, ) params = CallParameters( @@ -73,18 +73,18 @@ def cnarw( job_id: Optional[Any] = None, ) -> GraphSamplingResult: config = ConfigConverter.convert_to_gds_config( - startNodes=start_nodes, - restartProbability=restart_probability, - samplingRatio=sampling_ratio, - nodeLabelStratification=node_label_stratification, - relationshipWeightProperty=relationship_weight_property, - relationshipTypes=relationship_types, - nodeLabels=node_labels, + start_nodes=start_nodes, + restart_probability=restart_probability, + sampling_ratio=sampling_ratio, + node_label_stratification=node_label_stratification, + relationship_weight_property=relationship_weight_property, + relationship_types=relationship_types, + node_labels=node_labels, sudo=sudo, - logProgress=log_progress, + log_progress=log_progress, username=username, concurrency=concurrency, - jobId=job_id, + job_id=job_id, ) params = CallParameters( From c24762a05c00b6cfa7b367d645df4e7b67400f08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Thu, 28 Aug 2025 22:11:24 +0200 Subject: [PATCH 8/8] Fix tests --- .../cypher/test_graph_sampling_cypher_endpoints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_graph_sampling_cypher_endpoints.py b/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_graph_sampling_cypher_endpoints.py index 63de3e264..a6bb7fb20 100644 --- a/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_graph_sampling_cypher_endpoints.py +++ b/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_graph_sampling_cypher_endpoints.py @@ -46,7 +46,7 @@ def graph_sampling_endpoints(query_runner: QueryRunner) -> Generator[GraphSampli def test_rwr_basic(graph_sampling_endpoints: GraphSamplingCypherEndpoints, sample_graph: Graph) -> None: """Test RWR sampling with basic configuration.""" result = graph_sampling_endpoints.rwr( - G=sample_graph, graph_name="rwr_sampled", start_nodes=[0, 1], restart_probability=0.15, sampling_ratio=0.8 + G=sample_graph, graph_name="rwr_sampled", restart_probability=0.15, sampling_ratio=0.8 ) assert result.graph_name == "rwr_sampled"