diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/cast_to_dtype.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/cast_to_dtype.py
new file mode 100644
index 00000000000..f0af929955f
--- /dev/null
+++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/cast_to_dtype.py
@@ -0,0 +1,15 @@
+from typing import TypeVar
+
+import torch
+
+T = TypeVar("T", torch.Tensor, None, torch.Tensor | None)
+
+
+def cast_to_dtype(t: T, to_dtype: torch.dtype) -> T:
+    """Helper function to cast an optional tensor to a target dtype."""
+    if t is None:
+        return t
+
+    if t.dtype != to_dtype:
+        return t.to(dtype=to_dtype)
+    return t
diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py
index c440526b9b9..2f3f4266d75 100644
--- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py
+++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py
@@ -3,6 +3,7 @@
 import torch
 
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_dtype import cast_to_dtype
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
     CustomModuleMixin,
 )
@@ -73,6 +74,10 @@ def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
     def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
         weight = cast_to_device(self.weight, input.device)
         bias = cast_to_device(self.bias, input.device)
+
+        weight = cast_to_dtype(weight, input.dtype)
+        bias = cast_to_dtype(bias, input.dtype)
+
         return torch.nn.functional.linear(input, weight, bias)
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
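
A minimal usage sketch of the new `cast_to_dtype` helper, assuming only the import path introduced in the diff above; the specific dtypes are illustrative, not taken from the patch. It shows the behavior `_autocast_forward` now relies on: a parameter stored in one dtype is cast to the input's dtype before the linear op, and a `None` bias passes through unchanged.

```python
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_dtype import cast_to_dtype

x = torch.randn(2, 4, dtype=torch.float32)   # activations
w = torch.randn(8, 4, dtype=torch.bfloat16)  # weight stored in a different dtype (illustrative)
b = None                                     # bias is Optional; cast_to_dtype returns None as-is

# Align parameter dtypes with the input, as _autocast_forward now does after cast_to_device.
w = cast_to_dtype(w, x.dtype)
b = cast_to_dtype(b, x.dtype)

y = torch.nn.functional.linear(x, w, b)
assert w.dtype == torch.float32 and y.dtype == torch.float32
```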