89 changes: 89 additions & 0 deletions docs/advanced/pquant.rst
@@ -0,0 +1,89 @@
======================================
PQuantML
======================================

.. image:: https://img.shields.io/badge/License-Apache_2.0-blue.svg
   :target: https://www.apache.org/licenses/LICENSE-2.0
.. image:: https://github.com/nroope/PQuant/actions/workflows/python-publish.yml/badge.svg
   :target: https://pquantml.readthedocs.io
.. image:: https://badge.fury.io/py/pquant-ml.svg
   :target: https://badge.fury.io/py/pquant-ml

PQuantML is a hardware-aware model compression framework supporting:

- Joint pruning + quantization
- Layer-wise precision configuration
- Flexible training pipelines
- PyTorch and Keras V3 implementations
- Integration with hardware-friendly toolchains (e.g., hls4ml)

PQuantML enables efficient deployment of compact neural networks on resource-constrained hardware such as FPGAs and embedded accelerators.


Key Features
------------

- **Joint Quantization + Pruning**: Combine bit-width reduction with structured pruning.
- **Flexible Precision Control**: Per-layer and mixed-precision configuration.
- **Hardware-Aware Objective**: Include resource constraints (DSP, LUT, BRAM) in training.
- **Simple API**: Configure compression through a single YAML or Python object (see the sketch below).
- **PyTorch Integration**: Works with custom training/validation loops.
- **Export Support**: Model conversion to hardware toolchains such as hls4ml.
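
For instance, compression is steered through a single ``dst_config`` object, and per-layer precision can be overridden directly on the layer constructors. The snippet below is a minimal sketch that reuses the config fields and the ``in_quant_bits`` argument from the full example further down; reading the bit tuple as (sign, integer, fractional) bits is our assumption.

.. code-block:: python
    :caption: Per-layer precision control (sketch)

    from pquant import dst_config
    from pquant.layers import PQDense

    config = dst_config()  # single Python config object holding the compression settings
    config.quantization_parameters.default_weight_fractional_bits = 3.

    # Per-layer override: quantize this layer's input with its own bit allocation
    dense = PQDense(config, 16, 64, in_quant_bits=(1, 3, 3))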


.. code-block:: python
    :caption: Simple example

    import torch
    from pquant import dst_config
    from pquant.layers import PQDense
    from pquant.activations import PQActivation

    from hls4ml.converters import convert_from_pytorch_model
    from hls4ml.utils import config_from_pytorch_model

    # Define the compression config and model
    config = dst_config()
    config.training_parameters.epochs = 1000
    config.quantization_parameters.default_data_integer_bit = 3.
    config.quantization_parameters.default_data_fractional_bits = 2.
    config.quantization_parameters.default_weight_fractional_bits = 3.
    config.quantization_parameters.use_relu_multiplier = False

    def build_model(config):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.dense1 = PQDense(config, 16, 64,
                                      in_quant_bits=(1, 3, 3))
                self.relu1 = PQActivation(config, "relu")
                self.dense2 = PQDense(config, 64, 32)
                self.relu2 = PQActivation(config, "relu")
                self.dense3 = PQDense(config, 32, 32)
                self.relu3 = PQActivation(config, "relu")
                self.dense4 = PQDense(config, 32, 5,
                                      quantize_output=True,
                                      out_quant_bits=(1, 3, 3))

            def forward(self, x):
                x = self.relu1(self.dense1(x))
                x = self.relu2(self.dense2(x))
                x = self.relu3(self.dense3(x))
                x = self.dense4(x)
                return x

        return Model()  # config is captured from the enclosing scope

    PQmodel = build_model(config)
    PQmodel(torch.rand((1, 16)))

    ...  # Training, evaluation, and anything else you want to do with the model

    input_shape = (16,)  # shape of a single input sample for this model
    hls_config = config_from_pytorch_model(
        PQmodel,
        input_shape=input_shape,
    )
    hls_model = convert_from_pytorch_model(PQmodel, ...)
    # Model-wise precision propagation is done automatically for PQuantML models to ensure bit-exactness.
    # Do NOT pass a precision configuration unless you know what you are doing.

    hls_model.compile()

.. note::
    In general, do not pass any precision configuration to ``hls4ml.converters.convert_from_<frontend>_model``. PQuantML-defined models invoke model-wise precision propagation automatically to ensure bit-exactness between the PQuantML model and the generated HLS code (see `here <./precision.html>`__ for more details).
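
As a quick sanity check of that bit-exactness claim, the compiled HLS model can be compared against the PQuantML model on random inputs. This is a minimal sketch that continues the example above; asserting exact equality (rather than a tolerance) is our assumption, based on the bit-exactness guarantee stated in the note.

.. code-block:: python
    :caption: Checking bit-exactness (sketch)

    import numpy as np

    x = np.random.rand(100, 16).astype(np.float32)
    y_ref = PQmodel(torch.from_numpy(x)).detach().numpy()
    y_hls = hls_model.predict(x)

    # With automatic precision propagation the two outputs should agree exactly
    np.testing.assert_array_equal(y_hls.reshape(y_ref.shape), y_ref)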
1 change: 1 addition & 0 deletions hls4ml/converters/keras_v3/__init__.py
@@ -5,6 +5,7 @@
    hgq2,  # noqa: F401
    merge,  # noqa: F401
    pooling,  # noqa: F401
    pquant,  # noqa: F401
    recurrent,  # noqa: F401
)
from ._base import registry as layer_handlers
3 changes: 3 additions & 0 deletions hls4ml/converters/keras_v3/pquant/__init__.py
@@ -0,0 +1,3 @@
from . import _base, pooling

__all__ = ['_base', 'pooling']
235 changes: 235 additions & 0 deletions hls4ml/converters/keras_v3/pquant/_base.py
@@ -0,0 +1,235 @@
from collections.abc import Sequence
from math import prod
from typing import TYPE_CHECKING, Any

import numpy as np

from hls4ml.converters.keras_v3._base import KerasV3LayerHandler, register
from hls4ml.converters.keras_v3.conv import ConvHandler
from hls4ml.converters.keras_v3.core import ActivationHandler, DenseHandler
from hls4ml.converters.keras_v3.hgq2._base import override_io_tensor_confs

if TYPE_CHECKING:
    import pquant
    from keras import KerasTensor
    from keras.src.layers.layer import Layer as Layer


def extract_quantizer_config(
    q, extract_kif, tensor: 'KerasTensor', is_input: bool, overflow_attr: str = 'overflow_mode'
) -> dict[str, Any]:
    from keras import ops

    shape: tuple[int, ...] = tensor.shape[1:]  # type: ignore
    if any([s is None for s in shape]):
        raise ValueError(f'Tensor {tensor.name} has at least one dimension with no fixed size')

    # k: sign bit (keep_negative), B: total bit-width (k + i + f), I: integer bits including the sign (k + i)
    k, i, f = extract_kif(q)
    k, B, I = k, k + i + f, k + i  # type: ignore # noqa: E741
    k, B, I = ops.convert_to_numpy(k), ops.convert_to_numpy(B), ops.convert_to_numpy(I)  # noqa: E741
    I = np.where(B > 0, I, 0)  # noqa: E741 # type: ignore

    k = np.broadcast_to(k.astype(np.int16), (1,) + shape)  # type: ignore
    B = np.broadcast_to(B.astype(np.int16), (1,) + shape)  # type: ignore
    I = np.broadcast_to(I.astype(np.int16), (1,) + shape)  # noqa: E741

    overflow_mode: str = getattr(q, overflow_attr, 'SAT')
    round_mode: str = q.round_mode
    if round_mode.startswith('S_'):
        round_mode = round_mode[2:]
    # The quantizer is fusible only if the quantization parameters are uniform across the tensor
    fusible = np.unique(k).size == 1 and np.unique(B).size == 1 and np.unique(I).size == 1

    input_keras_tensor_names = tensor.name if is_input else f'{tensor.name}_q'
    output_keras_tensor_names = f'{tensor.name}_q' if is_input else tensor.name
    return {
        'name': q.name,
        'class_name': 'FixedPointQuantizer',
        'mask_kbi': (k, B, I),
        'SAT': overflow_mode,
        'RND': round_mode,
        'fusible': fusible,
        'input_keras_tensor_names': [input_keras_tensor_names],
        'output_keras_tensor_names': [output_keras_tensor_names],
        'overrides': {},
    }


def extract_pquant_quantizer_config(q, tensor: 'KerasTensor', is_input: bool) -> dict[str, Any]:
    from pquant.quantizer import Quantizer

    if not isinstance(q, Quantizer):
        raise TypeError(f'Quantizer {type(q).__name__} ({q.__module__}) is not an instance of any allowed Quantizer class.')

    if q.use_hgq:
        # HGQ-backed quantizer: the wrapped HGQ quantizer exposes kif directly
        return extract_quantizer_config(q.quantizer.quantizer, lambda q: q.kif, tensor, is_input)
    else:
        # Native PQuant quantizer: exposes k/i/f attributes and an 'overflow' attribute
        return extract_quantizer_config(q, lambda q: (q.k, q.i, q.f), tensor, is_input, 'overflow')


@register
class PQLayerHandler(KerasV3LayerHandler):
    def __call__(
        self,
        layer: (
            'pquant.core.keras.layers.PQWeightBiasBase | '
            'pquant.core.keras.layers.PQBatchNormalization | '
            'pquant.core.keras.layers.QuantizedPooling | '
            'pquant.core.keras.layers.QuantizedActivation'
        ),
        in_tensors: Sequence['KerasTensor'],
        out_tensors: Sequence['KerasTensor'],
    ):
        ret = super().__call__(layer, in_tensors, out_tensors)

        if getattr(layer, 'quantize_input', False) and hasattr(layer, 'input_quantizer'):
            if len(in_tensors) > 1:
                iq_confs = [
                    extract_pquant_quantizer_config(q, tensor, True) for q, tensor in zip(layer.input_quantizer, in_tensors)
                ]
            else:
                iq_confs = [extract_pquant_quantizer_config(layer.input_quantizer, in_tensors[0], True)]
        else:
            iq_confs = ()

        if getattr(layer, 'quantize_output', False) and hasattr(layer, 'output_quantizer'):
            if len(out_tensors) > 1:
                oq_confs = [
                    extract_pquant_quantizer_config(q, tensor, False)
                    for q, tensor in zip(layer.output_quantizer, out_tensors)
                ]
            else:
                oq_confs = [extract_pquant_quantizer_config(layer.output_quantizer, out_tensors[0], False)]
        else:
            oq_confs = ()

        if iq_confs:
            _froms = [t.name for t in in_tensors]
            _tos = [f'{t.name}_q' for t in in_tensors]
            overrides = dict(zip(_froms, _tos))
            override_io_tensor_confs(ret, overrides)

        if oq_confs:
            _froms = [t.name for t in out_tensors]
            _tos = [f'{t.name}_q' for t in out_tensors]
            overrides = dict(zip(_froms, _tos))
            override_io_tensor_confs(ret, overrides)

        return *iq_confs, *ret, *oq_confs

    def load_weight(self, layer: 'Layer', key: str):
        from keras import ops

        if hasattr(layer, f'q{key}'):
            return ops.convert_to_numpy(getattr(layer, f'q{key}'))
        return super().load_weight(layer, key)

    def default_class_name(self, layer: 'Layer') -> str:
        class_name = layer.__class__.__name__
        if class_name.startswith('PQ'):
            class_name = class_name[2:]
        return class_name


@register
class PQActivationHandler(PQLayerHandler, ActivationHandler):
    handles = ('pquant.core.keras.activations.PQActivation',)

    def handle(
        self,
        layer: 'pquant.core.keras.activations.PQActivation',
        in_tensors: Sequence['KerasTensor'],
        out_tensors: Sequence['KerasTensor'],
    ):
        config = {}
        config.update(self.default_config)

        activation = getattr(layer, 'activation_name', 'linear')
        match activation:
            case 'hard_tanh':
                class_name = 'HardActivation'
            case _:
                class_name = 'Activation'

        config['activation'] = activation
        config['class_name'] = class_name
        config['n_in'] = prod(in_tensors[0].shape[1:])  # type: ignore
        return (config,)


@register
class PQBatchNormalizationHandler(PQLayerHandler):
    handles = ('pquant.core.keras.layers.PQBatchNormalization',)

    def handle(
        self,
        layer: 'pquant.core.keras.layers.PQBatchNormalization',
        in_tensors: Sequence['KerasTensor'],
        out_tensors: Sequence['KerasTensor'],
    ):
        from keras import ops

        assert layer.axis in (len(in_tensors[0].shape) - 1, -1), 'Only batch_norm with axis=-1 is supported in hls4ml'

        conf = {}
        conf['class_name'] = layer.__class__.__name__[1:]
        conf['n_in'] = prod(in_tensors[0].shape[1:])

        conf['use_gamma'] = layer.scale
        if conf['use_gamma']:
            conf['gamma_data'] = ops.convert_to_numpy(layer.weight_quantizer(layer.gamma))
        else:
            conf['gamma_data'] = 1

        conf['use_beta'] = layer.center
        if conf['use_beta']:
            conf['beta_data'] = ops.convert_to_numpy(layer.bias_quantizer(layer.beta))
        else:
            conf['beta_data'] = 0

        conf['mean_data'] = ops.convert_to_numpy(layer.moving_mean)
        conf['variance_data'] = ops.convert_to_numpy(layer.moving_variance)
        conf['n_filt'] = conf['variance_data'].size

        return conf


@register
class PQConvHandler(PQLayerHandler, ConvHandler):
    handles = ('pquant.core.keras.layers.PQConv1d', 'pquant.core.keras.layers.PQConv2d')

    def handle(
        self,
        layer: 'pquant.core.keras.layers.PQConv1d | pquant.core.keras.layers.PQConv2d',
        in_tensors: Sequence['KerasTensor'],
        out_tensors: Sequence['KerasTensor'],
    ):
        conf = super().handle(layer, in_tensors, out_tensors)
        conf['class_name'] = layer.__class__.__name__[1:-1] + 'D'
        pf = layer.parallelization_factor
        out_shape: tuple[int, ...] = out_tensors[0].shape[1:]  # type: ignore
        if pf < 0:
            # Negative value: default to full parallelism over the output's spatial positions
            if layer.data_format == 'channels_last':
                pf = prod(out_shape[:-1])
            else:
                pf = prod(out_shape[1:])
        conf['parallelization_factor'] = pf
        return conf


@register
class PQDenseHandler(PQLayerHandler, DenseHandler):
    handles = ('pquant.core.keras.layers.PQDense',)

    def handle(
        self,
        layer: 'pquant.core.keras.layers.PQDense',
        in_tensors: Sequence['KerasTensor'],
        out_tensors: Sequence['KerasTensor'],
    ):
        conf = super().handle(layer, in_tensors, out_tensors)
        conf['class_name'] = 'Dense'
        in_shape: tuple[int, ...] = in_tensors[0].shape[1:]  # type: ignore
        if len(in_shape) > 1:
            pf = layer.parallelization_factor
            conf['parallelization_factor'] = pf
        return conf
30 changes: 30 additions & 0 deletions hls4ml/converters/keras_v3/pquant/pooling.py
@@ -0,0 +1,30 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING

from hls4ml.converters.keras_v3._base import register
from hls4ml.converters.keras_v3.pooling import PoolingHandler

from ._base import PQLayerHandler

if TYPE_CHECKING:
    import pquant
    from keras import KerasTensor


@register
class PQAvgPoolHandler(PQLayerHandler, PoolingHandler):
    handles = (
        'pquant.core.keras.layers.PQAvgPool1d',
        'pquant.core.keras.layers.PQAvgPool2d',
    )

    def handle(
        self,
        layer: 'pquant.core.keras.layers.PQAvgPool1d | pquant.core.keras.layers.PQAvgPool2d',
        in_tensors: Sequence['KerasTensor'],
        out_tensors: Sequence['KerasTensor'],
    ):
        conf = super().handle(layer, in_tensors, out_tensors)
        conf['class_name'] = 'AveragePooling' + layer.__class__.__name__[-2] + 'D'

        return conf