|
31 | 31 | class AnnotateQuantAttrs(ExportPass): |
32 | 32 | """ |
33 | 33 | Add "quant_attrs" to graph nodes' meta from the QDQ information |
34 | | - generated after quatization process. |
| 34 | + generated after quantization process. |
35 | 35 | """ |
36 | 36 |
|
37 | | - def __init__( |
38 | | - self, edge_program: torch.export.ExportedProgram, skip_advanced_requat: bool |
39 | | - ): |
| 37 | + def __init__(self, edge_program: torch.export.ExportedProgram): |
40 | 38 | super(AnnotateQuantAttrs, self).__init__() |
41 | 39 | self.edge_program = edge_program |
42 | | - self.skip_advanced_requant = skip_advanced_requat |
43 | 40 |
|
44 | 41 | def _annotate_source_nodes( |
45 | 42 | self, quant_node: torch.fx.Node, quant_attrs: Dict[str, Any] |
@@ -88,30 +85,21 @@ def _annotate_requant(self, n): |
88 | 85 | dq_attrs = get_quant_attrs(self.edge_program, dq_node) |
89 | 86 | # TODO: Store multiple pairs of requantize attributes when we have an op builder |
90 | 87 | # that has multiple outputs that requires quant attributes. |
91 | | - if self.skip_advanced_requant: |
92 | | - if q_attrs[QCOM_DTYPE] != dq_attrs[QCOM_DTYPE]: |
93 | | - dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING] |
94 | | - user_node = list(dq_node.users)[0] |
95 | | - n.args[0].meta.setdefault(QCOM_REQUANTIZE, {}) |
96 | | - n.args[0].meta[QCOM_REQUANTIZE][user_node.name] = dq_attrs |
97 | | - else: |
98 | | - # When dtype is the same but other specs such as scale and offset are different, |
99 | | - # insert requant to improve accuracy. |
100 | | - # Users can turn this feature off if any inference speed drop is observed. |
101 | | - if any( |
102 | | - q_attrs[attr] != dq_attrs[attr] |
103 | | - for attr in [ |
104 | | - QCOM_SCALE, |
105 | | - QCOM_ZERO_POINT, |
106 | | - QCOM_QUANT_MIN, |
107 | | - QCOM_QUANT_MAX, |
108 | | - QCOM_DTYPE, |
109 | | - ] |
110 | | - ): |
111 | | - dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING] |
112 | | - user_node = list(dq_node.users)[0] |
113 | | - n.args[0].meta.setdefault(QCOM_REQUANTIZE, {}) |
114 | | - n.args[0].meta[QCOM_REQUANTIZE][user_node.name] = dq_attrs |
| 88 | + |
| 89 | + if any( |
| 90 | + q_attrs[attr] != dq_attrs[attr] |
| 91 | + for attr in [ |
| 92 | + QCOM_SCALE, |
| 93 | + QCOM_ZERO_POINT, |
| 94 | + QCOM_QUANT_MIN, |
| 95 | + QCOM_QUANT_MAX, |
| 96 | + QCOM_DTYPE, |
| 97 | + ] |
| 98 | + ): |
| 99 | + dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING] |
| 100 | + user_node = list(dq_node.users)[0] |
| 101 | + n.args[0].meta.setdefault(QCOM_REQUANTIZE, {}) |
| 102 | + n.args[0].meta[QCOM_REQUANTIZE][user_node.name] = dq_attrs |
115 | 103 |
|
116 | 104 | # Dequant all the fold_quant parameters back to fp32. |
117 | 105 | # If an operation is not supported by QNN and got fallback, it will expect a fp32 param. |
|
0 commit comments