
Commit 82670d9

feat: add quant to mixtral (#1337)
1 parent ec6d459 commit 82670d9

4 files changed: +184, -35 lines


server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py

Lines changed: 1 addition & 3 deletions
```diff
@@ -434,8 +434,6 @@ def __init__(self, config, weights):
             weights=weights,
         )
         self.max_past = config.sliding_window
-        if self.max_past is None:
-            raise ValueError("max_past cannot be None")
 
     def forward(
         self,
@@ -454,7 +452,7 @@ def forward(
         if prefill_cache_indices is not None:
             # Slots also need to be sliced as it has the same size as the whole kv tensor
             slots = slots[prefill_cache_indices]
-        else:
+        elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
             max_s = min(self.max_past, max_s)
```
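
Note: the sliding window is now optional for Mistral. Instead of raising when `config.sliding_window` is `None`, `forward` only clamps `max_s` when a window is actually configured. A minimal standalone sketch of that clamping rule (the helper and values below are illustrative, not part of the model code):

```python
from typing import Optional

# Illustrative helper: clamp max_s to the sliding window only when one is
# configured; otherwise keep the true sequence length.
def clamp_max_s(max_s: int, max_past: Optional[int], prefill: bool) -> int:
    if prefill:
        # In prefill the kv tensor is sliced via prefill_cache_indices instead.
        return max_s
    if max_past is not None:
        # Paged attention in decode expects values clamped to the window.
        return min(max_past, max_s)
    return max_s

assert clamp_max_s(8192, 4096, prefill=False) == 4096  # window configured
assert clamp_max_s(8192, None, prefill=False) == 8192  # no sliding window
```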

server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py

Lines changed: 161 additions & 14 deletions
```diff
@@ -365,9 +365,9 @@ def __init__(self, prefix, config: MixtralConfig, weights):
         self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
 
         # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
-        self.w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).t()
+        self.w1 = _load_experts(config, f"{prefix}.experts", "w1", weights)
         self.w2 = _load_experts(config, f"{prefix}.experts", "w2", weights)
-        self.w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).t()
+        self.w3 = _load_experts(config, f"{prefix}.experts", "w3", weights)
 
         self.offsets = None
         self.offsets_block_rows = 0
@@ -467,8 +467,7 @@ def indices_and_padded_bins(self, selected_experts: torch.Tensor):
 
         return indices, bin_ids, bins, padded_bins, tokens_per_expert
 
-    @torch.inference_mode()
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def sparse_forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         x: (sequence_length, model_dim)
         gate_logits: (sequence_length, n_experts)
@@ -502,8 +501,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # (top_k * sequence_length + padding, ffn_dim * n_experts)
         x = stk.Matrix(
             topo.size(),
-            self.act(stk.ops.sdd(x, self.w1, topo).data)
-            * stk.ops.sdd(x, self.w3, topo).data,
+            self.act(stk.ops.sdd(x, self.w1.t(), topo).data)
+            * stk.ops.sdd(x, self.w3.t(), topo).data,
             topo.row_indices,
             topo.column_indices,
             topo.offsets,
```
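
Note: the merged expert weights are now kept in their on-disk `(n_experts * ffn_dim, hidden_dim)` layout instead of being transposed at load time; the sparse path transposes on the fly for `stk.ops.sdd`, while the new dense path views the same buffer per expert. A rough shape sketch with toy sizes and plain matmuls standing in for the block-sparse kernels:

```python
import torch

# Toy sizes; plain matmuls stand in for the stk block-sparse kernels.
n_experts, ffn_dim, hidden_dim, seq_len = 4, 16, 8, 3

# Merged expert weight stays (n_experts * ffn_dim, hidden_dim) at load time.
w1 = torch.randn(n_experts * ffn_dim, hidden_dim)

# Sparse path: transpose lazily, so x @ w1.t() -> (seq_len, n_experts * ffn_dim).
x = torch.randn(seq_len, hidden_dim)
print((x @ w1.t()).shape)  # torch.Size([3, 64])

# Dense path: view the same buffer per expert and permute for torch.bmm,
# (n_experts, ffn_dim, hidden_dim) -> (n_experts, hidden_dim, ffn_dim).
w1_per_expert = w1.view(n_experts, ffn_dim, hidden_dim).permute(0, 2, 1)
print(w1_per_expert.shape)  # torch.Size([4, 8, 16])
```
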
```diff
@@ -534,6 +533,156 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         return x.view(*input_shape)
 
+    def dense_forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        x: (sequence_length, model_dim)
+        gate_logits: (sequence_length, n_experts)
+        """
+        # optional reshape
+        input_shape = x.shape
+        x = x.view(-1, input_shape[-1])
+
+        # gate_logits: (sequence_length, n_experts)
+        gate_logits = self.gate(x)
+        # all_probs: (sequence_length, n_experts) and upcast for softmax
+        all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
+
+        if self.top_k < self.num_experts:
+            _, not_selected_experts = torch.topk(
+                all_probs,
+                self.num_experts - self.top_k,
+                largest=False,
+                sorted=False,
+                dim=1,
+            )
+            # Mask not selected experts
+            all_probs.scatter_(1, not_selected_experts, 0)
+
+        # Re-normalize
+        weights = all_probs / all_probs.sum(dim=1, keepdim=True)
+
+        # Expand to [num_experts, sequence_length, model_dim]
+        x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1])
+
+        # Permute to [num_experts, model_dim, ffn_dim]
+        w1 = self.w1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
+            0, 2, 1
+        )
+        w3 = self.w3.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
+            0, 2, 1
+        )
+
+        inter = self.act(torch.bmm(x, w1)) * torch.bmm(x, w3)
+
+        out = torch.bmm(
+            inter, self.w2.view(self.num_experts, self.ffn_dim, self.hidden_dim)
+        )
+        # Mask not selected experts
+        out *= weights.t().view(self.num_experts, -1, 1)
+
+        # Sum experts
+        out = out.sum(0)
+
+        # Reduce sum
+        if self.process_group.size() > 1:
+            torch.distributed.all_reduce(out, group=self.process_group)
+
+        return out
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if len(x) > 256:
+            return self.sparse_forward(x)
+        # This is faster when there is not a lot of tokens
+        return self.dense_forward(x)
+
+
+class DenseMoE(nn.Module):
+    def __init__(self, prefix, config: MixtralConfig, weights):
+        super().__init__()
+        self.hidden_dim = config.hidden_size
+        self.ffn_dim = config.intermediate_size // weights.process_group.size()
+        self.num_experts = config.num_local_experts
+        self.top_k = config.num_experts_per_tok
+
+        act = config.hidden_act
+        if "gelu" in act:
+            self.act = lambda x: torch.nn.functional.gelu(
+                x,
+                approximate="tanh"
+                if act in ["gelu_fast", "gelu_pytorch_tanh"]
+                else "none",
+            )
+        elif "silu" in act:
+            self.act = torch.nn.functional.silu
+        else:
+            self.act = ACT2FN[act]
+
+        # gating
+        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
+
+        self.w1 = [
+            TensorParallelColumnLinear.load(
+                config, prefix=f"{prefix}.experts.{i}.w1", weights=weights, bias=False
+            )
+            for i in range(self.num_experts)
+        ]
+        self.w3 = [
+            TensorParallelColumnLinear.load(
+                config, prefix=f"{prefix}.experts.{i}.w3", weights=weights, bias=False
+            )
+            for i in range(self.num_experts)
+        ]
+        self.w2 = [
+            TensorParallelRowLinear.load(
+                config, prefix=f"{prefix}.experts.{i}.w2", weights=weights, bias=False
+            )
+            for i in range(self.num_experts)
+        ]
+
+        self.process_group = weights.process_group
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        x: (sequence_length, model_dim)
+        gate_logits: (sequence_length, n_experts)
+        """
+        # optional reshape
+        input_shape = x.shape
+        x = x.view(-1, input_shape[-1])
+
+        # gate_logits: (sequence_length, n_experts)
+        gate_logits = self.gate(x)
+        # all_probs: (sequence_length, n_experts) and upcast for softmax
+        all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
+
+        if self.top_k < self.num_experts:
+            _, not_selected_experts = torch.topk(
+                all_probs,
+                self.num_experts - self.top_k,
+                largest=False,
+                sorted=False,
+                dim=1,
+            )
+            # Mask not selected experts
+            all_probs.scatter_(1, not_selected_experts, 0)
+
+        # Re-normalize
+        weights = all_probs / all_probs.sum(dim=1, keepdim=True)
+
+        # Final output tensor
+        out = x.new_zeros(x.shape[0], self.hidden_dim)
+        for i in range(self.num_experts):
+            h = self.act(self.w1[i](x)) * self.w3[i](x)
+            h = self.w2[i](h, reduce=False)
+            # Add expert output to out with masking
+            out += h * weights[:, i].view(-1, 1)
+
+        # Reduce sum
+        if self.process_group.size() > 1:
+            torch.distributed.all_reduce(out, group=self.process_group)
+
+        return out
+
 
 class MixtralLayer(nn.Module):
     def __init__(self, layer_id, config, weights):
```
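
Note: `dense_forward` (and `DenseMoE.forward`) keeps every expert in the computation and instead zeroes the routing weights of the experts outside the top-k, then renormalises; `forward` only dispatches to the sparse kernels above 256 tokens, where they start to pay off. A toy illustration of just the masking and renormalisation step (sizes and logits below are arbitrary):

```python
import torch

# Assumed toy sizes: 3 tokens, 4 experts, top_k = 2.
num_experts, top_k = 4, 2
gate_logits = torch.tensor(
    [[2.0, 1.0, 0.1, -1.0],
     [0.0, 0.5, 3.0, 0.2],
     [1.0, 1.0, 1.0, 1.0]]
)
all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)

# Zero the (num_experts - top_k) smallest probabilities per token...
_, not_selected = torch.topk(
    all_probs, num_experts - top_k, largest=False, sorted=False, dim=1
)
all_probs.scatter_(1, not_selected, 0)

# ...and renormalise so the surviving weights sum to 1 per token.
weights = all_probs / all_probs.sum(dim=1, keepdim=True)
print(weights)             # two non-zero entries per row
print(weights.sum(dim=1))  # tensor([1., 1., 1.])
```
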
```diff
@@ -543,9 +692,9 @@ def __init__(self, layer_id, config, weights):
         self.self_attn = MixtralAttention(
             prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
-        self.block_sparse_moe = BlockSparseMoE(
-            f"{prefix}.block_sparse_moe", config, weights
-        )
+
+        moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE
+        self.moe = moe_cls(f"{prefix}.block_sparse_moe", config, weights)
 
         self.input_layernorm = FastRMSNorm.load(
             prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
```
```diff
@@ -591,9 +740,9 @@ def forward(
             attn_output, res
         )
 
-        block_sparse_moe_output = self.block_sparse_moe(normed_attn_res_output)
+        moe_output = self.moe(normed_attn_res_output)
 
-        return block_sparse_moe_output, attn_res
+        return moe_output, attn_res
 
 
 class MixtralModel(torch.nn.Module):
@@ -675,8 +824,6 @@ def __init__(self, config, weights):
             weights=weights,
         )
         self.max_past = config.sliding_window
-        if self.max_past is None:
-            raise ValueError("max_past cannot be None")
 
     def forward(
         self,
@@ -695,7 +842,7 @@ def forward(
         if prefill_cache_indices is not None:
             # Slots also need to be sliced as it has the same size as the whole kv tensor
             slots = slots[prefill_cache_indices]
-        else:
+        elif self.max_past is not None:
            # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
            # kernel requires the true values
            max_s = min(self.max_past, max_s)
```

server/text_generation_server/models/flash_mistral.py

Lines changed: 20 additions & 14 deletions
```diff
@@ -136,9 +136,9 @@ def from_pb(
             total_tokens = input_length + max_new_tokens - 1 + speculative_length
 
             # Needed blocks can not go over SLIDING_WINDOW_BLOCKS
-            needed_blocks = min(
-                math.ceil(total_tokens / BLOCK_SIZE), SLIDING_WINDOW_BLOCKS
-            )
+            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
+            if SLIDING_WINDOW_BLOCKS is not None:
+                needed_blocks = min(needed_blocks, SLIDING_WINDOW_BLOCKS)
             blocks += needed_blocks
 
             needed_blocks_slots.append((needed_blocks, total_tokens))
```
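
Note: block accounting no longer assumes a sliding window; the block count is the plain ceiling of total tokens over the block size and is only capped when `SLIDING_WINDOW_BLOCKS` is set. A worked example with assumed numbers (BLOCK_SIZE of 16 for the paged KV cache, and a 4096-token window):

```python
import math
from typing import Optional

BLOCK_SIZE = 16  # assumed KV-cache block size

def needed_blocks(total_tokens: int, sliding_window_blocks: Optional[int]) -> int:
    blocks = math.ceil(total_tokens / BLOCK_SIZE)
    if sliding_window_blocks is not None:
        blocks = min(blocks, sliding_window_blocks)
    return blocks

print(needed_blocks(1_000, None))                           # 63: no cap without a window
print(needed_blocks(10_000, math.ceil(4096 / BLOCK_SIZE)))  # 256: capped by the window
```
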
```diff
@@ -152,12 +152,13 @@ def from_pb(
             slot_indices.append(request_slot_indices)
 
             # Create tensor to slice into the kv tensor in prefill
-            request_prefill_cache_indices = torch.arange(
-                cumulative_length + max(0, input_length - SLIDING_WINDOW),
-                cumulative_length + input_length,
-                dtype=torch.int64,
-            )
-            prefill_cache_indices.append(request_prefill_cache_indices)
+            if SLIDING_WINDOW is not None:
+                request_prefill_cache_indices = torch.arange(
+                    cumulative_length + max(0, input_length - SLIDING_WINDOW),
+                    cumulative_length + input_length,
+                    dtype=torch.int64,
+                )
+                prefill_cache_indices.append(request_prefill_cache_indices)
 
             all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
             no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
```
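
Note: with a sliding window, each request keeps only the last `SLIDING_WINDOW` tokens of its prefill KV, so the indices select that suffix out of the request's slice of the flattened prefill tensor; with no window the list simply stays empty. A standalone toy version of the index construction (window and lengths below are made up):

```python
import torch

# Assumed toy values: 4-token window, two requests of lengths 6 and 3 laid out back to back.
SLIDING_WINDOW = 4
input_lengths = [6, 3]

prefill_cache_indices = []
cumulative_length = 0
for input_length in input_lengths:
    # Keep only the last SLIDING_WINDOW tokens of this request's prefill KV.
    prefill_cache_indices.append(
        torch.arange(
            cumulative_length + max(0, input_length - SLIDING_WINDOW),
            cumulative_length + input_length,
            dtype=torch.int64,
        )
    )
    cumulative_length += input_length

print(torch.cat(prefill_cache_indices))  # tensor([2, 3, 4, 5, 6, 7, 8])
```
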
```diff
@@ -209,20 +210,24 @@ def from_pb(
             input_ids = np.concatenate(all_input_ids, dtype=np.int64)
             position_ids = torch.cat(position_ids)
             slot_indices = torch.cat(slot_indices)
-            prefill_cache_indices = torch.cat(prefill_cache_indices)
+            if SLIDING_WINDOW is not None:
+                prefill_cache_indices = torch.cat(prefill_cache_indices)
         else:
             input_ids = all_input_ids[0]
             position_ids = position_ids[0]
             slot_indices = slot_indices[0]
-            prefill_cache_indices = prefill_cache_indices[0]
+            if SLIDING_WINDOW is not None:
+                prefill_cache_indices = prefill_cache_indices[0]
 
         cu_seqlen_prefill = torch.tensor(
             cu_seqlen_prefill, device=device, dtype=torch.int32
         )
 
         position_ids = position_ids.to(device)
         slot_indices = slot_indices.to(device)
-        prefill_cache_indices = prefill_cache_indices.to(device)
+        prefill_cache_indices = (
+            prefill_cache_indices.to(device) if SLIDING_WINDOW is not None else None
+        )
         input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         input_lengths_tensor = torch.tensor(
             input_lengths, dtype=torch.int32, device=device
@@ -314,8 +319,9 @@ def __init__(
         config.quantize = quantize
 
         # Set context windows
-        SLIDING_WINDOW = config.sliding_window
-        SLIDING_WINDOW_BLOCKS = math.ceil(config.sliding_window / BLOCK_SIZE)
+        if config.sliding_window is not None:
+            SLIDING_WINDOW = config.sliding_window
+            SLIDING_WINDOW_BLOCKS = math.ceil(config.sliding_window / BLOCK_SIZE)
 
         torch.distributed.barrier(group=self.process_group)
 
```
server/text_generation_server/utils/layers.py

Lines changed: 2 additions & 4 deletions
```diff
@@ -64,8 +64,6 @@
 except ImportError:
     pass
 
-from typing import Optional
-
 HAS_EETQ = False
 try:
     from EETQ import quant_weights, w8_a16_gemm
@@ -489,9 +487,9 @@ def load(cls, config, prefix: str, weights, bias: bool):
             process_group=weights.process_group,
         )
 
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
+    def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
         out = super().forward(input)
-        if self.process_group.size() > 1:
+        if self.process_group.size() > 1 and reduce:
             torch.distributed.all_reduce(out, group=self.process_group)
         return out
 
```
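
Note: `TensorParallelRowLinear.forward` gains a `reduce` flag so callers can postpone the cross-rank all-reduce; `DenseMoE` passes `reduce=False` per expert and issues a single all-reduce after summing all expert contributions, instead of one collective per expert. A hedged sketch of that pattern, with plain tensors standing in for the row-parallel outputs:

```python
import torch

# Plain tensors stand in for per-expert TensorParallelRowLinear outputs; the
# all_reduce call is shown but only runs under an initialized process group.
def moe_partial_sum(expert_partial_outputs, routing_weights, process_group=None):
    # Each element is this rank's *partial* w2 output for one expert
    # (i.e. the row-parallel matmul without its usual all_reduce).
    out = torch.zeros_like(expert_partial_outputs[0])
    for i, h in enumerate(expert_partial_outputs):
        out += h * routing_weights[:, i].view(-1, 1)
    # One collective for the whole MoE block instead of one per expert.
    if process_group is not None and process_group.size() > 1:
        torch.distributed.all_reduce(out, group=process_group)
    return out

tokens, hidden, n_experts = 3, 8, 4
outs = [torch.randn(tokens, hidden) for _ in range(n_experts)]
weights = torch.full((tokens, n_experts), 1.0 / n_experts)
print(moe_partial_sum(outs, weights).shape)  # torch.Size([3, 8])
```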