diff --git a/QEfficient/finetune/experimental/core/component_registry.py b/QEfficient/finetune/experimental/core/component_registry.py
index 7744d71e6..d1f948031 100644
--- a/QEfficient/finetune/experimental/core/component_registry.py
+++ b/QEfficient/finetune/experimental/core/component_registry.py
@@ -5,7 +5,6 @@
 #
 # -----------------------------------------------------------------------------
-
 import logging
-from typing import Callable, Dict, Optional, Type
+from typing import Any, Callable, Dict, Optional, Type
@@ -198,3 +197,14 @@ def list_callbacks(self) -> list[str]:
 
 # Global registry instance
 registry = ComponentRegistry()
+
+
+class ComponentFactory:
+    @staticmethod
+    def create_model(model_type: str, model_name: str, **kwargs) -> Any:
+        """Create a model instance."""
+        model_class = registry.get_model(model_type)
+        if model_class is None:
+            raise ValueError(f"Unknown model: {model_type}. Available: {registry.list_models()}")
+        model_instance = model_class.create(model_name, **kwargs)
+        return model_instance
diff --git a/QEfficient/finetune/experimental/core/model.py b/QEfficient/finetune/experimental/core/model.py
index d647b73a6..0f087e665 100644
--- a/QEfficient/finetune/experimental/core/model.py
+++ b/QEfficient/finetune/experimental/core/model.py
@@ -4,3 +4,135 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
+
+import warnings
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Optional, Type
+
+import torch.nn as nn
+import transformers
+from transformers import AutoTokenizer
+
+from QEfficient.finetune.experimental.core.component_registry import registry
+from QEfficient.finetune.experimental.core.logger import Logger
+from QEfficient.finetune.experimental.core.utils.dataset_utils import insert_pad_token
+
+logger = Logger(__name__)
+
+
+class BaseModel(nn.Module, ABC):
+    """Shared skeleton for every finetunable model in the system."""
+
+    def __init__(self, model_name: str, **model_kwargs: Any) -> None:
+        super().__init__()
+        self.model_name = model_name
+        self.model_kwargs: Dict[str, Any] = model_kwargs
+        self._model: Optional[nn.Module] = None
+        self._tokenizer: Any = None  # HF tokenizers are not nn.Modules.
+
+    # Factory constructor: load model after __init__ finishes
+    @classmethod
+    def create(cls, model_name: str, **model_kwargs: Any) -> "BaseModel":
+        obj = cls(model_name, **model_kwargs)
+        # load model after __init__ finishes
+        module = obj.load_model()
+        if not isinstance(module, nn.Module):
+            raise TypeError(f"load_model() must return nn.Module, got {type(module)}")
+        obj._model = module
+        return obj
+
+    @abstractmethod
+    def load_model(self) -> nn.Module:
+        """Load and return the underlying torch.nn.Module."""
+        pass
+
+    def load_tokenizer(self) -> Any:
+        """Override if the model exposes a tokenizer."""
+        warnings.warn(f"{type(self).__name__} does not provide a tokenizer.", category=UserWarning)
+        return None
+
+    # Lazy accessors
+    @property
+    def model(self) -> nn.Module:
+        if self._model is None:
+            raise RuntimeError("Model not loaded; use .create(...) to load.")
+        return self._model
+
+    @property
+    def tokenizer(self) -> Any:
+        if self._tokenizer is None:
+            self._tokenizer = self.load_tokenizer()
+        return self._tokenizer
+
+    # nn.Module API surface
+    def forward(self, *args, **kwargs):
+        return self.model(*args, **kwargs)
+
+    def to(self, *args, **kwargs):
+        self.model.to(*args, **kwargs)
+        return self
+
+    def train(self, mode: bool = True):
+        self.model.train(mode)
+        return super().train(mode)
+
+    def eval(self):
+        return self.train(False)
+
+
+@registry.model("hf")
+class HFModel(BaseModel):
+    """HuggingFace-backed model with optional quantization."""
+
+    def __init__(
+        self,
+        model_name: str,
+        auto_class_name: str = "AutoModelForCausalLM",
+        *,
+        tokenizer_name: Optional[str] = None,
+        **model_kwargs: Any,
+    ) -> None:
+        super().__init__(model_name, **model_kwargs)
+        self.tokenizer_name = tokenizer_name or model_name
+        self.auto_class: Type = self._resolve_auto_class(auto_class_name)
+
+    @staticmethod
+    def _resolve_auto_class(auto_class_name: str) -> Type:
+        if not hasattr(transformers, auto_class_name):
+            candidates = sorted(name for name in dir(transformers) if name.startswith("AutoModel"))
+            raise ValueError(
+                f"Unsupported Auto class '{auto_class_name}'. Available candidates: {', '.join(candidates)}"
+            )
+        return getattr(transformers, auto_class_name)
+
+    # def _build_quant_config(self) -> Optional[BitsAndBytesConfig]:
+    #     if not self.model_kwargs.get("load_in_4bit"):
+    #         return None
+    #     return BitsAndBytesConfig(
+    #         load_in_4bit=True,
+    #         bnb_4bit_quant_type=self.model_kwargs.get("bnb_4bit_quant_type", "nf4"),
+    #         bnb_4bit_compute_dtype=self.model_kwargs.get("bnb_4bit_compute_dtype", torch.float16),
+    #         bnb_4bit_use_double_quant=self.model_kwargs.get("bnb_4bit_use_double_quant", True),
+    #     )
+
+    def configure_model_kwargs(self) -> Dict[str, Any]:
+        """Hook for subclasses to tweak HF `.from_pretrained` kwargs."""
+
+        extra = dict(self.model_kwargs)
+        # extra["quantization_config"] = self._build_quant_config()
+        return extra
+
+    def load_model(self) -> nn.Module:
+        logger.log_rank_zero(f"Loading HuggingFace model '{self.model_name}' via {self.auto_class.__name__}")
+
+        return self.auto_class.from_pretrained(
+            self.model_name,
+            **self.configure_model_kwargs(),
+        )
+
+    def load_tokenizer(self) -> AutoTokenizer:
+        """Load Hugging Face tokenizer."""
+        logger.log_rank_zero(f"Loading tokenizer '{self.tokenizer_name}'")
+        tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
+        insert_pad_token(tokenizer)
+        return tokenizer
diff --git a/QEfficient/finetune/experimental/tests/test_model.py b/QEfficient/finetune/experimental/tests/test_model.py
new file mode 100644
index 000000000..e83abf389
--- /dev/null
+++ b/QEfficient/finetune/experimental/tests/test_model.py
@@ -0,0 +1,136 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+from unittest import mock
+
+import pytest
+import torch
+import torch.nn as nn
+
+from QEfficient.finetune.experimental.core import model
+from QEfficient.finetune.experimental.core.component_registry import ComponentFactory, registry
+from QEfficient.finetune.experimental.core.model import BaseModel
+
+
+class TestMockModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(2, 2)
+
+    def forward(self, x):
+        return self.linear(x)
+
+
+@registry.model("testcustom")
+class TestCustomModel(BaseModel):
+    def __init__(self, model_name):
+        super().__init__(model_name)
+        print("init of custom class")
+
+    def load_model(self) -> nn.Module:
+        return TestMockModel()
+
+    def load_tokenizer(self):
+        return "dummy-tokenizer"
+
+
+# BaseModel tests
+def test_model_property_errors_if_not_created():
+    m = TestCustomModel("dummy")
+    with pytest.raises(RuntimeError):
+        _ = m.model  # must call .create()
+
+
+def test_create_builds_and_registers():
+    m = ComponentFactory.create_model("testcustom", "dummy")
+    # inner model exists and registered
+    assert "_model" in m._modules
+    assert isinstance(m.model, TestMockModel)
+    # forward works
+    out = m(torch.zeros(1, 2))
+    assert out.shape == (1, 2)
+
+
+def test_tokenizer_lazy_loading():
+    m = ComponentFactory.create_model("testcustom", "dummy")
+    assert m._tokenizer is None
+    tok = m.tokenizer
+    assert tok == "dummy-tokenizer"
+    assert m._tokenizer == tok
+
+
+def test_to_moves_inner_and_returns_self():
+    m = ComponentFactory.create_model("testcustom", "dummy")
+    with mock.patch.object(TestMockModel, "to", autospec=True) as mocked_to:
+        ret = m.to("cpu:0")
+        assert mocked_to.call_args[0][0] is m.model
+        assert mocked_to.call_args[0][1] == "cpu:0"
+        assert ret is m
+
+
+def test_train_eval_sync_flags():
+    m = ComponentFactory.create_model("testcustom", "dummy")
+    m.eval()
+    assert m.training is False
+    assert m.model.training is False
+    m.train()
+    assert m.training is True
+    assert m.model.training is True
+
+
+def test_state_dict_contains_inner_params():
+    m = ComponentFactory.create_model("testcustom", "dummy")
+    sd = m.state_dict()
+    # should contain params from TestMockModel.linear
+    assert any("linear.weight" in k for k in sd)
+    assert any("linear.bias" in k for k in sd)
+
+
+# HFModel tests
+def test_hfmodel_invalid_auto_class_raises():
+    with pytest.raises(ValueError):
+        ComponentFactory.create_model("hf", "hf-name", auto_class_name="AutoDoesNotExist")
+
+
+def test_hfmodel_loads_auto_and_tokenizer(monkeypatch):
+    # fake HF Auto class
+    class FakeAuto(nn.Module):
+        @classmethod
+        def from_pretrained(cls, name, **kwargs):
+            inst = cls()
+            inst.loaded = (name, kwargs)
+            return inst
+
+        def forward(self, x):
+            return x
+
+    fake_tok = mock.Mock()
+
+    # Monkeypatch transformer classes used in HFModel
+    monkeypatch.setattr(
+        "QEfficient.finetune.experimental.core.model.transformers.AutoModelForCausalLM",
+        FakeAuto,
+        raising=False,
+    )
+    monkeypatch.setattr(
+        model,
+        "AutoTokenizer",
+        mock.Mock(from_pretrained=mock.Mock(return_value=fake_tok)),
+    )
+    monkeypatch.setattr(
+        "QEfficient.finetune.experimental.core.model.insert_pad_token",
+        mock.Mock(),
+        raising=False,
+    )
+    m = ComponentFactory.create_model("hf", "hf-name")
+    assert isinstance(m.model, FakeAuto)
+
+    # load tokenizer
+    tok = m.load_tokenizer()
+
+    assert hasattr(tok, "pad_token_id")
+    assert m.model.loaded[0] == "hf-name"
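For reviewers, a minimal usage sketch of the registry-plus-factory flow introduced by this change. The checkpoint name ("gpt2") and the torch_dtype kwarg are illustrative placeholders, not part of the diff; the import paths follow the package layout shown above.

```python
from QEfficient.finetune.experimental.core.component_registry import ComponentFactory
from QEfficient.finetune.experimental.core import model  # noqa: F401  -- importing registers the "hf" model type

# "hf" is resolved to HFModel through the registry; remaining kwargs are
# forwarded to AutoModelForCausalLM.from_pretrained via configure_model_kwargs().
m = ComponentFactory.create_model("hf", "gpt2", torch_dtype="auto")

tok = m.tokenizer                        # lazily loaded; pad token handled by insert_pad_token()
batch = tok("hello world", return_tensors="pt")
logits = m(**batch).logits               # BaseModel.forward delegates to the wrapped HF module
```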