diff --git a/tests/test_functions.py b/tests/test_functions.py index a163144..5a9eea5 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -533,6 +533,26 @@ def test_remove_edge_by_key(): assert len(net.graph["keys"].keys()) == 0 +def test_get_source_edges(): + net = networkx.MultiDiGraph() + + net.add_edge(0, 1, key="A", **{"EDGE_ID": 1}) + net.add_edge(1, 2, key="B", **{"EDGE_ID": 2}) + net.add_edge(2, 3, key="C", **{"EDGE_ID": 3}) + edges = functions.get_source_edges(net) + assert edges["A"] == {"EDGE_ID": 1} + + +def test_get_sink_edges(): + net = networkx.MultiDiGraph() + + net.add_edge(0, 1, key="A", **{"EDGE_ID": 1}) + net.add_edge(1, 2, key="B", **{"EDGE_ID": 2}) + net.add_edge(2, 3, key="C", **{"EDGE_ID": 3}) + edges = functions.get_sink_edges(net) + assert edges["C"] == {"EDGE_ID": 3} + + def test_doctest(): import doctest @@ -568,5 +588,7 @@ def test_doctest(): # test_remove_edge_by_key() # test_add_edge_shorthand() # test_add_single_edge() - test_get_shortest_edge_no_length() + # test_get_shortest_edge_no_length() + test_get_source_edges() + test_get_sink_edges() print("Done!") diff --git a/wayfarer/functions.py b/wayfarer/functions.py index 741d4b7..fe4e9a4 100644 --- a/wayfarer/functions.py +++ b/wayfarer/functions.py @@ -565,7 +565,7 @@ def get_sink_edges(net): """ sinks = get_sink_nodes(net) - return get_edges(net, sinks) + return get_edges(net, sinks, in_edges=True) def get_source_nodes(net): @@ -598,10 +598,10 @@ def get_source_edges(net): """ sources = get_source_nodes(net) - return get_edges(net, sources) + return get_edges(net, sources, in_edges=False) -def get_edges(net, nodes): +def get_edges(net, nodes, in_edges=True): """ Get all the edges in the network that touch the nodes in the nodes list. Note only works with DiGraph or MultiDiGraph. @@ -609,6 +609,7 @@ def get_edges(net, nodes): Args: net (object): a networkx network nodes: (list): a list of network nodes e.g. ``[(224966, 437657), (225195, 437940)]`` + in_edges: (bool): set to True to get edges coming into the node, and False to edges leaving the node Returns: end_edges (dict): a dict containing segment codes and their associated network edge @@ -617,9 +618,12 @@ def get_edges(net, nodes): end_edges = {} for node in nodes: - in_edges = net.in_edges(node) + if in_edges is True: + node_edges = net.in_edges(node) + else: + node_edges = net.out_edges(node) - for edge in in_edges: + for edge in node_edges: edges = get_edges_from_node_pair(net, edge[0], edge[1]) assert len(edges) == 1 end_edges.update(edges)