Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions graphdatascience/procedure_surface/api/catalog_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from graphdatascience import Graph
from graphdatascience.procedure_surface.api.base_result import BaseResult
from graphdatascience.procedure_surface.api.graph_sampling_endpoints import GraphSamplingEndpoints


class CatalogEndpoints(ABC):
Expand Down Expand Up @@ -65,6 +66,11 @@ def filter(
"""
pass

@property
@abstractmethod
def sample(self) -> GraphSamplingEndpoints:
    """Entry point for graph sampling operations (e.g. ``rwr``, ``cnarw``).

    Returns
    -------
    GraphSamplingEndpoints
        The endpoints object exposing the sampling algorithms.
    """
    pass


class GraphListResult(BaseResult):
graph_name: str
Expand Down
164 changes: 164 additions & 0 deletions graphdatascience/procedure_surface/api/graph_sampling_endpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import List, Optional

from graphdatascience import Graph
from graphdatascience.procedure_surface.api.base_result import BaseResult


class GraphSamplingEndpoints(ABC):
    """
    Abstract base class defining the API for the graph sampling algorithms.

    For every optional parameter, ``None`` means the parameter is omitted from
    the request, so the default value defined by GDS on the server side is used.
    """

    @abstractmethod
    def rwr(
        self,
        G: Graph,
        graph_name: str,
        start_nodes: Optional[List[int]] = None,
        restart_probability: Optional[float] = None,
        sampling_ratio: Optional[float] = None,
        node_label_stratification: Optional[bool] = None,
        relationship_weight_property: Optional[str] = None,
        relationship_types: Optional[List[str]] = None,
        node_labels: Optional[List[str]] = None,
        sudo: Optional[bool] = None,
        log_progress: Optional[bool] = None,
        username: Optional[str] = None,
        concurrency: Optional[int] = None,
        job_id: Optional[str] = None,
    ) -> GraphSamplingResult:
        """
        Computes a set of Random Walks with Restart (RWR) for the given graph and stores the result as a new graph in the catalog.

        This method performs a random walk, beginning from a set of nodes (if provided),
        where at each step there is a probability to restart back at the original nodes.
        The result is turned into a new graph induced by the random walks and stored in the catalog.

        Parameters
        ----------
        G : Graph
            The input graph on which the Random Walk with Restart (RWR) will be
            performed.
        graph_name : str
            The name of the new graph in the catalog.
        start_nodes : list of int, optional
            A list of node IDs to start the random walk from. If not provided, all
            nodes are used as potential starting points.
        restart_probability : float, optional
            The probability of restarting back to the original node at each step.
            Should be a value between 0 and 1. If not specified, the server-side
            default value is used.
        sampling_ratio : float, optional
            The ratio of nodes to sample during the computation. This value should
            be between 0 and 1. If not specified, the server-side default value is used.
        node_label_stratification : bool, optional
            If True, the algorithm tries to preserve the label distribution of the original graph in the sampled graph.
        relationship_weight_property : str, optional
            The name of the property on relationships to use as weights during
            the random walk. If not specified, the relationships are treated as
            unweighted.
        relationship_types : list of str, optional
            The relationship types used to select relationships for this algorithm run.
        node_labels : list of str, optional
            The node labels used to select nodes for this algorithm run.
        sudo : bool, optional
            Override memory estimation limits. Use with caution as this can lead to
            memory issues if the estimation is significantly wrong.
        log_progress : bool, optional
            If True, logs the progress of the computation.
        username : str, optional
            The username to attribute the procedure run to.
        concurrency : int, optional
            The number of concurrent threads used for the algorithm execution.
        job_id : str, optional
            An identifier for the job that can be used for monitoring and cancellation.

        Returns
        -------
        GraphSamplingResult
            Details about the newly created sampled graph, such as its node and
            relationship counts.
        """
        pass

    @abstractmethod
    def cnarw(
        self,
        G: Graph,
        graph_name: str,
        start_nodes: Optional[List[int]] = None,
        restart_probability: Optional[float] = None,
        sampling_ratio: Optional[float] = None,
        node_label_stratification: Optional[bool] = None,
        relationship_weight_property: Optional[str] = None,
        relationship_types: Optional[List[str]] = None,
        node_labels: Optional[List[str]] = None,
        sudo: Optional[bool] = None,
        log_progress: Optional[bool] = None,
        username: Optional[str] = None,
        concurrency: Optional[int] = None,
        job_id: Optional[str] = None,
    ) -> GraphSamplingResult:
        """
        Computes a set of Common Neighbour Aware Random Walks (CNARW) for the given graph and stores the result as a new graph in the catalog.

        This method performs a common-neighbour-aware random walk, beginning from a
        set of nodes (if provided), where at each step there is a probability to
        restart back at the original nodes. Compared to plain RWR, the next-step
        choice takes common neighbours between candidate nodes into account.
        The result is turned into a new graph induced by the random walks and stored in the catalog.

        Parameters
        ----------
        G : Graph
            The input graph on which the Common Neighbour Aware Random Walk (CNARW)
            will be performed.
        graph_name : str
            The name of the new graph in the catalog.
        start_nodes : list of int, optional
            A list of node IDs to start the random walk from. If not provided, all
            nodes are used as potential starting points.
        restart_probability : float, optional
            The probability of restarting back to the original node at each step.
            Should be a value between 0 and 1. If not specified, the server-side
            default value is used.
        sampling_ratio : float, optional
            The ratio of nodes to sample during the computation. This value should
            be between 0 and 1. If not specified, the server-side default value is used.
        node_label_stratification : bool, optional
            If True, the algorithm tries to preserve the label distribution of the original graph in the sampled graph.
        relationship_weight_property : str, optional
            The name of the property on relationships to use as weights during
            the random walk. If not specified, the relationships are treated as
            unweighted.
        relationship_types : list of str, optional
            The relationship types used to select relationships for this algorithm run.
        node_labels : list of str, optional
            The node labels used to select nodes for this algorithm run.
        sudo : bool, optional
            Override memory estimation limits. Use with caution as this can lead to
            memory issues if the estimation is significantly wrong.
        log_progress : bool, optional
            If True, logs the progress of the computation.
        username : str, optional
            The username to attribute the procedure run to.
        concurrency : int, optional
            The number of concurrent threads used for the algorithm execution.
        job_id : str, optional
            An identifier for the job that can be used for monitoring and cancellation.

        Returns
        -------
        GraphSamplingResult
            Details about the newly created sampled graph, such as its node and
            relationship counts.
        """
        pass


class GraphSamplingResult(BaseResult):
    """Summary of a graph sampling run that produced a new graph in the catalog."""
    # Name of the newly created (sampled) graph in the catalog.
    graph_name: str
    # Name of the original graph the sample was taken from.
    from_graph_name: str
    # Number of nodes in the sampled graph.
    node_count: int
    # Number of relationships in the sampled graph.
    relationship_count: int
    # Number of start nodes used by the sampling walks.
    start_node_count: int
    # Time spent creating the sampled graph projection, in milliseconds.
    project_millis: int
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
GraphFilterResult,
GraphListResult,
)
from graphdatascience.procedure_surface.api.graph_sampling_endpoints import GraphSamplingEndpoints
from graphdatascience.procedure_surface.arrow.graph_sampling_arrow_endpoints import GraphSamplingArrowEndpoints
from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter
from graphdatascience.query_runner.protocol.project_protocols import ProjectProtocol
from graphdatascience.query_runner.termination_flag import TerminationFlag
Expand Down Expand Up @@ -116,6 +118,10 @@ def filter(

return GraphFilterResult(**JobClient.get_summary(self._arrow_client, job_id))

@property
def sample(self) -> GraphSamplingEndpoints:
return GraphSamplingArrowEndpoints(self._arrow_client)

def _arrow_config(self) -> dict[str, Any]:
connection_info = self._arrow_client.advertised_connection_info()

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from __future__ import annotations

from typing import Any, List, Optional

from graphdatascience import Graph
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
from graphdatascience.arrow_client.v2.job_client import JobClient
from graphdatascience.procedure_surface.api.graph_sampling_endpoints import (
GraphSamplingEndpoints,
GraphSamplingResult,
)
from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter


class GraphSamplingArrowEndpoints(GraphSamplingEndpoints):
    """Arrow-based (v2 job API) implementation of the graph sampling endpoints."""

    def __init__(self, arrow_client: AuthenticatedArrowClient):
        self._arrow_client = arrow_client

    def _run_sampling_job(
        self,
        job_endpoint: str,
        G: Graph,
        graph_name: str,
        **algo_config: Any,
    ) -> GraphSamplingResult:
        """Run one sampling job and map its summary onto ``GraphSamplingResult``.

        ``algo_config`` carries the algorithm-specific options; ``None`` values
        are passed through to ``ConfigConverter`` (presumably dropped so the
        server-side defaults apply -- see the abstract API).
        """
        config = ConfigConverter.convert_to_gds_config(
            from_graph_name=G.name(),
            graph_name=graph_name,
            **algo_config,
        )

        # Distinct name: ``algo_config`` may contain a caller-supplied ``job_id``;
        # this one identifies the job the server actually ran to completion.
        finished_job_id = JobClient.run_job_and_wait(self._arrow_client, job_endpoint, config)

        return GraphSamplingResult(**JobClient.get_summary(self._arrow_client, finished_job_id))

    def rwr(
        self,
        G: Graph,
        graph_name: str,
        start_nodes: Optional[List[int]] = None,
        restart_probability: Optional[float] = None,
        sampling_ratio: Optional[float] = None,
        node_label_stratification: Optional[bool] = None,
        relationship_weight_property: Optional[str] = None,
        relationship_types: Optional[List[str]] = None,
        node_labels: Optional[List[str]] = None,
        sudo: Optional[bool] = None,
        log_progress: Optional[bool] = None,
        username: Optional[str] = None,
        concurrency: Optional[int] = None,
        job_id: Optional[str] = None,
    ) -> GraphSamplingResult:
        """Sample the graph with Random Walk with Restart (RWR) and store the
        result as ``graph_name`` in the catalog.

        See :class:`GraphSamplingEndpoints` for full parameter documentation.
        """
        return self._run_sampling_job(
            "v2/graph.sample.rwr",
            G,
            graph_name,
            start_nodes=start_nodes,
            restart_probability=restart_probability,
            sampling_ratio=sampling_ratio,
            node_label_stratification=node_label_stratification,
            relationship_weight_property=relationship_weight_property,
            relationship_types=relationship_types,
            node_labels=node_labels,
            sudo=sudo,
            log_progress=log_progress,
            username=username,
            concurrency=concurrency,
            job_id=job_id,
        )

    def cnarw(
        self,
        G: Graph,
        graph_name: str,
        start_nodes: Optional[List[int]] = None,
        restart_probability: Optional[float] = None,
        sampling_ratio: Optional[float] = None,
        node_label_stratification: Optional[bool] = None,
        relationship_weight_property: Optional[str] = None,
        relationship_types: Optional[List[str]] = None,
        node_labels: Optional[List[str]] = None,
        sudo: Optional[bool] = None,
        log_progress: Optional[bool] = None,
        username: Optional[str] = None,
        concurrency: Optional[int] = None,
        job_id: Optional[str] = None,
    ) -> GraphSamplingResult:
        """Sample the graph with Common Neighbour Aware Random Walk (CNARW) and
        store the result as ``graph_name`` in the catalog.

        See :class:`GraphSamplingEndpoints` for full parameter documentation.
        """
        return self._run_sampling_job(
            "v2/graph.sample.cnarw",
            G,
            graph_name,
            start_nodes=start_nodes,
            restart_probability=restart_probability,
            sampling_ratio=sampling_ratio,
            node_label_stratification=node_label_stratification,
            relationship_weight_property=relationship_weight_property,
            relationship_types=relationship_types,
            node_labels=node_labels,
            sudo=sudo,
            log_progress=log_progress,
            username=username,
            concurrency=concurrency,
            job_id=job_id,
        )
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from __future__ import annotations

from typing import Any, List, Optional

from ...call_parameters import CallParameters
from ...graph.graph_object import Graph
from ...query_runner.query_runner import QueryRunner
from ..api.graph_sampling_endpoints import GraphSamplingEndpoints, GraphSamplingResult
from ..utils.config_converter import ConfigConverter


class GraphSamplingCypherEndpoints(GraphSamplingEndpoints):
    """Cypher-procedure implementation of the graph sampling endpoints."""

    def __init__(self, query_runner: QueryRunner):
        self._query_runner = query_runner

    def _run_sampling_procedure(
        self,
        procedure: str,
        G: Graph,
        graph_name: str,
        **algo_config: Any,
    ) -> GraphSamplingResult:
        """Call one sampling procedure and map its single result row onto
        ``GraphSamplingResult``.

        ``algo_config`` carries the algorithm-specific options; ``None`` values
        are passed through to ``ConfigConverter`` (presumably dropped so the
        server-side defaults apply -- see the abstract API).
        """
        config = ConfigConverter.convert_to_gds_config(**algo_config)

        params = CallParameters(
            graph_name=graph_name,
            from_graph_name=G.name(),
            config=config,
        )
        # Guarantees a job id is present in the config for monitoring/cancellation.
        params.ensure_job_id_in_config()

        row = self._query_runner.call_procedure(endpoint=procedure, params=params).squeeze()
        return GraphSamplingResult(**row.to_dict())

    def rwr(
        self,
        G: Graph,
        graph_name: str,
        start_nodes: Optional[List[int]] = None,
        restart_probability: Optional[float] = None,
        sampling_ratio: Optional[float] = None,
        node_label_stratification: Optional[bool] = None,
        relationship_weight_property: Optional[str] = None,
        relationship_types: Optional[List[str]] = None,
        node_labels: Optional[List[str]] = None,
        sudo: Optional[bool] = None,
        log_progress: Optional[bool] = None,
        username: Optional[str] = None,
        concurrency: Optional[int] = None,
        job_id: Optional[str] = None,
    ) -> GraphSamplingResult:
        """Sample the graph with Random Walk with Restart (RWR) via the
        ``gds.graph.sample.rwr`` procedure and store the result as
        ``graph_name`` in the catalog.

        See :class:`GraphSamplingEndpoints` for full parameter documentation.
        """
        return self._run_sampling_procedure(
            "gds.graph.sample.rwr",
            G,
            graph_name,
            start_nodes=start_nodes,
            restart_probability=restart_probability,
            sampling_ratio=sampling_ratio,
            node_label_stratification=node_label_stratification,
            relationship_weight_property=relationship_weight_property,
            relationship_types=relationship_types,
            node_labels=node_labels,
            sudo=sudo,
            log_progress=log_progress,
            username=username,
            concurrency=concurrency,
            job_id=job_id,
        )

    def cnarw(
        self,
        G: Graph,
        graph_name: str,
        start_nodes: Optional[List[int]] = None,
        restart_probability: Optional[float] = None,
        sampling_ratio: Optional[float] = None,
        node_label_stratification: Optional[bool] = None,
        relationship_weight_property: Optional[str] = None,
        relationship_types: Optional[List[str]] = None,
        node_labels: Optional[List[str]] = None,
        sudo: Optional[bool] = None,
        log_progress: Optional[bool] = None,
        username: Optional[str] = None,
        concurrency: Optional[int] = None,
        job_id: Optional[str] = None,
    ) -> GraphSamplingResult:
        """Sample the graph with Common Neighbour Aware Random Walk (CNARW) via
        the ``gds.graph.sample.cnarw`` procedure and store the result as
        ``graph_name`` in the catalog.

        See :class:`GraphSamplingEndpoints` for full parameter documentation.
        """
        return self._run_sampling_procedure(
            "gds.graph.sample.cnarw",
            G,
            graph_name,
            start_nodes=start_nodes,
            restart_probability=restart_probability,
            sampling_ratio=sampling_ratio,
            node_label_stratification=node_label_stratification,
            relationship_weight_property=relationship_weight_property,
            relationship_types=relationship_types,
            node_labels=node_labels,
            sudo=sudo,
            log_progress=log_progress,
            username=username,
            concurrency=concurrency,
            job_id=job_id,
        )
Loading