diff --git a/hls4ml/converters/keras_v3/__init__.py b/hls4ml/converters/keras_v3/__init__.py index 21950aea6..4f9922531 100644 --- a/hls4ml/converters/keras_v3/__init__.py +++ b/hls4ml/converters/keras_v3/__init__.py @@ -5,6 +5,7 @@ hgq2, # noqa: F401 merge, # noqa: F401 pooling, # noqa: F401 + pquant, # noqa: F401 recurrent, # noqa: F401 ) from ._base import registry as layer_handlers diff --git a/hls4ml/converters/keras_v3/pquant/__init__.py b/hls4ml/converters/keras_v3/pquant/__init__.py new file mode 100644 index 000000000..0a7ac5bb6 --- /dev/null +++ b/hls4ml/converters/keras_v3/pquant/__init__.py @@ -0,0 +1,3 @@ +from . import _base, pooling + +__all__ = ['_base', 'pooling'] diff --git a/hls4ml/converters/keras_v3/pquant/_base.py b/hls4ml/converters/keras_v3/pquant/_base.py new file mode 100644 index 000000000..a972d195a --- /dev/null +++ b/hls4ml/converters/keras_v3/pquant/_base.py @@ -0,0 +1,235 @@ +from collections.abc import Sequence +from math import prod +from typing import TYPE_CHECKING, Any + +import numpy as np + +from hls4ml.converters.keras_v3._base import KerasV3LayerHandler, register +from hls4ml.converters.keras_v3.conv import ConvHandler +from hls4ml.converters.keras_v3.core import ActivationHandler, DenseHandler +from hls4ml.converters.keras_v3.hgq2._base import override_io_tensor_confs + +if TYPE_CHECKING: + import pquant + from keras import KerasTensor + from keras.src.layers.layer import Layer as Layer + + +def extract_quantizer_config( + q, extract_kif, tensor: 'KerasTensor', is_input: bool, overflow_attr: str = 'overflow_mode' +) -> dict[str, Any]: + from keras import ops + + shape: tuple[int, ...] = tensor.shape[1:] # type: ignore + if any([s is None for s in shape]): + raise ValueError(f'Tensor {tensor.name} has at least one dimension with no fixed size') + + k, i, f = extract_kif(q) + k, B, I = k, k + i + f, k + i # type: ignore # noqa: E741 + k, B, I = ops.convert_to_numpy(k), ops.convert_to_numpy(B), ops.convert_to_numpy(I) # noqa: E741 + I = np.where(B > 0, I, 0) # noqa: E741 # type: ignore + + k = np.broadcast_to(k.astype(np.int16), (1,) + shape) # type: ignore + B = np.broadcast_to(B.astype(np.int16), (1,) + shape) # type: ignore + I = np.broadcast_to(I.astype(np.int16), (1,) + shape) # noqa: E741 + + overflow_mode: str = getattr(q, overflow_attr, 'SAT') + round_mode: str = q.round_mode + if round_mode.startswith('S_'): + round_mode = round_mode[2:] + fusible = np.unique(k).size == 1 and np.unique(B).size == 1 and np.unique(I).size == 1 + + input_keras_tensor_names = tensor.name if is_input else f'{tensor.name}_q' + output_keras_tensor_names = f'{tensor.name}_q' if is_input else tensor.name + return { + 'name': q.name, + 'class_name': 'FixedPointQuantizer', + 'mask_kbi': (k, B, I), + 'SAT': overflow_mode, + 'RND': round_mode, + 'fusible': fusible, + 'input_keras_tensor_names': [input_keras_tensor_names], + 'output_keras_tensor_names': [output_keras_tensor_names], + 'overrides': {}, + } + + +def extract_pquant_quantizer_config(q, tensor: 'KerasTensor', is_input: bool) -> dict[str, Any]: + from pquant.quantizer import Quantizer + + if not isinstance(q, Quantizer): + raise TypeError(f'Quantizer {type(q).__name__} ({q.__module__}) is not an instance of any allowed Quantizer class.') + + if q.use_hgq: + return extract_quantizer_config(q.quantizer.quantizer, lambda q: q.kif, tensor, is_input) + else: + return extract_quantizer_config(q, lambda q: (q.k, q.i, q.f), tensor, is_input, 'overflow') + + +@register +class PQLayerHandler(KerasV3LayerHandler): + def __call__( + self, + layer: ( + 'pquant.core.keras.layers.PQWeightBiasBase | ' + 'pquant.core.keras.layers.PQBatchNormalization | ' + 'pquant.core.keras.layers.QuantizedPooling | ' + 'pquant.core.keras.layers.QuantizedActivation' + ), + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + ret = super().__call__(layer, in_tensors, out_tensors) + + if getattr(layer, 'quantize_input', False) and hasattr(layer, 'input_quantizer'): + if len(in_tensors) > 1: + iq_confs = [ + extract_pquant_quantizer_config(q, tensor, True) for q, tensor in zip(layer.input_quantizer, in_tensors) + ] + else: + iq_confs = [extract_pquant_quantizer_config(layer.input_quantizer, in_tensors[0], True)] + else: + iq_confs = () + + if getattr(layer, 'quantize_output', False) and hasattr(layer, 'output_quantizer'): + if len(out_tensors) > 1: + oq_confs = [ + extract_pquant_quantizer_config(q, tensor, False) + for q, tensor in zip(layer.output_quantizer, out_tensors) + ] + else: + oq_confs = [extract_pquant_quantizer_config(layer.output_quantizer, out_tensors[0], False)] + else: + oq_confs = () + + if iq_confs: + _froms = [t.name for t in in_tensors] + _tos = [f'{t.name}_q' for t in in_tensors] + overrides = dict(zip(_froms, _tos)) + override_io_tensor_confs(ret, overrides) + + if oq_confs: + _froms = [t.name for t in out_tensors] + _tos = [f'{t.name}_q' for t in out_tensors] + overrides = dict(zip(_froms, _tos)) + override_io_tensor_confs(ret, overrides) + + return *iq_confs, *ret, *oq_confs + + def load_weight(self, layer: 'Layer', key: str): + from keras import ops + + if hasattr(layer, f'q{key}'): + return ops.convert_to_numpy(getattr(layer, f'q{key}')) + return super().load_weight(layer, key) + + def default_class_name(self, layer: 'Layer') -> str: + class_name = layer.__class__.__name__ + if class_name.startswith('PQ'): + class_name = class_name[2:] + return class_name + + +@register +class PQActivationHandler(PQLayerHandler, ActivationHandler): + handles = ('pquant.core.keras.activations.PQActivation',) + + def handle( + self, + layer: 'pquant.core.keras.activations.PQActivation', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + config = {} + config.update(self.default_config) + + activation = getattr(layer, 'activation_name', 'linear') + match activation: + case 'hard_tanh': + class_name = 'HardActivation' + case _: + class_name = 'Activation' + + config['activation'] = activation + config['class_name'] = class_name + config['n_in'] = prod(in_tensors[0].shape[1:]) # type: ignore + return (config,) + + +@register +class PQBatchNormalizationHandler(PQLayerHandler): + handles = ('pquant.core.keras.layers.PQBatchNormalization',) + + def handle( + self, + layer: 'pquant.core.keras.layers.PQBatchNormalization', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + from keras import ops + + assert layer.axis in (len(in_tensors[0].shape) - 1, -1), 'Only batch_norm with axis=-1 is supported in hls4ml' + + conf = {} + conf['class_name'] = layer.__class__.__name__[1:] + conf['n_in'] = prod(in_tensors[0].shape[1:]) + + conf['use_gamma'] = layer.scale + if conf['use_gamma']: + conf['gamma_data'] = ops.convert_to_numpy(layer.weight_quantizer(layer.gamma)) + else: + conf['gamma_data'] = 1 + + conf['use_beta'] = layer.center + if conf['use_beta']: + conf['beta_data'] = ops.convert_to_numpy(layer.bias_quantizer(layer.beta)) + else: + conf['beta_data'] = 0 + + conf['mean_data'] = ops.convert_to_numpy(layer.moving_mean) + conf['variance_data'] = ops.convert_to_numpy(layer.moving_variance) + conf['n_filt'] = conf['variance_data'].size + + return conf + + +@register +class PQConvHandler(PQLayerHandler, ConvHandler): + handles = ('pquant.core.keras.layers.PQConv1d', 'pquant.core.keras.layers.PQConv2d') + + def handle( + self, + layer: 'pquant.core.keras.layers.PQConv1D | pquant.core.keras.layers.PQConv2D', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + conf = super().handle(layer, in_tensors, out_tensors) + conf['class_name'] = layer.__class__.__name__[1:-1] + 'D' + pf = layer.parallelization_factor + out_shape: tuple[int, ...] = out_tensors[0].shape[1:] # type: ignore + if pf < 0: + if layer.data_format == 'channels_last': + pf = prod(out_shape[:-1]) + else: + pf = prod(out_shape[1:]) + conf['parallelization_factor'] = pf + return conf + + +@register +class PQDenseHandler(PQLayerHandler, DenseHandler): + handles = ('pquant.core.keras.layers.PQDense',) + + def handle( + self, + layer: 'pquant.core.keras.layers.PQDense', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + conf = super().handle(layer, in_tensors, out_tensors) + conf['class_name'] = 'Dense' + in_shape: tuple[int, ...] = in_tensors[0].shape[1:] # type: ignore + if len(in_shape) > 1: + pf = layer.parallelization_factor + conf['parallelization_factor'] = pf + return conf diff --git a/hls4ml/converters/keras_v3/pquant/pooling.py b/hls4ml/converters/keras_v3/pquant/pooling.py new file mode 100644 index 000000000..5625502cf --- /dev/null +++ b/hls4ml/converters/keras_v3/pquant/pooling.py @@ -0,0 +1,30 @@ +from collections.abc import Sequence +from typing import TYPE_CHECKING + +from hls4ml.converters.keras_v3._base import register +from hls4ml.converters.keras_v3.pooling import PoolingHandler + +from ._base import PQLayerHandler + +if TYPE_CHECKING: + import pquant + from keras import KerasTensor + + +@register +class PQAvgPoolHandler(PQLayerHandler, PoolingHandler): + handles = ( + 'pquant.core.keras.layers.PQAvgPool1d', + 'pquant.core.keras.layers.PQAvgPool2d', + ) + + def handle( + self, + layer: 'pquant.core.keras.layers.PQAvgPool1d | pquant.core.keras.layers.PQAvgPool2d', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + conf = super().handle(layer, in_tensors, out_tensors) + conf['class_name'] = 'AveragePooling' + layer.__class__.__name__[-2] + 'D' + + return conf diff --git a/hls4ml/converters/pytorch/pquant.py b/hls4ml/converters/pytorch/pquant.py new file mode 100644 index 000000000..14502f995 --- /dev/null +++ b/hls4ml/converters/pytorch/pquant.py @@ -0,0 +1,154 @@ +from collections.abc import Iterable +from warnings import warn + +import numpy as np + +from hls4ml.converters.pytorch.convolution import parse_conv1d_layer, parse_conv2d_layer +from hls4ml.converters.pytorch.core import parse_batchnorm_layer, parse_linear_layer +from hls4ml.converters.pytorch.pooling import parse_pooling_layer +from hls4ml.converters.pytorch_to_hls import pytorch_handler +from hls4ml.model.types import FixedPrecisionType + + +def extract_fixed_quantizer_config(q, shape, input, name): + q_params = q._parameters + + shape = tuple(shape[1:]) # type: ignore + print(f'FixedPointQuantizer shape: {shape}') + if any([s is None for s in shape]): + raise ValueError(f'Tensor {input} has at least one dimension with no fixed size') + k, i, f = q_params['k'].data, q_params['i'].data, q_params['f'].data + k, B, I = k, k + i + f, k + i # type: ignore # noqa: E741 + k, B, I = k.detach().cpu().numpy(), B.detach().cpu().numpy(), I.detach().cpu().numpy() # noqa: E741 + I = np.where(B > 0, I, 0) # noqa: E741 # type: ignore + + k = np.broadcast_to(k.astype(np.int16), (1,) + shape) # type: ignore + B = np.broadcast_to(B.astype(np.int16), (1,) + shape) # type: ignore + I = np.broadcast_to(I.astype(np.int16), (1,) + shape) # noqa: E741 + + overflow_mode: str = q.overflow + round_mode: str = q.round_mode + if round_mode.startswith('S_'): + round_mode = round_mode[2:] + fusible = np.unique(k).size == 1 and np.unique(B).size == 1 and np.unique(I).size == 1 + + return { + 'name': name, + 'inputs': [input], + 'class_name': 'FixedPointQuantizer', + 'mask_kbi': (k, B, I), + 'SAT': overflow_mode, + 'RND': round_mode, + 'fusible': fusible, + 'overrides': {}, + } + + +def add_quantizer_info(class_object, input_names, input_shapes, output_shape, layer): + if getattr(class_object, 'quantize_input', False) and hasattr(class_object, 'input_quantizer'): + if isinstance(class_object.input_quantizer, Iterable): + iq_confs = [ + extract_fixed_quantizer_config(q, shape, input, f'{layer["name"]}_iq_{i}') + for q, shape, input, i in zip( + class_object.input_quantizer, input_shapes, input_names, [k for k in range(len(input_names))] + ) + ] + else: + iq_confs = [ + extract_fixed_quantizer_config( + class_object.input_quantizer, input_shapes[0], input_names[0], f'{layer["name"]}_iq' + ) + ] + layer['inputs'] = [q['name'] for q in iq_confs] + iq_shapes = input_shapes + else: + iq_confs = [] + iq_shapes = [] + + if getattr(class_object, 'quantize_output', False) and hasattr(class_object, 'output_quantizer'): + if isinstance(class_object.output_quantizer, Iterable): + oq_confs = [ + extract_fixed_quantizer_config(q, output_shape, layer['name'], f'{layer["name"]}_oq_{i}') + for q, i in zip(class_object.output_quantizer, [k for k in range(len(class_object.output_quantizer))]) + ] + oq_shapes = [output_shape for _ in len(class_object.output_quantizer)] + else: + oq_confs = [ + extract_fixed_quantizer_config( + class_object.output_quantizer, output_shape, layer['name'], f'{layer["name"]}_oq' + ) + ] + oq_shapes = [output_shape] + else: + oq_confs = [] + oq_shapes = [] + + out_shapes = [] + if iq_shapes: + out_shapes.append(iq_shapes) + out_shapes.append(output_shape) + if oq_shapes: + out_shapes.append(oq_shapes) + + return iq_confs + [layer] + oq_confs, iq_shapes + [output_shape] + oq_shapes + + +def make_pquant_handler(base_parse_func, op, op_check=None): + if op_check is None: + op_check = op + + @pytorch_handler(op) + def handler(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config): + assert op in operation + layer, output_shape = base_parse_func( + op_check, layer_name, input_names, input_shapes, node, class_object, data_reader, config + ) + layers, output_shapes = add_quantizer_info(class_object, input_names, input_shapes, output_shape, layer) + return layers, output_shapes + + handler.__name__ = f'parse_{op.lower()}_layer' + return handler + + +parse_pqlinear_layer = make_pquant_handler(parse_linear_layer, 'PQDense', 'PQLinear') +parse_pqbatchnorm_layer = make_pquant_handler(parse_batchnorm_layer, 'PQBatchNorm2d') +parse_pqconv1d_layer = make_pquant_handler(parse_conv1d_layer, 'PQConv1d') +parse_pqconv2d_layer = make_pquant_handler(parse_conv2d_layer, 'PQConv2d') +parse_pqpool1d_layer = make_pquant_handler(parse_pooling_layer, 'PQAvgPool1d', 'AvgPool1d') +parse_pqpool2d_layer = make_pquant_handler(parse_pooling_layer, 'PQAvgPool2d', 'AvgPool2d') + + +def parse_quant_activation_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config): + layer = {} + + layer['activation'] = class_object.activation_name + + print(f'Parsing activation: {layer["activation"]}') + + layer['name'] = layer_name + layer['inputs'] = input_names + + if layer['activation'] == 'hard_tanh': + layer['class_name'] = 'HardActivation' + layer['slope'] = 0.5 + layer['shift'] = 0.5 + layer['slope_prec'] = FixedPrecisionType(width=2, integer=0, signed=False) + layer['shift_prec'] = FixedPrecisionType(width=2, integer=0, signed=False) + warn(f'Hard Tanh activation {layer_name} is currently not supported for bit-exactness.') + + elif layer['activation'] == 'relu' and class_object.use_multiplier: + raise Exception('hls4ml does not currently support activations with multiplier') + """ + layer['activation'] = 'multiplier_relu' + layer['class_name'] = 'MultiplierReLU' + layer['param_data'] = class_object.multiplier.data.numpy() + """ + + else: + layer['class_name'] = 'Activation' + + output_shape = input_shapes[0] + return layer, output_shape + + +parse_pqactivation_layer = make_pquant_handler(parse_quant_activation_layer, 'PQActivation') diff --git a/hls4ml/converters/pytorch_to_hls.py b/hls4ml/converters/pytorch_to_hls.py index 5399cf37c..9431650a1 100644 --- a/hls4ml/converters/pytorch_to_hls.py +++ b/hls4ml/converters/pytorch_to_hls.py @@ -274,20 +274,41 @@ def resolve_getitem_source(node_name, visited=None): pytorch_class, layer_name, input_names, input_shapes, node, class_object, reader, config ) - if verbose: - print( - 'Layer name: {}, layer type: {}, input shape: {}'.format( - layer['name'], - layer['class_name'], - input_shapes, + if isinstance(layer, dict): + if verbose: + print( + 'Layer name: {}, layer type: {}, input shape: {}'.format( + layer['name'], + layer['class_name'], + input_shapes, + ) ) - ) - layer_list.append(layer) + layer_list.append(layer) - assert output_shape is not None - output_shapes[layer['name']] = output_shape + assert output_shape is not None + output_shapes[layer['name']] = output_shape - layer_counter += 1 + layer_counter += 1 + + else: + for idx, (lay, out_shape) in enumerate(zip(layer, output_shape)): + if verbose: + print( + 'Layer name: {}, layer type: {}, input shape: {}'.format( + lay['name'], + lay['class_name'], + input_shapes, + ) + ) + layer_list.append(lay) + + if idx < len(layer) - 1: + inputs_map[lay['name']] = inputs_map.get(layer[idx + 1]['name'], layer[idx + 1]['name']) + + assert out_shape is not None + output_shapes[lay['name']] = out_shape + + layer_counter += 1 if node.op == 'placeholder': # 'placeholder' indicates an input layer. Multiple inputs are supported diff --git a/hls4ml/model/optimizer/passes/convert_to_channels_last.py b/hls4ml/model/optimizer/passes/convert_to_channels_last.py index cc3b6d0e1..ff35c9518 100644 --- a/hls4ml/model/optimizer/passes/convert_to_channels_last.py +++ b/hls4ml/model/optimizer/passes/convert_to_channels_last.py @@ -2,8 +2,11 @@ # Based on https://github.com/fastmachinelearning/qonnx/blob/ # 12c96a3ded06beacab08e0f554e4ed014476c0aa/src/qonnx/transformation/channels_last.py +import numpy as np + from hls4ml.model.layers import GRU, LSTM, Concatenate, Dense, Input, LayerNormalization, Reshape, Transpose from hls4ml.model.optimizer import OptimizerPass +from hls4ml.model.optimizer.passes.hgq_proxy_model import FixedPointQuantizer from hls4ml.model.types import WeightVariable @@ -62,21 +65,28 @@ def transform(self, model, node): elif isinstance(node, LSTM) or isinstance(node, GRU): pass else: - # Transpose weight tensors - tensors = ['weight', 'depthwise', 'pointwise', 'zero_bias', 'scale', 'recurrent_weight'] - for tensor in tensors: - try: - if len(node.get_weights(tensor).shape) == 2: - weights_channels_last = node.get_weights(tensor).data.transpose() - node.get_weights(tensor).data = weights_channels_last - elif len(node.get_weights(tensor).shape) == 3: - weights_channels_last = node.get_weights(tensor).data.transpose([2, 1, 0]) - node.get_weights(tensor).data = weights_channels_last - elif len(node.get_weights(tensor).shape) == 4: - weights_channels_last = node.get_weights(tensor).data.transpose([2, 3, 1, 0]) - node.get_weights(tensor).data = weights_channels_last - except KeyError: - pass + if isinstance(node, FixedPointQuantizer): + transpose_map = {3: (0, 2, 1), 4: (0, 2, 3, 1), 5: (0, 2, 3, 4, 1)} + node.mask_kbi = tuple( + np.transpose(t, transpose_map[t.ndim]) if t.ndim in transpose_map else t for t in node.mask_kbi + ) + else: + # Transpose weight tensors + tensors = ['weight', 'depthwise', 'pointwise', 'zero_bias', 'scale', 'recurrent_weight'] + for tensor in tensors: + try: + t_shape = node.get_weights(tensor).shape + if len(t_shape) == 2: + weights_channels_last = node.get_weights(tensor).data.transpose() + node.get_weights(tensor).data = weights_channels_last + elif len(t_shape) == 3: + weights_channels_last = node.get_weights(tensor).data.transpose([2, 1, 0]) + node.get_weights(tensor).data = weights_channels_last + elif len(t_shape) == 4: + weights_channels_last = node.get_weights(tensor).data.transpose([2, 3, 1, 0]) + node.get_weights(tensor).data = weights_channels_last + except KeyError: + pass try: node.set_attr('data_format', 'channels_last') except AttributeError: diff --git a/hls4ml/utils/torch.py b/hls4ml/utils/torch.py index 25d2754b1..71d97dfaf 100644 --- a/hls4ml/utils/torch.py +++ b/hls4ml/utils/torch.py @@ -22,4 +22,5 @@ def is_leaf_module(self, m, module_qualified_name: str) -> bool: or m.__module__.startswith('torch.nn') or m.__module__.startswith('torch.ao.nn') or m.__module__.startswith('brevitas.nn') + or m.__module__.startswith('pquant.core') ) and not isinstance(m, torch.nn.Sequential) diff --git a/test/pytest/test_pquant_keras.py b/test/pytest/test_pquant_keras.py new file mode 100644 index 000000000..7d985caa8 --- /dev/null +++ b/test/pytest/test_pquant_keras.py @@ -0,0 +1,174 @@ +import os +from pathlib import Path + +import numpy as np +import pytest +from pquant.activations import PQActivation +from pquant.core.finetuning import TuningConfig +from pquant.core.utils import get_default_config +from pquant.layers import PQAvgPool1d, PQAvgPool2d, PQBatchNormalization, PQConv1d, PQConv2d, PQDense + +from hls4ml.converters import convert_from_keras_model +from hls4ml.utils import config_from_keras_model + +os.environ['KERAS_BACKEND'] = 'tensorflow' +import keras # noqa: E402 + +test_path = Path(__file__).parent + + +def _run_synth_match_test(PQmodel: keras.Model, data, io_type: str, backend: str, dir: str, cond=None, strategy='latency'): + output_dir = dir + '/hls4ml_prj' + hls_config = config_from_keras_model( + PQmodel, + granularity='name', + default_precision='ap_fixed<32, 16>', + backend=backend, + ) + hls_model = convert_from_keras_model( + PQmodel, + io_type=io_type, + output_dir=output_dir, + backend=backend, + hls_config=hls_config, + ) + hls_model.compile() + + data_len = data.shape[0] if isinstance(data, np.ndarray) else data[0].shape[0] + r_pq: list[np.ndarray] = [PQmodel(data).numpy()] # type: ignore + r_hls: list[np.ndarray] = [hls_model.predict(np.ascontiguousarray(data)).reshape(r_pq[0].shape)] # type: ignore + + errors = [] + for i, (p, h) in enumerate(zip(r_pq, r_hls)): + try: + if cond is None: + mismatch_ph = p != h + assert np.sum(mismatch_ph) == 0, ( + f'Proxy-HLS4ML mismatch for out {i}: {np.sum(np.any(mismatch_ph, axis=1))} out of {data_len} samples are different. Sample: {p[mismatch_ph].ravel()[:5]} vs {h[mismatch_ph].ravel()[:5]}' # noqa: E501 + ) + else: + cond(p, h) + except AssertionError as e: + errors.append(e) + if len(errors) > 0: + msgs = [str(e) for e in errors] + raise AssertionError('\n'.join(msgs)) + + +def run_model_test( + PQmodel: keras.Model, + data, + io_type: str, + backend: str, + dir: str, + cond=None, + strategy='latency', +): + _run_synth_match_test(PQmodel, data, io_type, backend, dir, cond=cond, strategy=strategy) + + +def create_pqlayer_model(layer: str, use_hgq: bool): + config = get_default_config('pdp') + config['pruning_parameters']['disable_pruning_for_layers'] = [''] + config['quantization_parameters']['use_high_granularity_quantization'] = use_hgq + config = TuningConfig.load_from_config(config) + + idx = layer.find('(') + 1 + layer = ( + layer[:idx] + + 'config, ' + + layer[idx:-1] + + (', quantize_output=True, out_quant_bits=(1., 2., 7.)' if 'BatchNorm' not in layer else '') + + ')' + ) + _layer = eval(layer) + + shape = get_shape(_layer) + inp = keras.Input(shape[1:]) + out = _layer(inp) + if 'BatchNorm' in layer: + flat = keras.layers.Flatten() + _layer2 = PQDense(config, 16, in_quant_bits=(1.0, 1.0, 7.0), quantize_output=True, out_quant_bits=(1.0, 2.0, 7.0)) + out = _layer2(flat(out)) + model = keras.Model(inp, out) + + return model, shape + + +def get_data(shape: tuple[int, ...], v: float, max_scale: float): + rng = np.random.default_rng() + a1 = rng.uniform(-v, v, shape).astype(np.float32) + a2 = rng.uniform(0, max_scale, (1, *shape[1:])).astype(np.float32) + return (a1 * a2).astype(np.float32) + + +def get_shape( + layer: keras.layers.Layer, + batch_size: int = 1, + default_length: int = 32, + default_hw: tuple[int, int] = (32, 32), + default_channels: int = 2, +): + match layer: + case PQActivation(): + # (N, L) + return (batch_size, default_length) + case PQAvgPool1d(): + # (N, L, C) + return (batch_size, default_length, default_channels) + case PQAvgPool2d(): + # (N, H, W, C) + return (batch_size, *default_hw, default_channels) + case PQBatchNormalization(): + # (N, num_features, H, W) + return (batch_size, *default_hw, default_channels) + case PQConv1d(): + # (N, C_in, L) + return (batch_size, default_length, default_channels) + case PQConv2d(): + # (N, C_in, H, W) + return (batch_size, *default_hw, default_channels) + case PQDense(): + # (N, in_features) + return (batch_size, default_length) + case _: + raise TypeError(f'Unsupported layer type: {type(layer).__name__}') + + +@pytest.mark.parametrize( + 'layer', + [ + 'PQDense(16)', + 'PQDense(16, use_bias=False)', + "PQConv1d(3, kernel_size=3, padding='same')", + "PQConv1d(3, kernel_size=3, padding='valid')", + "PQConv1d(3, kernel_size=3, padding='valid', use_bias=False)", + "PQConv1d(3, kernel_size=3, padding='valid', strides=2)", + "PQConv1d(3, kernel_size=3, padding='same', strides=2)", + "PQConv2d(3, kernel_size=(3,3), padding='same')", + "PQConv2d(3, kernel_size=(3,3), padding='valid')", + "PQConv2d(3, kernel_size=(3,3), padding='valid', use_bias=False)", + "PQConv2d(3, kernel_size=(3,3), padding='valid', strides=2)", + "PQConv2d(3, kernel_size=(3,3), padding='same', strides=2)", + 'PQBatchNormalization()', + "PQAvgPool1d(2, padding='same')", + "PQAvgPool2d((1,2), padding='same')", + "PQAvgPool2d((2,2), padding='same')", + "PQAvgPool1d(2, padding='valid')", + "PQAvgPool2d((1,2), padding='valid')", + "PQAvgPool2d((2,2), padding='valid')PQActivation('relu')", + "PQActivation('tanh')", + ], +) +@pytest.mark.parametrize('N', [1000]) +@pytest.mark.parametrize('io_type', ['io_parallel']) +@pytest.mark.parametrize('backend', ['vivado', 'vitis']) +@pytest.mark.parametrize('use_hgq', [True, False]) +@pytest.mark.parametrize('strategy', ['latency', 'resource']) +def test_syn_hlayers(layer, N: int, io_type: str, backend: str, use_hgq: bool, strategy: str): + model, data_shape = create_pqlayer_model(layer=layer, use_hgq=use_hgq) + data = get_data(data_shape, 7, 1) + + path = test_path / f'hls4mlprj_pquant_keras__{layer}_{io_type}_{backend}_{use_hgq}_{strategy}' + + run_model_test(model, data, io_type, backend, str(path), None, strategy) diff --git a/test/pytest/test_pquant_pytorch.py b/test/pytest/test_pquant_pytorch.py new file mode 100644 index 000000000..472d9cb89 --- /dev/null +++ b/test/pytest/test_pquant_pytorch.py @@ -0,0 +1,184 @@ +import os +from pathlib import Path + +import numpy as np +import pytest +from pquant.activations import PQActivation +from pquant.core.finetuning import TuningConfig +from pquant.core.utils import get_default_config +from pquant.layers import PQAvgPool1d, PQAvgPool2d, PQBatchNorm2d, PQConv1d, PQConv2d, PQDense + +from hls4ml.converters import convert_from_pytorch_model +from hls4ml.utils import config_from_pytorch_model + +os.environ['KERAS_BACKEND'] = 'torch' +import torch # noqa: E402 +import torch.nn as nn # noqa: E402 + +test_path = Path(__file__).parent + + +def _run_synth_match_test(PQmodel: nn.Module, data, io_type: str, backend: str, dir: str, cond=None, strategy='latency'): + output_dir = dir + '/hls4ml_prj' + hls_config = config_from_pytorch_model( + PQmodel, + input_shape=tuple(data.shape[1:]), + granularity='name', + default_precision='ap_fixed<32, 16>', + backend=backend, + transpose_outputs=True, + ) + hls_model = convert_from_pytorch_model( + PQmodel, + io_type=io_type, + output_dir=output_dir, + backend=backend, + hls_config=hls_config, + ) + hls_model.compile() + + data_len = data.shape[0] if isinstance(data, np.ndarray) else data[0].shape[0] + r_pq: list[np.ndarray] = [PQmodel(data).detach().cpu().numpy()] # type: ignore + r_hls: list[np.ndarray] = [hls_model.predict(np.ascontiguousarray(data)).reshape(r_pq[0].shape)] # type: ignore + + errors = [] + for i, (p, h) in enumerate(zip(r_pq, r_hls)): + try: + if cond is None: + mismatch_ph = p != h + assert np.sum(mismatch_ph) == 0, ( + f'Proxy-HLS4ML mismatch for out {i}: {np.sum(np.any(mismatch_ph, axis=1))} out of {data_len} samples are different. Sample: {p[mismatch_ph].ravel()[:5]} vs {h[mismatch_ph].ravel()[:5]}' # noqa: E501 + ) + else: + cond(p, h) + except AssertionError as e: + errors.append(e) + if len(errors) > 0: + msgs = [str(e) for e in errors] + raise AssertionError('\n'.join(msgs)) + + +def run_model_test( + PQmodel: nn.Module, + data, + io_type: str, + backend: str, + dir: str, + cond=None, + strategy='latency', +): + PQmodel.eval() + PQmodel(data[:1]) + _run_synth_match_test(PQmodel, data, io_type, backend, dir, cond=cond, strategy=strategy) + + +def create_pqlayer_model(layer: str, use_hgq: bool): + config = get_default_config('pdp') + config['pruning_parameters']['disable_pruning_for_layers'] = [''] + config['quantization_parameters']['use_high_granularity_quantization'] = use_hgq + config = TuningConfig.load_from_config(config) + + idx = layer.find('(') + 1 + layer = ( + layer[:idx] + + 'config, ' + + layer[idx:-1] + + (', quantize_output=True, out_quant_bits=(1, 2, 7)' if 'BatchNorm' not in layer else '') + + ')' + ) + _layer = eval(layer) + + class SingleLayerModel(nn.Module): + def __init__(self, layer): + super().__init__() + self.layer = layer + + def forward(self, x): + return self.layer(x) + + model = SingleLayerModel(_layer) + return model + + +def get_data(shape: tuple[int, ...], v: float, max_scale: float): + rng = np.random.default_rng() + a1 = rng.uniform(-v, v, shape).astype(np.float32) + a2 = rng.uniform(0, max_scale, (1, *shape[1:])).astype(np.float32) + return torch.tensor((a1 * a2), dtype=torch.float32) + + +def get_shape(model: nn.Module, batch_size: int = 1, default_length: int = 32, default_hw: tuple[int, int] = (32, 32)): + for lay in list(model.modules())[1:]: + if not isinstance(lay, (nn.Sequential, nn.ModuleList, nn.Identity)): + layer = lay + break + else: + raise ValueError('Model has no valid layers to infer shape from.') + + match layer: + case PQActivation(): + # (N, L) + return (batch_size, default_length) + case PQAvgPool1d(): + # (N, C, L) + return (batch_size, 1, default_length) + case PQAvgPool2d(): + # (N, C, H, W) + return (batch_size, 1, *default_hw) + # case PQBatchNorm1d(): + # # (N, num_features, L) + # return (batch_size, layer.num_features, *default_length) + case PQBatchNorm2d(): + # (N, num_features, H, W) + return (batch_size, layer.num_features, *default_hw) + case PQConv1d(): + # (N, C_in, L) + return (batch_size, layer.in_channels, default_length) + case PQConv2d(): + # (N, C_in, H, W) + return (batch_size, layer.in_channels, *default_hw) + case PQDense(): + # (N, in_features) + return (batch_size, layer.in_features) + case _: + raise TypeError(f'Unsupported layer type: {type(layer).__name__}') + + +@pytest.mark.parametrize( + 'layer', + [ + 'PQDense(16, 4)', + 'PQDense(16, 4, bias=False)', + 'PQConv1d(2, 3, kernel_size=3, padding=1)', + 'PQConv1d(2, 3, kernel_size=3, padding=0)', + 'PQConv1d(2, 3, kernel_size=3, padding=0, bias=False)', + 'PQConv1d(2, 3, kernel_size=3, padding=0, stride=2)', + 'PQConv1d(2, 3, kernel_size=3, padding=1, stride=2)', + 'PQConv2d(2, 3, kernel_size=(3,3), padding=1)', + 'PQConv2d(2, 3, kernel_size=(3,3), padding=0)', + 'PQConv2d(2, 3, kernel_size=(3,3), padding=0, bias=False)', + 'PQConv2d(2, 3, kernel_size=(3,3), padding=0, stride=2)', + 'PQConv2d(2, 3, kernel_size=(3,3), padding=1, stride=2)', + 'PQBatchNorm2d(3)', + 'PQAvgPool1d(2, padding=1)', + 'PQAvgPool1d(2, padding=0)', + 'PQAvgPool2d((2,2), padding=1)', + 'PQAvgPool2d((2,2), padding=0)', + 'PQAvgPool2d((1, 2), stride=(1, 2), padding=(0, 1))', + "PQActivation('relu')", + "PQActivation('tanh')", + ], +) +@pytest.mark.parametrize('N', [1000]) +@pytest.mark.parametrize('io_type', ['io_parallel']) +@pytest.mark.parametrize('backend', ['vivado', 'vitis']) +@pytest.mark.parametrize('use_hgq', [True, False]) +@pytest.mark.parametrize('strategy', ['latency', 'resource']) +def test_syn_hlayers(layer, N: int, io_type: str, backend: str, use_hgq: bool, strategy: str): + model = create_pqlayer_model(layer=layer, use_hgq=use_hgq) + data_shape = get_shape(model, batch_size=N) + data = get_data(data_shape, 7, 1) + + path = test_path / f'hls4mlprj_pquant_pytorch_{layer}_{io_type}_{backend}_{use_hgq}_{strategy}' + + run_model_test(model, data, io_type, backend, str(path), None, strategy)