diff --git a/packages/ltx-core/src/ltx_core/loader/fuse_loras.py b/packages/ltx-core/src/ltx_core/loader/fuse_loras.py index 66269dd..4a80db3 100644 --- a/packages/ltx-core/src/ltx_core/loader/fuse_loras.py +++ b/packages/ltx-core/src/ltx_core/loader/fuse_loras.py @@ -1,7 +1,11 @@ import torch -import triton -from ltx_core.loader.kernels import fused_add_round_kernel +try: + import triton + from ltx_core.loader.kernels import fused_add_round_kernel + TRITON_AVAILABLE = True +except ImportError: + TRITON_AVAILABLE = False from ltx_core.loader.primitives import LoraStateDictWithStrength, StateDict BLOCK_SIZE = 1024 @@ -84,7 +88,7 @@ def apply_loras( continue deltas = weight.clone().to(dtype=target_dtype, device=device) elif weight.dtype == torch.float8_e4m3fn: - if str(device).startswith("cuda"): + if TRITON_AVAILABLE and str(device).startswith("cuda"): deltas = calculate_weight_float8_(deltas, weight) else: deltas.add_(weight.to(dtype=deltas.dtype, device=device)) diff --git a/packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py b/packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py index 9e8853a..a6394f2 100644 --- a/packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py +++ b/packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py @@ -72,9 +72,9 @@ def _return_model(self, meta_model: ModelType, device: torch.device) -> ModelTyp def build(self, device: torch.device | None = None, dtype: torch.dtype | None = None) -> ModelType: device = torch.device("cuda") if device is None else device config = self.model_config() - meta_model = self.meta_model(config, self.module_ops) model_paths = self.model_path if isinstance(self.model_path, tuple) else [self.model_path] model_state_dict = self.load_sd(model_paths, sd_ops=self.model_sd_ops, registry=self.registry, device=device) + meta_model = self.meta_model(config, self.module_ops) lora_strengths = [lora.strength for lora in self.loras] if not lora_strengths or (min(lora_strengths) == 0 and max(lora_strengths) == 0):