29 changes: 20 additions & 9 deletions QEfficient/__init__.py
@@ -6,7 +6,17 @@
# -----------------------------------------------------------------------------

import os
import warnings

# ----------------------------------------------------------------------------- #
# For faster downloads via hf_transfer
# This code is put above import statements as this needs to be executed before
# hf_transfer is imported (will happen on line 15 via leading imports)
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# DO NOT ADD ANY CODE ABOVE THIS LINE
# Please contact maintainers if you must edit this file above this line.
# ----------------------------------------------------------------------------- #
# Placeholder for all non-transformer models registered in QEfficient
import warnings # noqa: I001

import QEfficient.utils.model_registery # noqa: F401
from QEfficient.base import (
@@ -18,12 +28,19 @@
QEFFCommonLoader,
)
from QEfficient.compile.compile_helper import compile
from QEfficient.diffusers.pipelines.flux.pipeline_flux import QEffFluxPipeline
from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv
from QEfficient.peft import QEffAutoPeftModelForCausalLM
from QEfficient.transformers.transform import transform
from QEfficient.utils import custom_format_warning
from QEfficient.utils.logging_utils import logger
from QEfficient.utils.logging_utils import QEFFLogger

logger = QEFFLogger.get_logger("INFRA", loglevel="INFO")

# custom warning for the better logging experience
warnings.formatwarning = custom_format_warning


# Users can use QEfficient.export for exporting models to ONNX
export = qualcomm_efficient_converter
Expand All @@ -39,15 +56,9 @@
"QEFFAutoModelForImageTextToText",
"QEFFAutoModelForSpeechSeq2Seq",
"QEFFCommonLoader",
"QEffFluxPipeline",
]
# For faster downloads via hf_transfer
# This code is put above import statements as this needs to be executed before
# hf_transfer is imported (will happen on line 15 via leading imports)
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# Placeholder for all non-transformer models registered in QEfficient

# custom warning for the better logging experience
warnings.formatwarning = custom_format_warning

# Conditionally import QAIC-related modules if the SDK is installed
__version__ = "0.0.1.dev0"
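
Editor's note on the hunk above: the `HF_HUB_ENABLE_HF_TRANSFER` assignment moves from below `__all__` to before all imports. The ordering matters because huggingface_hub evaluates the flag once, at import time, so setting it afterwards has no effect. A minimal sketch of that constraint (assuming huggingface_hub and hf_transfer are installed; not part of the PR):

import os

# Must run before anything imports huggingface_hub, which evaluates
# HF_HUB_ENABLE_HF_TRANSFER in its constants module during import.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

from huggingface_hub import snapshot_download  # noqa: E402  (flag is read here)

# snapshot_download("gpt2")  # downloads would now route through hf_transfer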
141 changes: 102 additions & 39 deletions QEfficient/base/modeling_qeff.py
@@ -7,8 +7,6 @@

import gc
import inspect
import logging
import re
import shutil
import subprocess
import warnings
@@ -21,28 +19,24 @@

from QEfficient.base.onnx_transforms import (
BaseOnnxTransform,
CustomOpTransform,
OnnxTransformPipeline,
RenameFunctionOutputsTransform,
)
from QEfficient.base.pytorch_transforms import PytorchTransform
from QEfficient.compile.qnn_compiler import compile as qnn_compile
from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.transformers.cache_utils import InvalidIndexProvider
from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export
from QEfficient.utils import (
constants,
create_json,
create_model_params,
dump_qconfig,
export_wrapper,
generate_mdp_partition_config,
hash_dict_params,
load_json,
)
from QEfficient.utils.torch_patches import apply_torch_patches, undo_torch_patches
from QEfficient.utils.export_utils import export_wrapper
from QEfficient.utils.logging_utils import QEFFLogger

logger = logging.getLogger(__name__)
logger = QEFFLogger.get_logger("INFRA", loglevel="INFO")


class QEFFBaseModel(ABC):
@@ -66,6 +60,7 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None:
super().__init__()
self.model = model
self.hash_params = create_model_params(self, **kwargs)
self.prefill_onnx_path: Optional[str] = None
self.onnx_path: Optional[str] = None
self.qpc_path: Optional[str] = None
self.qpc_session: Optional[QAICInferenceSession] = None
@@ -125,9 +120,35 @@ def _model_offloaded_check(self) -> None:
logger.error(error_msg)
raise RuntimeError(error_msg)

@property
def model_name(self) -> str:
"""
Get the model class name without QEff/QEFF prefix.

This property extracts the underlying model's class name and removes
any QEff or QEFF prefix that may have been added during wrapping.

Returns:
str: Model class name (e.g., "CLIPTextModel" instead of "QEffCLIPTextModel")
"""
mname = self.model.__class__.__name__
if mname.startswith("QEff") or mname.startswith("QEFF"):
mname = mname[4:]
return mname

@property
@abstractmethod
def model_name(self) -> str: ...
def get_model_config(self) -> Dict:
"""
Get the model configuration as a dictionary.

This is an abstract property that must be implemented by all subclasses.
Typically returns: self.model.config.__dict__

Returns:
Dict: The configuration dictionary of the underlying model
"""
pass

@abstractmethod
def export(self, export_dir: Optional[str] = None) -> Path:
Expand Down Expand Up @@ -184,11 +205,11 @@ def _export(
example_inputs: Dict[str, torch.Tensor],
output_names: List[str],
dynamic_axes: Dict[str, Dict[int, str]],
export_kwargs: Optional[Dict[str, any]] = None,
onnx_transform_kwargs: Optional[Dict[str, any]] = None,
export_dir: Optional[str] = None,
offload_pt_weights: bool = True,
use_onnx_subfunctions: bool = False,
prefill_only: Optional[bool] = False,
**export_kwargs,
) -> str:
"""
Export the PyTorch model to ONNX and apply ONNX transforms
@@ -213,11 +234,16 @@
instance using from_pretrained() for re-export.

"""
# TODO: Hack for retain_full_kv, handle this outside
export_kwargs.pop("retain_full_kv", None)
onnx_path = export_dir / f"{self.model_name}.onnx"

# Return early if ONNX already exists
if onnx_path.is_file():
self.onnx_path = onnx_path
if prefill_only:
self.prefill_onnx_path = onnx_path
else:
self.onnx_path = onnx_path
return onnx_path

# check if the model is in meta state or weights are offloaded
@@ -253,19 +279,6 @@
input_names.append(param)

try:
# Initialize the registry with your custom ops
export_kwargs = {} if export_kwargs is None else export_kwargs
if use_onnx_subfunctions:
warnings.warn(
"The subfunction feature is experimental. Please note that using compile consecutively with and without subfunction may produce inconsistent results."
)
apply_torch_patches()
InvalidIndexProvider.SUBFUNC_ENABLED = True
output_names = [re.sub("_RetainedState", "_InternalRetainedState", s) for s in output_names]
export_kwargs["export_modules_as_functions"] = get_decoder_layer_classes_for_export(self.model)
self._onnx_transforms.append(RenameFunctionOutputsTransform)
self._onnx_transforms.append(CustomOpTransform)

torch.onnx.export(
self.model,
(example_inputs,),
@@ -309,15 +322,43 @@
finally:
shutil.rmtree(tmp_onnx_dir, ignore_errors=True)

if use_onnx_subfunctions:
undo_torch_patches()
InvalidIndexProvider.SUBFUNC_ENABLED = False
self._onnx_transforms.remove(CustomOpTransform)
self._onnx_transforms.remove(RenameFunctionOutputsTransform)

self.onnx_path = onnx_path
if prefill_only:
self.prefill_onnx_path = onnx_path
else:
self.onnx_path = onnx_path
logger.info("Model export is finished and saved at: %s", onnx_path)
return onnx_path

def get_onnx_path(
self,
prefill_only: Optional[bool] = False,
enable_chunking: Optional[bool] = False,
specializations: Optional[List[Dict[str, int]]] = None,
offload_pt_weights: Optional[bool] = True,
use_onnx_subfunctions: Optional[bool] = False,
retain_full_kv: Optional[bool] = False,
):
kwargs = {
"offload_pt_weights": offload_pt_weights,
"use_onnx_subfunctions": use_onnx_subfunctions,
"retain_full_kv": retain_full_kv,
}
if prefill_only:
if self.prefill_onnx_path is None:
kwargs.update(
{
"prefill_only": prefill_only,
"prefill_seq_len": specializations[0].get("seq_len"),
"enable_chunking": enable_chunking,
}
)
self.export(**kwargs)
return self.prefill_onnx_path
else:
if self.onnx_path is None:
self.export(**kwargs)
return self.onnx_path

@dump_qconfig
def _compile(
self,
@@ -332,6 +373,10 @@ def _compile(
enable_qnn: Optional[bool] = False,
qnn_config: Optional[str] = None,
use_onnx_subfunctions: bool = False,
prefill_only: Optional[str] = None,
offload_pt_weights: Optional[bool] = True,
enable_chunking: Optional[bool] = False,
retain_full_kv: Optional[bool] = None,
**compiler_options,
) -> str:
"""
@@ -357,11 +402,18 @@

For QNN Compilation path, when enable_qnn is set to True, any parameter passed in compiler_options will be ignored.
"""

if onnx_path is None and self.onnx_path is None:
self.export(use_onnx_subfunctions=use_onnx_subfunctions)

onnx_path = Path(onnx_path or self.onnx_path)
onnx_path = Path(
onnx_path
if onnx_path
else self.get_onnx_path(
prefill_only,
enable_chunking,
specializations,
offload_pt_weights,
use_onnx_subfunctions,
retain_full_kv,
)
)
compile_dir = Path(compile_dir or onnx_path.parent)
qpc_path = compile_dir / "qpc"
if not onnx_path.is_file():
@@ -423,6 +475,7 @@ def _compile(
"mdp_ts_num_devices": mdp_ts_num_devices,
"mdp_ts_json": mdp_ts_json,
"num_speculative_tokens": num_speculative_tokens,
"prefill_only": prefill_only,
}
compile_hash = hash_dict_params(compile_hash_params)

@@ -462,6 +515,16 @@

command.append(f"-aic-binary-dir={qpc_path}")
logger.info(f"Running compiler: {' '.join(command)}")
if use_onnx_subfunctions:

class FeatureNotAvailableError(Exception):
pass

exec_command = f'QAIC_COMPILER_OPTS_UNSUPPORTED="-loader-inline-all=0" {" ".join(command)}'
raise FeatureNotAvailableError(
"ONNX graph is exported with subfunctions, assert version of apps SDK should be used for compiling this model."
+ f"\nRun following command manually with assert compiler:\n{exec_command}"
)
try:
subprocess.run(command, capture_output=True, check=True)
except subprocess.CalledProcessError as e:
@@ -482,5 +545,5 @@ def _compile(
logger.info("Hashed parameters exported successfully.")

self.qpc_path = qpc_path

logger.info("Model compilation is finished and saved at: %s", qpc_path)
return qpc_path
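
Editor's note on the `_compile` changes above: the ONNX input is now resolved through `get_onnx_path(...)`, which exports a dedicated prefill graph (cached on `self.prefill_onnx_path`) when `prefill_only` is set, and `prefill_only` now participates in `compile_hash_params`, so prefill-only and full builds hash to different QPC cache directories even when every other knob matches. A self-contained toy of that cache-keying idea (illustrative values; this is not the library's `hash_dict_params` implementation):

import hashlib
import json

def toy_hash_dict_params(params: dict, length: int = 16) -> str:
    # Canonical JSON keeps the key stable across dict orderings.
    blob = json.dumps(params, sort_keys=True, default=str).encode("utf-8")
    return hashlib.sha256(blob).hexdigest()[:length]

base = {"onnx_path": "model.onnx", "mdp_ts_num_devices": 1}
prefill_key = toy_hash_dict_params({**base, "prefill_only": True})
decode_key = toy_hash_dict_params({**base, "prefill_only": None})
assert prefill_key != decode_key  # distinct qpc-<hash> cache directories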
10 changes: 8 additions & 2 deletions QEfficient/base/onnx_transforms.py
@@ -19,16 +19,20 @@
from QEfficient.customop.ctx_scatter_gather import (
CtxGather,
CtxGather3D,
CtxGatherBlockedKV,
CtxGatherFunc,
CtxGatherFunc3D,
CtxGatherFuncBlockedKV,
CtxScatter,
CtxScatter3D,
CtxScatterFunc,
CtxScatterFunc3D,
)
from QEfficient.customop.ctx_scatter_gather_cb import (
CtxGatherBlockedKVCB,
CtxGatherCB,
CtxGatherCB3D,
CtxGatherFuncBlockedKVCB,
CtxGatherFuncCB,
CtxGatherFuncCB3D,
CtxScatterCB,
@@ -91,10 +95,12 @@ class CustomOpTransform(BaseOnnxTransform):
"CtxScatterFunc3D": (CtxScatterFunc3D, CtxScatter3D),
"CtxGatherFunc": (CtxGatherFunc, CtxGather),
"CtxGatherFunc3D": (CtxGatherFunc3D, CtxGather3D),
"CtxScatterFuncCB": (CtxScatterFuncCB, CtxScatterCB),
"CtxScatterFuncCB3D": (CtxScatterFuncCB3D, CtxScatterCB3D),
"CtxGatherFuncCB": (CtxGatherFuncCB, CtxGatherCB),
"CtxGatherFuncCB3D": (CtxGatherFuncCB3D, CtxGatherCB3D),
"CtxGatherFuncBlockedKV": (CtxGatherFuncBlockedKV, CtxGatherBlockedKV),
"CtxGatherFuncBlockedKVCB": (CtxGatherFuncBlockedKVCB, CtxGatherBlockedKVCB),
"CtxScatterFuncCB": (CtxScatterFuncCB, CtxScatterCB),
"CtxGatherFuncCB": (CtxGatherFuncCB, CtxGatherCB),
}

@classmethod
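
For context, the map edited above is `CustomOpTransform`'s registry pairing each exported autograd function with its ONNXScript op; this PR adds the BlockedKV variants and reorders the CB entries. A toy sketch of the lookup pattern such a registry supports (illustrative only; the real transform rewrites ONNX graph nodes, which this does not):

from typing import Any, Callable, Dict, Tuple

# name -> (python-side function, replacement op name) lookup table
_CUSTOM_OPS: Dict[str, Tuple[Callable[..., Any], str]] = {}

def register(name: str, impl: Callable[..., Any], onnx_op: str) -> None:
    _CUSTOM_OPS[name] = (impl, onnx_op)

register("CtxGatherFuncBlockedKV", lambda *args: args, "CtxGatherBlockedKV")

impl, onnx_op = _CUSTOM_OPS["CtxGatherFuncBlockedKV"]
print(onnx_op)  # -> CtxGatherBlockedKV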
4 changes: 3 additions & 1 deletion QEfficient/base/pytorch_transforms.py
@@ -9,7 +9,9 @@

from torch import nn

from QEfficient.utils.logging_utils import logger
from QEfficient.utils.logging_utils import QEFFLogger

logger = QEFFLogger.get_logger("INFRA", loglevel="INFO")


class PytorchTransform:
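
Across this PR, module-level `from ...logging_utils import logger` imports are replaced with `QEFFLogger.get_logger(<component>, loglevel=...)`, tagged "INFRA" for infrastructure modules and "FT" for finetuning. A plausible shape for that helper, inferred only from these call sites (the real QEfficient/utils/logging_utils.py may differ):

import logging
from typing import ClassVar, Dict

class QEFFLogger:
    # Sketch: one cached logging.Logger per component tag ("INFRA", "FT", ...).
    _loggers: ClassVar[Dict[str, logging.Logger]] = {}

    @classmethod
    def get_logger(cls, component: str, loglevel: str = "INFO") -> logging.Logger:
        if component not in cls._loggers:
            log = logging.getLogger(f"QEfficient.{component}")
            log.setLevel(getattr(logging, loglevel.upper(), logging.INFO))
            cls._loggers[component] = log
        return cls._loggers[component]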
4 changes: 3 additions & 1 deletion QEfficient/cloud/export.py
@@ -12,7 +12,9 @@
from QEfficient.base.common import QEFFCommonLoader
from QEfficient.utils import check_and_assign_cache_dir
from QEfficient.utils.custom_yaml import generate_custom_io
from QEfficient.utils.logging_utils import logger
from QEfficient.utils.logging_utils import QEFFLogger

logger = QEFFLogger.get_logger("INFRA", loglevel="INFO")

# Specifically for Docker images.
ROOT_DIR = os.path.dirname(os.path.abspath(""))
4 changes: 3 additions & 1 deletion QEfficient/cloud/finetune.py
@@ -29,10 +29,12 @@
from QEfficient.finetune.utils.dataset_utils import get_dataloader, get_longest_seq_length
from QEfficient.finetune.utils.device_map import get_device_map
from QEfficient.finetune.utils.helper import Task_Mode, get_world_size
from QEfficient.finetune.utils.logging_utils import logger
from QEfficient.finetune.utils.parser import get_finetune_parser
from QEfficient.finetune.utils.train_utils import print_model_size, print_trainable_parameters, train
from QEfficient.utils._utils import hf_download
from QEfficient.utils.logging_utils import QEFFLogger

logger = QEFFLogger.get_logger("FT", loglevel="INFO")

# Try importing QAIC-specific module, proceed without it if unavailable
try:
4 changes: 3 additions & 1 deletion QEfficient/cloud/infer.py
@@ -17,7 +17,9 @@

from QEfficient.base.common import QEFFCommonLoader
from QEfficient.utils import check_and_assign_cache_dir, load_hf_processor, load_hf_tokenizer
from QEfficient.utils.logging_utils import logger
from QEfficient.utils.logging_utils import QEFFLogger

logger = QEFFLogger.get_logger("INFRA", loglevel="INFO")


# TODO: Remove after adding support for VLM's compile and execute