19 changes: 18 additions & 1 deletion backends/vulkan/quantizer/vulkan_quantizer.py
@@ -124,11 +124,22 @@ class VulkanQuantizer(Quantizer):
    def __init__(self) -> None:
        super().__init__()
        self.global_config: Optional[QuantizationConfig] = None
        # If specified, quantize only the nodes for which this filter
        # function returns True.
        self.filter_fn: Optional[Callable[[Node], bool]] = None

    def set_global(self, quantization_config: QuantizationConfig) -> VulkanQuantizer:
        self.global_config = quantization_config
        return self

    def set_filter_function(self, filter_fn: Callable[[Node], bool]) -> VulkanQuantizer:
        """
        Set the filter function. Only nodes for which the filter function
        returns True will be quantized.
        """
        self.filter_fn = filter_fn
        return self

    def transform_for_annotation(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
@@ -149,8 +160,14 @@ def _annotate_all_patterns(
        if quantization_config is None:
            return model

        # Combine the caller-provided filter with the quantizer-level filter;
        # a node is annotated only if every non-None filter returns True for it.
        def combined_filter_fn(n: Node) -> bool:
            combined_filter = [self.filter_fn, filter_fn]
            return all(f(n) for f in combined_filter if f is not None)

        for op in _SUPPORTED_OPS:
-            OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
+            OP_TO_ANNOTATOR[op](model, quantization_config, combined_filter_fn)
        return model

    def _annotate_for_quantization_config(
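For context, here is a minimal usage sketch of the new filter hook: it restricts quantization to `aten.linear` nodes. The `get_weight_quantization_config` import is an assumption (the config helper is not shown in this diff), and the predicate itself is just an illustrative example, not part of the change.

```python
import torch
from torch.fx import Node

from executorch.backends.vulkan.quantizer.vulkan_quantizer import (
    VulkanQuantizer,
    get_weight_quantization_config,  # assumed helper; exact name not shown in this diff
)


def only_linear_nodes(node: Node) -> bool:
    # Quantize only call_function nodes whose target is aten.linear.
    return (
        node.op == "call_function"
        and node.target == torch.ops.aten.linear.default
    )


quantizer = (
    VulkanQuantizer()
    .set_global(get_weight_quantization_config())
    .set_filter_function(only_linear_nodes)
)
```

Because `combined_filter_fn` AND-combines the per-annotation filter with the quantizer-level one (skipping any that are `None`), a node is annotated only when every configured filter accepts it; leaving `set_filter_function` unset preserves the previous behavior.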