diff --git a/exllamav2/attn.py b/exllamav2/attn.py
index 4370dceb..7e2d2909 100644
--- a/exllamav2/attn.py
+++ b/exllamav2/attn.py
@@ -16,6 +16,7 @@
 import torch.nn.functional as F
 import inspect
 import os
+from exllamav2.util import substitute_inf_with_max
 
 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
@@ -964,17 +965,17 @@ def forward(self,
         use_flash_attn = has_flash_attn and not cfg.no_flash_attn
 
         if isinstance(attn_params, ExLlamaV2Attention.PagedParams):
-            return self.forward_paged(
+            return substitute_inf_with_max(self.forward_paged(
                 hidden_states,
                 cache,
                 attn_params,
                 loras = loras,
                 **kwargs
-            )
+            ))
 
         if self.is_tp:
             if cache is not None and use_flash_attn:
-                return self.forward_tp(
+                return substitute_inf_with_max(self.forward_tp(
                     hidden_states,
                     cache,
                     attn_params,
@@ -982,11 +983,11 @@ def forward(self,
                     intermediates,
                     loras,
                     **kwargs,
-                )
+                ))
             else:
                 # TODO: Can't use the optimized forward function because it writes directly to a fixed output
                 # tensor, and flash-attn currently has a bug that prevents that from working when q_len == 1
-                return self.forward_tp_old(
+                return substitute_inf_with_max(self.forward_tp_old(
                     hidden_states,
                     cache,
                     attn_params,
@@ -994,7 +995,7 @@ def forward(self,
                     intermediates,
                     loras,
                     **kwargs,
-                )
+                ))
 
         if self.q_handle is None or intermediates:
             return self.forward_torch(
@@ -1113,7 +1114,7 @@ def forward(self,
         if cfg.arch.clamp_hidden_states:
             hidden_states.clamp_(-65504, 65504)
 
-        return hidden_states
+        return substitute_inf_with_max(hidden_states)
 
     def forward_tp(
         self,
@@ -1428,9 +1429,9 @@ def forward_torch(
         if intermediates:
             return {"post_norm": post_norm,
                     "attn_output": attn_output,
-                    "hidden_states": hidden_states}
+                    "hidden_states": substitute_inf_with_max(hidden_states)}
         else:
-            return hidden_states
+            return substitute_inf_with_max(hidden_states)
 
 
     def update_loras(self):
diff --git a/exllamav2/embedding.py b/exllamav2/embedding.py
index 48168b2d..b67ee649 100644
--- a/exllamav2/embedding.py
+++ b/exllamav2/embedding.py
@@ -9,6 +9,8 @@
 if TYPE_CHECKING:
     from exllamav2.model import ExLlamaV2
 
+from exllamav2.util import substitute_inf_with_max
+
 EMBEDDING_INDEX: int = 1000000
 
 class ExLlamaV2Embedding(ExLlamaV2Module):
@@ -185,6 +187,6 @@ def forward(
             hidden_states = ctx.copy_pinned(0, hidden_states)
 
         if intermediates:
-            return {"hidden_states": hidden_states}
+            return {"hidden_states": substitute_inf_with_max(hidden_states)}
         else:
-            return hidden_states
+            return substitute_inf_with_max(hidden_states)
diff --git a/exllamav2/headnorm.py b/exllamav2/headnorm.py
index b890ba11..dbf6c9ec 100644
--- a/exllamav2/headnorm.py
+++ b/exllamav2/headnorm.py
@@ -8,6 +8,8 @@
 if TYPE_CHECKING:
     from exllamav2.model import ExLlamaV2
 
+from exllamav2.util import substitute_inf_with_max
+
 class ExLlamaV2HeadNorm(ExLlamaV2Module):
 
     name: str = "LayerNorm"
@@ -122,9 +124,9 @@ def forward(
                              self.variance_epsilon)
 
         if intermediates:
-            return {"hidden_states": hidden_states}
+            return {"hidden_states": substitute_inf_with_max(hidden_states)}
         else:
-            return hidden_states
+            return substitute_inf_with_max(hidden_states)
 
     def forward_torch(
         self,
@@ -146,8 +148,8 @@ def forward_torch(
         hidden_states = hidden_states.to(input_dtype)
 
         if intermediates:
-            return {"hidden_states": hidden_states}
+            return {"hidden_states": substitute_inf_with_max(hidden_states)}
         else:
-            return hidden_states
+            return substitute_inf_with_max(hidden_states)
 
diff --git a/exllamav2/layernorm.py b/exllamav2/layernorm.py
index 7b8f6c5b..0a2092cd 100644
--- a/exllamav2/layernorm.py
+++ b/exllamav2/layernorm.py
@@ -8,6 +8,8 @@
 if TYPE_CHECKING:
     from exllamav2.model import ExLlamaV2
 
+from exllamav2.util import substitute_inf_with_max
+
 class ExLlamaV2LayerNorm(ExLlamaV2Module):
 
     name: str = "LayerNorm"
@@ -119,9 +121,9 @@ def forward(
         hidden_states = norm.view(output_shape)
 
         if intermediates:
-            return {"hidden_states": hidden_states}
+            return {"hidden_states": substitute_inf_with_max(hidden_states)}
         else:
-            return hidden_states
+            return substitute_inf_with_max(hidden_states)
 
 
     def forward_torch(
@@ -139,8 +141,8 @@ def forward_torch(
         hidden_states = self.layernorm(hidden_states)
 
         if intermediates:
-            return {"hidden_states": hidden_states}
+            return {"hidden_states": substitute_inf_with_max(hidden_states)}
         else:
-            return hidden_states
+            return substitute_inf_with_max(hidden_states)
 
diff --git a/exllamav2/linear.py b/exllamav2/linear.py
index 5d6855dd..61504814 100644
--- a/exllamav2/linear.py
+++ b/exllamav2/linear.py
@@ -7,7 +7,7 @@
 from exllamav2.module import ExLlamaV2Module
 from exllamav2.compat import safe_move_tensor
 from exllamav2.tensor_p import BROADCAST_VC
-from exllamav2.util import unpack_4bit, pack_4bit
+from exllamav2.util import unpack_4bit, pack_4bit, substitute_inf_with_max
 import gc
 
 from typing import TYPE_CHECKING
@@ -295,8 +295,7 @@ def temp_fwd_size(self) -> int:
         max_len = self.model.config.max_input_len if self.max_out_len is None else \
             min(self.max_out_len, self.model.config.max_input_len)
         return self.out_features * max_len * self.model.config.max_batch_size * 4 + 128
-
-
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -312,7 +311,7 @@ def forward(
 
         if self.is_tp:
             if self.out_features_tp:
-                return self.forward_tp(
+                return substitute_inf_with_max(self.forward_tp(
                     hidden_states,
                     cache,
                     attn_params,
@@ -322,9 +321,9 @@ def forward(
                     force_recons,
                     force_cuda,
                     **kwargs
-                )
+                ))
             elif self.in_features_tp:
-                return self.forward_tp_row(
+                return substitute_inf_with_max(self.forward_tp_row(
                     hidden_states,
                     cache,
                     attn_params,
@@ -334,7 +333,7 @@ def forward(
                     force_recons,
                     force_cuda,
                     **kwargs
-                )
+                ))
             else:
                 assert False, "Unitialized TP linear layer"
 
@@ -344,9 +343,9 @@ def forward(
             hidden_states_out = loras[0].lm_head(hidden_states)
 
             if intermediates:
-                return {"hidden_states": hidden_states_out}
+                return {"hidden_states": substitute_inf_with_max(hidden_states_out)}
             else:
-                return hidden_states_out
+                return substitute_inf_with_max(hidden_states_out)
 
 
         if self.q_handle is not None and not force_recons:
@@ -380,9 +379,9 @@ def forward(
                 hidden_states_out += torch.matmul(temp, lora_b)
 
         if intermediates:
-            return {"hidden_states": hidden_states_out}
+            return {"hidden_states": substitute_inf_with_max(hidden_states_out)}
         else:
-            return hidden_states_out
+            return substitute_inf_with_max(hidden_states_out)
 
 
     def forward_tp(
diff --git a/exllamav2/mlp.py b/exllamav2/mlp.py
index 2d8282d5..27525932 100644
--- a/exllamav2/mlp.py
+++ b/exllamav2/mlp.py
@@ -9,6 +9,7 @@
 from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
 from exllamav2.lora import ExLlamaV2Lora
 from exllamav2.tensor_p import BROADCAST_ID, BROADCAST_RS
+from exllamav2.util import substitute_inf_with_max
 # from line_profiler import profile
 
 from typing import TYPE_CHECKING
@@ -288,7 +289,7 @@ def forward(
     ) -> torch.Tensor | dict[str: torch.Tensor]:
 
         if self.is_tp:
-            return self.forward_tp(
+            return substitute_inf_with_max(self.forward_tp(
                 hidden_states,
                 cache,
                 attn_params,
@@ -296,7 +297,7 @@ def forward(
                 intermediates,
                 loras,
                 **kwargs
-            )
+            ))
 
         cfg = self.model.config
 
@@ -319,7 +320,7 @@ def forward(
         if cfg.arch.clamp_hidden_states:
             hidden_states.clamp_(-65504, 65504)
 
-        return hidden_states
+        return substitute_inf_with_max(hidden_states)
 
 
     # @profile
@@ -457,9 +458,9 @@ def forward_torch(
         if intermediates:
             return {"post_norm": post_norm,
                     "pre_down": y,
-                    "hidden_states": hidden_states}
+                    "hidden_states": substitute_inf_with_max(hidden_states)}
         else:
-            return hidden_states
+            return substitute_inf_with_max(hidden_states)
 
 
     def update_loras(self):
diff --git a/exllamav2/model.py b/exllamav2/model.py
index 3fb4f5be..eaf57d3f 100644
--- a/exllamav2/model.py
+++ b/exllamav2/model.py
@@ -2,6 +2,7 @@
 
 import os, sys
 from exllamav2.architecture import RopeStyle
+from exllamav2.util import substitute_inf_with_max
 
 min_version = (3, 8)
 if sys.version_info < min_version:
@@ -820,9 +821,9 @@ def forward(
             if abort_event and abort_event.is_set(): return
 
             if "last_state" in result:
-                return result.get("logits"), result["last_state"]
+                return substitute_inf_with_max(result.get("logits")), substitute_inf_with_max(result["last_state"])
             else:
-                return result.get("logits")
+                return substitute_inf_with_max(result.get("logits"))
 
         # Confirm that the input fits within the allocated cache space
@@ -893,9 +894,9 @@ def forward(
         last_state = r.get("last_state")
 
         if last_state is None:
-            return result
+            return substitute_inf_with_max(result)
         else:
-            return result, last_state
+            return substitute_inf_with_max(result), substitute_inf_with_max(last_state)
 
 
     @torch.inference_mode()
diff --git a/exllamav2/module.py b/exllamav2/module.py
index 5bd672a6..c671bbf0 100644
--- a/exllamav2/module.py
+++ b/exllamav2/module.py
@@ -4,6 +4,7 @@
 from exllamav2.config import ExLlamaV2Config
 from exllamav2.fasttensors import STFile
 from exllamav2.compat import safe_move_tensor
+from exllamav2.util import substitute_inf_with_max
 
 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
@@ -282,4 +283,4 @@ def forward(self, hidden_states, *args, **kwargs):
         hidden_states = self.post_forward(hidden_states, *args, **kwargs)
         hidden_states = safe_move_tensor(hidden_states, dev)
 
-        return hidden_states
+        return substitute_inf_with_max(hidden_states)
diff --git a/exllamav2/moe_mlp.py b/exllamav2/moe_mlp.py
index 403c5ca1..d67ef371 100644
--- a/exllamav2/moe_mlp.py
+++ b/exllamav2/moe_mlp.py
@@ -7,6 +7,7 @@
 from exllamav2.linear import ExLlamaV2Linear
 from exllamav2.lora import ExLlamaV2Lora
 from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
+from exllamav2.util import substitute_inf_with_max
 
 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
@@ -244,7 +245,7 @@ def forward(
         # ext_c.q_moe_mlp_forward_(self.q_handle, hidden_states.view(-1, hidden_states.shape[-1]), pass_loras, pass_lora_temp)
         ext_c.q_moe_mlp_forward_(self.q_handle, hidden_states.view(-1, hidden_states.shape[-1]))
 
-        return hidden_states
+        return substitute_inf_with_max(hidden_states)
 
 
     def forward_torch(
@@ -313,9 +314,9 @@ def forward_torch(
         if intermediates:
             result["hidden_states"] = final_hidden_states
-            return result
+            return substitute_inf_with_max(result)
         else:
-            return final_hidden_states
+            return substitute_inf_with_max(final_hidden_states)
 
 
     def update_loras(self):
diff --git a/exllamav2/parallel_decoder.py b/exllamav2/parallel_decoder.py
index be772eca..902dd3f7 100644
--- a/exllamav2/parallel_decoder.py
+++ b/exllamav2/parallel_decoder.py
@@ -9,6 +9,7 @@
 from exllamav2.lora import ExLlamaV2Lora
 from exllamav2.layernorm import ExLlamaV2LayerNorm
 from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
+from exllamav2.util import substitute_inf_with_max
 
 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
@@ -119,7 +120,7 @@ def forward(
         b = self.mlp.forward(b, cache, attn_params, past_len, intermediates, loras, **kwargs)
         hidden_states += a
         hidden_states += b
-        return hidden_states
+        return substitute_inf_with_max(hidden_states)
 
 
     def forward_interm(
diff --git a/exllamav2/pos_embedding.py b/exllamav2/pos_embedding.py
index d7b85625..ac2e9eeb 100644
--- a/exllamav2/pos_embedding.py
+++ b/exllamav2/pos_embedding.py
@@ -4,6 +4,7 @@
 from exllamav2.module import ExLlamaV2Module
 from exllamav2.attn import ExLlamaV2Attention
 from exllamav2.compat import safe_move_tensor
+from exllamav2.util import substitute_inf_with_max
 
 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
@@ -118,6 +119,6 @@ def forward(
             hidden_states[b, target_a:target_b] += emb_slice
 
         if intermediates:
-            return {"hidden_states": hidden_states}
+            return {"hidden_states": substitute_inf_with_max(hidden_states)}
         else:
-            return hidden_states
+            return substitute_inf_with_max(hidden_states)
diff --git a/exllamav2/rmsnorm.py b/exllamav2/rmsnorm.py
index be6f00aa..8f640857 100644
--- a/exllamav2/rmsnorm.py
+++ b/exllamav2/rmsnorm.py
@@ -4,6 +4,7 @@
 from exllamav2.module import ExLlamaV2Module
 from exllamav2.ext import exllamav2_ext as ext_c
 from exllamav2.compat import safe_move_tensor
+from exllamav2.util import substitute_inf_with_max
 
 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
@@ -114,7 +115,7 @@ def forward(
     ) -> torch.Tensor | dict[str: torch.Tensor]:
 
         if self.is_tp:
-            return self.forward_tp(
+            return substitute_inf_with_max(self.forward_tp(
                 hidden_states,
                 cache,
                 attn_params,
@@ -123,7 +124,7 @@ def forward(
                 loras,
                 output_fp32,
                 **kwargs
-            )
+            ))
 
         output_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
@@ -137,9 +138,9 @@ def forward(
         hidden_states = norm.view(output_shape)
 
         if intermediates:
-            return {"hidden_states": hidden_states}
+            return {"hidden_states": substitute_inf_with_max(hidden_states)}
         else:
-            return hidden_states
+            return substitute_inf_with_max(hidden_states)
 
 
     def forward_tp(
@@ -198,9 +199,9 @@ def forward_torch(
         hidden_states *= self.weight
 
         if intermediates:
-            return {"hidden_states": hidden_states}
+            return {"hidden_states": substitute_inf_with_max(hidden_states)}
         else:
-            return hidden_states
+            return substitute_inf_with_max(hidden_states)
 
 
     def tp_split(self, broadcast_type: int):
diff --git a/exllamav2/util.py b/exllamav2/util.py
index fd44462e..2c8ad13b 100644
--- a/exllamav2/util.py
+++ b/exllamav2/util.py
@@ -367,4 +367,13 @@ def pack_4bit(unpacked: torch.Tensor):
     for i in range(8):
         packed |= (unpacked[:, i::8].to(torch.int64) << (i * 4))
     packed = packed.to(torch.int32)
-    return packed
\ No newline at end of file
+    return packed
+
+
+# Function to substitute inf and NaN with the maximum value of the type
+def substitute_inf_with_max(tensor):
+    dtype = tensor.dtype
+    max_value = torch.finfo(dtype).max if dtype.is_floating_point else torch.iinfo(dtype).max
+    tensor = torch.where(torch.isinf(tensor), max_value, tensor)
+    tensor = torch.where(torch.isnan(tensor), max_value, tensor)
+    return tensor
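Illustrative sketch (not part of the patch): a minimal example of what the new exllamav2.util.substitute_inf_with_max helper does once this diff is applied. As written in the patch, every +inf, -inf and NaN entry is replaced by the maximum finite value of the tensor's dtype, and finite entries pass through unchanged. The example assumes PyTorch and an exllamav2 checkout containing this change are importable.

import torch
from exllamav2.util import substitute_inf_with_max  # helper added in exllamav2/util.py by this diff

# Small fp16 tensor containing the non-finite values the helper targets
x = torch.tensor([1.0, float("inf"), float("-inf"), float("nan")], dtype=torch.float16)
y = substitute_inf_with_max(x)

# The three non-finite entries become torch.finfo(torch.float16).max (65504.0);
# the finite entry is returned unchanged and the dtype stays float16.
print(y)  # values: [1.0, 65504.0, 65504.0, 65504.0]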