-
Notifications
You must be signed in to change notification settings - Fork 54
AV2 - graph.sample endpoints #942
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2a4f61c
f051ff6
470bb31
d647c45
3895d59
e197129
c5e00b4
c24762a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
from __future__ import annotations | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import List, Optional | ||
|
||
from graphdatascience import Graph | ||
from graphdatascience.procedure_surface.api.base_result import BaseResult | ||
|
||
|
||
class GraphSamplingEndpoints(ABC): | ||
""" | ||
Abstract base class defining the API for graph sampling algorithms algorithm. | ||
""" | ||
|
||
@abstractmethod | ||
def rwr( | ||
self, | ||
G: Graph, | ||
graph_name: str, | ||
start_nodes: Optional[List[int]] = None, | ||
restart_probability: Optional[float] = None, | ||
sampling_ratio: Optional[float] = None, | ||
node_label_stratification: Optional[bool] = None, | ||
relationship_weight_property: Optional[str] = None, | ||
relationship_types: Optional[List[str]] = None, | ||
node_labels: Optional[List[str]] = None, | ||
sudo: Optional[bool] = None, | ||
log_progress: Optional[bool] = None, | ||
username: Optional[str] = None, | ||
concurrency: Optional[int] = None, | ||
job_id: Optional[str] = None, | ||
) -> GraphSamplingResult: | ||
""" | ||
Computes a set of Random Walks with Restart (RWR) for the given graph and stores the result as a new graph in the catalog. | ||
|
||
This method performs a random walk, beginning from a set of nodes (if provided), | ||
where at each step there is a probability to restart back at the original nodes. | ||
The result is turned into a new graph induced by the random walks and stored in the catalog. | ||
Mats-SX marked this conversation as resolved.
Show resolved
Hide resolved
Mats-SX marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Parameters | ||
---------- | ||
G : Graph | ||
The input graph on which the Random Walk with Restart (RWR) will be | ||
performed. | ||
graph_name : str | ||
The name of the new graph in the catalog. | ||
Mats-SX marked this conversation as resolved.
Show resolved
Hide resolved
|
||
start_nodes : list of int, optional | ||
Mats-SX marked this conversation as resolved.
Show resolved
Hide resolved
|
||
A list of node IDs to start the random walk from. If not provided, all | ||
nodes are used as potential starting points. | ||
Mats-SX marked this conversation as resolved.
Show resolved
Hide resolved
|
||
restart_probability : float, optional | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why is it There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We do not actually specify the GDS defined default value here. None essentially means that the default value defined by GDS will be used. That was a decision I made in order to avoid differences between GDS and the API. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should include all default values, or no default values. it is strange to me that we include some default values, such as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah the none is there, so that Python allows you to omit this parameter. It has nothing to do with documenting an actual GDS default. The default will be set on the GDS side |
||
The probability of restarting back to the original node at each step. | ||
Mats-SX marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Should be a value between 0 and 1. If not specified, a default value is used. | ||
Mats-SX marked this conversation as resolved.
Show resolved
Hide resolved
|
||
sampling_ratio : float, optional | ||
The ratio of nodes to sample during the computation. This value should | ||
be between 0 and 1. If not specified, no sampling is performed. | ||
Mats-SX marked this conversation as resolved.
Show resolved
Hide resolved
|
||
node_label_stratification : bool, optional | ||
If True, the algorithm tries to preserve the label distribution of the original graph in the sampled graph. | ||
relationship_weight_property : str, optional | ||
The name of the property on relationships to use as weights during | ||
the random walk. If not specified, the relationships are treated as | ||
unweighted. | ||
relationship_types : list of str, optional | ||
The relationship types used to select relationships for this algorithm run. | ||
node_labels : list of str, optional | ||
The node labels used to select nodes for this algorithm run. | ||
sudo : bool, optional | ||
Override memory estimation limits. Use with caution as this can lead to | ||
Mats-SX marked this conversation as resolved.
Show resolved
Hide resolved
|
||
memory issues if the estimation is significantly wrong. | ||
Mats-SX marked this conversation as resolved.
Show resolved
Hide resolved
|
||
log_progress : bool, optional | ||
If True, logs the progress of the computation. | ||
Mats-SX marked this conversation as resolved.
Show resolved
Hide resolved
|
||
username : str, optional | ||
The username to attribute the procedure run to | ||
Mats-SX marked this conversation as resolved.
Show resolved
Hide resolved
|
||
concurrency : int, optional | ||
The number of concurrent threads used for the algorithm execution. | ||
job_id : str, optional | ||
An identifier for the job that can be used for monitoring and cancellation | ||
|
||
Returns | ||
------- | ||
GraphSamplingResult | ||
The result of the Random Walk with Restart (RWR), including the sampled | ||
nodes and their scores. | ||
Mats-SX marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def cnarw( | ||
self, | ||
G: Graph, | ||
graph_name: str, | ||
start_nodes: Optional[List[int]] = None, | ||
restart_probability: Optional[float] = None, | ||
sampling_ratio: Optional[float] = None, | ||
node_label_stratification: Optional[bool] = None, | ||
relationship_weight_property: Optional[str] = None, | ||
relationship_types: Optional[List[str]] = None, | ||
node_labels: Optional[List[str]] = None, | ||
sudo: Optional[bool] = None, | ||
log_progress: Optional[bool] = None, | ||
username: Optional[str] = None, | ||
concurrency: Optional[int] = None, | ||
job_id: Optional[str] = None, | ||
) -> GraphSamplingResult: | ||
""" | ||
Computes a set of Random Walks with Restart (RWR) for the given graph and stores the result as a new graph in the catalog. | ||
|
||
This method performs a random walk, beginning from a set of nodes (if provided), | ||
where at each step there is a probability to restart back at the original nodes. | ||
The result is turned into a new graph induced by the random walks and stored in the catalog. | ||
Mats-SX marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Parameters | ||
---------- | ||
G : Graph | ||
The input graph on which the Random Walk with Restart (RWR) will be | ||
performed. | ||
graph_name : str | ||
The name of the new graph in the catalog. | ||
start_nodes : list of int, optional | ||
A list of node IDs to start the random walk from. If not provided, all | ||
nodes are used as potential starting points. | ||
restart_probability : float, optional | ||
The probability of restarting back to the original node at each step. | ||
Should be a value between 0 and 1. If not specified, a default value is used. | ||
sampling_ratio : float, optional | ||
The ratio of nodes to sample during the computation. This value should | ||
be between 0 and 1. If not specified, no sampling is performed. | ||
node_label_stratification : bool, optional | ||
If True, the algorithm tries to preserve the label distribution of the original graph in the sampled graph. | ||
relationship_weight_property : str, optional | ||
The name of the property on relationships to use as weights during | ||
the random walk. If not specified, the relationships are treated as | ||
unweighted. | ||
relationship_types : list of str, optional | ||
The relationship types used to select relationships for this algorithm run. | ||
node_labels : list of str, optional | ||
The node labels used to select nodes for this algorithm run. | ||
sudo : bool, optional | ||
Override memory estimation limits. Use with caution as this can lead to | ||
Mats-SX marked this conversation as resolved.
Show resolved
Hide resolved
|
||
memory issues if the estimation is significantly wrong. | ||
log_progress : bool, optional | ||
If True, logs the progress of the computation. | ||
username : str, optional | ||
The username to attribute the procedure run to | ||
concurrency : int, optional | ||
The number of concurrent threads used for the algorithm execution. | ||
job_id : str, optional | ||
An identifier for the job that can be used for monitoring and cancellation | ||
|
||
Returns | ||
------- | ||
GraphSamplingResult | ||
The result of the Random Walk with Restart (RWR), including the sampled | ||
nodes and their scores. | ||
Mats-SX marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
pass | ||
|
||
|
||
class GraphSamplingResult(BaseResult): | ||
graph_name: str | ||
from_graph_name: str | ||
node_count: int | ||
relationship_count: int | ||
start_node_count: int | ||
project_millis: int |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any, List, Optional | ||
|
||
from graphdatascience import Graph | ||
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient | ||
from graphdatascience.arrow_client.v2.job_client import JobClient | ||
from graphdatascience.procedure_surface.api.graph_sampling_endpoints import ( | ||
GraphSamplingEndpoints, | ||
GraphSamplingResult, | ||
) | ||
from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter | ||
|
||
|
||
class GraphSamplingArrowEndpoints(GraphSamplingEndpoints): | ||
def __init__(self, arrow_client: AuthenticatedArrowClient): | ||
self._arrow_client = arrow_client | ||
|
||
def rwr( | ||
self, | ||
G: Graph, | ||
graph_name: str, | ||
start_nodes: Optional[List[int]] = None, | ||
restart_probability: Optional[float] = None, | ||
sampling_ratio: Optional[float] = None, | ||
node_label_stratification: Optional[bool] = None, | ||
relationship_weight_property: Optional[str] = None, | ||
relationship_types: Optional[List[str]] = None, | ||
node_labels: Optional[List[str]] = None, | ||
sudo: Optional[bool] = None, | ||
log_progress: Optional[bool] = None, | ||
username: Optional[str] = None, | ||
concurrency: Optional[Any] = None, | ||
job_id: Optional[Any] = None, | ||
) -> GraphSamplingResult: | ||
config = ConfigConverter.convert_to_gds_config( | ||
from_graph_name=G.name(), | ||
graph_name=graph_name, | ||
start_nodes=start_nodes, | ||
restart_probability=restart_probability, | ||
sampling_ratio=sampling_ratio, | ||
node_label_stratification=node_label_stratification, | ||
relationship_weight_property=relationship_weight_property, | ||
relationship_types=relationship_types, | ||
node_labels=node_labels, | ||
sudo=sudo, | ||
log_progress=log_progress, | ||
username=username, | ||
concurrency=concurrency, | ||
job_id=job_id, | ||
) | ||
|
||
job_id = JobClient.run_job_and_wait(self._arrow_client, "v2/graph.sample.rwr", config) | ||
|
||
return GraphSamplingResult(**JobClient.get_summary(self._arrow_client, job_id)) | ||
|
||
def cnarw( | ||
self, | ||
G: Graph, | ||
graph_name: str, | ||
start_nodes: Optional[List[int]] = None, | ||
restart_probability: Optional[float] = None, | ||
sampling_ratio: Optional[float] = None, | ||
node_label_stratification: Optional[bool] = None, | ||
relationship_weight_property: Optional[str] = None, | ||
relationship_types: Optional[List[str]] = None, | ||
node_labels: Optional[List[str]] = None, | ||
sudo: Optional[bool] = None, | ||
log_progress: Optional[bool] = None, | ||
username: Optional[str] = None, | ||
concurrency: Optional[Any] = None, | ||
job_id: Optional[Any] = None, | ||
) -> GraphSamplingResult: | ||
config = ConfigConverter.convert_to_gds_config( | ||
from_graph_name=G.name(), | ||
graph_name=graph_name, | ||
start_nodes=start_nodes, | ||
restart_probability=restart_probability, | ||
sampling_ratio=sampling_ratio, | ||
node_label_stratification=node_label_stratification, | ||
relationship_weight_property=relationship_weight_property, | ||
relationship_types=relationship_types, | ||
node_labels=node_labels, | ||
sudo=sudo, | ||
log_progress=log_progress, | ||
username=username, | ||
concurrency=concurrency, | ||
job_id=job_id, | ||
) | ||
|
||
job_id = JobClient.run_job_and_wait(self._arrow_client, "v2/graph.sample.cnarw", config) | ||
|
||
return GraphSamplingResult(**JobClient.get_summary(self._arrow_client, job_id)) |
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,98 @@ | ||||||||||||||
from __future__ import annotations | ||||||||||||||
|
||||||||||||||
from typing import Any, List, Optional | ||||||||||||||
|
||||||||||||||
from ...call_parameters import CallParameters | ||||||||||||||
from ...graph.graph_object import Graph | ||||||||||||||
from ...query_runner.query_runner import QueryRunner | ||||||||||||||
from ..api.graph_sampling_endpoints import GraphSamplingEndpoints, GraphSamplingResult | ||||||||||||||
from ..utils.config_converter import ConfigConverter | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
class GraphSamplingCypherEndpoints(GraphSamplingEndpoints): | ||||||||||||||
def __init__(self, query_runner: QueryRunner): | ||||||||||||||
self._query_runner = query_runner | ||||||||||||||
|
||||||||||||||
def rwr( | ||||||||||||||
self, | ||||||||||||||
G: Graph, | ||||||||||||||
graph_name: str, | ||||||||||||||
start_nodes: Optional[List[int]] = None, | ||||||||||||||
restart_probability: Optional[float] = None, | ||||||||||||||
sampling_ratio: Optional[float] = None, | ||||||||||||||
node_label_stratification: Optional[bool] = None, | ||||||||||||||
relationship_weight_property: Optional[str] = None, | ||||||||||||||
relationship_types: Optional[List[str]] = None, | ||||||||||||||
node_labels: Optional[List[str]] = None, | ||||||||||||||
sudo: Optional[bool] = None, | ||||||||||||||
log_progress: Optional[bool] = None, | ||||||||||||||
username: Optional[str] = None, | ||||||||||||||
concurrency: Optional[Any] = None, | ||||||||||||||
job_id: Optional[Any] = None, | ||||||||||||||
) -> GraphSamplingResult: | ||||||||||||||
config = ConfigConverter.convert_to_gds_config( | ||||||||||||||
start_nodes=start_nodes, | ||||||||||||||
restart_probability=restart_probability, | ||||||||||||||
sampling_ratio=sampling_ratio, | ||||||||||||||
node_label_stratification=node_label_stratification, | ||||||||||||||
relationship_weight_property=relationship_weight_property, | ||||||||||||||
relationship_types=relationship_types, | ||||||||||||||
node_labels=node_labels, | ||||||||||||||
sudo=sudo, | ||||||||||||||
log_progress=log_progress, | ||||||||||||||
username=username, | ||||||||||||||
concurrency=concurrency, | ||||||||||||||
job_id=job_id, | ||||||||||||||
) | ||||||||||||||
|
||||||||||||||
params = CallParameters( | ||||||||||||||
graph_name=graph_name, | ||||||||||||||
from_graph_name=G.name(), | ||||||||||||||
config=config, | ||||||||||||||
Comment on lines
+49
to
+51
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess it doesn't matter functionally here, but we could try to be consistent if we like
Suggested change
|
||||||||||||||
) | ||||||||||||||
params.ensure_job_id_in_config() | ||||||||||||||
|
||||||||||||||
result = self._query_runner.call_procedure(endpoint="gds.graph.sample.rwr", params=params).squeeze() | ||||||||||||||
return GraphSamplingResult(**result.to_dict()) | ||||||||||||||
|
||||||||||||||
def cnarw( | ||||||||||||||
self, | ||||||||||||||
G: Graph, | ||||||||||||||
graph_name: str, | ||||||||||||||
start_nodes: Optional[List[int]] = None, | ||||||||||||||
restart_probability: Optional[float] = None, | ||||||||||||||
sampling_ratio: Optional[float] = None, | ||||||||||||||
node_label_stratification: Optional[bool] = None, | ||||||||||||||
relationship_weight_property: Optional[str] = None, | ||||||||||||||
relationship_types: Optional[List[str]] = None, | ||||||||||||||
node_labels: Optional[List[str]] = None, | ||||||||||||||
sudo: Optional[bool] = None, | ||||||||||||||
log_progress: Optional[bool] = None, | ||||||||||||||
username: Optional[str] = None, | ||||||||||||||
concurrency: Optional[Any] = None, | ||||||||||||||
job_id: Optional[Any] = None, | ||||||||||||||
) -> GraphSamplingResult: | ||||||||||||||
config = ConfigConverter.convert_to_gds_config( | ||||||||||||||
start_nodes=start_nodes, | ||||||||||||||
restart_probability=restart_probability, | ||||||||||||||
sampling_ratio=sampling_ratio, | ||||||||||||||
node_label_stratification=node_label_stratification, | ||||||||||||||
relationship_weight_property=relationship_weight_property, | ||||||||||||||
relationship_types=relationship_types, | ||||||||||||||
node_labels=node_labels, | ||||||||||||||
sudo=sudo, | ||||||||||||||
log_progress=log_progress, | ||||||||||||||
username=username, | ||||||||||||||
concurrency=concurrency, | ||||||||||||||
job_id=job_id, | ||||||||||||||
) | ||||||||||||||
|
||||||||||||||
params = CallParameters( | ||||||||||||||
graph_name=graph_name, | ||||||||||||||
from_graph_name=G.name(), | ||||||||||||||
config=config, | ||||||||||||||
) | ||||||||||||||
params.ensure_job_id_in_config() | ||||||||||||||
|
||||||||||||||
result = self._query_runner.call_procedure(endpoint="gds.graph.sample.cnarw", params=params).squeeze() | ||||||||||||||
return GraphSamplingResult(**result.to_dict()) |
Uh oh!
There was an error while loading. Please reload this page.