Skip to content

Commit 479056e

Browse files
committed
Add from_snowflake constructor
1 parent 4b4a43b commit 479056e

File tree

5 files changed

+421
-1
lines changed

5 files changed

+421
-1
lines changed

changelog.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,15 @@
66

77
## New features
88

9+
* Added new constructor `from_snowflake` that creates visualization graphs from Snowflake tables.
10+
911

1012
## Bug fixes
1113

1214

1315
## Improvements
1416

17+
* The `field` parameter of `color_nodes` now also accepts casing other than `snake_case`.
18+
1519

1620
## Other changes

python-wrapper/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ docs = [
6161
pandas = ["pandas>=2, <3", "pandas-stubs>=2, <3"]
6262
gds = ["graphdatascience>=1, <2"]
6363
neo4j = ["neo4j"]
64+
snowflake = ["snowflake-snowpark-python>=1, <2"]
6465
notebook = [
6566
"ipykernel>=6.29.5",
6667
"pykernel>=0.1.6",
Lines changed: 339 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,339 @@
1+
from __future__ import annotations
2+
3+
from enum import Enum
4+
from typing import Annotated, Any, Optional
5+
6+
from pandas import DataFrame
7+
from pydantic import (
8+
AfterValidator,
9+
BaseModel,
10+
BeforeValidator,
11+
)
12+
from pydantic_core.core_schema import ValidationInfo
13+
from snowflake.snowpark import Session
14+
from snowflake.snowpark.exceptions import SnowparkSQLException
15+
from snowflake.snowpark.types import (
16+
ArrayType,
17+
BooleanType,
18+
ByteType,
19+
DataType,
20+
DateType,
21+
DecimalType,
22+
DoubleType,
23+
FloatType,
24+
GeographyType,
25+
GeometryType,
26+
IntegerType,
27+
LongType,
28+
MapType,
29+
ShortType,
30+
StringType,
31+
StructField,
32+
StructType,
33+
TimestampType,
34+
TimeType,
35+
VariantType,
36+
VectorType,
37+
)
38+
39+
from neo4j_viz import VisualizationGraph
40+
from neo4j_viz.pandas import from_dfs
41+
42+
43+
def data_type_name(type: DataType) -> str:
    """Return the Snowflake SQL type name for a Snowpark ``DataType`` instance.

    Types without an explicit mapping fall back to Snowpark's own
    ``simple_string()`` representation, upper-cased.
    """
    # Checked top to bottom with isinstance; order mirrors the original chain.
    sql_names = (
        (StringType, "VARCHAR"),
        (LongType, "BIGINT"),
        (IntegerType, "INT"),
        (DoubleType, "DOUBLE"),
        (DecimalType, "NUMBER"),
        (BooleanType, "BOOLEAN"),
        (ByteType, "TINYINT"),
        (DateType, "DATE"),
        (ShortType, "SMALLINT"),
        (FloatType, "FLOAT"),
        (ArrayType, "ARRAY"),
        (VectorType, "VECTOR"),
        (MapType, "OBJECT"),
        (TimeType, "TIME"),
        (TimestampType, "TIMESTAMP"),
        (VariantType, "VARIANT"),
        (GeographyType, "GEOGRAPHY"),
        (GeometryType, "GEOMETRY"),
    )
    for snowpark_type, sql_name in sql_names:
        if isinstance(type, snowpark_type):
            return sql_name

    # This actually does the job much of the time anyway
    return type.simple_string().upper()
83+
84+
85+
# SQL type names accepted for node-identifier columns: VARCHAR, BIGINT and INT.
SUPPORTED_ID_TYPES = [data_type_name(data_type) for data_type in [StringType(), LongType(), IntegerType()]]
86+
87+
88+
def _validate_id_column(schema: StructType, column_name: str, index: int, supported_types: list[str]) -> None:
    """Check that ``schema`` has a column named ``column_name`` (case-insensitive)
    sitting at position ``index`` with a type whose SQL name is in
    ``supported_types``.

    Raises:
        ValueError: if the column is missing, at the wrong position, or has an
            unsupported type.
    """
    wanted = column_name.lower()

    if not any(name.lower() == wanted for name in schema.names):
        raise ValueError(f"Schema must contain a `{column_name}` column")

    field: StructField = schema.fields[index]
    if field.name.lower() != wanted:
        raise ValueError(f"Column `{column_name}` must have column index {index}")

    actual_type = data_type_name(field.datatype)
    if actual_type not in supported_types:
        raise ValueError(
            f"Column `{column_name}` has invalid type `{actual_type}`. Expected one of [{', '.join(supported_types)}]"
        )
101+
102+
103+
def validate_viz_node_table(table: str, info: ValidationInfo) -> str:
    """Pydantic validator: when a Snowpark session is present in the validation
    context, verify that ``table`` exists and has a supported `nodeId` id
    column at index 0. Returns the table name unchanged on success.
    """
    context = info.context
    if not context or context["session"] is None:
        # No live session in the context; accept the name as-is.
        return table

    session = context["session"]
    try:
        schema = session.table(table).schema
        _validate_id_column(schema, "nodeId", 0, SUPPORTED_ID_TYPES)
    except SnowparkSQLException as e:
        raise ValueError(f"Table '{table}' does not exist or is not accessible.") from e
    return table
113+
114+
115+
def validate_viz_relationship_table(
    table: str,
    info: ValidationInfo,
) -> str:
    """Pydantic validator: when a Snowpark session is present in the validation
    context, verify that ``table`` exists and that its first two columns are
    supported `sourceNodeId` and `targetNodeId` id columns. Returns the table
    name unchanged on success.
    """
    context = info.context
    if not context or context["session"] is None:
        # No live session in the context; accept the name as-is.
        return table

    session = context["session"]
    try:
        schema = session.table(table).schema
        _validate_id_column(schema, "sourceNodeId", 0, SUPPORTED_ID_TYPES)
        _validate_id_column(schema, "targetNodeId", 1, SUPPORTED_ID_TYPES)
    except SnowparkSQLException as e:
        raise ValueError(f"Table '{table}' does not exist or is not accessible.") from e
    return table
129+
130+
131+
def parse_identifier_groups(identifier: str) -> list[str]:
    """Split a (possibly double-quoted) identifier into its dot-separated parts.

    Quoted parts keep their surrounding double quotes, e.g.
    ``'db."My Schema".tbl'`` -> ``['db', '"My Schema"', 'tbl']``.

    Args:
        identifier (str): The input string identifier to parse.

    Returns:
        list[str]: A list of parsed identifier groups.

    Raises:
        ValueError: If the identifier contains:
            - Empty double quotes.
            - Consecutive dots outside of double quotes.
            - Unbalanced double quotes.
            - Invalid characters in unquoted segments.
            - Improper placement of dots around double-quoted segments.
    """
    in_quotes = False  # Whether the scan position is currently inside quotes
    starts: list[int] = []  # Index just past each opening quote
    ends: list[int] = []  # Index of each closing quote
    unquoted_chars: list[str] = []  # Characters seen outside of quotes
    last_was_dot = False  # Whether the previous unquoted character was a dot

    for pos, ch in enumerate(identifier):
        if ch == '"':
            if in_quotes:
                ends.append(pos)
                if ends[-1] == starts[-1]:
                    raise ValueError("Empty double quotes")
            else:
                starts.append(pos + 1)
                last_was_dot = False
            in_quotes = not in_quotes
            continue

        if in_quotes:
            # Quoted content is kept verbatim; nothing to validate here.
            continue

        unquoted_chars.append(ch)
        if ch == ".":
            if last_was_dot:
                raise ValueError("Not ok to have consecutive dots outside of double quote")
            last_was_dot = True
        else:
            last_was_dot = False

    if len(starts) != len(ends):
        raise ValueError("Unbalanced double quotes")

    # A quoted segment must be delimited by dots (or the string boundaries).
    for start in starts:
        if start > 1 and identifier[start - 2] != ".":
            raise ValueError("Only dot character may precede before double quoted identifier")
    for end in ends:
        if end < len(identifier) - 1 and identifier[end + 1] != ".":
            raise ValueError("Only dot character may follow double quoted identifier")

    words = "".join(unquoted_chars).split(".")
    for word in words:
        if not word:
            continue
        lowered = word.lower()
        if lowered[0] not in "abcdefghijklmnopqrstuvwxyz_":
            raise ValueError(f"Invalid first character in identifier {word}. Only a-z, A-Z, and _ are allowed.")
        if not set(lowered).issubset(set("abcdefghijklmnopqrstuvwxyz$_0123456789")):
            raise ValueError(f"Invalid characters in identifier {word}. Only a-z, A-Z, 0-9, _, and $ are allowed.")

    # Each quoted segment left an empty word behind; substitute them back in order.
    empty_slots = [i for i, w in enumerate(words) if w == ""]
    for i in range(len(starts)):
        words[empty_slots[i]] = f'"{identifier[starts[i] : ends[i]]}"'

    return words
208+
209+
210+
def validate_table_name(table: str) -> str:
    """Pydantic validator: ensure ``table`` is a syntactically valid table
    identifier, either fully qualified (`<database>.<schema>.<table>`) or bare.

    Returns the input unchanged on success.

    Raises:
        TypeError: if ``table`` is not a string.
        ValueError: if the identifier cannot be parsed or does not have
            exactly one or three parts.
    """
    if not isinstance(table, str):
        raise TypeError(f"Table name must be a string, got {type(table).__name__}")

    try:
        parts = parse_identifier_groups(table)
    except ValueError as e:
        raise ValueError(f"Invalid table name '{table}'. {str(e)}") from e

    if len(parts) not in (1, 3):
        raise ValueError(
            f"Invalid table name '{table}'. Table names must be in the format '<database>.<schema>.<table>' or '<table>'"
        )

    return table
225+
226+
227+
# A syntactically valid table identifier (checked before pydantic coercion).
Table = Annotated[str, BeforeValidator(validate_table_name)]

# Table identifiers that are additionally verified against a live Snowpark
# session (existence + required id columns) when one is supplied in the
# validation context.
VizNodeTable = Annotated[Table, AfterValidator(validate_viz_node_table)]
VizRelationshipTable = Annotated[Table, AfterValidator(validate_viz_relationship_table)]
231+
232+
233+
class Orientation(Enum):
    """How the rows of a relationship table are turned into edges."""

    NATURAL = "natural"  # edges as stored: source -> target
    UNDIRECTED = "undirected"  # both directions
    REVERSE = "reverse"  # target -> source only
237+
238+
239+
def to_lower(value: str) -> str:
    """Lower-case ``value`` when it is a non-empty string; pass any other
    value through unchanged (pydantic ``BeforeValidator`` helper)."""
    if isinstance(value, str) and value:
        return value.lower()
    return value
241+
242+
243+
# Accepts orientation values in any casing (e.g. "NATURAL") by lower-casing
# the input before enum coercion.
LowercaseOrientation = Annotated[Orientation, BeforeValidator(to_lower)]
244+
245+
246+
class VizRelationshipTableConfig(BaseModel, extra="forbid"):
    """Per-relationship-table settings: the node tables its endpoints live in,
    plus how to orient the edges."""

    sourceTable: VizNodeTable
    targetTable: VizNodeTable
    # Defaults to NATURAL; None is also accepted by the annotation.
    orientation: LowercaseOrientation | None = Orientation.NATURAL
250+
251+
252+
class VizProjectConfig(BaseModel, extra="forbid"):
    """Validated shape of the ``project_config`` dict passed to ``from_snowflake``."""

    # NOTE(review): declared but not read anywhere in this module — confirm
    # the intended use of this field.
    defaultTablePrefix: str | None = None
    nodeTables: list[VizNodeTable]
    # Maps each relationship table to its endpoint node tables and orientation.
    relationshipTables: dict[VizRelationshipTable, VizRelationshipTableConfig]
256+
257+
258+
def _map_tables(
    session: Session, project_model: VizProjectConfig
) -> tuple[list[DataFrame], list[DataFrame], list[str]]:
    """Load the configured node and relationship tables into pandas DataFrames,
    remapping per-table Snowflake node ids into a single contiguous internal
    id space shared across all node tables.

    Returns:
        (node_dfs, rel_dfs, rel_table_names), where rel_table_names[i] names
        the table rel_dfs[i] came from. A table contributes two frames for
        UNDIRECTED orientation (as stored plus reversed) and only the reversed
        frame for REVERSE.
    """
    offset = 0  # Next free internal id; advanced per node table
    # Per node table name: DataFrame mapping original NODEID -> INTERNALID.
    to_internal = {}
    node_dfs = []
    for table in project_model.nodeTables:
        # NOTE(review): the all-caps column names below assume Snowflake's
        # to_pandas() upper-cases unquoted column names — confirm behavior
        # for quoted/mixed-case columns.
        df = session.table(table).to_pandas()
        internal_ids = range(offset, offset + df.shape[0])
        to_internal[table] = df[["NODEID"]].copy()
        to_internal[table]["INTERNALID"] = internal_ids
        offset += df.shape[0]

        # Keep the original id around, then overwrite NODEID with internal ids.
        df["SNOWFLAKEID"] = df["NODEID"]
        df["NODEID"] = internal_ids

        node_dfs.append(df)

    rel_dfs = []
    rel_table_names = []
    for table, rel_table_config in project_model.relationshipTables.items():
        df = session.table(table).to_pandas()

        source_table = rel_table_config.sourceTable
        target_table = rel_table_config.targetTable

        # Translate endpoint ids to internal ids via inner merges; relationship
        # rows whose endpoints are absent from the node tables are silently
        # dropped by the default merge semantics.
        df = df.merge(to_internal[source_table], left_on="SOURCENODEID", right_on="NODEID")
        df.drop(["SOURCENODEID", "NODEID"], axis=1, inplace=True)
        df.rename({"INTERNALID": "SOURCENODEID"}, axis=1, inplace=True)
        df = df.merge(to_internal[target_table], left_on="TARGETNODEID", right_on="NODEID")
        df.drop(["TARGETNODEID", "NODEID"], axis=1, inplace=True)
        df.rename({"INTERNALID": "TARGETNODEID"}, axis=1, inplace=True)

        if (
            rel_table_config.orientation == Orientation.NATURAL
            or rel_table_config.orientation == Orientation.UNDIRECTED
        ):
            rel_dfs.append(df)
            rel_table_names.append(table)

        if rel_table_config.orientation == Orientation.REVERSE:
            # Swap source/target columns; copy=False since the un-reversed
            # frame is not kept in this branch.
            df_rev = df.rename(columns={"SOURCENODEID": "TARGETNODEID", "TARGETNODEID": "SOURCENODEID"}, copy=False)
            rel_dfs.append(df_rev)
            rel_table_names.append(table)

        if rel_table_config.orientation == Orientation.UNDIRECTED:
            # Both directions are kept, so copy to avoid the two frames
            # sharing underlying data.
            df_rev = df.rename(columns={"SOURCENODEID": "TARGETNODEID", "TARGETNODEID": "SOURCENODEID"}, copy=True)
            rel_dfs.append(df_rev)
            rel_table_names.append(table)

    return node_dfs, rel_dfs, rel_table_names
309+
310+
311+
def from_snowflake(
    session: Session,
    project_config: dict[str, Any],
    node_radius_min_max: Optional[tuple[float, float]] = (3, 60),
) -> VisualizationGraph:
    """Create a ``VisualizationGraph`` from Snowflake tables.

    Args:
        session: An open Snowpark session used to read the tables.
        project_config: Dict matching ``VizProjectConfig`` — node tables plus
            relationship tables with their endpoint tables and orientation.
            Validated against the live session (table existence, id columns).
        node_radius_min_max: Min/max node radius, passed through to ``from_dfs``.

    Returns:
        The visualization graph built from the mapped DataFrames.
    """
    project_model = VizProjectConfig.model_validate(project_config, strict=False, context={"session": session})
    node_dfs, rel_dfs, rel_table_names = _map_tables(session, project_model)

    # If no node table provides a caption column, fall back to using the
    # (unqualified) table name as every node's caption.
    # NOTE(review): the check is for upper-case "CAPTION" (Snowflake-style
    # column naming) but the fallback column added is lower-case "caption" —
    # presumably `from_dfs` treats these the same; confirm.
    node_caption_present = False
    for node_df in node_dfs:
        if "CAPTION" in node_df.columns:
            node_caption_present = True
            break

    if not node_caption_present:
        for i, node_df in enumerate(node_dfs):
            node_df["caption"] = project_model.nodeTables[i].split(".")[-1]

    # Same fallback for relationships, using the relationship table's name.
    rel_caption_present = False
    for rel_df in rel_dfs:
        if "CAPTION" in rel_df.columns:
            rel_caption_present = True
            break

    if not rel_caption_present:
        for i, rel_df in enumerate(rel_dfs):
            rel_df["caption"] = rel_table_names[i].split(".")[-1]

    return from_dfs(node_dfs, rel_dfs, node_radius_min_max)

python-wrapper/src/neo4j_viz/visualization_graph.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Any, Callable, Hashable, Optional, Union
66

77
from IPython.display import HTML
8+
from pydantic.alias_generators import to_snake
89
from pydantic_extra_types.color import Color, ColorType
910

1011
from .colors import NEO4J_COLORS_CONTINUOUS, NEO4J_COLORS_DISCRETE, ColorSpace, ColorsType
@@ -277,7 +278,7 @@ def node_to_attr(node: Node) -> Any:
277278
return node.properties.get(attribute)
278279
else:
279280
assert field is not None
280-
attribute = field
281+
attribute = to_snake(field)
281282

282283
def node_to_attr(node: Node) -> Any:
283284
return getattr(node, attribute)

0 commit comments

Comments
 (0)