Skip to content

Commit

Permalink
Fix source and sink functions and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
geographika committed May 21, 2024
1 parent 0f3cfd7 commit 4497d13
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 6 deletions.
24 changes: 23 additions & 1 deletion tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,26 @@ def test_remove_edge_by_key():
assert len(net.graph["keys"].keys()) == 0


def test_get_source_edges():
net = networkx.MultiDiGraph()

net.add_edge(0, 1, key="A", **{"EDGE_ID": 1})
net.add_edge(1, 2, key="B", **{"EDGE_ID": 2})
net.add_edge(2, 3, key="C", **{"EDGE_ID": 3})
edges = functions.get_source_edges(net)
assert edges["A"] == {"EDGE_ID": 1}


def test_get_sink_edges():
net = networkx.MultiDiGraph()

net.add_edge(0, 1, key="A", **{"EDGE_ID": 1})
net.add_edge(1, 2, key="B", **{"EDGE_ID": 2})
net.add_edge(2, 3, key="C", **{"EDGE_ID": 3})
edges = functions.get_sink_edges(net)
assert edges["C"] == {"EDGE_ID": 3}


def test_doctest():
import doctest

Expand Down Expand Up @@ -568,5 +588,7 @@ def test_doctest():
# test_remove_edge_by_key()
# test_add_edge_shorthand()
# test_add_single_edge()
test_get_shortest_edge_no_length()
# test_get_shortest_edge_no_length()
test_get_source_edges()
test_get_sink_edges()
print("Done!")
14 changes: 9 additions & 5 deletions wayfarer/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ def get_sink_edges(net):
"""
sinks = get_sink_nodes(net)

return get_edges(net, sinks)
return get_edges(net, sinks, in_edges=True)


def get_source_nodes(net):
Expand Down Expand Up @@ -598,17 +598,18 @@ def get_source_edges(net):
"""
sources = get_source_nodes(net)

return get_edges(net, sources)
return get_edges(net, sources, in_edges=False)


def get_edges(net, nodes):
def get_edges(net, nodes, in_edges=True):
"""
Get all the edges in the network that touch the nodes in the
nodes list. Note only works with DiGraph or MultiDiGraph.
Args:
net (object): a networkx network
nodes: (list): a list of network nodes e.g. ``[(224966, 437657), (225195, 437940)]``
in_edges: (bool): set to True to get edges coming into the node, and False to edges leaving the node
Returns:
end_edges (dict): a dict containing segment codes and their associated network edge
Expand All @@ -617,9 +618,12 @@ def get_edges(net, nodes):
end_edges = {}

for node in nodes:
in_edges = net.in_edges(node)
if in_edges is True:
node_edges = net.in_edges(node)
else:
node_edges = net.out_edges(node)

for edge in in_edges:
for edge in node_edges:
edges = get_edges_from_node_pair(net, edge[0], edge[1])
assert len(edges) == 1
end_edges.update(edges)
Expand Down

0 comments on commit 4497d13

Please sign in to comment.