|
8 | 8 | import json |
9 | 9 | import os |
10 | 10 | import subprocess |
| 11 | +import sys |
| 12 | +import warnings |
11 | 13 | from dataclasses import dataclass |
12 | 14 | from typing import Any, Dict, List, Optional, Tuple, Union |
13 | 15 |
|
14 | 16 | import requests |
15 | 17 | import torch |
16 | 18 | from huggingface_hub import login, snapshot_download |
17 | 19 | from requests.exceptions import HTTPError |
18 | | -from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast |
19 | | - |
| 20 | +from transformers import ( |
| 21 | + AutoConfig, |
| 22 | + AutoProcessor, |
| 23 | + AutoTokenizer, |
| 24 | + PreTrainedTokenizer, |
| 25 | + PreTrainedTokenizerFast, |
| 26 | +) |
| 27 | + |
| 28 | +from QEfficient.transformers.modeling_utils import ( |
| 29 | + SwiftKVModelCardNameToSwiftKVModelTypeDict, |
| 30 | + SwiftKVModelTypeToConfigClassAndModelArchClassDict, |
| 31 | +) |
20 | 32 | from QEfficient.utils.constants import QEFF_MODELS_DIR, Constants |
21 | 33 | from QEfficient.utils.logging_utils import logger |
22 | | -from QEfficient.transformers.modeling_utils import SwiftKVModelCardNameToSwiftKVModelTypeDict, SwiftKVModelTypeToConfigClassAndModelArchClassDict |
| 34 | + |
23 | 35 |
|
24 | 36 | class DownloadRetryLimitExceeded(Exception): |
25 | 37 | """ |
@@ -442,3 +454,63 @@ class IOInfo: |
442 | 454 |
|
443 | 455 | def __repr__(self): |
444 | 456 | return f"input_name:{self.name}\tdatatype:{self.datatype}\tshape:{self.shape}" |
| 457 | + |
| 458 | + |
def convert_str_to_class(className):
    """
    Look up an attribute of this module by its name
    ---------
    :className: `str` - Name of the attribute (typically a class) to fetch.
    Return:
        The object bound to ``className`` in this module's namespace.
    Raises:
        AttributeError: if this module has no attribute called ``className``.
    """
    # Only names defined in, or imported into, this very module are visible here.
    this_module = sys.modules[__name__]
    resolved = getattr(this_module, className)
    return resolved
| 468 | + |
| 469 | + |
def register_swiftKV_model(model_type, SwiftkvConfigCls, SwiftKVModelCls):
    """
    Register a SwiftKV config/model pair with the transformers Auto classes
    ---------------------------------------
    : model_type: str: name of the swiftKVModel for example llama_swiftkv
    : SwiftkvConfigCls: SwiftKV Config class for example LlamaSwiftKVConfig
    : SwiftKVModelCls: SwiftKV model class for example LlamaSwiftKVForCausalLM.
      Its ``__name__`` must contain ``SwiftKVFor`` followed by a task suffix
      (e.g. ``CausalLM``); the suffix selects the ``AutoModelFor<suffix>`` class
      the model is registered with.
    Raises:
        ValueError: if the model class name does not contain ``SwiftKVFor<Task>``
        or if transformers exposes no matching ``AutoModelFor<Task>`` class.
    """
    # Local import: the Auto model classes are attributes of the transformers
    # package, not of this module, so they have to be resolved from there.
    import transformers

    marker = "SwiftKVFor"
    model_cls_name = SwiftKVModelCls.__name__

    # Everything after the first "SwiftKVFor" is the task suffix (e.g. "CausalLM").
    _, found_marker, task_suffix = model_cls_name.partition(marker)
    if not found_marker or not task_suffix:
        raise ValueError(
            f"Cannot derive an AutoModel class from '{model_cls_name}': "
            f"expected a class name of the form '<Prefix>{marker}<Task>'"
        )

    auto_model_name = "AutoModelFor" + task_suffix
    auto_model_cls = getattr(transformers, auto_model_name, None)
    if auto_model_cls is None:
        raise ValueError(f"transformers has no class named '{auto_model_name}' (derived from '{model_cls_name}')")

    # Validate everything above before touching the registries so a bad class
    # name cannot leave the config registered without its model.
    # exist_ok=True keeps repeated registration of the same model from raising.
    AutoConfig.register(model_type, SwiftkvConfigCls, exist_ok=True)
    auto_model_cls.register(SwiftkvConfigCls, SwiftKVModelCls, exist_ok=True)
| 499 | + |
| 500 | + |
def QEFFLoadSwiftKVModels(pretrained_model_name_or_path):
    """
    Register the SwiftKV classes needed to load the requested model card
    ---------------------------------------
    : pretrained_model_name_or_path: str: name of the swiftKVModel for example Snowflake/Llama-3.1-SwiftKV-8B-Instruct

    The model card name is mapped to a SwiftKV model type, and the model type to
    its (config class, model architecture class) entry; that pair is then passed
    to ``register_swiftKV_model``. If either lookup fails, a warning is emitted
    and nothing is registered.
    """
    # Keep the try body limited to the table lookups so that an unrelated
    # KeyError raised during registration is not reported as "not supported".
    try:
        modelType = SwiftKVModelCardNameToSwiftKVModelTypeDict[pretrained_model_name_or_path]
        configAndArch = SwiftKVModelTypeToConfigClassAndModelArchClassDict[modelType]
        SwiftKVConfigCls = configAndArch[0]
        SwiftKVModelArchCls = configAndArch[1]
    except KeyError:
        warnings.warn("Requested SwiftKVModel is currently not supported... stay tuned for future releases", Warning)
        return

    register_swiftKV_model(modelType, SwiftKVConfigCls, SwiftKVModelArchCls)
0 commit comments