From f481dca4df9dbddc456bfec5365f71df80136c07 Mon Sep 17 00:00:00 2001 From: koshe Date: Fri, 9 Jan 2026 08:28:09 +0100 Subject: [PATCH 01/38] Run on 8gb vram laptop --- .../text_encoders/gemma/encoders/base_encoder.py | 10 +++++++++- .../src/ltx_pipelines/ti2vid_two_stages.py | 6 ++++++ .../ltx-pipelines/src/ltx_pipelines/utils/helpers.py | 9 +++++++-- .../src/ltx_pipelines/utils/model_ledger.py | 6 +++++- 4 files changed, 27 insertions(+), 4 deletions(-) diff --git a/packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/base_encoder.py b/packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/base_encoder.py index e689c1af..deba4f45 100644 --- a/packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/base_encoder.py +++ b/packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/base_encoder.py @@ -244,8 +244,16 @@ def module_ops_from_gemma_root(gemma_root: str) -> tuple[ModuleOps, ...]: tokenizer_path = _find_matching_dir(gemma_root, "tokenizer.model") def load_gemma(module: GemmaTextEncoderModelBase) -> GemmaTextEncoderModelBase: + # Reserve 2GB VRAM for context window and activations + # Limit Gemma to 6GB, forcing more layers to CPU RAM + max_memory = {0: "6GiB", "cpu": "32GiB"} # GPU 0: 6GB, CPU: 32GB + module.model = Gemma3ForConditionalGeneration.from_pretrained( - gemma_path, local_files_only=True, torch_dtype=torch.bfloat16 + gemma_path, + local_files_only=True, + torch_dtype=torch.bfloat16, + device_map="auto", # Enable sequential offloading + max_memory=max_memory # Reserve 2GB VRAM for inference ) module._gemma_root = module._gemma_root or gemma_root return module diff --git a/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py b/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py index b835bfe5..0ae75ba5 100644 --- a/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py +++ b/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py @@ -111,6 +111,12 @@ def __call__( # noqa: PLR0913 v_context_n, a_context_n = context_n torch.cuda.synchronize() + # For device-mapped models, need to explicitly remove hooks before deletion + if hasattr(text_encoder, 'model') and hasattr(text_encoder.model, 'hf_device_map'): + # Remove all hooks to fully release memory + from accelerate.hooks import remove_hook_from_module + remove_hook_from_module(text_encoder.model, recurse=True) + text_encoder.model = None del text_encoder cleanup_memory() diff --git a/packages/ltx-pipelines/src/ltx_pipelines/utils/helpers.py b/packages/ltx-pipelines/src/ltx_pipelines/utils/helpers.py index 867db18f..31d27643 100644 --- a/packages/ltx-pipelines/src/ltx_pipelines/utils/helpers.py +++ b/packages/ltx-pipelines/src/ltx_pipelines/utils/helpers.py @@ -33,9 +33,14 @@ def get_device() -> torch.device: def cleanup_memory() -> None: + """Clean up GPU and system memory, including device-mapped models.""" gc.collect() - torch.cuda.empty_cache() - torch.cuda.synchronize() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + # Second pass to ensure device-mapped tensors are released + gc.collect() + torch.cuda.empty_cache() def image_conditionings_by_replacing_latent( diff --git a/packages/ltx-pipelines/src/ltx_pipelines/utils/model_ledger.py b/packages/ltx-pipelines/src/ltx_pipelines/utils/model_ledger.py index c507ff4c..edec562f 100644 --- a/packages/ltx-pipelines/src/ltx_pipelines/utils/model_ledger.py +++ b/packages/ltx-pipelines/src/ltx_pipelines/utils/model_ledger.py @@ -218,7 +218,11 @@ def text_encoder(self) -> 
AVGemmaTextEncoderModel: "ModelLedger constructor." ) - return self.text_encoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval() + model = self.text_encoder_builder.build(device=self._target_device(), dtype=self.dtype) + # If the model has device mapping (from device_map="auto"), don't call .to() as it's already distributed + if hasattr(model, 'model') and hasattr(model.model, 'hf_device_map') and model.model.hf_device_map: + return model.eval() + return model.to(self.device).eval() def audio_decoder(self) -> AudioDecoder: if not hasattr(self, "audio_decoder_builder"): From 5ae02a4f9734bb14f0e2450da37c70ab21e70cbc Mon Sep 17 00:00:00 2001 From: koshe Date: Wed, 14 Jan 2026 03:34:17 +0100 Subject: [PATCH 02/38] Optimize for 8Gb vram --- .../src/ltx_core/loader/fuse_loras.py | 33 ++- .../ltx-core/src/ltx_core/loader/kernels.py | 139 +++++----- .../loader/single_gpu_model_builder.py | 98 +++++-- .../src/ltx_core/model/transformer/model.py | 24 +- .../ltx_core/model/transformer/transformer.py | 15 +- .../gemma/encoders/base_encoder.py | 24 +- .../prompts/gemma_t2v_system_prompt.txt | 59 ++--- .../text_encoders/gemma/feature_extractor.py | 3 +- .../src/ltx_pipelines/distilled.py | 106 ++++++-- .../src/ltx_pipelines/ti2vid_two_stages.py | 240 +++++++++++++++++- .../src/ltx_pipelines/utils/helpers.py | 43 ++-- .../src/ltx_pipelines/utils/model_ledger.py | 22 +- 12 files changed, 589 insertions(+), 217 deletions(-) diff --git a/packages/ltx-core/src/ltx_core/loader/fuse_loras.py b/packages/ltx-core/src/ltx_core/loader/fuse_loras.py index 66269dd7..0da89375 100644 --- a/packages/ltx-core/src/ltx_core/loader/fuse_loras.py +++ b/packages/ltx-core/src/ltx_core/loader/fuse_loras.py @@ -1,13 +1,41 @@ import torch -import triton +# import triton from ltx_core.loader.kernels import fused_add_round_kernel from ltx_core.loader.primitives import LoraStateDictWithStrength, StateDict BLOCK_SIZE = 1024 +from line_profiler import profile +@profile def fused_add_round_launch(target_weight: torch.Tensor, original_weight: torch.Tensor, seed: int) -> torch.Tensor: + """ + Native PyTorch implementation of fused_add_round_launch. + + Note: + 1. Requires PyTorch 2.1 or newer for torch.float8 support. + 2. The 'seed' argument is accepted to maintain API compatibility but is ignored + because native PyTorch addition uses deterministic Round-To-Nearest-Even (RTNE) + rather than stochastic rounding. + """ + # Validation logic from original function + if original_weight.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]: + raise ValueError("Unsupported dtype") + + if target_weight.dtype != torch.bfloat16: + raise ValueError("target_weight dtype must be bfloat16") + + # Implementation: + # 1. Cast original_weight (fp8) to target_weight dtype (bf16). + # Since bf16 has higher dynamic range and precision than fp8, this upcast is exact. + # 2. Add in-place. 
+ target_weight.add_(original_weight.to(target_weight.dtype)) + + return target_weight + + +def fused_add_round_launch__(target_weight: torch.Tensor, original_weight: torch.Tensor, seed: int) -> torch.Tensor: if original_weight.dtype == torch.float8_e4m3fn: exponent_bits, mantissa_bits, exponent_bias = 4, 3, 7 elif original_weight.dtype == torch.float8_e5m2: @@ -20,7 +48,8 @@ def fused_add_round_launch(target_weight: torch.Tensor, original_weight: torch.T # Calculate grid and block sizes n_elements = original_weight.numel() - grid = (triton.cdiv(n_elements, BLOCK_SIZE),) + #grid = (triton.cdiv(n_elements, BLOCK_SIZE),) + grid = 0 # Launch kernel fused_add_round_kernel[grid]( diff --git a/packages/ltx-core/src/ltx_core/loader/kernels.py b/packages/ltx-core/src/ltx_core/loader/kernels.py index ee4cefbe..765f367e 100644 --- a/packages/ltx-core/src/ltx_core/loader/kernels.py +++ b/packages/ltx-core/src/ltx_core/loader/kernels.py @@ -1,72 +1,79 @@ -# ruff: noqa: ANN001, ANN201, ERA001, N803, N806 -import triton -import triton.language as tl +import torch +from line_profiler import profile -@triton.jit +@profile def fused_add_round_kernel( - x_ptr, - output_ptr, # contents will be added to the output - seed, - n_elements, - EXPONENT_BIAS, - MANTISSA_BITS, - BLOCK_SIZE: tl.constexpr, + x: torch.Tensor, + output: torch.Tensor, + seed: int, + n_elements: int, # Kept for signature compatibility, but unused + EXPONENT_BIAS: int, + MANTISSA_BITS: int, + BLOCK_SIZE: int = None, # Kept for signature compatibility, but unused ): """ - A kernel to upcast 8bit quantized weights to bfloat16 with stochastic rounding - and add them to bfloat16 output weights. Might be used to upcast original model weights - and to further add them to precalculated deltas coming from LoRAs. + Native PyTorch implementation of the fused_add_round_kernel. + + This performs: + 1. Upcast 8-bit weights (x) to match output precision. + 2. Add output weights (deltas) to x. + 3. Calculate the epsilon (quantization noise step) based on the target + Float8 parameters (EXPONENT_BIAS, MANTISSA_BITS). + 4. Apply stochastic rounding (add noise proportional to epsilon). + 5. Store back to output. """ - # Get program ID and compute offsets - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - - # Load data - x = tl.load(x_ptr + offsets, mask=mask) - rand_vals = tl.rand(seed, offsets) - 0.5 - - x = tl.cast(x, tl.float16) - delta = tl.load(output_ptr + offsets, mask=mask) - delta = tl.cast(delta, tl.float16) - x = x + delta - - x_bits = tl.cast(x, tl.int16, bitcast=True) - - # Calculate the exponent. Unbiased fp16 exponent is ((x_bits & 0x7C00) >> 10) - 15 for - # normal numbers and -14 for subnormals. - fp16_exponent_bits = (x_bits & 0x7C00) >> 10 - fp16_normals = fp16_exponent_bits > 0 - fp16_exponent = tl.where(fp16_normals, fp16_exponent_bits - 15, -14) - - # Add the target dtype's exponent bias and clamp to the target dtype's exponent range. 
- exponent = fp16_exponent + EXPONENT_BIAS - MAX_EXPONENT = 2 * EXPONENT_BIAS + 1 - exponent = tl.where(exponent > MAX_EXPONENT, MAX_EXPONENT, exponent) - exponent = tl.where(exponent < 0, 0, exponent) - - # Normal ULP exponent, expressed as an fp16 exponent field: - # (exponent - EXPONENT_BIAS - MANTISSA_BITS) + 15 - # Simplifies to: fp16_exponent - MANTISSA_BITS + 15 - # See https://en.wikipedia.org/wiki/Unit_in_the_last_place - eps_exp = tl.maximum(0, tl.minimum(31, exponent - EXPONENT_BIAS - MANTISSA_BITS + 15)) - - # Calculate epsilon in the target dtype - eps_normal = tl.cast(tl.cast(eps_exp << 10, tl.int16), tl.float16, bitcast=True) - - # Subnormal ULP: 2^(1 - EXPONENT_BIAS - MANTISSA_BITS) -> - # fp16 exponent bits: (1 - EXPONENT_BIAS - MANTISSA_BITS) + 15 = - # 16 - EXPONENT_BIAS - MANTISSA_BITS - eps_subnormal = tl.cast(tl.cast((16 - EXPONENT_BIAS - MANTISSA_BITS) << 10, tl.int16), tl.float16, bitcast=True) - eps = tl.where(exponent > 0, eps_normal, eps_subnormal) - - # Apply zero mask to epsilon - eps = tl.where(x == 0, 0.0, eps) - - # Apply stochastic rounding - output = tl.cast(x + rand_vals * eps, tl.bfloat16) - - # Store the result - tl.store(output_ptr + offsets, output, mask=mask) + + # 1. Setup Generators for stochastic rounding + # We use a specific generator to respect the seed argument + gen = torch.Generator(device=output.device).manual_seed(seed) + + # 2. Load and Cast to calculation precision (Float32 for safety, or Float16) + # Using Float32 ensures high precision during the intermediate math + val_x = x.to(torch.float32) + val_delta = output.to(torch.float32) + + # x = x + delta + val = val_x + val_delta + + # 3. Calculate Epsilon (The Stochastic Rounding Step) + # The Triton kernel calculates epsilon based on the magnitude of 'val' + # mapped onto the specific Float8 exponent grid. + + # Extract exponent: val = mantissa * 2^exp. + # torch.frexp returns exp such that 0.5 <= |mantissa| < 1.0. + # IEEE 754 log2(x) is (exp - 1). + _, exp_obj = torch.frexp(val) + unbiased_exp = exp_obj - 1 + + # Map to target Float8 exponent space + target_exp = unbiased_exp + EXPONENT_BIAS + + # Clamp exponent to target dtype range. + # Max is standard formulation (2*Bias + 1). + # Min is 1. Why 1? In the original Triton kernel, subnormals (exp <= 0) + # utilize a constant epsilon calculated based on exponent=1 (the smallest normal). + max_exponent = 2 * EXPONENT_BIAS + 1 + target_exp_clamped = torch.clamp(target_exp, min=1, max=max_exponent) + + # Calculate ULP exponent: E_target - BIAS - Mantissa_Bits + eps_exponent = target_exp_clamped - EXPONENT_BIAS - MANTISSA_BITS + + # Convert exponent to actual epsilon value: 2^eps_exponent + eps = torch.pow(2.0, eps_exponent.to(torch.float32)) + + # Mask epsilon where value is exactly 0 (matches `tl.where(x == 0, 0.0, eps)`) + eps = torch.where(val == 0, 0.0, eps) + + # 4. Generate Random Noise [-0.5, 0.5] + rand_vals = torch.rand(val.shape, generator=gen, device=val.device) - 0.5 + + # 5. Apply Stochastic Rounding + # output = x + (noise * epsilon) + result = val + (rand_vals * eps) + + # 6. 
Store Result + # In-place update of the output tensor, cast to bfloat16 + output.copy_(result.to(torch.bfloat16)) + + # No return value needed as operation is in-place on output_ptr/output \ No newline at end of file diff --git a/packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py b/packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py index 9e8853a4..55119036 100644 --- a/packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py +++ b/packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py @@ -21,11 +21,14 @@ logger: logging.Logger = logging.getLogger(__name__) +from loguru import logger +from accelerate import dispatch_model, infer_auto_device_map + @dataclass(frozen=True) class SingleGPUModelBuilder(Generic[ModelType], ModelBuilderProtocol[ModelType], LoRAAdaptableProtocol): """ - Builder for PyTorch models residing on a single GPU. + Builder for PyTorch models residing on a single GPU or offloaded via Accelerate. """ model_class_configurator: type[ModelConfigurator[ModelType]] @@ -69,33 +72,78 @@ def _return_model(self, meta_model: ModelType, device: torch.device) -> ModelTyp retval = meta_model.to(device) return retval - def build(self, device: torch.device | None = None, dtype: torch.dtype | None = None) -> ModelType: - device = torch.device("cuda") if device is None else device + def build( + self, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_memory: dict[int | str, str] | None = None + ) -> ModelType: + target_device = torch.device("cuda") if device is None else device + + # 1. Get Config and Meta Model config = self.model_config() meta_model = self.meta_model(config, self.module_ops) + + # 2. Load Base State Dict model_paths = self.model_path if isinstance(self.model_path, tuple) else [self.model_path] - model_state_dict = self.load_sd(model_paths, sd_ops=self.model_sd_ops, registry=self.registry, device=device) + load_device = target_device if max_memory is None else torch.device("cpu") + model_state_dict = self.load_sd(model_paths, sd_ops=self.model_sd_ops, registry=self.registry, + device=load_device) + # 3. 
Handle LoRAs lora_strengths = [lora.strength for lora in self.loras] + final_sd_map = {} + if not lora_strengths or (min(lora_strengths) == 0 and max(lora_strengths) == 0): - sd = model_state_dict.sd - if dtype is not None: - sd = {key: value.to(dtype=dtype) for key, value in model_state_dict.sd.items()} - meta_model.load_state_dict(sd, strict=False, assign=True) - return self._return_model(meta_model, device) - - lora_state_dicts = [ - self.load_sd([lora.path], sd_ops=lora.sd_ops, registry=self.registry, device=device) for lora in self.loras - ] - lora_sd_and_strengths = [ - LoraStateDictWithStrength(sd, strength) - for sd, strength in zip(lora_state_dicts, lora_strengths, strict=True) - ] - final_sd = apply_loras( - model_sd=model_state_dict, - lora_sd_and_strengths=lora_sd_and_strengths, - dtype=dtype, - destination_sd=model_state_dict if isinstance(self.registry, DummyRegistry) else None, - ) - meta_model.load_state_dict(final_sd.sd, strict=False, assign=True) - return self._return_model(meta_model, device) + final_sd_map = model_state_dict.sd + else: + # Convert LoRAs to float32 on CPU to prevent slow BF16 emulation + lora_state_dicts = [] + for lora in self.loras: + lsd = self.load_sd([lora.path], sd_ops=lora.sd_ops, registry=self.registry, device=load_device) + + if load_device.type == "cpu": + # In-place conversion of LoRA tensors to float32 for speed + # This speeds up the matmul in apply_loras significantly + for k, v in lsd.sd.items(): + if v.dtype in [torch.bfloat16, torch.float16]: + lsd.sd[k] = v.to(dtype=torch.float32) + + lora_state_dicts.append(lsd) + + lora_sd_and_strengths = [ + LoraStateDictWithStrength(sd, strength) + for sd, strength in zip(lora_state_dicts, lora_strengths, strict=True) + ] + + dest_sd = model_state_dict if isinstance(self.registry, DummyRegistry) else None + + final_sd_obj = apply_loras( + model_sd=model_state_dict, + lora_sd_and_strengths=lora_sd_and_strengths, + dtype=dtype, + destination_sd=dest_sd, + ) + final_sd_map = final_sd_obj.sd + + # 4. Cast Dtypes if requested + if dtype is not None: + final_sd_map = {k: v.to(dtype=dtype) for k, v in final_sd_map.items()} + + # 5. Load State Dict into Model + meta_model.load_state_dict(final_sd_map, strict=False, assign=True) + + # 6. Return based on Offloading strategy + if max_memory is not None: + logger.info(f"Dispatching model with max_memory constraints: {max_memory}") + no_split_modules = getattr(self.model_class_configurator, "no_split_modules", None) + device_map = infer_auto_device_map( + meta_model, + max_memory=max_memory, + no_split_module_classes=no_split_modules, + dtype=dtype + ) + model = dispatch_model(meta_model, device_map=device_map) + return model + + return self._return_model(meta_model, target_device) diff --git a/packages/ltx-core/src/ltx_core/model/transformer/model.py b/packages/ltx-core/src/ltx_core/model/transformer/model.py index 411e3b42..dc4d662f 100644 --- a/packages/ltx-core/src/ltx_core/model/transformer/model.py +++ b/packages/ltx-core/src/ltx_core/model/transformer/model.py @@ -16,6 +16,8 @@ ) from ltx_core.utils import to_denoised +#from line_profiler import profile + class LTXModelType(Enum): AudioVideo = "ltx av model" @@ -35,6 +37,7 @@ class LTXModel(torch.nn.Module): This class implements the transformer blocks for the LTX model. 
""" + #@profile 1.37738 s def __init__( # noqa: PLR0913 self, *, @@ -105,7 +108,7 @@ def __init__( # noqa: PLR0913 self._init_preprocessors(cross_pe_max_pos) # Initialize transformer blocks - self._init_transformer_blocks( + self._init_transformer_blocks( # 98.2% num_layers=num_layers, attention_head_dim=attention_head_dim if model_type.is_video_enabled() else 0, cross_attention_dim=cross_attention_dim, @@ -115,6 +118,7 @@ def __init__( # noqa: PLR0913 attention_type=attention_type, ) + #@profile 0.0069204 s def _init_video( self, in_channels: int, @@ -139,6 +143,7 @@ def _init_video( self.norm_out = torch.nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=norm_eps) self.proj_out = torch.nn.Linear(self.inner_dim, out_channels) + #@profile 0.0063044 s def _init_audio( self, in_channels: int, @@ -166,6 +171,7 @@ def _init_audio( self.audio_norm_out = torch.nn.LayerNorm(self.audio_inner_dim, elementwise_affine=False, eps=norm_eps) self.audio_proj_out = torch.nn.Linear(self.audio_inner_dim, out_channels) + #@profile 0.0111731 s def _init_audio_video( self, num_scale_shift_values: int, @@ -191,6 +197,7 @@ def _init_audio_video( embedding_coefficient=1, ) + #@profile 0.0002355 s def _init_preprocessors( self, cross_pe_max_pos: int | None = None, @@ -263,6 +270,7 @@ def _init_preprocessors( rope_type=self.rope_type, ) + #@profile 1.3519 s def _init_transformer_blocks( self, num_layers: int, @@ -296,7 +304,7 @@ def _init_transformer_blocks( ) self.transformer_blocks = torch.nn.ModuleList( [ - BasicAVTransformerBlock( + BasicAVTransformerBlock( # 99.9% idx=idx, video=video_config, audio=audio_config, @@ -308,6 +316,7 @@ def _init_transformer_blocks( ] ) + #@profile unused def set_gradient_checkpointing(self, enable: bool) -> None: """Enable or disable gradient checkpointing for transformer blocks. 
Gradient checkpointing trades compute for memory by recomputing activations @@ -318,6 +327,7 @@ def set_gradient_checkpointing(self, enable: bool) -> None: """ self._enable_gradient_checkpointing = enable + #@profile 498.557 s def _process_transformer_blocks( self, video: TransformerArgs | None, @@ -340,7 +350,7 @@ def _process_transformer_blocks( use_reentrant=False, ) else: - video, audio = block( + video, audio = block( # 100% video=video, audio=audio, perturbations=perturbations, @@ -348,6 +358,7 @@ def _process_transformer_blocks( return video, audio + #@profile 0.0648487 s def _process_output( self, scale_shift_table: torch.Tensor, @@ -368,6 +379,7 @@ def _process_output( x = proj_out(x) return x + #@profile 502.847 s def forward( self, video: Modality | None, audio: Modality | None, perturbations: BatchedPerturbationConfig ) -> tuple[torch.Tensor, torch.Tensor]: @@ -384,7 +396,7 @@ def forward( video_args = self.video_args_preprocessor.prepare(video) if video is not None else None audio_args = self.audio_args_preprocessor.prepare(audio) if audio is not None else None # Process transformer blocks - video_out, audio_out = self._process_transformer_blocks( + video_out, audio_out = self._process_transformer_blocks( # 99.1% video=video_args, audio=audio_args, perturbations=perturbations, @@ -450,7 +462,7 @@ class X0Model(torch.nn.Module): def __init__(self, velocity_model: LTXModel): super().__init__() self.velocity_model = velocity_model - + #@profile 502.854 s def forward( self, video: Modality | None, @@ -462,7 +474,7 @@ def forward( Returns: Denoised video and audio """ - vx, ax = self.velocity_model(video, audio, perturbations) + vx, ax = self.velocity_model(video, audio, perturbations) # 100% denoised_video = to_denoised(video.latent, vx, video.timesteps) if vx is not None else None denoised_audio = to_denoised(audio.latent, ax, audio.timesteps) if ax is not None else None return denoised_video, denoised_audio diff --git a/packages/ltx-core/src/ltx_core/model/transformer/transformer.py b/packages/ltx-core/src/ltx_core/model/transformer/transformer.py index 047faaab..78173ddd 100644 --- a/packages/ltx-core/src/ltx_core/model/transformer/transformer.py +++ b/packages/ltx-core/src/ltx_core/model/transformer/transformer.py @@ -9,6 +9,8 @@ from ltx_core.model.transformer.transformer_args import TransformerArgs from ltx_core.utils import rms_norm +#from line_profiler import profile + @dataclass class TransformerConfig: @@ -103,17 +105,19 @@ def __init__( self.norm_eps = norm_eps + #@profile 1.26368 s def get_ada_values( self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice ) -> tuple[torch.Tensor, ...]: num_ada_params = scale_shift_table.shape[0] ada_values = ( - scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype) + scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype) # 89.6% + timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :] ).unbind(dim=2) return ada_values + #@profile 0.925723 s def get_av_ca_ada_values( self, scale_shift_table: torch.Tensor, @@ -122,7 +126,7 @@ def get_av_ca_ada_values( gate_timestep: torch.Tensor, num_scale_shift_values: int = 4, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - scale_shift_ada_values = self.get_ada_values( + scale_shift_ada_values = self.get_ada_values( # 86% scale_shift_table[:num_scale_shift_values, :], batch_size, scale_shift_timestep, 
slice(None, None) ) gate_ada_values = self.get_ada_values( @@ -134,6 +138,7 @@ def get_av_ca_ada_values( return (*scale_shift_chunks, *gate_ada_values) + #@profile 859.862 s def forward( # noqa: PLR0915 self, video: TransformerArgs | None, @@ -160,9 +165,9 @@ def forward( # noqa: PLR0915 if not perturbations.all_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx): norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa v_mask = perturbations.mask_like(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx, vx) - vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings) * vgate_msa * v_mask + vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings) * vgate_msa * v_mask # 24% - vx = vx + self.attn2(rms_norm(vx, eps=self.norm_eps), context=video.context, mask=video.context_mask) + vx = vx + self.attn2(rms_norm(vx, eps=self.norm_eps), context=video.context, mask=video.context_mask) # 14% del vshift_msa, vscale_msa, vgate_msa @@ -258,7 +263,7 @@ def forward( # noqa: PLR0915 self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, None) ) vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp - vx = vx + self.ff(vx_scaled) * vgate_mlp + vx = vx + self.ff(vx_scaled) * vgate_mlp # 33% del vshift_mlp, vscale_mlp, vgate_mlp diff --git a/packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/base_encoder.py b/packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/base_encoder.py index deba4f45..4c14ed48 100644 --- a/packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/base_encoder.py +++ b/packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/base_encoder.py @@ -3,7 +3,7 @@ import torch from einops import rearrange -from transformers import AutoImageProcessor, Gemma3ForConditionalGeneration, Gemma3Processor +from transformers import AutoImageProcessor, Gemma3ForConditionalGeneration, Gemma3Processor, BitsAndBytesConfig from ltx_core.loader.module_ops import ModuleOps from ltx_core.text_encoders.gemma.feature_extractor import GemmaFeaturesExtractorProjLinear @@ -44,12 +44,16 @@ def _run_feature_extractor( encoded_text_features = torch.stack(hidden_states, dim=-1) encoded_text_features_dtype = encoded_text_features.dtype + print(encoded_text_features_dtype) + sequence_lengths = attention_mask.sum(dim=-1) normed_concated_encoded_text_features = _norm_and_concat_padded_batch( encoded_text_features, sequence_lengths, padding_side=padding_side ) + print("normed_concated_encoded_text_features") + print(normed_concated_encoded_text_features.dtype) - return self.feature_extractor_linear(normed_concated_encoded_text_features.to(encoded_text_features_dtype)) + return self.feature_extractor_linear(normed_concated_encoded_text_features.to(torch.bfloat16)) def _convert_to_additive_mask(self, attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: return (attention_mask - 1).to(dtype).reshape( @@ -244,13 +248,21 @@ def module_ops_from_gemma_root(gemma_root: str) -> tuple[ModuleOps, ...]: tokenizer_path = _find_matching_dir(gemma_root, "tokenizer.model") def load_gemma(module: GemmaTextEncoderModelBase) -> GemmaTextEncoderModelBase: + #max_memory = {0: "8GiB", "cpu": "32GiB"} + # 2. 
Load the model + #module.model = Gemma3ForConditionalGeneration.from_pretrained( + # gemma_path, + # local_files_only=True, + # device_map="auto", + # max_memory=max_memory + #) + # Reserve 2GB VRAM for context window and activations # Limit Gemma to 6GB, forcing more layers to CPU RAM - max_memory = {0: "6GiB", "cpu": "32GiB"} # GPU 0: 6GB, CPU: 32GB - + max_memory = {0: "3GiB", "cpu": "32GiB"} # GPU 0: 3GiB, CPU: 32GiB module.model = Gemma3ForConditionalGeneration.from_pretrained( - gemma_path, - local_files_only=True, + gemma_path, + local_files_only=True, torch_dtype=torch.bfloat16, device_map="auto", # Enable sequential offloading max_memory=max_memory # Reserve 2GB VRAM for inference diff --git a/packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/prompts/gemma_t2v_system_prompt.txt b/packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/prompts/gemma_t2v_system_prompt.txt index e8642019..f16acd88 100644 --- a/packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/prompts/gemma_t2v_system_prompt.txt +++ b/packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/prompts/gemma_t2v_system_prompt.txt @@ -1,40 +1,23 @@ -You are a Creative Assistant. Given a user's raw input prompt describing a scene or concept, expand it into a detailed video generation prompt with specific visuals and integrated audio to guide a text-to-video model. - -#### Guidelines -- Strictly follow all aspects of the user's raw input: include every element requested (style, visuals, motions, actions, camera movement, audio). - - If the input is vague, invent concrete details: lighting, textures, materials, scene settings, etc. - - For characters: describe gender, clothing, hair, expressions. DO NOT invent unrequested characters. -- Use active language: present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural movements. -- Maintain chronological flow: use temporal connectors ("as," "then," "while"). -- Audio layer: Describe complete soundscape (background audio, ambient sounds, SFX, speech/music when requested). Integrate sounds chronologically alongside actions. Be specific (e.g., "soft footsteps on tile"), not vague (e.g., "ambient sound is present"). -- Speech (only when requested): - - For ANY speech-related input (talking, conversation, singing, etc.), ALWAYS include exact words in quotes with voice characteristics (e.g., "The man says in an excited voice: 'You won't believe what I just saw!'"). - - Specify language if not English and accent if relevant. -- Style: Include visual style at the beginning: "Style:
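
Taken together, the patches above converge on a single load-and-teardown pattern for running the Gemma text encoder within an 8 GB VRAM budget: cap its GPU allocation with max_memory, let accelerate spread the remaining layers to CPU RAM via device_map="auto", and strip the accelerate hooks before dropping the model so the offloaded weights are actually released. The sketch below restates that pattern in isolation; it is a minimal sketch assuming recent transformers and accelerate releases, the checkpoint path is hypothetical, and the memory caps are illustrative rather than the values the patches settle on.

import gc

import torch
from accelerate.hooks import remove_hook_from_module
from transformers import Gemma3ForConditionalGeneration

# Cap GPU 0 and spill the remaining layers to CPU RAM (illustrative caps for an 8 GB card).
max_memory = {0: "6GiB", "cpu": "32GiB"}

model = Gemma3ForConditionalGeneration.from_pretrained(
    "/path/to/gemma",            # hypothetical local checkpoint
    local_files_only=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",           # accelerate places layers across GPU and CPU
    max_memory=max_memory,
)

# ... encode prompts here ...

# Device-mapped models keep accelerate hooks attached; remove them before dropping
# the reference, otherwise the offloaded weights are not fully freed.
remove_hook_from_module(model, recurse=True)
del model
gc.collect()
torch.cuda.empty_cache()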