Skip to content

Commit

Permalink
feat: leaf_paths (Get a dictionary of leaf paths of a nested dictiona…
Browse files Browse the repository at this point in the history
…ry) + refactors of flatten_dict
  • Loading branch information
thorwhalen committed Dec 6, 2024
1 parent 0a7b66e commit 8ea5e78
Showing 1 changed file with 113 additions and 19 deletions.
132 changes: 113 additions & 19 deletions dol/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from functools import wraps, partial
from dataclasses import dataclass
from typing import (
Optional,
Union,
Callable,
Any,
Expand All @@ -31,6 +32,8 @@
KT,
VT,
TypeVar,
Generator,
TypeAlias,
List,
Dict,
)
Expand All @@ -51,37 +54,72 @@
from dol.dig import recursive_get_attr


# A generator that yields (key, value) pairs; nothing is sent into it and it
# returns nothing.
KeyValueGenerator = Generator[tuple[KT, VT], None, None]
# Type variable standing for whatever object is used to represent a path.
Path = TypeVar("Path")
# A function that maps (current path, key) to the path extended with that key.
PathExtenderFunc = Callable[[Path, KT], Path]
# How path extension can be specified: a separator string or an extender function.
PathExtenderSpec = Union[str, PathExtenderFunc]
# A mapping whose values are either leaf values or, recursively, nested mappings.
NestedMapping: TypeAlias = Mapping[KT, Union[VT, "NestedMapping[KT, VT]"]]


def separator_based_path_extender(path: Path, key: KT, sep: str) -> Path:
    """
    Extends a given path with a new key using the specified separator.

    If there is no path yet (``None``, an empty string, or any other empty
    non-numeric value), the key is returned as is (it is not converted to a string).
    Otherwise the result is the string ``f"{path}{sep}{key}"``.

    Numeric paths are never considered empty: a key such as ``0`` is falsy, but it
    is a legitimate path component and must not be dropped from the path.

    Args:
        path: The path accumulated so far, or ``None``/empty if there is none.
        key: The key to append to the path.
        sep: The separator to insert between the path and the key.

    >>> separator_based_path_extender('a', 'b', sep='.')
    'a.b'
    >>> separator_based_path_extender('', 'b', sep='.')
    'b'
    >>> separator_based_path_extender(None, 'b', sep='.')
    'b'
    >>> separator_based_path_extender(0, 1, sep='.')
    '0.1'
    """
    # A plain truthiness test would treat the numeric key 0 (or False) as "no path"
    # and silently drop it, so numbers are explicitly excluded from the empty case.
    path_is_empty = path is None or (
        not path and not isinstance(path, (int, float, complex))
    )
    return key if path_is_empty else f"{path}{sep}{key}"


def ensure_path_extender_func(path_extender: PathExtenderSpec) -> PathExtenderFunc:
    """
    Normalize a path-extender specification into a ``(path, key) -> path`` function.

    A string is interpreted as a separator and is turned into a
    separator-based extender; anything else is returned unchanged.
    """
    if not isinstance(path_extender, str):
        # Not a separator: hand it back as given.
        return path_extender
    separator = path_extender
    return partial(separator_based_path_extender, sep=separator)


def flattened_dict_items(
    d,
    sep: PathExtenderSpec = ".",
    *,
    parent_path: Optional[Path] = None,
    visit_nested: Callable = lambda obj: isinstance(obj, Mapping),
) -> KeyValueGenerator:
    """
    Yield flattened (path, value) pairs from a nested dictionary.

    Args:
        d: The (possibly nested) mapping to flatten.
        sep: The separator to use for joining keys, or a function that takes a
            path and a key and returns a new path.
        parent_path: The path to the parent of the current mapping
            (``None`` at the top level).
        visit_nested: A function that returns True if a value should be
            descended into rather than yielded as a leaf.

    >>> list(flattened_dict_items({'a': {'b': 2}, 'c': 3}))
    [('a.b', 2), ('c', 3)]
    """
    # Resolve the spec once; nested calls receive the already-resolved function,
    # which ensure_path_extender_func returns unchanged.
    path_extender = ensure_path_extender_func(sep)
    stable_kwargs = dict(sep=path_extender, visit_nested=visit_nested)

    for k, v in d.items():
        new_path = path_extender(parent_path, k)
        if visit_nested(v):
            yield from flattened_dict_items(v, parent_path=new_path, **stable_kwargs)
        else:
            yield new_path, v


def flatten_dict(
d,
sep=".",
sep: PathExtenderSpec = ".",
*,
parent_key="",
parent_path: Optional[Path] = None,
visit_nested: Callable = lambda obj: isinstance(obj, Mapping),
egress: Callable[[KeyValueGenerator], Mapping] = dict,
):
r"""
Flatten a nested dictionary into a flat one.
Flatten a nested dictionary into a flat one, using key-paths as keys.
See also `leaf_paths` for a related function that returns paths to leaf values.
Args:
d: The dictionary to flatten
sep: The separator to use for joining keys, or a function that takes a path and
a key and returns a new path.
parent_path: The path to the parent of the current dict
visit_nested: A function that returns True if a value should be visited
egress: A function that takes a generator of key-value pairs and returns a mapping
>>> d = {'a': {'b': 2}, 'c': 3}
>>> flatten_dict(d)
Expand All @@ -90,13 +128,74 @@ def flatten_dict(
{'a/b': 2, 'c': 3}
"""
return dict(
return egress(
flattened_dict_items(
d, sep=sep, parent_key=parent_key, visit_nested=visit_nested
d, sep=sep, parent_path=parent_path, visit_nested=visit_nested
)
)


def leaf_paths(
    d: NestedMapping,
    sep: PathExtenderSpec = ".",
    *,
    parent_path: Optional[Path] = None,
    egress: Callable[[KeyValueGenerator], Mapping] = dict,
) -> Dict[KT, Union[KT, Path]]:
    """
    Map every leaf of a nested dictionary to the path that leads to it.

    The result mirrors the keys and nesting of the input, except that each leaf
    value is replaced by its flattened path.

    Typical use: you flatten nested dictionaries with ``flatten_dict`` and refer
    to values by path, and you want to see which paths a given nested dictionary
    will flatten to.

    Args:
        d: The nested dictionary to get the leaf paths from
        sep: The separator to use for joining keys, or a function that takes a
            path and a key and returns a new path.
        parent_path: The path to the parent of the current dict
        egress: A function that takes a generator of key-value pairs and returns
            a mapping

    Example:

    >>> leaf_paths({'a': {'b': 2}, 'c': 3})
    {'a': {'b': 'a.b'}, 'c': 'c'}
    >>> leaf_paths({'a': {'b': 2}, 'c': 3}, sep="/")
    {'a': {'b': 'a/b'}, 'c': 'c'}
    >>> leaf_paths({'a': {'b': 2}, 'c': 3}, sep=lambda p, k: f"{p}-{k}" if p else k)
    {'a': {'b': 'a-b'}, 'c': 'c'}
    """
    key_path_pairs = _leaf_paths_recursive(
        d, ensure_path_extender_func(sep), parent_path=parent_path
    )
    return egress(key_path_pairs)


def _leaf_paths_recursive(
d: NestedMapping,
path_extender: PathExtenderFunc,
parent_path: Optional[Path] = None,
*,
visit_nested: Callable[[Any], bool] = lambda x: isinstance(x, dict),
) -> KeyValueGenerator:
"""
A recursive generator that yields (key, value) pairs.
A helper for leaf_paths.
"""
for k, v in d.items():
current_path = path_extender(parent_path, k)
if visit_nested(v):
yield k, dict(_leaf_paths_recursive(v, path_extender, current_path))
else:
yield k, current_path


# The path component separator of the current OS, as reported by os.path.sep.
path_sep = os.path.sep


Expand Down Expand Up @@ -214,11 +313,6 @@ def get_attr_or_item(obj, k):
# key-path operations


from typing import Iterable, KT, VT, Callable, Mapping, Union

Path = Union[Iterable[KT], str]


# TODO: Needs a lot more documentation and tests to show how versatile it is
def path_get(
obj: Any,
Expand Down

0 comments on commit 8ea5e78

Please sign in to comment.