diff --git a/.gitignore b/.gitignore
index 8b56a79c..9b14f350 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,4 +5,5 @@ build/
 __pycache__/
 .idea
 venv
-dist
\ No newline at end of file
+dist
+*.so
\ No newline at end of file
diff --git a/exllamav2/model.py b/exllamav2/model.py
index 8154b0ef..a4346efa 100644
--- a/exllamav2/model.py
+++ b/exllamav2/model.py
@@ -130,14 +130,14 @@ def __init__(self, config: ExLlamaV2Config, lazy_load = False):
         self.modules.append(ExLlamaV2Embedding(self, "model.embed_tokens"))
         self.modules_dict[self.modules[-1].key] = self.modules[-1]
 
-        for layer_idx in range(self.config.num_hidden_layers):
+        for layer_list in range(self.config.num_hidden_layers):
 
-            self.modules.append(ExLlamaV2Attention(self, f"model.layers.{layer_idx}", layer_idx))
+            self.modules.append(ExLlamaV2Attention(self, f"model.layers.{layer_list}", layer_list))
             for m in self.modules[-1].submodules: self.modules_dict[m.key] = m
             if self.config.architecture == "Mixtral":
-                self.modules.append(ExLlamaV2MoEMLP(self, f"model.layers.{layer_idx}", layer_idx))
+                self.modules.append(ExLlamaV2MoEMLP(self, f"model.layers.{layer_list}", layer_list))
             else:
-                self.modules.append(ExLlamaV2MLP(self, f"model.layers.{layer_idx}", layer_idx))
+                self.modules.append(ExLlamaV2MLP(self, f"model.layers.{layer_list}", layer_list))
             for m in self.modules[-1].submodules: self.modules_dict[m.key] = m
 
 
@@ -150,13 +150,32 @@ def __init__(self, config: ExLlamaV2Config, lazy_load = False):
 
         # Find last layer that affects k/v cache
 
-        layer_idx = len(self.modules)
+        layer_list = len(self.modules)
         while True:
-            layer_idx -= 1
-            if isinstance(self.modules[layer_idx], ExLlamaV2Attention): break
+            layer_list -= 1
+            if isinstance(self.modules[layer_list], ExLlamaV2Attention): break
 
-        self.last_kv_layer_idx = layer_idx
+        self.last_kv_layer_idx = layer_list
+
+        if hasattr(config, 'repeats'):
+            embedTokenLayers = 1
+            transformerSublayers = 2
+            layer_arrangement = [list(range(*interval)) for interval in config.repeats]
+            layer_arrangement = [item for sublist in layer_arrangement for item in sublist]
+
+
+            LayeredModules = self.modules[:embedTokenLayers]
+            for idx in layer_arrangement:
+                LayeredModules += self.modules[idx*transformerSublayers + embedTokenLayers : idx*transformerSublayers + transformerSublayers + embedTokenLayers]
+            LayeredModules += self.modules[-2:]
+            self.head_layer_idx = len(self.modules) -1
+            self.last_kv_layer_idx = len(self.modules) -4
+
+            for i, m in enumerate(LayeredModules):
+                print(i, m.key)
+
+            self.layeredModules = LayeredModules
 
 
     def set_device_map(self, allocation, embed_cpu = True):
 
@@ -582,6 +601,23 @@ def _forward(self,
                  return_last_state = False,
                  position_offsets = None):
 
+        def process_module(module, x, last_state):
+            device = _torch_device(module.device_idx)
+
+            if idx == self.head_layer_idx:
+                if last_id_only and return_last_state:
+                    x = x.narrow(-2, -1, 1)
+                    last_state = x
+                elif last_id_only:
+                    x = x.narrow(-2, -1, 1)
+                elif return_last_state:
+                    last_state = x.narrow(-2, -1, 1)
+
+            x = safe_move_tensor(x, device)
+            x = module.forward(x, cache=cache, attn_params=attn_params, past_len=past_len, loras=loras)
+
+            return x, last_state
+
         batch_size, seq_len = input_ids.shape
         past_len = 0
         if cache is not None:
@@ -596,27 +632,18 @@ def _forward(self,
         attn_params = ExLlamaV2Attention.Params(batch_size, seq_len, past_len, input_mask, position_offsets)
         last_state = None
 
-        for idx, module in enumerate(self.modules):
-
-            device = _torch_device(module.device_idx)
-
-            # Onward
-
-            if idx == self.head_layer_idx:
-                if last_id_only and return_last_state:
-                    x = x.narrow(-2, -1, 1)
-                    last_state = x
-                elif last_id_only:
-                    x = x.narrow(-2, -1, 1)
-                elif return_last_state:
-                    last_state = x.narrow(-2, -1, 1)
-
-            x = safe_move_tensor(x, device)
-            x = module.forward(x, cache = cache, attn_params = attn_params, past_len = past_len, loras = loras)
-
-            if preprocess_only and idx == self.last_kv_layer_idx:
-                x = None
-                break
+        if hasattr(self, 'layeredModules'):
+            for idx, module in enumerate(self.layeredModules):
+                x, last_state = process_module(module, x, last_state)
+                if preprocess_only and idx == self.last_kv_layer_idx:
+                    x = None
+                    break
+        else:
+            for idx, module in enumerate(self.modules):
+                x, last_state = process_module(module, x, last_state)
+                if preprocess_only and idx == self.last_kv_layer_idx:
+                    x = None
+                    break
 
         # Advance cache
 
diff --git a/exllamav2/model_init.py b/exllamav2/model_init.py
index 834aa465..89c9708b 100644
--- a/exllamav2/model_init.py
+++ b/exllamav2/model_init.py
@@ -1,5 +1,5 @@
 
-import argparse, sys, os, glob
+import argparse, sys, os, glob, ast
 
 from exllamav2 import(
     ExLlamaV2,
@@ -17,6 +17,7 @@ def add_args(parser):
     parser.add_argument("-nfa", "--no_flash_attn", action = "store_true", help = "Disable Flash Attention")
     parser.add_argument("-lm", "--low_mem", action = "store_true", help = "Enable VRAM optimizations, potentially trading off speed")
    parser.add_argument("-ept", "--experts_per_token", type = int, help = "Override MoE model's default number of experts per token")
+    parser.add_argument("--repeats", type=parse_tuple_list, help="List of tuples of the layers to repeat")
 
 
 def print_options(args):
@@ -60,6 +61,22 @@ def check_args(args):
             print(f" ## Error: Cannot find {filename} in {args.model_dir}")
             sys.exit()
 
+def parse_tuple_list(string):
+    try:
+        # Safely evaluate the string as a Python literal (list of tuples)
+        tuple_list = ast.literal_eval(string)
+
+        # Ensure all elements in the list are tuples
+        if not all(isinstance(item, tuple) for item in tuple_list):
+            raise ValueError("All elements must be tuples")
+
+        # Convert tuple elements to integers
+        int_tuple_list = [tuple(int(x) for x in item) for item in tuple_list]
+
+        return int_tuple_list
+    except:
+        raise argparse.ArgumentTypeError("Input must be a valid list of tuples with integer elements")
+
 
 def init(args, quiet = False, allow_auto_split = False, skip_load = False):
 
@@ -76,7 +93,8 @@ def init(args, quiet = False, allow_auto_split = False, skip_load = False):
     if args.rope_alpha: config.scale_alpha_value = args.rope_alpha
     config.no_flash_attn = args.no_flash_attn
     if args.experts_per_token: config.num_experts_per_token = args.experts_per_token
-
+    if args.repeats: config.repeats = args.repeats
+
     # Set low-mem options
 
     if args.low_mem: config.set_low_mem()
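
A minimal usage sketch of the new option, with illustrative interval values only (any model with at least 12 decoder layers is assumed); it mirrors the expansion added to ExLlamaV2.__init__ above:

# Hypothetical input: --repeats "[(0, 8), (4, 12)]", converted by parse_tuple_list
# into a list of integer tuples and stored on the config as config.repeats.
repeats = [(0, 8), (4, 12)]

# Same expansion as in ExLlamaV2.__init__: each (start, stop) interval follows
# Python range() semantics, so (0, 8) selects decoder layers 0 through 7.
layer_arrangement = [list(range(*interval)) for interval in repeats]
layer_arrangement = [item for sublist in layer_arrangement for item in sublist]

print(layer_arrangement)
# [0, 1, 2, 3, 4, 5, 6, 7, 4, 5, 6, 7, 8, 9, 10, 11]

# Each decoder layer contributes two modules (attention + MLP) after the single
# embedding module, so layer idx maps to self.modules[idx*2 + 1 : idx*2 + 3],
# matching the embedTokenLayers/transformerSublayers slicing in the patch.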