diff --git a/test/fx/test_subgraph_rewriter.py b/test/fx/test_subgraph_rewriter.py index 17cf05bb67d9a..9e38969700726 100644 --- a/test/fx/test_subgraph_rewriter.py +++ b/test/fx/test_subgraph_rewriter.py @@ -458,3 +458,88 @@ def forward(self, x): if n.op == 'placeholder': assert n.type == int assert m.type == int + + def test_subgraph_writer_replace_consecutive_submodules(self): + + def f(x): + x = torch.sigmoid(x) + x = torch.sigmoid(x) + return torch.sigmoid(x) + + def pattern(x): + return torch.sigmoid(x) + + def replacement(x): + return torch.exp(x) + + def comparison(x): + x = torch.exp(x) + x = torch.exp(x) + return torch.exp(x) + + traced = symbolic_trace(f) + comparison_fn = symbolic_trace(comparison) + + x = torch.randn(3, 4) + + subgraph_rewriter.replace_pattern(traced, pattern, replacement) + + traced.graph.lint() + + ref_outs = comparison_fn(x) + test_outs = traced.forward(x) + self.assertEqual(ref_outs, test_outs) + + def test_subgraph_rewriter_replaces_parallel_functions(self): + def f(x): + y = torch.sigmoid(x) + z = torch.sigmoid(x) + return y, z + + def pattern(x): + return torch.sigmoid(x) + + def replacement(x): + return torch.relu(x) + + def comparison(x): + y = torch.relu(x) + z = torch.relu(x) + return y, z + + traced = symbolic_trace(f) + + subgraph_rewriter.replace_pattern(traced, pattern, replacement) + traced.graph.lint() + + x = torch.randn(3, 4) + ref_outs = comparison(x) + test_outs = traced.forward(x) + self.assertEqual(ref_outs, test_outs) + + def test_subgraph_rewriter_replaces_parallel_functions_when_aggregated(self): + def f(x): + y = torch.sigmoid(x) + z = torch.sigmoid(x) + return y + z + + def pattern(x): + return torch.sigmoid(x) + + def replacement(x): + return torch.relu(x) + + def comparison(x): + y = torch.relu(x) + z = torch.relu(x) + return y + z + + traced = symbolic_trace(f) + + subgraph_rewriter.replace_pattern(traced, pattern, replacement) + traced.graph.lint() + + x = torch.randn(3, 4) + ref_outs = comparison(x) + 
test_outs = traced.forward(x) + self.assertEqual(ref_outs, test_outs) \ No newline at end of file diff --git a/torch/fx/subgraph_rewriter.py b/torch/fx/subgraph_rewriter.py index 72ea56aa31196..c466592ded194 100644 --- a/torch/fx/subgraph_rewriter.py +++ b/torch/fx/subgraph_rewriter.py @@ -28,9 +28,9 @@ def __init__(self, pattern: Graph) -> None: assert len(self.pattern_anchor.all_input_nodes) == 1, \ "Pattern matching on multiple outputs is not supported" # Maps nodes in the pattern subgraph to nodes in the larger graph - self.nodes_map: Dict[Node, Node] = {} + self.nodes_map: List[Dict[Node, Node]] = [{}] - def matches_subgraph_from_anchor(self, anchor: Node) -> bool: + def matches_subgraph_from_anchor(self, anchor: Node) -> List[Dict[Node, Node]]: """ Checks if the whole pattern can be matched starting from ``anchor`` in the larger graph. @@ -38,16 +38,20 @@ def matches_subgraph_from_anchor(self, anchor: Node) -> bool: Pattern matching is done by recursively comparing the pattern node's use-def relationships against the graph node's. """ - self.nodes_map = {} - return self._match_nodes(self.pattern_anchor, anchor) + self.nodes_map: List[Dict[Node, Node]] = [{}] + self._match_nodes(self.pattern_anchor, anchor) + + # We need to filter out the node maps that are empty + self.nodes_map = [elt for elt in self.nodes_map if len(elt) > 0] + return self.nodes_map # Compare the pattern node `pn` against the graph node `gn` - def _match_nodes(self, pn: Node, gn: Node) -> bool: + def _match_nodes(self, pn: Node, gn: Node, graph_id: int = 0) -> bool: # Check if we've already matched these nodes in the current # traversal - if pn in self.nodes_map: - return self.nodes_map[pn] == gn + if pn in self.nodes_map[graph_id]: + return self.nodes_map[graph_id][pn] == gn def attributes_are_equal(pn: Node, gn: Node) -> bool: # Use placeholder and output nodes as wildcards. 
The @@ -63,7 +67,7 @@ def attributes_are_equal(pn: Node, gn: Node) -> bool: return False # Optimistically mark `pn` as a match for `gn` - self.nodes_map[pn] = gn + self.nodes_map[graph_id][pn] = gn # Traverse the use-def relationships to ensure that `pn` is a true # match for `gn` @@ -73,14 +77,22 @@ def attributes_are_equal(pn: Node, gn: Node) -> bool: and len(pn.all_input_nodes) != len(gn.all_input_nodes)): return False if pn.op == "output": - match_found = any(self._match_nodes(pn.all_input_nodes[0], gn_) - for gn_ in gn.all_input_nodes) + # Only the first graph compares the output. + assert graph_id == 0 + # We broadcast the result to all the other potential graph matching. + self.nodes_map += [copy.copy(self.nodes_map[graph_id]) for _ in range(len(gn.all_input_nodes) - 1)] + all_matches = tuple(self._match_nodes(pn.all_input_nodes[0], gn_, graph_id_) + for graph_id_, gn_ in enumerate(gn.all_input_nodes) + ) + self.nodes_map = [node_map for node_map, match in zip(self.nodes_map, all_matches) if match] + # This is not really needed to return that value + return any(all_matches) else: match_found = (len(pn.all_input_nodes) == len(gn.all_input_nodes) - and all(self._match_nodes(pn_, gn_) for pn_, gn_ + and all(self._match_nodes(pn_, gn_, graph_id) for pn_, gn_ in zip(pn.all_input_nodes, gn.all_input_nodes))) if not match_found: - self.nodes_map.pop(pn) + self.nodes_map[graph_id].pop(pn) return False return True @@ -256,64 +268,67 @@ def forward(self, x, w1, w2): matcher = _SubgraphMatcher(pattern_graph) matches: List[Match] = [] - # Consider each node as an "anchor" (deepest matching graph node) - for anchor in original_graph.nodes: + def pattern_is_contained(nodes_map: Dict[Node, Node]) -> bool: + # `lookup` represents all the nodes in `original_graph` + # that are part of `pattern` + lookup: Dict[Node, Node] = {v: k for k, v in nodes_map.items()} + for n in lookup.keys(): - if matcher.matches_subgraph_from_anchor(anchor): - - def pattern_is_contained(nodes_map : 
Dict[Node, Node]) -> bool: - # `lookup` represents all the nodes in `original_graph` - # that are part of `pattern` - lookup: Dict[Node, Node] = {v : k for k, v - in nodes_map.items()} - for n in lookup.keys(): - - # Nodes that can "leak"... - - # Placeholders (by definition) - if n.op == "placeholder": - continue - # Pattern output (acts as a container) - if lookup[n].op == "output": - continue - # Result contained by pattern output (what we'll - # hook in to the new Graph, thus what we'll - # potentially use in other areas of the Graph as - # an input Node) - if (len(lookup[n].users) == 1 - and list(lookup[n].users.keys())[0].op == "output"): - continue - - for user in n.users: - # If this node has users that were not in - # `lookup`, then it must leak out of the - # pattern subgraph - if user not in lookup: - return False - return True + # Nodes that can "leak"... + + # Placeholders (by definition) + if n.op == "placeholder": + continue + # Pattern output (acts as a container) + if lookup[n].op == "output": + continue + # Pattern placeholders (wildcards for graph nodes) + if lookup[n].op == "placeholder": + continue + # Result contained by pattern output (what we'll + # hook in to the new Graph, thus what we'll + # potentially use in other areas of the Graph as + # an input Node) + if (len(lookup[n].users) == 1 + and list(lookup[n].users.keys())[0].op == "output"): + continue - # It's not a match if the pattern leaks out into the rest - # of the graph - if pattern_is_contained(matcher.nodes_map): - for k, v in matcher.nodes_map.items(): - # Shallow copy nodes_map - matches.append(Match(anchor=anchor, - nodes_map=copy.copy(matcher.nodes_map))) + for user in n.users: + # If this node has users that were not in + # `lookup`, then it must leak out of the + # pattern subgraph + if user not in lookup: + return False + return True + + # Consider each node as an "anchor" (deepest matching graph node) + for anchor in original_graph.nodes: + potential_matches = 
matcher.matches_subgraph_from_anchor(anchor) + # It's not a match if the pattern leaks out into the rest + # of the graph + for node_map in potential_matches: + if pattern_is_contained(node_map): + # Shallow copy nodes_map + matches.append(Match(anchor=anchor, + nodes_map=copy.copy(node_map))) # The set of all nodes in `original_graph` that we've seen thus far # as part of a pattern match replaced_nodes: Set[Node] = set() + # As we progressively replace nodes, we need to keep track of how the match results change as well + match_changed_node: Dict[Node, Node] = dict() # Return True if one of the nodes in the current match has already # been used as part of another match def overlaps_with_prev_match(match: Match) -> bool: - for n in match.nodes_map.values(): - if n in replaced_nodes and n.op != "placeholder": + for pn, gn in match.nodes_map.items(): + if pn.op in ["placeholder", "output"]: + continue + if gn in replaced_nodes and gn.op != "placeholder": return True return False - for match in matches: - + for i, match in enumerate(matches): # Skip overlapping matches if overlaps_with_prev_match(match): continue @@ -327,7 +342,7 @@ def overlaps_with_prev_match(match: Match) -> bool: replacement_placeholders = [n for n in replacement_graph.nodes if n.op == "placeholder"] assert len(pattern_placeholders) == len(replacement_placeholders) - placeholder_map = {r : p for r, p + placeholder_map = {r: p for r, p in zip(replacement_placeholders, pattern_placeholders)} # node from `original_graph` that matched with the output node @@ -341,15 +356,17 @@ def mark_node_as_replaced(n: Node) -> None: mark_node_as_replaced(n_) replaced_nodes.add(n) - mark_node_as_replaced(subgraph_output) + for input_node in subgraph_output.all_input_nodes: + mark_node_as_replaced(input_node) - # Intialize `val_map` with mappings from placeholder nodes in + # Initialize `val_map` with mappings from placeholder nodes in # `replacement` to their corresponding node in `original_graph` for 
replacement_node in replacement_placeholders: # Get the `original_graph` placeholder node # corresponding to the current `replacement_node` pattern_node = placeholder_map[replacement_node] - original_graph_node = match.nodes_map[pattern_node] + original_graph_node = match_changed_node.get(match.nodes_map[pattern_node], match.nodes_map[pattern_node]) + # Populate `val_map` val_map[replacement_node] = original_graph_node @@ -361,39 +378,36 @@ def mark_node_as_replaced(n: Node) -> None: # Hook the output Node of the replacement subgraph in to the # original Graph at the correct location - # CASE 1: We need to hook the replacement subgraph in somewhere - # in the middle of the graph. We replace the Node in the - # original graph that corresponds to the end of the pattern - # subgraph - if subgraph_output.op != "output": - # `subgraph_output` may have multiple args. These args could - # be from the orignal graph, or they could have come from - # the insertion of `replacement_subgraph`. We need to find - # the Node that was originally matched as part of - # `pattern` (i.e. a Node from the original graph). We can - # figure this out by looking in `match.nodes_map`. The map - # was created before `replacement_subgraph` was spliced in, - # so we know that, if a Node is in `match.nodes_map.values`, - # it must have come from the original graph - for n in subgraph_output.all_input_nodes: - if (n.op != "placeholder" - and n in match.nodes_map.values()): - subgraph_output = n - break - assert subgraph_output.op != "output" - # CASE 2: The pattern subgraph match extends to the end of the - # original graph, so we need to change the current graph's - # output Node to reflect the insertion of the replacement graph. 
- # We'll keep the current output Node, but update its args and - # `_input_nodes` as necessary - else: - subgraph_output.args = ((copied_output,)) - if isinstance(copied_output, Node): - subgraph_output._input_nodes = {copied_output: None} + pattern_outputs = [n for n in pattern_graph.nodes + if n.op == "output"] + assert len(pattern_outputs) + replacement_outputs = [n for n in replacement_graph.nodes + if n.op == "output"] + assert len(replacement_outputs) == len(pattern_outputs) + outputs_map = {p: r for r, p + in zip(replacement_outputs, pattern_outputs)} + + for pn, gn in match.nodes_map.items(): + if gn.op == "placeholder": + continue - assert isinstance(copied_output, Node) - subgraph_output.replace_all_uses_with(copied_output) + # We search for the node corresponding to the output of the pattern. + if pn.op != "output": + continue + + # The anchor should correspond to `subgraph_output` + assert subgraph_output == gn + # We update all anchor inputs to the new nodes + rn = outputs_map[pn] + for pn_input, rn_input in zip(pn.all_input_nodes, rn.all_input_nodes): + gn_input = match.nodes_map[pn_input] + rn_input_in_original_graph = val_map[rn_input] + gn_input.replace_all_uses_with(rn_input_in_original_graph) + # Record which node replaced `gn_input` so that later matches can look it up + match_changed_node[gn_input] = rn_input_in_original_graph + + assert isinstance(copied_output, Node) # Erase the `pattern` nodes for node in reversed(original_graph.nodes): if len(node.users) == 0 and node.op != "output":