Skip to content

Commit ed4e1a9

Browse files
committed
rebased
Signed-off-by: Onkar Chougule <[email protected]>
1 parent a73da94 commit ed4e1a9

File tree

5 files changed

+78
-11
lines changed

5 files changed

+78
-11
lines changed

QEfficient/transformers/modeling_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,5 @@ def _create_causal_mask(
378378
# While onboarding new models make sure to add the new SwiftKV model card names to this dictionary.
379379
SwiftKVModelTypeToConfigClassAndModelArchClassDict = {
380380
# LlamaSwiftKV Model
381-
"llama_swiftkv" : [LlamaSwiftKVConfig, LlamaSwiftKVForCausalLM]
381+
"llama_swiftkv": [LlamaSwiftKVConfig, LlamaSwiftKVForCausalLM]
382382
}
383-

QEfficient/transformers/models/llama_swiftkv/config_llama_swiftkv.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99

1010
"""Inference-only LLaMA model compatible with HuggingFace weights."""
1111

12-
13-
1412
from typing import Optional
1513
from transformers import LlamaConfig
1614

@@ -40,6 +38,4 @@ def __init__(
4038
self.swiftkv = swiftkv
4139
self.num_key_value_layers = num_key_value_layers or self.num_hidden_layers
4240
self.key_value_group_size = key_value_group_size or 1
43-
assert (
44-
self.num_hidden_layers - self.num_key_value_layers
45-
) % self.key_value_group_size == 0
41+
assert (self.num_hidden_layers - self.num_key_value_layers) % self.key_value_group_size == 0

QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
)
3030
from QEfficient.transformers.models.llama_swiftkv.config_llama_swiftkv import LlamaSwiftKVConfig
3131

32+
3233
class LlamaSwiftKVAttention(nn.Module):
3334
def __init__(self, config: LlamaSwiftKVConfig, layer_idx) -> None:
3435
super().__init__()

QEfficient/transformers/models/modeling_auto.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ def __repr__(self) -> str:
8080
@classmethod
8181
@with_replaced_quantizers
8282
def from_pretrained(cls, pretrained_model_name_or_path: str, is_tlm: bool = False, *args, **kwargs):
83-
8483
# Load the SwiftKV model if supported
8584
QEFFLoadSwiftKVModels(pretrained_model_name_or_path)
8685

QEfficient/utils/_utils.py

Lines changed: 75 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,30 @@
88
import json
99
import os
1010
import subprocess
11+
import sys
12+
import warnings
1113
from dataclasses import dataclass
1214
from typing import Any, Dict, List, Optional, Tuple, Union
1315

1416
import requests
1517
import torch
1618
from huggingface_hub import login, snapshot_download
1719
from requests.exceptions import HTTPError
18-
from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
19-
20+
from transformers import (
21+
AutoConfig,
22+
AutoProcessor,
23+
AutoTokenizer,
24+
PreTrainedTokenizer,
25+
PreTrainedTokenizerFast,
26+
)
27+
28+
from QEfficient.transformers.modeling_utils import (
29+
SwiftKVModelCardNameToSwiftKVModelTypeDict,
30+
SwiftKVModelTypeToConfigClassAndModelArchClassDict,
31+
)
2032
from QEfficient.utils.constants import QEFF_MODELS_DIR, Constants
2133
from QEfficient.utils.logging_utils import logger
22-
from QEfficient.transformers.modeling_utils import SwiftKVModelCardNameToSwiftKVModelTypeDict, SwiftKVModelTypeToConfigClassAndModelArchClassDict
34+
2335

2436
class DownloadRetryLimitExceeded(Exception):
2537
"""
@@ -442,3 +454,63 @@ class IOInfo:
442454

443455
def __repr__(self):
444456
return f"input_name:{self.name}\tdatatype:{self.datatype}\tshape:{self.shape}"
457+
458+
459+
def convert_str_to_class(className):
    """
    Resolve an object (typically a class) from its name given as a string.
    ---------
    :className: `str`- Class name string.
    Return:
        The object bound to ``className``.
    Raises:
        AttributeError: if ``className`` is found neither in this module nor in
        the ``transformers`` package namespace.
    """
    not_found = object()

    # Names defined in, or imported into, this module take precedence; this is
    # the original lookup and keeps existing callers working unchanged.
    resolved = getattr(sys.modules[__name__], className, not_found)
    if resolved is not not_found:
        return resolved

    # Auto classes such as ``AutoModelForCausalLM`` may not be imported into this
    # module, so fall back to the top-level ``transformers`` namespace.
    try:
        import transformers
    except ImportError:
        transformers = None
    if transformers is not None:
        resolved = getattr(transformers, className, not_found)
        if resolved is not not_found:
            return resolved

    raise AttributeError(f"Unable to resolve class '{className}' in module '{__name__}' or in 'transformers'")
468+
469+
470+
def register_swiftKV_model(model_type, SwiftkvConfigCls, SwiftKVModelCls):
    """
    Register the SwiftKV Models
    ---------------------------------------
    : model_type: str: name of the swiftKVModel for example llama_swiftkv
    : SwiftkVConfigCls: SwiftKV Config class for example LlamaSwiftKVConfig
    : SwiftKVModelCls: SwiftKV model class name for example LlamaSwiftKVForCausalLM
    Raises:
        ValueError: if the model class name does not contain ``SwiftKVFor``
        followed by a task suffix (for example ``CausalLM``).
    """
    marker = "SwiftKVFor"
    model_cls_name = SwiftKVModelCls.__name__

    # Derive the task suffix ("CausalLM" in "LlamaSwiftKVForCausalLM") so that the
    # matching AutoModelFor<Task> class can be located generically.
    # Validate it before registering anything: an unexpected class name must not
    # leave the config registered without its model class.
    _, found, task_suffix = model_cls_name.partition(marker)
    if not found or not task_suffix:
        raise ValueError(
            f"Cannot derive the AutoModel class for '{model_cls_name}': "
            f"expected a class name of the form '<Prefix>{marker}<Task>'"
        )

    # Register the SwiftKV Config class using AutoConfig
    AutoConfig.register(model_type, SwiftkvConfigCls)

    # Convert the AutoModel class name string to the class itself
    auto_model_cls = convert_str_to_class("AutoModelFor" + task_suffix)

    # Register the SwiftKVModel Class and config class using AutoModelClass
    auto_model_cls.register(SwiftkvConfigCls, SwiftKVModelCls)
499+
500+
501+
def QEFFLoadSwiftKVModels(pretrained_model_name_or_path):
    """
    Load the SwiftKV Models
    ---------------------------------------
    : pretrained_model_name_or_path: str: name of the swiftKVModel for example Snowflake/Llama-3.1-SwiftKV-8B-Instruct

    Looks the model card name up in ``SwiftKVModelCardNameToSwiftKVModelTypeDict``,
    fetches the matching config/model classes from
    ``SwiftKVModelTypeToConfigClassAndModelArchClassDict`` and registers them.
    If either lookup misses, a warning is emitted and nothing is registered.
    """
    # Keep the try body limited to the dictionary lookups so that a KeyError raised
    # while registering is not misreported as an unsupported model.
    try:
        modelType = SwiftKVModelCardNameToSwiftKVModelTypeDict[pretrained_model_name_or_path]
        configAndModelClasses = SwiftKVModelTypeToConfigClassAndModelArchClassDict[modelType]
    except KeyError:
        warnings.warn("Requested SwiftKVModel is currently not supported... stay tuned for future releases", Warning)
        return

    # Each entry is laid out as [config class, model architecture class].
    SwiftKVConfigCls = configAndModelClasses[0]
    SwiftKVModelArchCls = configAndModelClasses[1]

    register_swiftKV_model(modelType, SwiftKVConfigCls, SwiftKVModelArchCls)

0 commit comments

Comments
 (0)