diff --git a/changelog.md b/changelog.md index 61bb9e5..d00de45 100644 --- a/changelog.md +++ b/changelog.md @@ -3,20 +3,29 @@ ## Breaking changes - Do not automatically derive size and caption for `from_neo4j` and `from_gql_create`. Use the `size_property` and `node_caption` parameters to explicitly configure them. +- Change API of integrations to only provide basic parameters. Any further configuration should happen ons the Visualization Graph object: + - `from_gds` + - Drop parameters size_property, node_radius_min_max. `Use VG.resize_nodes(property=...)` instead + - rename additional_node_properties to node_properties + - Don't derive fields from properties. Use `VG.map_properties_to_fields` instead + - `from_pandas` + - Drop `node_radius_min_max` parameter. `VG.resize_nodes(...)` instead ## New features -- Allow to include db node properties in addition to the properties in the GDS Graph. Specify `additional_db_node_properties` in `from_gds`. - +- Allow to include db node properties in addition to the properties in the GDS Graph. Specify `db_node_properties` in `from_gds`. ## Bug fixes - fixed a bug in `from_neo4j`, where the node size would always be set to the `size` property. - fixed a bug in `from_neo4j`, where the node caption would always be set to the `caption` property. +- Color nodes in `from_snowflake` only if there are less than 13 node tables used. This avoids reuse of colors for different tables. ## Improvements - Validate fields of a node and relationship not only at construction but also on assignment. - Allow resizing per node property such as `VG.resize_nodes(property="score")`. +- Color nodes by label in `from_gds`. +- Add `table` property to nodes and relationships created by `from_snowflake`. This is used as a default caption. ## Other changes diff --git a/docs/source/conf.py b/docs/source/conf.py index d19093c..2fa9ee2 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -33,7 +33,10 @@ # -- Options for autodoc extension ------------------------------------------- autodoc_typehints = "description" +autoclass_content = "both" +# -- Options for napoleon extension ------------------------------------------- +napoleon_use_admonition_for_examples = True # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output diff --git a/docs/source/integration.rst b/docs/source/integration.rst index 3a8a6c2..d719885 100644 --- a/docs/source/integration.rst +++ b/docs/source/integration.rst @@ -35,28 +35,17 @@ The ``from_dfs`` method takes two mandatory positional parameters: * A Pandas ``DataFrame``, or iterable (eg. list) of DataFrames representing the nodes of the graph. The rows of the DataFrame(s) should represent the individual nodes, and the columns should represent the node IDs and attributes. - If a column shares the name with a field of :doc:`Node <./api-reference/node>`, the values it contains will be set - on corresponding nodes under that field name. - Otherwise, the column name will be a key in each node's `properties` dictionary, that maps to the node's corresponding + The node ID will be set on the :doc:`Node <./api-reference/node>`, + Other columns will be a key in each node's `properties` dictionary, that maps to the node's corresponding value in the column. If the graph has no node properties, the nodes can be derived from the relationships DataFrame alone. * A Pandas ``DataFrame``, or iterable (eg. list) of DataFrames representing the relationships of the graph. The rows of the DataFrame(s) should represent the individual relationships, and the columns should represent the relationship IDs and attributes. - If a column shares the name with a field of :doc:`Relationship <./api-reference/relationship>`, the values it contains - will be set on corresponding relationships under that field name. - Otherwise, the column name will be a key in each node's `properties` dictionary, that maps to the node's corresponding + The relationship id, source and target node IDs will be set on the :doc:`Relationship <./api-reference/relationship>`. + Other columns will be a key in each relationship's `properties` dictionary, that maps to the relationship's corresponding value in the column. -``from_dfs`` also takes an optional property, ``node_radius_min_max``, that can be used (and is used by default) to -scale the node sizes for the visualization. -It is a tuple of two numbers, representing the radii (sizes) in pixels of the smallest and largest nodes respectively in -the visualization. -The node sizes will be scaled such that the smallest node will have the size of the first value, and the largest node -will have the size of the second value. -The other nodes will be scaled linearly between these two values according to their relative size. -This can be useful if node sizes vary a lot, or are all very small or very big. - Example ~~~~~~~ @@ -111,25 +100,13 @@ If you want to have more control of the sampling, such as choosing a specific st a `sampling `_ method yourself and passing the resulting projection to ``from_gds``. -We can also provide an optional ``size_property`` parameter, which should refer to a node property of the projection, -and will be used to determine the sizes of the nodes in the visualization. - -The ``additional_node_properties`` parameter is also optional, and should be a list of additional node properties of the +The ``node_properties`` parameter is also optional, and should be a list of additional node properties of the projection that you want to include in the visualization. The default is ``None``, which means that all properties of the nodes in the projection will be included. Apart from being visible through on-hover tooltips, these properties could be used to color the nodes, or give captions to them in the visualization, or simply included in the nodes' ``Node.properties`` maps without directly impacting the visualization. -If you want to include node properties stored at the Neo4j database, you can include them in the visualization by using the `additional_db_node_properties` parameter. - -The last optional property, ``node_radius_min_max``, can be used (and is used by default) to scale the node sizes for -the visualization. -It is a tuple of two numbers, representing the radii (sizes) in pixels of the smallest and largest nodes respectively in -the visualization. -The node sizes will be scaled such that the smallest node will have the size of the first value, and the largest node -will have the size of the second value. -The other nodes will be scaled linearly between these two values according to their relative size. -This can be useful if node sizes vary a lot, or are all very small or very big. +If you want to include node properties stored at the Neo4j database, you can include them in the visualization by using the `db_node_properties` parameter. Example @@ -137,7 +114,7 @@ Example In this small example, we import a graph projection from the GDS library, that has the node properties "pagerank" and "componentId". -We use the "pagerank" property to determine the size of the nodes, and the "componentId" property to color the nodes. +We use the "pagerank" property to compute the size of the nodes, and the "componentId" property to color the nodes. .. code-block:: python @@ -156,9 +133,10 @@ We use the "pagerank" property to determine the size of the nodes, and the "comp VG = from_gds( gds, G, - size_property="pagerank", - additional_node_properties=["componentId"], + node_properties=["componentId"], ) + # Size the nodes by the `pagerank` property + VG.resize_nodes(property="pagerank") # Color the nodes by the `componentId` property, so that the nodes are # colored by the connected component they belong to diff --git a/python-wrapper/src/neo4j_viz/gds.py b/python-wrapper/src/neo4j_viz/gds.py index 6297375..10c1c8b 100644 --- a/python-wrapper/src/neo4j_viz/gds.py +++ b/python-wrapper/src/neo4j_viz/gds.py @@ -8,6 +8,8 @@ import pandas as pd from graphdatascience import Graph, GraphDataScience +from neo4j_viz.colors import NEO4J_COLORS_DISCRETE, ColorSpace + from .pandas import _from_dfs from .visualization_graph import VisualizationGraph @@ -55,18 +57,20 @@ def _fetch_rel_dfs(gds: GraphDataScience, G: Graph) -> list[pd.DataFrame]: def from_gds( gds: GraphDataScience, G: Graph, - size_property: Optional[str] = None, - additional_node_properties: Optional[list[str]] = None, - additional_db_node_properties: Optional[list[str]] = None, - node_radius_min_max: Optional[tuple[float, float]] = (3, 60), + node_properties: Optional[list[str]] = None, + db_node_properties: Optional[list[str]] = None, max_node_count: int = 10_000, ) -> VisualizationGraph: """ Create a VisualizationGraph from a GraphDataScience object and a Graph object. - All `additional_node_properties` will be included in the visualization graph. - If the properties are named as the fields of the `Node` class, they will be included as top level fields of the - created `Node` objects. Otherwise, they will be included in the `properties` dictionary. + By default: + + * the caption of a node will be based on its `labels`. + * the caption of a relationship will be based on its `relationshipType`. + * the color of nodes will be set based on their label, unless there are more than 12 unique labels. + + All `node_properties` and `db_node_properties` will be included in the visualization graph under the `properties` field. Additionally, a new "labels" node property will be added, containing the node labels of the node. Similarly for relationships, a new "relationshipType" property will be added. @@ -76,49 +80,36 @@ def from_gds( GraphDataScience object. G : Graph Graph object. - size_property : str, optional - Property to use for node size, by default None. - additional_node_properties : list[str], optional + node_properties : list[str], optional Additional properties to include in the visualization node, by default None which means that all node properties from the Graph will be fetched. - additional_db_node_properties : list[str], optional + db_node_properties : list[str], optional Additional node properties to fetch from the database, by default None. Only works if the graph was projected from the database. - node_radius_min_max : tuple[float, float], optional - Minimum and maximum node radius, by default (3, 60). - To avoid tiny or huge nodes in the visualization, the node sizes are scaled to fit in the given range. max_node_count : int, optional The maximum number of nodes to fetch from the graph. The graph will be sampled using random walk with restarts if its node count exceeds this number. """ - if additional_db_node_properties is None: - additional_db_node_properties = [] + if db_node_properties is None: + db_node_properties = [] node_properties_from_gds = G.node_properties() assert isinstance(node_properties_from_gds, pd.Series) actual_node_properties: dict[str, list[str]] = cast(dict[str, list[str]], node_properties_from_gds.to_dict()) all_actual_node_properties = list(chain.from_iterable(actual_node_properties.values())) - if size_property is not None: - if size_property not in all_actual_node_properties: - raise ValueError(f"There is no node property '{size_property}' in graph '{G.name()}'") - node_properties_by_label_sets: dict[str, set[str]] = dict() - if additional_node_properties is None: + if node_properties is None: node_properties_by_label_sets = {k: set(v) for k, v in actual_node_properties.items()} else: - for prop in additional_node_properties: + for prop in node_properties: if prop not in all_actual_node_properties: raise ValueError(f"There is no node property '{prop}' in graph '{G.name()}'") for label, props in actual_node_properties.items(): node_properties_by_label_sets[label] = { - prop for prop in actual_node_properties[label] if prop in additional_node_properties + prop for prop in actual_node_properties[label] if prop in node_properties } - if size_property is not None: - for label, label_props in node_properties_by_label_sets.items(): - label_props.add(size_property) - node_properties_by_label = {k: list(v) for k, v in node_properties_by_label_sets.items()} node_count = G.node_count() @@ -143,7 +134,7 @@ def from_gds( props.append(property_name) node_dfs = _fetch_node_dfs( - gds, G_fetched, node_properties_by_label, G_fetched.node_labels(), additional_db_node_properties + gds, G_fetched, node_properties_by_label, G_fetched.node_labels(), db_node_properties ) if property_name is not None: for df in node_dfs.values(): @@ -161,13 +152,6 @@ def from_gds( df.drop(columns=[property_name], inplace=True) node_props_df = pd.concat(node_dfs.values(), ignore_index=True, axis=0).drop_duplicates() - if size_property is not None: - if "size" in all_actual_node_properties and size_property != "size": - node_props_df.rename(columns={"size": "__size"}, inplace=True) - if additional_node_properties is not None and size_property not in additional_node_properties: - node_props_df.rename(columns={size_property: "size"}, inplace=True) - else: - node_props_df["size"] = node_props_df[size_property] for lbl, df in node_dfs.items(): if "labels" in all_actual_node_properties: @@ -179,22 +163,22 @@ def from_gds( node_df = node_props_df.merge(node_labels_df, on="nodeId") - if "caption" not in all_actual_node_properties: - node_df["caption"] = node_df["labels"].astype(str) + try: + VG = _from_dfs(node_df, rel_dfs, dropna=True) - for rel_df in rel_dfs: - if "caption" not in rel_df.columns: - rel_df["caption"] = rel_df["relationshipType"] + for node in VG.nodes: + node.caption = str(node.properties.get("labels")) + for rel in VG.relationships: + rel.caption = rel.properties.get("relationshipType") - try: - return _from_dfs( - node_df, rel_dfs, node_radius_min_max=node_radius_min_max, rename_properties={"__size": "size"}, dropna=True - ) + number_of_colors = node_df["labels"].drop_duplicates().count() + if number_of_colors <= len(NEO4J_COLORS_DISCRETE): + VG.color_nodes(property="labels", color_space=ColorSpace.DISCRETE) + + return VG except ValueError as e: err_msg = str(e) if "column" in err_msg: err_msg = err_msg.replace("column", "property") - if ("'size'" in err_msg) and (size_property is not None): - err_msg = err_msg.replace("'size'", f"'{size_property}'") raise ValueError(err_msg) raise e diff --git a/python-wrapper/src/neo4j_viz/node.py b/python-wrapper/src/neo4j_viz/node.py index 03dbc29..db0ad5f 100644 --- a/python-wrapper/src/neo4j_viz/node.py +++ b/python-wrapper/src/neo4j_viz/node.py @@ -98,3 +98,10 @@ def all_validation_aliases(exempted_fields: Optional[list[str]] = None) -> set[s by_field = [v.validation_alias.choices for k, v in Node.model_fields.items() if k not in exempted_fields] # type: ignore return {str(alias) for aliases in by_field for alias in aliases} + + @staticmethod + def basic_fields_validation_aliases() -> set[str]: + mandatory_fields = ["id"] + by_field = [v.validation_alias.choices for k, v in Node.model_fields.items() if k in mandatory_fields] # type: ignore + + return {str(alias) for aliases in by_field for alias in aliases} diff --git a/python-wrapper/src/neo4j_viz/pandas.py b/python-wrapper/src/neo4j_viz/pandas.py index 8ed41c3..f7f01c0 100644 --- a/python-wrapper/src/neo4j_viz/pandas.py +++ b/python-wrapper/src/neo4j_viz/pandas.py @@ -29,8 +29,6 @@ def _parse_validation_error(e: ValidationError, entity_type: type[BaseModel]) -> def _from_dfs( node_dfs: Optional[DFS_TYPE] = None, rel_dfs: Optional[DFS_TYPE] = None, - node_radius_min_max: Optional[tuple[float, float]] = (3, 60), - rename_properties: Optional[dict[str, str]] = None, dropna: bool = False, ) -> VisualizationGraph: if node_dfs is None and rel_dfs is None: @@ -39,29 +37,21 @@ def _from_dfs( if rel_dfs is None: relationships = [] else: - relationships = _parse_relationships(rel_dfs, rename_properties=rename_properties, dropna=dropna) + relationships = _parse_relationships(rel_dfs, dropna=dropna) if node_dfs is None: - has_size = False node_ids = set() for rel in relationships: node_ids.add(rel.source) node_ids.add(rel.target) nodes = [Node(id=id) for id in node_ids] else: - nodes, has_size = _parse_nodes(node_dfs, rename_properties=rename_properties, dropna=dropna) + nodes = _parse_nodes(node_dfs, dropna=dropna) - VG = VisualizationGraph(nodes=nodes, relationships=relationships) + return VisualizationGraph(nodes=nodes, relationships=relationships) - if node_radius_min_max is not None and has_size: - VG.resize_nodes(node_radius_min_max=node_radius_min_max) - return VG - - -def _parse_nodes( - node_dfs: DFS_TYPE, rename_properties: Optional[dict[str, str]], dropna: bool = False -) -> tuple[list[Node], bool]: +def _parse_nodes(node_dfs: DFS_TYPE, dropna: bool = False) -> list[Node]: if isinstance(node_dfs, DataFrame): node_dfs_iter: Iterable[DataFrame] = [node_dfs] elif node_dfs is None: @@ -69,37 +59,31 @@ def _parse_nodes( else: node_dfs_iter = node_dfs - all_node_field_aliases = Node.all_validation_aliases() + basic_node_fields_aliases = Node.basic_fields_validation_aliases() - has_size = True nodes = [] for node_df in node_dfs_iter: - has_size &= "size" in [c.lower() for c in node_df.columns] for _, row in node_df.iterrows(): if dropna: row = row.dropna(inplace=False) - top_level = {} + mandatory_fields = {} properties = {} for key, value in row.to_dict().items(): - if key in all_node_field_aliases: - top_level[key] = value + if key in basic_node_fields_aliases: + mandatory_fields[key] = value else: - if rename_properties and key in rename_properties: - key = rename_properties[key] properties[key] = value try: - nodes.append(Node(**top_level, properties=properties)) + nodes.append(Node(**mandatory_fields, properties=properties)) except ValidationError as e: _parse_validation_error(e, Node) - return nodes, has_size + return nodes -def _parse_relationships( - rel_dfs: DFS_TYPE, rename_properties: Optional[dict[str, str]], dropna: bool = False -) -> list[Relationship]: - all_rel_field_aliases = Relationship.all_validation_aliases() +def _parse_relationships(rel_dfs: DFS_TYPE, dropna: bool = False) -> list[Relationship]: + basic_rel_field_aliases = Relationship.basic_fields_validation_aliases() if isinstance(rel_dfs, DataFrame): rel_dfs_iter: Iterable[DataFrame] = [rel_dfs] @@ -111,18 +95,16 @@ def _parse_relationships( for _, row in rel_df.iterrows(): if dropna: row = row.dropna(inplace=False) - top_level = {} + mandatory_fields = {} properties = {} for key, value in row.to_dict().items(): - if key in all_rel_field_aliases: - top_level[key] = value + if key in basic_rel_field_aliases: + mandatory_fields[key] = value else: - if rename_properties and key in rename_properties: - key = rename_properties[key] properties[key] = value try: - relationships.append(Relationship(**top_level, properties=properties)) + relationships.append(Relationship(**mandatory_fields, properties=properties)) except ValidationError as e: _parse_validation_error(e, Relationship) @@ -132,14 +114,17 @@ def _parse_relationships( def from_dfs( node_dfs: Optional[DFS_TYPE] = None, rel_dfs: Optional[DFS_TYPE] = None, - node_radius_min_max: Optional[tuple[float, float]] = (3, 60), ) -> VisualizationGraph: """ Create a VisualizationGraph from pandas DataFrames representing a graph. All columns will be included in the visualization graph. - If the columns are named as the fields of the `Node` or `Relationship` classes, they will be included as - top level fields of the respective objects. Otherwise, they will be included in the `properties` dictionary. + The following columns will be treated as fields: + + * `id` for the node_dfs + * `id`, `source`, `target` for the rel_dfs + + Other columns will be included in the `properties` dictionary on the respective node or relationship objects. Parameters ---------- @@ -149,9 +134,7 @@ def from_dfs( rel_dfs: Optional[Union[DataFrame, Iterable[DataFrame]]], optional DataFrame or iterable of DataFrames containing relationship data. If None, no relationships will be created. - node_radius_min_max : tuple[float, float], optional - Minimum and maximum node radius. - To avoid tiny or huge nodes in the visualization, the node sizes are scaled to fit in the given range. + """ - return _from_dfs(node_dfs, rel_dfs, node_radius_min_max, dropna=False) + return _from_dfs(node_dfs, rel_dfs, dropna=False) diff --git a/python-wrapper/src/neo4j_viz/relationship.py b/python-wrapper/src/neo4j_viz/relationship.py index f72fb66..0498e09 100644 --- a/python-wrapper/src/neo4j_viz/relationship.py +++ b/python-wrapper/src/neo4j_viz/relationship.py @@ -109,3 +109,14 @@ def all_validation_aliases(exempted_fields: Optional[list[str]] = None) -> set[s ] return {str(alias) for aliases in by_field for alias in aliases} + + @staticmethod + def basic_fields_validation_aliases() -> set[str]: + basic_fields = ["id", "source", "target"] + by_field = [ + v.validation_alias.choices # type: ignore + for k, v in Relationship.model_fields.items() + if k in basic_fields + ] + + return {str(alias) for aliases in by_field for alias in aliases} diff --git a/python-wrapper/src/neo4j_viz/snowflake.py b/python-wrapper/src/neo4j_viz/snowflake.py index 8e5e181..4332601 100644 --- a/python-wrapper/src/neo4j_viz/snowflake.py +++ b/python-wrapper/src/neo4j_viz/snowflake.py @@ -37,7 +37,7 @@ ) from neo4j_viz import VisualizationGraph -from neo4j_viz.colors import ColorSpace +from neo4j_viz.colors import NEO4J_COLORS_DISCRETE, ColorSpace from neo4j_viz.pandas import from_dfs @@ -312,43 +312,39 @@ def _map_tables( def from_snowflake( session: Session, project_config: dict[str, Any], - node_radius_min_max: Optional[tuple[float, float]] = (3, 60), ) -> VisualizationGraph: """ Create a VisualizationGraph from Snowflake tables based on a project configuration. + By default: + * The caption of the nodes will be set to the table name. + * The caption of the relationships will be set to the table name. + * The color of the nodes will be set based on the caption, unless there are more than 12 node tables used. + Otherwise, columns will be included as properties on the nodes and relationships. + Args: - session (Session): The Snowflake session to use for querying the tables. - project_config (dict[str, Any]): The project configuration dictionary defining node and relationship tables. - node_radius_min_max (Optional[tuple[float, float]], optional): Tuple defining the min and max radius for nodes. Defaults to (3, 60). + session (Session): An active Snowflake session. + project_config (dict[str, Any]): A dictionary representing the project configuration. Returns: - VisualizationGraph: The constructed visualization graph. + VisualizationGraph: The resulting visualization graph. """ project_model = VizProjectConfig.model_validate(project_config, strict=False, context={"session": session}) node_dfs, rel_dfs, rel_table_names = _map_tables(session, project_model) - node_caption_present = False - for node_df in node_dfs: - if "CAPTION" in node_df.columns: - node_caption_present = True - break - - if not node_caption_present: - for i, node_df in enumerate(node_dfs): - node_df["caption"] = project_model.nodeTables[i].split(".")[-1] - - rel_caption_present = False - for rel_df in rel_dfs: - if "CAPTION" in rel_df.columns: - rel_caption_present = True - break + for i, node_df in enumerate(node_dfs): + node_df["table"] = project_model.nodeTables[i].split(".")[-1] + for i, rel_df in enumerate(rel_dfs): + rel_df["table"] = rel_table_names[i].split(".")[-1] - if not rel_caption_present: - for i, rel_df in enumerate(rel_dfs): - rel_df["caption"] = rel_table_names[i].split(".")[-1] + VG = from_dfs(node_dfs, rel_dfs) - VG = from_dfs(node_dfs, rel_dfs, node_radius_min_max) + for node in VG.nodes: + node.caption = node.properties.get("table") + for rel in VG.relationships: + rel.caption = rel.properties.get("table") - VG.color_nodes(field="caption", color_space=ColorSpace.DISCRETE) + number_of_colors = node_df["table"].drop_duplicates().count() + if number_of_colors <= len(NEO4J_COLORS_DISCRETE): + VG.color_nodes(field="caption", color_space=ColorSpace.DISCRETE) return VG diff --git a/python-wrapper/src/neo4j_viz/visualization_graph.py b/python-wrapper/src/neo4j_viz/visualization_graph.py index a9661bc..ecc40ab 100644 --- a/python-wrapper/src/neo4j_viz/visualization_graph.py +++ b/python-wrapper/src/neo4j_viz/visualization_graph.py @@ -22,9 +22,14 @@ from .relationship import Relationship +# TODO helper for map properties to fields. helper for set caption (simplicity) class VisualizationGraph: """ A graph to visualize. + + The `VisualizationGraph` class represents a collection of nodes and relationships that can be + rendered as an interactive graph visualization. You can customize the appearance of nodes and + relationships by setting their properties, colors, sizes, and other visual attributes. """ #: "The nodes in the graph" @@ -33,15 +38,48 @@ class VisualizationGraph: relationships: list[Relationship] def __init__(self, nodes: list[Node], relationships: list[Relationship]) -> None: - """ " - Create a new `VisualizationGraph`. - + """ Parameters ---------- - nodes: + nodes : list[Node] The nodes in the graph. - relationships: + relationships : list[Relationship] The relationships in the graph. + + Examples + -------- + Basic usage with nodes and relationships: + + >>> from neo4j_viz import Node, Relationship, VisualizationGraph + >>> nodes = [ + ... Node(id="1", properties={"name": "Alice", "age": 30}), + ... Node(id="2", properties={"name": "Bob", "age": 25}), + ... ] + >>> relationships = [ + ... Relationship(id="r1", source="1", target="2", properties={"type": "KNOWS"}) + ... ] + >>> VG = VisualizationGraph(nodes=nodes, relationships=relationships) + + Setting a node field such as captions from properties: + + >>> # Set caption from a specific property + >>> for node in VG.nodes: + ... node.caption = node.properties.get("name") + + Setting a relationship field such as type from properties: + + >>> # Set relationship caption from property + >>> for rel in VG.relationships: + ... rel.caption = rel.properties.get("type") + + Using built-in helper methods: + + >>> # Use the color_nodes method for automatic coloring + >>> VG.color_nodes(property="age", color_space=ColorSpace.CONTINUOUS) + >>> + >>> # Use resize_nodes for automatic sizing + >>> VG.resize_nodes(property="degree", node_radius_min_max=(10, 50)) + """ self.nodes = nodes self.relationships = relationships @@ -90,6 +128,12 @@ def render( The maximum allowed number of nodes to render. show_hover_tooltip: Whether to show an info tooltip when hovering over nodes and relationships. + + + Example + ------- + Basic rendering of a VisualizationGraph: + >>> from neo4j_viz import Node, Relationship, VisualizationGraph """ num_nodes = len(self.nodes) @@ -280,6 +324,29 @@ def color_nodes( colors are assigned based on unique field/property values or a gradient of the values of the field/property. override: Whether to override existing colors of the nodes, if they have any. + + Examples + -------- + + Given a VisualizationGraph `VG`: + + >>> nodes = [ + ... Node(id="0", properties={"label": "Person", "score": 10}), + ... Node(id="1", properties={"label": "Person", "score": 20}), + ... ] + >>> VG = VisualizationGraph(nodes=nodes) + + Color nodes based on a discrete field such as "label": + >>> VG.color_nodes(field="label", color_space=ColorSpace.DISCRETE) + + Color nodes based on a continuous field such as "score": + + >>> VG.color_nodes(field="score", color_space=ColorSpace.CONTINUOUS) + + Color nodes based on a custom colors such as from palettable: + + >>> from palettable.wesanderson import Moonrise1_5 # type: ignore[import-untyped] + >>> VG.color_nodes(field="label", colors=Moonrise1_5.colors) """ if not ((field is None) ^ (property is None)): raise ValueError( diff --git a/python-wrapper/tests/test_gds.py b/python-wrapper/tests/test_gds.py index ce9e1c5..7542f28 100644 --- a/python-wrapper/tests/test_gds.py +++ b/python-wrapper/tests/test_gds.py @@ -3,7 +3,6 @@ import pandas as pd import pytest -from pytest_mock import MockerFixture from neo4j_viz import Node @@ -21,79 +20,13 @@ def db_setup(gds: Any) -> Generator[None, None, None]: gds.run_cypher("MATCH (n:_CI_A|_CI_B) DETACH DELETE n") -@pytest.mark.requires_neo4j_and_gds -def test_from_gds_integration_size(gds: Any) -> None: - from neo4j_viz.gds import from_gds - - nodes = pd.DataFrame( - { - "nodeId": [0, 1, 2], - "labels": [["A"], ["C"], ["A", "B"]], - "score": [1337, 42, 3.14], - "component": [1, 4, 2], - "size": [0.1, 0.2, 0.3], - } - ) - rels = pd.DataFrame( - { - "sourceNodeId": [0, 1, 2], - "targetNodeId": [1, 2, 0], - "cost": [1.0, 2.0, 3.0], - "weight": [0.5, 1.5, 2.5], - "relationshipType": ["REL", "REL2", "REL"], - } - ) - - with gds.graph.construct("flo", nodes, rels) as G: - VG = from_gds( - gds, - G, - size_property="score", - additional_node_properties=["component", "size"], - node_radius_min_max=(3.14, 1337), - ) - - assert len(VG.nodes) == 3 - assert sorted(VG.nodes, key=lambda x: x.id) == [ - Node(id=0, size=float(1337), caption="['A']", properties=dict(labels=["A"], component=float(1), size=0.1)), - Node(id=1, size=float(42), caption="['C']", properties=dict(labels=["C"], component=float(4), size=0.2)), - Node( - id=2, - size=float(3.14), - caption="['A', 'B']", - properties=dict(labels=["A", "B"], component=float(2), size=0.3), - ), - ] - - assert len(VG.relationships) == 3 - vg_rels = sorted( - [ - ( - e.source, - e.target, - e.caption, - e.properties["relationshipType"], - e.properties["cost"], - e.properties["weight"], - ) - for e in VG.relationships - ], - key=lambda x: x[0], - ) - assert vg_rels == [ - (0, 1, "REL", "REL", 1.0, 0.5), - (1, 2, "REL2", "REL2", 2.0, 1.5), - (2, 0, "REL", "REL", 3.0, 2.5), - ] - - @pytest.mark.filterwarnings("ignore::DeprecationWarning") @pytest.mark.requires_neo4j_and_gds def test_from_gds_integration_all_db_properties(gds: Any, db_setup: None) -> None: from neo4j_viz.gds import from_gds with gds.graph.project("g2", ["_CI_A", "_CI_B"], "*") as G: - VG = from_gds(gds, G, node_radius_min_max=None, additional_db_node_properties=["name"]) + VG = from_gds(gds, G, db_node_properties=["name"]) assert len(VG.nodes) == 2 assert {n.properties["name"] for n in VG.nodes} == {"Alice", "Bob"} @@ -123,18 +56,27 @@ def test_from_gds_integration_all_properties(gds: Any) -> None: ) with gds.graph.construct("flo", nodes, rels) as G: - VG = from_gds( - gds, - G, - node_radius_min_max=None, - ) + VG = from_gds(gds, G) assert len(VG.nodes) == 3 assert sorted(VG.nodes, key=lambda x: x.id) == [ - Node(id=0, size=0.1, caption="['A']", properties=dict(labels=["A"], component=float(1), score=1337.0)), - Node(id=1, size=0.2, caption="['C']", properties=dict(labels=["C"], component=float(4), score=42.0)), Node( - id=2, size=0.3, caption="['A', 'B']", properties=dict(labels=["A", "B"], component=float(2), score=3.14) + id=0, + caption="['A']", + color="#ffdf81", + properties=dict(size=0.1, labels=["A"], component=float(1), score=1337.0), + ), + Node( + id=1, + caption="['C']", + color="#f79767", + properties=dict(size=0.2, labels=["C"], component=float(4), score=42.0), + ), + Node( + id=2, + caption="['A', 'B']", + color="#c990c0", + properties=dict(size=0.3, labels=["A", "B"], component=float(2), score=3.14), ), ] @@ -160,150 +102,6 @@ def test_from_gds_integration_all_properties(gds: Any) -> None: ] -def test_from_gds_mocked(mocker: MockerFixture) -> None: - from graphdatascience import Graph, GraphDataScience - - from neo4j_viz.gds import from_gds - - nodes = { - "A": pd.DataFrame( - { - "nodeId": [0, 2], - "score": [1337, 3.14], - "component": [1, 2], - } - ), - "B": pd.DataFrame( - { - "nodeId": [2], - "score": [3.14], - "component": [2], - } - ), - "C": pd.DataFrame( - { - "nodeId": [1], - "score": [42], - "component": [4], - } - ), - } - rels = [ - pd.DataFrame( - { - "sourceNodeId": [0, 1, 2], - "targetNodeId": [1, 2, 0], - "relationshipType": ["REL", "REL2", "REL"], - } - ) - ] - - mocker.patch( - "graphdatascience.Graph.__init__", - lambda x: None, - ) - mocker.patch( - "graphdatascience.Graph.name", - lambda x: "DUMMY", - ) - node_properties = ["score", "component"] - mocker.patch( - "graphdatascience.Graph.node_properties", - lambda x: pd.Series({lbl: node_properties for lbl in nodes.keys()}), - ) - mocker.patch("graphdatascience.Graph.node_labels", lambda x: list(nodes.keys())) - mocker.patch("graphdatascience.Graph.node_count", lambda x: sum(len(df) for df in nodes.values())) - mocker.patch("graphdatascience.GraphDataScience.__init__", lambda x: None) - mocker.patch("neo4j_viz.gds._fetch_node_dfs", return_value=nodes) - mocker.patch("neo4j_viz.gds._fetch_rel_dfs", return_value=rels) - - gds = GraphDataScience() # type: ignore[call-arg] - G = Graph() # type: ignore[call-arg] - - VG = from_gds( - gds, - G, - size_property="score", - additional_node_properties=["component", "score"], - node_radius_min_max=(3.14, 1337), - ) - - assert len(VG.nodes) == 3 - assert sorted(VG.nodes, key=lambda x: x.id) == [ - Node( - id=0, - caption="['A']", - size=float(1337), - properties=dict(labels=["A"], component=float(1), score=float(1337)), - ), - Node(id=1, caption="['C']", size=float(42), properties=dict(labels=["C"], component=float(4), score=float(42))), - Node( - id=2, - caption="['A', 'B']", - size=float(3.14), - properties=dict(labels=["A", "B"], component=float(2), score=float(3.14)), - ), - ] - - assert len(VG.relationships) == 3 - vg_rels = sorted( - [(e.source, e.target, e.caption, e.properties["relationshipType"]) for e in VG.relationships], - key=lambda x: x[0], - ) - assert vg_rels == [ - (0, 1, "REL", "REL"), - (1, 2, "REL2", "REL2"), - (2, 0, "REL", "REL"), - ] - - -@pytest.mark.requires_neo4j_and_gds -def test_from_gds_node_errors(gds: Any) -> None: - from neo4j_viz.gds import from_gds - - nodes = pd.DataFrame( - { - "nodeId": [0, 1, 2], - "labels": [["A"], ["C"], ["A", "B"]], - "component": [1, 4, 2], - "score": [1337, -42, 3.14], - "size": [-0.1, 0.2, 0.3], - } - ) - rels = pd.DataFrame( - { - "sourceNodeId": [0, 1, 2], - "targetNodeId": [1, 2, 0], - "relationshipType": ["REL", "REL2", "REL"], - } - ) - - with gds.graph.construct("flo", nodes, rels) as G: - with pytest.raises( - ValueError, - match=r"Error for node property 'size' with provided input '-0.1'. Reason: Input should be greater than or equal to 0", - ): - from_gds( - gds, - G, - additional_node_properties=["component", "size"], - node_radius_min_max=None, - ) - - with gds.graph.construct("flo", nodes, rels) as G: - with pytest.raises( - ValueError, - match=r"Error for node property 'score' with provided input '-42.0'. Reason: Input should be greater than or equal to 0", - ): - from_gds( - gds, - G, - size_property="score", - additional_node_properties=["component", "size"], - node_radius_min_max=None, - ) - - @pytest.mark.requires_neo4j_and_gds def test_from_gds_sample(gds: Any) -> None: from neo4j_viz.gds import from_gds @@ -369,10 +167,10 @@ def test_from_gds_hetero(gds: Any) -> None: assert len(VG.nodes) == 4 assert sorted(VG.nodes, key=lambda x: x.id) == [ - Node(id=0, caption="['A']", properties=dict(labels=["A"], component=float(1))), - Node(id=1, caption="['A']", properties=dict(labels=["A"], component=float(2))), - Node(id=2, caption="['B']", properties=dict(labels=["B"])), - Node(id=3, caption="['B']", properties=dict(labels=["B"])), + Node(id=0, caption="['A']", color="#ffdf81", properties=dict(labels=["A"], component=float(1))), + Node(id=1, caption="['A']", color="#ffdf81", properties=dict(labels=["A"], component=float(2))), + Node(id=2, caption="['B']", color="#c990c0", properties=dict(labels=["B"])), + Node(id=3, caption="['B']", color="#c990c0", properties=dict(labels=["B"])), ] assert len(VG.relationships) == 2 diff --git a/python-wrapper/tests/test_pandas.py b/python-wrapper/tests/test_pandas.py index b29f2db..9f40924 100644 --- a/python-wrapper/tests/test_pandas.py +++ b/python-wrapper/tests/test_pandas.py @@ -1,6 +1,5 @@ import pytest from pandas import DataFrame -from pydantic_extra_types.color import Color from neo4j_viz.node import Node from neo4j_viz.pandas import from_dfs @@ -18,33 +17,33 @@ def test_from_df() -> None: "weight": [1.0, 2.0], } ) - VG = from_dfs(nodes, relationships, node_radius_min_max=(42, 1337)) + VG = from_dfs(nodes, relationships) assert len(VG.nodes) == 2 - assert VG.nodes[0].id == 0 - assert VG.nodes[0].caption == "A" - assert VG.nodes[0].size == 1337 - assert VG.nodes[0].color == Color("#ff0000") - assert VG.nodes[0].properties == {"instrument": "piano"} + assert VG.nodes[0] == Node( + id=0, + caption=None, + properties={"size": 1337, "color": "#FF0000", "instrument": "piano", "caption": "A"}, + ) - assert VG.nodes[1].id == 1 - assert VG.nodes[1].caption == "B" - assert VG.nodes[1].size == 42 - assert VG.nodes[1].color == Color("#ff0000") - assert VG.nodes[1].properties == {"instrument": "guitar"} + assert VG.nodes[1] == Node( + id=1, + caption=None, + properties={"size": 42, "color": "#FF0000", "instrument": "guitar", "caption": "B"}, + ) assert len(VG.relationships) == 2 assert VG.relationships[0].source == 0 assert VG.relationships[0].target == 1 - assert VG.relationships[0].caption == "REL" - assert VG.relationships[0].properties == {"weight": 1.0} + assert VG.relationships[0].caption is None + assert VG.relationships[0].properties == {"weight": 1.0, "caption": "REL"} assert VG.relationships[1].source == 1 assert VG.relationships[1].target == 0 - assert VG.relationships[1].caption == "REL2" - assert VG.relationships[1].properties == {"weight": 2.0} + assert VG.relationships[1].caption is None + assert VG.relationships[1].properties == {"weight": 2.0, "caption": "REL2"} def test_from_rel_dfs() -> None: @@ -108,29 +107,24 @@ def test_from_dfs() -> None: } ), ] - VG = from_dfs(nodes, relationships, node_radius_min_max=(42, 1337)) + VG = from_dfs(nodes, relationships) assert len(VG.nodes) == 2 - assert VG.nodes[0].id == 0 - assert VG.nodes[0].caption == "A" - assert VG.nodes[0].size == 1337 - assert VG.nodes[0].color == Color("#ff0000") - - assert VG.nodes[1].id == 1 - assert VG.nodes[1].caption == "B" - assert VG.nodes[1].size == 42 - assert VG.nodes[0].color == Color("#ff0000") + assert VG.nodes[0] == Node(id=0, caption=None, properties={"size": 1337, "color": "#FF0000", "caption": "A"}) + assert VG.nodes[1] == Node(id=1, caption=None, properties={"size": 42, "color": "#FF0000", "caption": "B"}) assert len(VG.relationships) == 2 assert VG.relationships[0].source == 0 assert VG.relationships[0].target == 1 - assert VG.relationships[0].caption == "REL" + assert VG.relationships[0].caption is None + assert VG.relationships[0].properties == {"caption": "REL"} assert VG.relationships[1].source == 1 assert VG.relationships[1].target == 0 - assert VG.relationships[1].caption == "REL2" + assert VG.relationships[1].caption is None + assert VG.relationships[1].properties == {"caption": "REL2"} def test_node_errors() -> None: @@ -152,11 +146,6 @@ def test_node_errors() -> None: "instrument": ["piano", "guitar"], } ) - with pytest.raises( - ValueError, - match=r"Error for node column 'size' with provided input 'aaa'. Reason: Input should be a valid integer, unable to parse string as an integer", - ): - from_dfs(nodes, []) def test_rel_errors() -> None: @@ -176,21 +165,6 @@ def test_rel_errors() -> None: ): from_dfs(nodes, relationships) - relationships = DataFrame( - { - "source": [0, 1], - "target": [1, 0], - "caption": ["REL", "REL2"], - "caption_size": [1.0, -300], - "weight": [1.0, 2.0], - } - ) - with pytest.raises( - ValueError, - match=r"Error for relationship column 'caption_size' with provided input '-300.0'. Reason: Input should be greater than 0", - ): - from_dfs(nodes, relationships) - def test_from_dfs_no_rels() -> None: nodes = [ @@ -211,18 +185,10 @@ def test_from_dfs_no_rels() -> None: } ), ] - VG = from_dfs(nodes, [], node_radius_min_max=(42, 1337)) + VG = from_dfs(nodes, []) assert len(VG.nodes) == 2 - - assert VG.nodes[0].id == 0 - assert VG.nodes[0].caption == "A" - assert VG.nodes[0].size == 1337 - assert VG.nodes[0].color == Color("#ff0000") - - assert VG.nodes[1].id == 1 - assert VG.nodes[1].caption == "B" - assert VG.nodes[1].size == 42 - assert VG.nodes[0].color == Color("#ff0000") + assert VG.nodes[0] == Node(id=0, caption=None, properties={"size": 1337, "color": "#FF0000", "caption": "A"}) + assert VG.nodes[1] == Node(id=1, caption=None, properties={"size": 42, "color": "#FF0000", "caption": "B"}) assert len(VG.relationships) == 0 diff --git a/python-wrapper/tests/test_snowflake.py b/python-wrapper/tests/test_snowflake.py index b7bbece..0cb86ba 100644 --- a/python-wrapper/tests/test_snowflake.py +++ b/python-wrapper/tests/test_snowflake.py @@ -2,6 +2,7 @@ from snowflake.snowpark import Session from snowflake.snowpark.types import LongType, StructField, StructType +from neo4j_viz.node import Node from neo4j_viz.snowflake import from_snowflake @@ -58,20 +59,14 @@ def test_from_snowflake(session_with_minimal_graph: Session) -> None: }, ) - assert len(VG.nodes) == 2 - - assert VG.nodes[0].id == 0 - assert VG.nodes[0].caption == "NODES" - assert VG.nodes[0].color is not None - assert VG.nodes[0].properties == {"SNOWFLAKEID": 6} - - assert VG.nodes[1].id == 1 - assert VG.nodes[1].caption == "NODES" - assert VG.nodes[0].color is not None - assert VG.nodes[1].properties == {"SNOWFLAKEID": 7} + assert VG.nodes == [ + Node(id=0, caption="NODES", color="#ffdf81", properties={"SNOWFLAKEID": 6, "table": "NODES"}), + Node(id=1, caption="NODES", color="#ffdf81", properties={"SNOWFLAKEID": 7, "table": "NODES"}), + ] assert len(VG.relationships) == 1 assert VG.relationships[0].source == 0 assert VG.relationships[0].target == 1 assert VG.relationships[0].caption == "RELS" + assert VG.relationships[0].properties == {"table": "RELS"}