Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 21 additions & 12 deletions modelconverter/packages/rvc4/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(self, config: SingleStageConfig, output_dir: Path):
super().__init__(config=config, output_dir=output_dir)

rvc4_cfg = config.rvc4
self.custom_quantization_overrides = rvc4_cfg.quantization_overrides
self.snpe_onnx_to_dlc = rvc4_cfg.snpe_onnx_to_dlc_args
self.snpe_dlc_quant = rvc4_cfg.snpe_dlc_quant_args
self.snpe_dlc_graph_prepare = rvc4_cfg.snpe_dlc_graph_prepare_args
Expand Down Expand Up @@ -273,19 +274,27 @@ class Entry(NamedTuple):
return self.input_list_path

def generate_io_encodings(self) -> Path:
encodings_dict = {"activation_encodings": {}, "param_encodings": {}}
if not (list(self.inputs.keys()) and list(self.outputs.keys())):
logger.warning(
"Cannot generate I/O encodings as inputs or outputs are not defined. The resulting DLC may not be compatible with DAI."
if self.custom_quantization_overrides is not None:
encodings_dict = self.custom_quantization_overrides.model_dump(
mode="json", exclude_none=True
)
for name in (
list(self.inputs.keys())
+ list(self.outputs.keys())
+ self.extra_quant_tensors
):
encodings_dict["activation_encodings"][name] = [
{"bitwidth": 8, "dtype": "int"}
]
else:
encodings_dict = {
"activation_encodings": {},
"param_encodings": {},
}
if not (list(self.inputs.keys()) and list(self.outputs.keys())):
logger.warning(
"Cannot generate I/O encodings as inputs or outputs are not defined. The resulting DLC may not be compatible with DAI."
)
for name in (
list(self.inputs.keys())
+ list(self.outputs.keys())
+ self.extra_quant_tensors
):
encodings_dict["activation_encodings"][name] = [
{"bitwidth": 8, "dtype": "int"}
]
encodings_path = self.intermediate_outputs_dir / "io_encodings.json"
with open(encodings_path, "w") as encodings_file:
json.dump(encodings_dict, encodings_file, indent=4)
Expand Down
63 changes: 48 additions & 15 deletions modelconverter/utils/config.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
import json
from itertools import chain
from pathlib import Path
from typing import Annotated, Any, Literal, cast

import onnx
from loguru import logger
from luxonis_ml.typing import PathType
from luxonis_ml.typing import BaseModelExtraForbid, PathType
from luxonis_ml.utils import LuxonisConfig
from onnx import TypeProto
from pydantic import (
BaseModel,
ConfigDict,
Field,
PositiveInt,
field_serializer,
field_validator,
model_validator,
)
from typing_extensions import Self

from modelconverter.utils.calibration_data import download_calibration_data
from modelconverter.utils.constants import MODELS_DIR
from modelconverter.utils.constants import MISC_DIR, MODELS_DIR
from modelconverter.utils.filesystem_utils import resolve_path
from modelconverter.utils.layout import make_default_layout
from modelconverter.utils.metadata import Metadata, get_metadata
Expand All @@ -40,11 +40,7 @@
}


class CustomBaseModel(BaseModel):
model_config = ConfigDict(extra="forbid")


class LinkCalibrationConfig(CustomBaseModel):
class LinkCalibrationConfig(BaseModelExtraForbid):
stage: str
output: str | None = None
script: str | None = None
Expand All @@ -68,7 +64,7 @@ def _download_calibration_script(script: Any) -> Path | None:
return script


class ImageCalibrationConfig(CustomBaseModel):
class ImageCalibrationConfig(BaseModelExtraForbid):
path: Path
max_images: int = -1
resize_method: ResizeMethod = ResizeMethod.RESIZE
Expand All @@ -81,7 +77,7 @@ def _download_calibration_data(value: Any) -> Path | None:
return download_calibration_data(str(value))


class RandomCalibrationConfig(CustomBaseModel):
class RandomCalibrationConfig(BaseModelExtraForbid):
max_images: int = 20
min_value: float = 0.0
max_value: float = 255.0
Expand All @@ -90,7 +86,7 @@ class RandomCalibrationConfig(CustomBaseModel):
data_type: DataType = DataType.FLOAT32


class OutputConfig(CustomBaseModel):
class OutputConfig(BaseModelExtraForbid):
name: str
shape: list[int] | None = None
layout: str | None = None
Expand Down Expand Up @@ -123,7 +119,7 @@ def validate_layout(self) -> Self:
return self


class EncodingConfig(CustomBaseModel):
class EncodingConfig(BaseModelExtraForbid):
from_: Annotated[
Encoding, Field(alias="from", serialization_alias="from")
] = Encoding.RGB
Expand Down Expand Up @@ -225,7 +221,7 @@ def _parse_values(
return value


class TargetConfig(CustomBaseModel):
class TargetConfig(BaseModelExtraForbid):
disable_calibration: bool = False


Expand Down Expand Up @@ -265,6 +261,28 @@ class RVC3Config(BlobBaseConfig):
pot_target_device: PotDevice = PotDevice.VPU


class QuantizationOverridesItem(BaseModelExtraForbid):
    """One quantization-encoding entry attached to a single tensor.

    Every field is optional and defaults to ``None``.
    """

    # When provided, must lie in the inclusive range 4..32.
    bitwidth: Annotated[int, Field(ge=4, le=32)] | None = None
    is_symmetric: bool | None = None
    dtype: Literal["int", "float"] | None = None
    max: float | None = None
    min: float | None = None
    offset: int | None = None
    scale: float | None = None

    @field_serializer("is_symmetric", when_used="json")
    @staticmethod
    def serialize_is_symmetric(value: bool | None) -> str | None:
        """Stringify the flag for JSON output.

        Returns ``"True"`` or ``"False"`` for a boolean and passes
        ``None`` through unchanged.
        """
        return None if value is None else str(value)


class QuantizationOverrides(BaseModelExtraForbid):
    """User-supplied quantization overrides, keyed by tensor name.

    Both fields map a tensor name to a list of encoding entries.

    The parameter field is named ``param_encodings`` so that the dumped
    dictionary uses the same key as the exporter's default encodings
    and the documented configuration format. The previous spelling,
    ``parameter_encodings``, is still accepted as input and remains
    readable as an attribute for backward compatibility.
    """

    activation_encodings: dict[str, list[QuantizationOverridesItem]]
    param_encodings: dict[str, list[QuantizationOverridesItem]]

    @model_validator(mode="before")
    @classmethod
    def _accept_legacy_parameter_key(cls, data: Any) -> Any:
        """Rename the legacy ``parameter_encodings`` key if present.

        When both spellings are supplied the input is left untouched,
        so the ambiguous extra key is rejected by the model rather than
        one of the two being silently dropped.
        """
        if (
            isinstance(data, dict)
            and "parameter_encodings" in data
            and "param_encodings" not in data
        ):
            # Copy first so the caller's dictionary is not mutated.
            data = dict(data)
            data["param_encodings"] = data.pop("parameter_encodings")
        return data

    @property
    def parameter_encodings(
        self,
    ) -> dict[str, list[QuantizationOverridesItem]]:
        """Backward-compatible read-only alias for ``param_encodings``."""
        return self.param_encodings


class RVC4Config(TargetConfig):
snpe_onnx_to_dlc_args: list[str] = []
snpe_dlc_quant_args: list[str] = []
Expand All @@ -277,6 +295,21 @@ class RVC4Config(TargetConfig):
htp_socs: list[
Literal["sm8350", "sm8450", "sm8550", "sm8650", "qcs6490", "qcs8550"]
] = ["sm8550"]
quantization_overrides: QuantizationOverrides | None = None

@field_validator("quantization_overrides", mode="before")
@staticmethod
def validate_quantization_overrides(
    value: Any,
) -> QuantizationOverrides | None:
    """Normalize the raw ``quantization_overrides`` value.

    Args:
        value: ``None``, a path-like reference (``str`` or ``Path``) to
            a JSON file, a mapping with the override contents, or an
            already constructed ``QuantizationOverrides`` instance.

    Returns:
        ``None`` when no overrides are configured, otherwise the
        validated ``QuantizationOverrides`` model.
    """
    if value is None:
        return None

    if isinstance(value, (str, Path)):
        # Path-like values point to a JSON file; it is resolved through
        # ``resolve_path`` with ``MISC_DIR`` and then parsed.
        value_path = resolve_path(str(value), MISC_DIR)
        value = json.loads(value_path.read_text(encoding="utf-8"))

    # ``model_validate`` accepts mappings and existing instances and
    # reports anything else as a validation error. Unpacking with
    # ``**value`` would raise a bare ``TypeError`` on non-mappings.
    return QuantizationOverrides.model_validate(value)

@model_validator(mode="after")
def _validate_fp16(self) -> Self:
Expand All @@ -286,7 +319,7 @@ def _validate_fp16(self) -> Self:
return self


class SingleStageConfig(CustomBaseModel):
class SingleStageConfig(BaseModelExtraForbid):
input_model: Path
input_bin: Path | None = None
input_file_type: InputFileType
Expand Down
9 changes: 9 additions & 0 deletions shared_with_container/configs/defaults.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -221,3 +221,12 @@ stages:

# Pre-defined quantization modes for the RVC4 exporter. Pre-defined modes (except CUSTOM) will override any user-provided SNPE arguments via `snpe_onnx_to_dlc_args`, `snpe_dlc_quant_args`, and `snpe_dlc_graph_prepare_args`. The available quantization modes are: INT8_STANDARD, INT8_ACCURACY_FOCUSED, INT8_INT16_MIXED, FP16_STANDARD, and CUSTOM.
quantization_mode: INT8_STANDARD

# Custom quantization overrides for the RVC4 exporter. The value is
# either a dictionary or a path to a JSON file containing one. The
# dictionary has two mandatory fields: `activation_encodings` and
# `param_encodings`. Each field is a dictionary that maps a tensor
# name to a list of dictionaries specifying the quantization
# encodings for that activation or parameter, respectively.
# The dictionaries in the lists must follow the AIMET specification.
quantization_overrides: ~
1 change: 1 addition & 0 deletions tests/test_utils/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
"use_per_row_quantization": False,
"quantization_mode": QuantizationMode.INT8_STD,
"optimization_level": 2,
"quantization_overrides": None,
},
"hailo": {
"force_onnx_names": True,
Expand Down
Loading