[QEff. Finetune]: Adding base class and HF class #658
@@ -4,3 +4,135 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import warnings
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Type

import torch.nn as nn
import transformers
from transformers import AutoTokenizer

from QEfficient.finetune.experimental.core.component_registry import registry
from QEfficient.finetune.experimental.core.logger import Logger
from QEfficient.finetune.experimental.core.utils.dataset_utils import insert_pad_token

logger = Logger(__name__)


class BaseModel(nn.Module, ABC):
    """Shared skeleton for every finetunable model in the system."""

    def __init__(self, model_name: str, **model_kwargs: Any) -> None:
        super().__init__()
        self.model_name = model_name
        self.model_kwargs: Dict[str, Any] = model_kwargs
        self._model: Optional[nn.Module] = None
        self._tokenizer: Any = None  # HF tokenizers are not nn.Modules.

    # Factory constructor: load model after __init__ finishes
    @classmethod
    def create(cls, model_name: str, **model_kwargs: Any) -> "BaseModel":
Contributor: Not needed. Use the registry mechanism to instantiate BaseModel-type objects, which in turn instantiates the nn.Modules as well.

Contributor (Author): The create factory seems to me a clean way to construct objects. It is the only place that guarantees the wrapped model is actually registered with the nn.Module machinery before anyone calls state_dict, parameters, to, or train. Dropping it and relying on a bare __init__ puts us back into lazy-init land, where we would unnecessarily have to call load_model() in every method, as the reference code does.

Contributor: I understand your concern; basically, you want a one-liner initialization of the model. The reference code does it in two lines: https://github.com/quic-meetkuma/LightningLLMs/blob/da2f6b39e8533cbd05d563ca68113828f783e73e/LightningLLM/main.py#L131C34-L131C46 I suggest moving this create into the ComponentRegistry class as a "create_model" method there. That class has the sole responsibility of storing references to particular classes and creating instances of the classes the user wants, so all registration and instance creation stays encapsulated in one place and we only deal with ComponentRegistry, not individual classes.

Contributor (Author): The reference shared above is stale, since PR #645 did not add create_model() to component_registry.py (why was it not added?). I have added it. Even after adding it, we still need a method on BaseModel to load the model and register it with nn.Module, which create_model() will now call. I am keeping that method's name as create for now. The loading part could also be moved into create_model() in component_registry.py if that seems better.

Contributor (Author): Going ahead with the existing approach and not adding load_model in __init__, as the current approach follows the Hugging Face pattern of the from_pretrained() factory method on PreTrainedModel.
        obj = cls(model_name, **model_kwargs)
        # load model after __init__ finishes
        module = obj.load_model()
        if not isinstance(module, nn.Module):
            raise TypeError(f"load_model() must return nn.Module, got {type(module)}")
        obj._model = module
        return obj
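
A minimal sketch of the ComponentFactory.create_model hook discussed in the thread above, assuming the registry is a simple name-to-class map populated by the @registry.model(...) decorator; the attribute name _registered_models is illustrative, not the PR's actual implementation.

class ComponentFactory:
    # Hypothetical mapping filled in by the @registry.model(...) decorator.
    _registered_models: dict = {}

    @classmethod
    def create_model(cls, kind: str, model_name: str, **model_kwargs):
        # Resolve the registered BaseModel subclass (e.g. "hf" -> HFModel) and
        # delegate to its create() factory so callers get a loaded model in one call.
        model_cls = cls._registered_models[kind]
        return model_cls.create(model_name, **model_kwargs)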

    @abstractmethod
    def load_model(self) -> nn.Module:
        """Load and return the underlying torch.nn.Module."""
        pass

    def load_tokenizer(self) -> Any:
        """Override if the model exposes a tokenizer."""
        warnings.warn(f"{type(self).__name__} does not provide a tokenizer.", category=UserWarning)
        return None

    # Lazy accessors
    @property
    def model(self) -> nn.Module:
        if self._model is None:
            raise RuntimeError("Model not loaded; use .create(...) to load.")
        return self._model

    @property
    def tokenizer(self) -> Any:
Contributor: This is applicable only to NLP models, so we can't put it in the base class. Better to declare an abstract method called "preprocessor" that defines the generic preprocessing hook for a model. There would be no implementation here; the child classes should implement it. For LLMs, that method should return the tokenizer.
        if self._tokenizer is None:
            self._tokenizer = self.load_tokenizer()
        return self._tokenizer
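
A small sketch of the reviewer's "preprocessor" suggestion, not part of this PR: a generic abstract hook on the base class that child classes implement, with LLM subclasses returning their tokenizer. The class name below is hypothetical.

from abc import ABC, abstractmethod
from typing import Any

import torch.nn as nn


class PreprocessorAwareModel(nn.Module, ABC):
    # Hypothetical generic hook: LLM subclasses would return a tokenizer,
    # vision subclasses could return an image processor instead.
    @abstractmethod
    def preprocessor(self) -> Any:
        """Return the input preprocessor for this model."""
        ...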

    # nn.Module API surface
    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def to(self, *args, **kwargs):
        self.model.to(*args, **kwargs)
        return self

    def train(self, mode: bool = True):
        self.model.train(mode)
        return super().train(mode)

    def eval(self):
        return self.train(False)


@registry.model("hf")
class HFModel(BaseModel):
Contributor: We also need to load the model from a configuration. That is mainly for testing: in integration tests we will not load an entire 32-layer model, but the same model with only 2 or 4 layers, and run the tests on that. A config should be used to load the model for that purpose; see the Hugging Face documentation on how to do it.
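
A minimal sketch of the config-based loading the reviewer asks for, using the standard Hugging Face AutoConfig / from_config pattern; the checkpoint name and layer count below are illustrative examples, not values from this PR.

from transformers import AutoConfig, AutoModelForCausalLM

# Build a tiny, randomly initialized variant of the architecture for integration tests.
config = AutoConfig.from_pretrained("gpt2")  # example checkpoint name
config.num_hidden_layers = 2                 # shrink from the full depth to 2 layers
tiny_model = AutoModelForCausalLM.from_config(config)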

    """HuggingFace-backed model with optional quantization."""

    def __init__(
        self,
        model_name: str,
        auto_class_name: str = "AutoModelForCausalLM",
        *,
        tokenizer_name: Optional[str] = None,
        **model_kwargs: Any,
    ) -> None:
        super().__init__(model_name, **model_kwargs)
        self.tokenizer_name = tokenizer_name or model_name
        self.auto_class: Type = self._resolve_auto_class(auto_class_name)

    @staticmethod
    def _resolve_auto_class(auto_class_name: str) -> Type:
        if not hasattr(transformers, auto_class_name):
            candidates = sorted(name for name in dir(transformers) if name.startswith("AutoModel"))
            raise ValueError(
                f"Unsupported Auto class '{auto_class_name}'. Available candidates: {', '.join(candidates)}"
            )
        return getattr(transformers, auto_class_name)

    # def _build_quant_config(self) -> Optional[BitsAndBytesConfig]:
    #     if not self.model_kwargs.get("load_in_4bit"):
    #         return None
    #     return BitsAndBytesConfig(
    #         load_in_4bit=True,
    #         bnb_4bit_quant_type=self.model_kwargs.get("bnb_4bit_quant_type", "nf4"),
    #         bnb_4bit_compute_dtype=self.model_kwargs.get("bnb_4bit_compute_dtype", torch.float16),
    #         bnb_4bit_use_double_quant=self.model_kwargs.get("bnb_4bit_use_double_quant", True),
    #     )

    def configure_model_kwargs(self) -> Dict[str, Any]:
        """Hook for subclasses to tweak HF `.from_pretrained` kwargs."""

        extra = dict(self.model_kwargs)
        # extra["quantization_config"] = self._build_quant_config()
        return extra

    def load_model(self) -> nn.Module:
        logger.log_rank_zero(f"Loading HuggingFace model '{self.model_name}' via {self.auto_class.__name__}")

        return self.auto_class.from_pretrained(
            self.model_name,
            **self.configure_model_kwargs(),
Contributor: Don't pass the result of a function call directly. Keep it in an explicit variable and then unpack the dict here; that way it is easier to pinpoint an error if one occurs.

Contributor: You are unpacking all of the kwargs here. If any extra kwargs are given, does self.auto_class.from_pretrained accept and discard them? If not, it will surely throw an error. Please check and correct this if needed.
        )
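
One way to address the two comments above, sketched here rather than taken from the PR: a variant of the load_model shown above that materializes the kwargs in a named variable and logs them before unpacking, so an unexpected key is easy to pinpoint when from_pretrained rejects it.

    def load_model(self) -> nn.Module:
        logger.log_rank_zero(f"Loading HuggingFace model '{self.model_name}' via {self.auto_class.__name__}")
        # Keep the kwargs explicit so a bad key is visible in the logs before the call fails.
        from_pretrained_kwargs = self.configure_model_kwargs()
        logger.log_rank_zero(f"from_pretrained kwargs: {from_pretrained_kwargs}")
        return self.auto_class.from_pretrained(self.model_name, **from_pretrained_kwargs)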

    def load_tokenizer(self) -> AutoTokenizer:
        """Load Hugging Face tokenizer."""
        logger.log_rank_zero(f"Loading tokenizer '{self.tokenizer_name}'")
        tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
        insert_pad_token(tokenizer)
        return tokenizer

@@ -0,0 +1,136 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

from unittest import mock

import pytest
import torch
import torch.nn as nn

from QEfficient.finetune.experimental.core import model
from QEfficient.finetune.experimental.core.component_registry import ComponentFactory, registry
from QEfficient.finetune.experimental.core.model import BaseModel


class TestMockModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)

    def forward(self, x):
        return self.linear(x)


@registry.model("testcustom")
class TestCustomModel(BaseModel):
    def __init__(self, model_name):
        super().__init__(model_name)
        print("init of custom class")

    def load_model(self) -> nn.Module:
        return TestMockModel()

    def load_tokenizer(self):
        return "dummy-tokenizer"

# BaseModel tests
Contributor: Add tests that register a child class of BaseModel in the registry and verify the correct instance of that class is returned when the required parameters are passed.
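
A hypothetical test along the lines the reviewer asks for, reusing the TestCustomModel registered above; the test name is an assumption.

def test_registry_returns_registered_subclass():
    # The factory should resolve "testcustom" to TestCustomModel and return a loaded instance.
    m = ComponentFactory.create_model("testcustom", "dummy")
    assert isinstance(m, TestCustomModel)
    assert m.model_name == "dummy"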

def test_model_property_errors_if_not_created():
    m = TestCustomModel("dummy")
    with pytest.raises(RuntimeError):
        _ = m.model  # must call .create()


def test_create_builds_and_registers():
    m = ComponentFactory.create_model("testcustom", "dummy")
    # inner model exists and registered
    assert "_model" in m._modules
    assert isinstance(m.model, TestMockModel)
    # forward works
    out = m(torch.zeros(1, 2))
    assert out.shape == (1, 2)


def test_tokenizer_lazy_loading():
    m = ComponentFactory.create_model("testcustom", "dummy")
    assert m._tokenizer is None
    tok = m.tokenizer
    assert tok == "dummy-tokenizer"
    assert m._tokenizer == tok


def test_to_moves_inner_and_returns_self():
    m = ComponentFactory.create_model("testcustom", "dummy")
    with mock.patch.object(TestMockModel, "to", autospec=True) as mocked_to:
        ret = m.to("cpu:0")
        assert mocked_to.call_args[0][0] is m.model
        assert mocked_to.call_args[0][1] == "cpu:0"
        assert ret is m
Contributor: What does this mean? This is not needed.

def test_train_eval_sync_flags():
    m = ComponentFactory.create_model("testcustom", "dummy")
    m.eval()
    assert m.training is False
    assert m.model.training is False
    m.train()
    assert m.training is True
    assert m.model.training is True


def test_state_dict_contains_inner_params():
    m = ComponentFactory.create_model("testcustom", "dummy")
    sd = m.state_dict()
    # should contain params from TestMockModel.linear
    assert any("linear.weight" in k for k in sd)
    assert any("linear.bias" in k for k in sd)


# HFModel tests
def test_hfmodel_invalid_auto_class_raises():
    with pytest.raises(ValueError):
        ComponentFactory.create_model("hf", "hf-name", auto_class_name="AutoDoesNotExist")


def test_hfmodel_loads_auto_and_tokenizer(monkeypatch):
    # fake HF Auto class
    class FakeAuto(nn.Module):
        @classmethod
        def from_pretrained(cls, name, **kwargs):
            inst = cls()
            inst.loaded = (name, kwargs)
            return inst

        def forward(self, x):
            return x

    fake_tok = mock.Mock()

    # Monkeypatch transformers classes used in HFModel
    monkeypatch.setattr(
        "QEfficient.finetune.experimental.core.model.transformers.AutoModelForCausalLM",
        FakeAuto,
        raising=False,
    )
    monkeypatch.setattr(
        model,
        "AutoTokenizer",
        mock.Mock(from_pretrained=mock.Mock(return_value=fake_tok)),
    )
    monkeypatch.setattr(
        "QEfficient.finetune.experimental.core.model.insert_pad_token",
        mock.Mock(),
        raising=False,
    )
    m = ComponentFactory.create_model("hf", "hf-name")
    assert isinstance(m.model, FakeAuto)

    # load tokenizer
    tok = m.load_tokenizer()

    assert hasattr(tok, "pad_token_id")
    assert m.model.loaded[0] == "hf-name"