Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions test/fx/test_subgraph_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,3 +458,34 @@ def forward(self, x):
if n.op == 'placeholder':
assert n.type == int
assert m.type == int

def test_subgraph_writer_replace_consecutive_submodules(self):

def f(x):
x = torch.sigmoid(x)
x = torch.sigmoid(x)
return torch.sigmoid(x)

def pattern(x):
return torch.sigmoid(x)

def replacement(x):
return torch.exp(x)

def comparison(x):
x = torch.exp(x)
x = torch.exp(x)
return torch.exp(x)

traced = symbolic_trace(f)
comparison_fn = symbolic_trace(comparison)

x = torch.randn(3, 4)

subgraph_rewriter.replace_pattern(traced, pattern, replacement)

traced.graph.lint()

ref_outs = comparison_fn(x)
test_outs = traced.forward(x)
self.assertEqual(ref_outs, test_outs)
79 changes: 48 additions & 31 deletions torch/fx/subgraph_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,11 +261,10 @@ def forward(self, x, w1, w2):

if matcher.matches_subgraph_from_anchor(anchor):

def pattern_is_contained(nodes_map : Dict[Node, Node]) -> bool:
def pattern_is_contained(nodes_map: Dict[Node, Node]) -> bool:
# `lookup` represents all the nodes in `original_graph`
# that are part of `pattern`
lookup: Dict[Node, Node] = {v : k for k, v
in nodes_map.items()}
lookup: Dict[Node, Node] = {v: k for k, v in nodes_map.items()}
for n in lookup.keys():

# Nodes that can "leak"...
Expand Down Expand Up @@ -295,25 +294,30 @@ def pattern_is_contained(nodes_map : Dict[Node, Node]) -> bool:
# It's not a match if the pattern leaks out into the rest
# of the graph
if pattern_is_contained(matcher.nodes_map):
for k, v in matcher.nodes_map.items():
# Shallow copy nodes_map
matches.append(Match(anchor=anchor,
nodes_map=copy.copy(matcher.nodes_map)))
# Shallow copy nodes_map
matches.append(Match(anchor=anchor,
nodes_map=copy.copy({
key: value
for key, value in matcher.nodes_map.items()
})))

# The set of all nodes in `original_graph` that we've seen thus far
# as part of a pattern match
replaced_nodes: Set[Node] = set()
# As we progressively replace node, we need to keep track on how the match results need to change also
match_changed_node: Dict[Node, Node] = dict()

# Return True if one of the nodes in the current match has already
# been used as part of another match
def overlaps_with_prev_match(match: Match) -> bool:
for n in match.nodes_map.values():
if n in replaced_nodes and n.op != "placeholder":
for pn, gn in match.nodes_map.items():
if pn.op in ["placeholder", "output"]:
continue
if gn in replaced_nodes and gn.op != "placeholder":
return True
return False

for match in matches:

for i, match in enumerate(matches):
# Skip overlapping matches
if overlaps_with_prev_match(match):
continue
Expand All @@ -327,7 +331,7 @@ def overlaps_with_prev_match(match: Match) -> bool:
replacement_placeholders = [n for n in replacement_graph.nodes
if n.op == "placeholder"]
assert len(pattern_placeholders) == len(replacement_placeholders)
placeholder_map = {r : p for r, p
placeholder_map = {r: p for r, p
in zip(replacement_placeholders, pattern_placeholders)}

# node from `original_graph` that matched with the output node
Expand All @@ -341,15 +345,17 @@ def mark_node_as_replaced(n: Node) -> None:
mark_node_as_replaced(n_)
replaced_nodes.add(n)

mark_node_as_replaced(subgraph_output)
for input_node in subgraph_output.all_input_nodes:
mark_node_as_replaced(input_node)

# Intialize `val_map` with mappings from placeholder nodes in
# Initialize `val_map` with mappings from placeholder nodes in
# `replacement` to their corresponding node in `original_graph`
for replacement_node in replacement_placeholders:
# Get the `original_graph` placeholder node
# corresponding to the current `replacement_node`
pattern_node = placeholder_map[replacement_node]
original_graph_node = match.nodes_map[pattern_node]
original_graph_node = match_changed_node.get(match.nodes_map[pattern_node], match.nodes_map[pattern_node])

# Populate `val_map`
val_map[replacement_node] = original_graph_node

Expand All @@ -366,20 +372,33 @@ def mark_node_as_replaced(n: Node) -> None:
# original graph that corresponds to the end of the pattern
# subgraph
if subgraph_output.op != "output":
# `subgraph_output` may have multiple args. These args could
# be from the orignal graph, or they could have come from
# the insertion of `replacement_subgraph`. We need to find
# the Node that was originally matched as part of
# `pattern` (i.e. a Node from the original graph). We can
# figure this out by looking in `match.nodes_map`. The map
# was created before `replacement_subgraph` was spliced in,
# so we know that, if a Node is in `match.nodes_map.values`,
# it must have come from the original graph
for n in subgraph_output.all_input_nodes:
if (n.op != "placeholder"
and n in match.nodes_map.values()):
subgraph_output = n
break
pattern_outputs = [n for n in pattern_graph.nodes
if n.op == "output"]
assert len(pattern_outputs)
replacement_outputs = [n for n in replacement_graph.nodes
if n.op == "output"]
assert len(replacement_outputs) == len(pattern_outputs)
outputs_map = {p: r for r, p
in zip(replacement_outputs, pattern_outputs)}

for pn, gn in match.nodes_map.items():
if gn.op == "placeholder":
continue

# We search for the node corresponding to the output of the pattern.
if pn.op != "output":
continue
assert subgraph_output == gn

# We update all anchor inputs to the new nodes
rn = outputs_map[pn]
for pn_input, rn_input in zip(pn.all_input_nodes, rn.all_input_nodes):
gn_input = match.nodes_map[pn_input]
rn_input_in_original_graph = val_map[rn_input]
gn_input.replace_all_uses_with(rn_input_in_original_graph)
# We store the updated node point in case other nodes want to use it
match_changed_node[gn_input] = rn_input_in_original_graph

assert subgraph_output.op != "output"
# CASE 2: The pattern subgraph match extends to the end of the
# original graph, so we need to change the current graph's
Expand All @@ -392,8 +411,6 @@ def mark_node_as_replaced(n: Node) -> None:
subgraph_output._input_nodes = {copied_output: None}

assert isinstance(copied_output, Node)
subgraph_output.replace_all_uses_with(copied_output)

# Erase the `pattern` nodes
for node in reversed(original_graph.nodes):
if len(node.users) == 0 and node.op != "output":
Expand Down