diff --git a/docs/source/en/api/pipelines/ltx_video.md b/docs/source/en/api/pipelines/ltx_video.md
index 26d17733c10d..3464d88145ea 100644
--- a/docs/source/en/api/pipelines/ltx_video.md
+++ b/docs/source/en/api/pipelines/ltx_video.md
@@ -31,12 +31,103 @@ Available models:
 
 | Model name | Recommended dtype |
 |:-------------:|:-----------------:|
-| [`LTX Video 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` |
-| [`LTX Video 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` |
-| [`LTX Video 0.9.5`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.5.safetensors) | `torch.bfloat16` |
+| [`LTX Video 2B 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` |
+| [`LTX Video 2B 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` |
+| [`LTX Video 2B 0.9.5`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.5.safetensors) | `torch.bfloat16` |
+| [`LTX Video 13B 0.9.7`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-dev.safetensors) | `torch.bfloat16` |
+| [`LTX Video Spatial Upscaler 0.9.7`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-spatial-upscaler-0.9.7.safetensors) | `torch.bfloat16` |
 
 Note: The recommended dtype is for the transformer component. The VAE and text encoders can be either `torch.float32`, `torch.bfloat16` or `torch.float16` but the recommended dtype is `torch.bfloat16` as used in the original repository.
 
+## Recommended settings for generation
+
+For the best results, it is recommended to follow the guidelines in the official LTX Video [repository](https://github.com/Lightricks/LTX-Video).
+
+- Some variants of LTX Video are guidance-distilled. For guidance-distilled models, `guidance_scale` must be set to `1.0`. For any other models, `guidance_scale` should be set higher (e.g., `5.0`) for good generation quality.
+- For variants with a timestep-aware VAE (LTXV 0.9.1 and above), it is recommended to set `decode_timestep` to `0.05` and `image_cond_noise_scale` to `0.025`.
+- For variants that support interpolation between multiple conditioning images and videos (LTXV 0.9.5 and above), it is recommended to use similar-looking images/videos for the best results. High divergence between the conditionings may lead to abrupt transitions in the generated video.
+
+## Using LTX Video 13B 0.9.7
+
+LTX Video 0.9.7 comes with a spatial latent upscaler and a 13B-parameter transformer. Inference involves first generating a low-resolution video, which is fast, and then upscaling and refining it.
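+
+Before the full multi-stage example, here is a minimal sketch of using the spatial upsampler on its own to 2x-upscale an existing low-resolution video: the pipeline encodes the frames with the VAE, upsamples the latents, and decodes them back. The input path is a placeholder, and this assumes the hosted upsampler checkpoint bundles a VAE; adapt it to your setup.
+
+```python
+import torch
+from diffusers import LTXLatentUpsamplePipeline
+from diffusers.utils import export_to_video, load_video
+
+pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained(
+    "a-r-r-o-w/LTX-Video-0.9.7-Latent-Spatial-Upsampler-diffusers", torch_dtype=torch.bfloat16
+)
+pipe_upsample.to("cuda")
+
+# Frame count should be of the form 8 * k + 1; otherwise the pipeline truncates it with a warning.
+video = load_video("path/to/low_res_video.mp4")[:81]  # placeholder input path
+# `height`/`width` are the input resolution (must be divisible by 32); the output is 2x larger.
+video = pipe_upsample(video=video, height=512, width=768, output_type="pil").frames[0]
+export_to_video(video, "upscaled.mp4", fps=24)
+```
+
+The full workflow below generates a low-resolution video with `LTXConditionPipeline`, upscales its latents 2x with `LTXLatentUpsamplePipeline`, refines the upscaled video with a partial denoising pass, and finally resizes it to the target resolution.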
+
+
+
+```python
+import torch
+from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline
+from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
+from diffusers.utils import export_to_video, load_video
+
+pipe = LTXConditionPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.7-diffusers", torch_dtype=torch.bfloat16)
+pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.7-Latent-Spatial-Upsampler-diffusers", vae=pipe.vae, torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+pipe_upsample.to("cuda")
+pipe.vae.enable_tiling()
+
+def round_to_nearest_resolution_acceptable_by_vae(height, width):
+    height = height - (height % pipe.vae_spatial_compression_ratio)
+    width = width - (width % pipe.vae_spatial_compression_ratio)
+    return height, width
+
+video = load_video(
+    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4"
+)[:21]  # Use only the first 21 frames as conditioning
+condition1 = LTXVideoCondition(video=video, frame_index=0)
+
+prompt = "The video depicts a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region."
+negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
+expected_height, expected_width = 768, 1152
+downscale_factor = 2 / 3
+num_frames = 161
+
+# Part 1. Generate video at smaller resolution
+# Text-only conditioning is also supported without the need to pass `conditions`
+downscaled_height, downscaled_width = int(expected_height * downscale_factor), int(expected_width * downscale_factor)
+downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(downscaled_height, downscaled_width)
+latents = pipe(
+    conditions=[condition1],
+    prompt=prompt,
+    negative_prompt=negative_prompt,
+    width=downscaled_width,
+    height=downscaled_height,
+    num_frames=num_frames,
+    num_inference_steps=30,
+    generator=torch.Generator().manual_seed(0),
+    output_type="latent",
+).frames
+
+# Part 2. Upscale generated video using latent upsampler
+# The available latent upsampler upscales the height/width by 2x
+upscaled_height, upscaled_width = downscaled_height * 2, downscaled_width * 2
+upscaled_latents = pipe_upsample(
+    latents=latents,
+    output_type="latent"
+).frames
+
+# Part 3. Denoise the upscaled video with few steps to improve texture (optional, but recommended)
+video = pipe(
+    conditions=[condition1],
+    prompt=prompt,
+    negative_prompt=negative_prompt,
+    width=upscaled_width,
+    height=upscaled_height,
+    num_frames=num_frames,
+    denoise_strength=0.4,  # Effectively, 4 inference steps out of 10
+    num_inference_steps=10,
+    latents=upscaled_latents,
+    decode_timestep=0.05,
+    image_cond_noise_scale=0.025,
+    generator=torch.Generator().manual_seed(0),
+    output_type="pil",
+).frames[0]
+
+# Part 4. Downscale the video to the expected resolution
+video = [frame.resize((expected_width, expected_height)) for frame in video]
+
+export_to_video(video, "output.mp4", fps=24)
+```
+
 ## Loading Single Files
 
 Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`]. 
We recommend using `from_single_file` for the Lightricks series of models, as they plan to release multiple models in the future in the single file format. @@ -204,6 +295,12 @@ export_to_video(video, "ship.mp4", fps=24) - all - __call__ +## LTXLatentUpsamplePipeline + +[[autodoc]] LTXLatentUpsamplePipeline + - all + - __call__ + ## LTXPipelineOutput [[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput diff --git a/scripts/convert_ltx_to_diffusers.py b/scripts/convert_ltx_to_diffusers.py index 2e966d5d110b..256312cc72ff 100644 --- a/scripts/convert_ltx_to_diffusers.py +++ b/scripts/convert_ltx_to_diffusers.py @@ -7,7 +7,15 @@ from safetensors.torch import load_file from transformers import T5EncoderModel, T5Tokenizer -from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXVideoTransformer3DModel +from diffusers import ( + AutoencoderKLLTXVideo, + FlowMatchEulerDiscreteScheduler, + LTXConditionPipeline, + LTXLatentUpsamplePipeline, + LTXPipeline, + LTXVideoTransformer3DModel, +) +from diffusers.pipelines.ltx.modeling_latent_upsampler import LTXLatentUpsamplerModel def remove_keys_(key: str, state_dict: Dict[str, Any]): @@ -123,17 +131,10 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: state_dict[new_key] = state_dict.pop(old_key) -def convert_transformer( - ckpt_path: str, - dtype: torch.dtype, - version: str = "0.9.0", -): +def convert_transformer(ckpt_path: str, config, dtype: torch.dtype): PREFIX_KEY = "model.diffusion_model." original_state_dict = get_state_dict(load_file(ckpt_path)) - config = {} - if version == "0.9.5": - config["_use_causal_rope_fix"] = True with init_empty_weights(): transformer = LTXVideoTransformer3DModel(**config) @@ -180,8 +181,59 @@ def convert_vae(ckpt_path: str, config, dtype: torch.dtype): return vae +def convert_spatial_latent_upsampler(ckpt_path: str, config, dtype: torch.dtype): + original_state_dict = get_state_dict(load_file(ckpt_path)) + + with init_empty_weights(): + latent_upsampler = LTXLatentUpsamplerModel(**config) + + latent_upsampler.load_state_dict(original_state_dict, strict=True, assign=True) + latent_upsampler.to(dtype) + return latent_upsampler + + +def get_transformer_config(version: str) -> Dict[str, Any]: + if version == "0.9.7": + config = { + "in_channels": 128, + "out_channels": 128, + "patch_size": 1, + "patch_size_t": 1, + "num_attention_heads": 32, + "attention_head_dim": 128, + "cross_attention_dim": 4096, + "num_layers": 48, + "activation_fn": "gelu-approximate", + "qk_norm": "rms_norm_across_heads", + "norm_elementwise_affine": False, + "norm_eps": 1e-6, + "caption_channels": 4096, + "attention_bias": True, + "attention_out_bias": True, + } + else: + config = { + "in_channels": 128, + "out_channels": 128, + "patch_size": 1, + "patch_size_t": 1, + "num_attention_heads": 32, + "attention_head_dim": 64, + "cross_attention_dim": 2048, + "num_layers": 28, + "activation_fn": "gelu-approximate", + "qk_norm": "rms_norm_across_heads", + "norm_elementwise_affine": False, + "norm_eps": 1e-6, + "caption_channels": 4096, + "attention_bias": True, + "attention_out_bias": True, + } + return config + + def get_vae_config(version: str) -> Dict[str, Any]: - if version == "0.9.0": + if version in ["0.9.0"]: config = { "in_channels": 3, "out_channels": 3, @@ -210,7 +262,7 @@ def get_vae_config(version: str) -> Dict[str, Any]: "decoder_causal": False, "timestep_conditioning": False, } - elif version == "0.9.1": + elif version in ["0.9.1"]: config = { "in_channels": 3, 
"out_channels": 3, @@ -240,7 +292,39 @@ def get_vae_config(version: str) -> Dict[str, Any]: "decoder_causal": False, } VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT) - elif version == "0.9.5": + elif version in ["0.9.5"]: + config = { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "block_out_channels": (128, 256, 512, 1024, 2048), + "down_block_types": ( + "LTXVideo095DownBlock3D", + "LTXVideo095DownBlock3D", + "LTXVideo095DownBlock3D", + "LTXVideo095DownBlock3D", + ), + "decoder_block_out_channels": (256, 512, 1024), + "layers_per_block": (4, 6, 6, 2, 2), + "decoder_layers_per_block": (5, 5, 5, 5), + "spatio_temporal_scaling": (True, True, True, True), + "decoder_spatio_temporal_scaling": (True, True, True), + "decoder_inject_noise": (False, False, False, False), + "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + "upsample_residual": (True, True, True), + "upsample_factor": (2, 2, 2), + "timestep_conditioning": True, + "patch_size": 4, + "patch_size_t": 1, + "resnet_norm_eps": 1e-6, + "scaling_factor": 1.0, + "encoder_causal": True, + "decoder_causal": False, + "spatial_compression_ratio": 32, + "temporal_compression_ratio": 8, + } + VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT) + elif version in ["0.9.7"]: config = { "in_channels": 3, "out_channels": 3, @@ -275,12 +359,33 @@ def get_vae_config(version: str) -> Dict[str, Any]: return config +def get_spatial_latent_upsampler_config(version: str) -> Dict[str, Any]: + if version == "0.9.7": + config = { + "in_channels": 128, + "mid_channels": 512, + "num_blocks_per_stage": 4, + "dims": 3, + "spatial_upsample": True, + "temporal_upsample": False, + } + else: + raise ValueError(f"Unsupported version: {version}") + return config + + def get_args(): parser = argparse.ArgumentParser() parser.add_argument( "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint" ) parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint") + parser.add_argument( + "--spatial_latent_upsampler_path", + type=str, + default=None, + help="Path to original spatial latent upsampler checkpoint", + ) parser.add_argument( "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory" ) @@ -294,7 +399,11 @@ def get_args(): parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.") parser.add_argument( - "--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1", "0.9.5"], help="Version of the LTX model" + "--version", + type=str, + default="0.9.0", + choices=["0.9.0", "0.9.1", "0.9.5", "0.9.7"], + help="Version of the LTX model", ) return parser.parse_args() @@ -320,11 +429,9 @@ def get_args(): variant = VARIANT_MAPPING[args.dtype] output_path = Path(args.output_path) - if args.save_pipeline: - assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None - if args.transformer_ckpt_path is not None: - transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, dtype) + config = get_transformer_config(args.version) + transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, config, dtype) if not args.save_pipeline: transformer.save_pretrained( output_path / "transformer", safe_serialization=True, max_shard_size="5GB", variant=variant @@ -336,6 +443,16 @@ def get_args(): if 
not args.save_pipeline: vae.save_pretrained(output_path / "vae", safe_serialization=True, max_shard_size="5GB", variant=variant) + if args.spatial_latent_upsampler_path is not None: + config = get_spatial_latent_upsampler_config(args.version) + latent_upsampler: LTXLatentUpsamplerModel = convert_spatial_latent_upsampler( + args.spatial_latent_upsampler_path, config, dtype + ) + if not args.save_pipeline: + latent_upsampler.save_pretrained( + output_path / "latent_upsampler", safe_serialization=True, max_shard_size="5GB", variant=variant + ) + if args.save_pipeline: text_encoder_id = "google/t5-v1_1-xxl" tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH) @@ -348,7 +465,7 @@ def get_args(): for param in text_encoder.parameters(): param.data = param.data.contiguous() - if args.version == "0.9.5": + if args.version in ["0.9.5", "0.9.7"]: scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False) else: scheduler = FlowMatchEulerDiscreteScheduler( @@ -360,12 +477,40 @@ def get_args(): shift_terminal=0.1, ) - pipe = LTXPipeline( - scheduler=scheduler, - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - transformer=transformer, - ) - - pipe.save_pretrained(args.output_path, safe_serialization=True, variant=variant, max_shard_size="5GB") + if args.version in ["0.9.0", "0.9.1", "0.9.5"]: + pipe = LTXPipeline( + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + ) + pipe.save_pretrained( + output_path.as_posix(), safe_serialization=True, variant=variant, max_shard_size="5GB" + ) + elif args.version in ["0.9.7"]: + pipe = LTXConditionPipeline( + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + ) + pipe_upsample = LTXLatentUpsamplePipeline( + vae=vae, + latent_upsampler=latent_upsampler, + ) + pipe.save_pretrained( + (output_path / "ltx_pipeline").as_posix(), + safe_serialization=True, + variant=variant, + max_shard_size="5GB", + ) + pipe_upsample.save_pretrained( + (output_path / "ltx_upsample_pipeline").as_posix(), + safe_serialization=True, + variant=variant, + max_shard_size="5GB", + ) + else: + raise ValueError(f"Unsupported version: {args.version}") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 9e172dc3e7a1..f6dee1eb2dfa 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -420,6 +420,7 @@ "LEditsPPPipelineStableDiffusionXL", "LTXConditionPipeline", "LTXImageToVideoPipeline", + "LTXLatentUpsamplePipeline", "LTXPipeline", "Lumina2Pipeline", "Lumina2Text2ImgPipeline", @@ -1001,6 +1002,7 @@ LEditsPPPipelineStableDiffusionXL, LTXConditionPipeline, LTXImageToVideoPipeline, + LTXLatentUpsamplePipeline, LTXPipeline, Lumina2Pipeline, Lumina2Text2ImgPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 90ff30279f96..51a3d5e33377 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -268,7 +268,12 @@ ] ) _import_structure["latte"] = ["LattePipeline"] - _import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline", "LTXConditionPipeline"] + _import_structure["ltx"] = [ + "LTXPipeline", + "LTXImageToVideoPipeline", + "LTXConditionPipeline", + "LTXLatentUpsamplePipeline", + ] _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"] _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"] _import_structure["marigold"].extend( @@ -636,7 
+641,7 @@ LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, ) - from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXPipeline + from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline from .lumina import LuminaPipeline, LuminaText2ImgPipeline from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline from .marigold import ( diff --git a/src/diffusers/pipelines/ltx/__init__.py b/src/diffusers/pipelines/ltx/__init__.py index 199e730d9b4d..6001867916b3 100644 --- a/src/diffusers/pipelines/ltx/__init__.py +++ b/src/diffusers/pipelines/ltx/__init__.py @@ -22,9 +22,11 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: + _import_structure["modeling_latent_upsampler"] = ["LTXLatentUpsamplerModel"] _import_structure["pipeline_ltx"] = ["LTXPipeline"] _import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"] _import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"] + _import_structure["pipeline_ltx_latent_upsample"] = ["LTXLatentUpsamplePipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -34,9 +36,11 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: + from .modeling_latent_upsampler import LTXLatentUpsamplerModel from .pipeline_ltx import LTXPipeline from .pipeline_ltx_condition import LTXConditionPipeline from .pipeline_ltx_image2video import LTXImageToVideoPipeline + from .pipeline_ltx_latent_upsample import LTXLatentUpsamplePipeline else: import sys diff --git a/src/diffusers/pipelines/ltx/modeling_latent_upsampler.py b/src/diffusers/pipelines/ltx/modeling_latent_upsampler.py new file mode 100644 index 000000000000..6dce792a2b43 --- /dev/null +++ b/src/diffusers/pipelines/ltx/modeling_latent_upsampler.py @@ -0,0 +1,188 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
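+
+# This module implements a small convolutional latent upsampler for LTX Video:
+# `PixelShuffleND` rearranges channels into spatial and/or temporal positions (pixel-shuffle style),
+# and `LTXLatentUpsamplerModel` stacks residual blocks around it to 2x-upsample VAE latents
+# before they are decoded or further refined by the denoising pipeline.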
+ +from typing import Optional + +import torch + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.modeling_utils import ModelMixin + + +class ResBlock(torch.nn.Module): + def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3): + super().__init__() + if mid_channels is None: + mid_channels = channels + + Conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d + + self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1) + self.norm1 = torch.nn.GroupNorm(32, mid_channels) + self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1) + self.norm2 = torch.nn.GroupNorm(32, channels) + self.activation = torch.nn.SiLU() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + hidden_states = self.conv1(hidden_states) + hidden_states = self.norm1(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.conv2(hidden_states) + hidden_states = self.norm2(hidden_states) + hidden_states = self.activation(hidden_states + residual) + return hidden_states + + +class PixelShuffleND(torch.nn.Module): + def __init__(self, dims, upscale_factors=(2, 2, 2)): + super().__init__() + + self.dims = dims + self.upscale_factors = upscale_factors + + if dims not in [1, 2, 3]: + raise ValueError("dims must be 1, 2, or 3") + + def forward(self, x): + if self.dims == 3: + # spatiotemporal: b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3) + return ( + x.unflatten(1, (-1, *self.upscale_factors[:3])) + .permute(0, 1, 5, 2, 6, 3, 7, 4) + .flatten(6, 7) + .flatten(4, 5) + .flatten(2, 3) + ) + elif self.dims == 2: + # spatial: b (c p1 p2) h w -> b c (h p1) (w p2) + return ( + x.unflatten(1, (-1, *self.upscale_factors[:2])).permute(0, 1, 4, 2, 5, 3).flatten(4, 5).flatten(2, 3) + ) + elif self.dims == 1: + # temporal: b (c p1) f h w -> b c (f p1) h w + return x.unflatten(1, (-1, *self.upscale_factors[:1])).permute(0, 1, 3, 2, 4, 5).flatten(2, 3) + + +class LTXLatentUpsamplerModel(ModelMixin, ConfigMixin): + """ + Model to spatially upsample VAE latents. 
+ + Args: + in_channels (`int`, defaults to `128`): + Number of channels in the input latent + mid_channels (`int`, defaults to `512`): + Number of channels in the middle layers + num_blocks_per_stage (`int`, defaults to `4`): + Number of ResBlocks to use in each stage (pre/post upsampling) + dims (`int`, defaults to `3`): + Number of dimensions for convolutions (2 or 3) + spatial_upsample (`bool`, defaults to `True`): + Whether to spatially upsample the latent + temporal_upsample (`bool`, defaults to `False`): + Whether to temporally upsample the latent + """ + + @register_to_config + def __init__( + self, + in_channels: int = 128, + mid_channels: int = 512, + num_blocks_per_stage: int = 4, + dims: int = 3, + spatial_upsample: bool = True, + temporal_upsample: bool = False, + ): + super().__init__() + + self.in_channels = in_channels + self.mid_channels = mid_channels + self.num_blocks_per_stage = num_blocks_per_stage + self.dims = dims + self.spatial_upsample = spatial_upsample + self.temporal_upsample = temporal_upsample + + ConvNd = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d + + self.initial_conv = ConvNd(in_channels, mid_channels, kernel_size=3, padding=1) + self.initial_norm = torch.nn.GroupNorm(32, mid_channels) + self.initial_activation = torch.nn.SiLU() + + self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]) + + if spatial_upsample and temporal_upsample: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(3), + ) + elif spatial_upsample: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(2), + ) + elif temporal_upsample: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(1), + ) + else: + raise ValueError("Either spatial_upsample or temporal_upsample must be True") + + self.post_upsample_res_blocks = torch.nn.ModuleList( + [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] + ) + + self.final_conv = ConvNd(mid_channels, in_channels, kernel_size=3, padding=1) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + + if self.dims == 2: + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = self.initial_conv(hidden_states) + hidden_states = self.initial_norm(hidden_states) + hidden_states = self.initial_activation(hidden_states) + + for block in self.res_blocks: + hidden_states = block(hidden_states) + + hidden_states = self.upsampler(hidden_states) + + for block in self.post_upsample_res_blocks: + hidden_states = block(hidden_states) + + hidden_states = self.final_conv(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + else: + hidden_states = self.initial_conv(hidden_states) + hidden_states = self.initial_norm(hidden_states) + hidden_states = self.initial_activation(hidden_states) + + for block in self.res_blocks: + hidden_states = block(hidden_states) + + if self.temporal_upsample: + hidden_states = self.upsampler(hidden_states) + hidden_states = hidden_states[:, :, 1:, :, :] + else: + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = self.upsampler(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 
4) + + for block in self.post_upsample_res_blocks: + hidden_states = block(hidden_states) + + hidden_states = self.final_conv(hidden_states) + + return hidden_states diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py index dcfdfaf23288..50ef7ae7d1db 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py @@ -430,6 +430,7 @@ def check_inputs( video, frame_index, strength, + denoise_strength, height, width, callback_on_step_end_tensor_inputs=None, @@ -497,6 +498,9 @@ def check_inputs( elif isinstance(video, list) and isinstance(strength, list) and len(video) != len(strength): raise ValueError("If `conditions` is not provided, `video` and `strength` must be of the same length.") + if denoise_strength < 0 or denoise_strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {denoise_strength}") + @staticmethod def _prepare_video_ids( batch_size: int, @@ -649,6 +653,8 @@ def prepare_latents( width: int = 704, num_frames: int = 161, num_prefix_latent_frames: int = 2, + sigma: Optional[torch.Tensor] = None, + latents: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, @@ -658,7 +664,18 @@ def prepare_latents( latent_width = width // self.vae_spatial_compression_ratio shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + if latents is not None and sigma is not None: + if latents.shape != shape: + raise ValueError( + f"Latents shape {latents.shape} does not match expected shape {shape}. Please check the input." + ) + latents = latents.to(device=device, dtype=dtype) + sigma = sigma.to(device=device, dtype=dtype) + latents = sigma * noise + (1 - sigma) * latents + else: + latents = noise if len(conditions) > 0: condition_latent_frames_mask = torch.zeros( @@ -766,6 +783,13 @@ def prepare_latents( return latents, conditioning_mask, video_ids, extra_conditioning_num_latents + def get_timesteps(self, sigmas, timesteps, num_inference_steps, strength): + num_steps = min(int(num_inference_steps * strength), num_inference_steps) + start_index = max(num_inference_steps - num_steps, 0) + sigmas = sigmas[start_index:] + timesteps = timesteps[start_index:] + return sigmas, timesteps, num_inference_steps - start_index + @property def guidance_scale(self): return self._guidance_scale @@ -799,6 +823,7 @@ def __call__( video: List[PipelineImageInput] = None, frame_index: Union[int, List[int]] = 0, strength: Union[float, List[float]] = 1.0, + denoise_strength: float = 1.0, prompt: Union[str, List[str]] = None, negative_prompt: Optional[Union[str, List[str]]] = None, height: int = 512, @@ -842,6 +867,10 @@ def __call__( generation. If not provided, one has to pass `conditions`. strength (`float` or `List[float]`, *optional*): The strength or strengths of the conditioning effect. If not provided, one has to pass `conditions`. + denoise_strength (`float`, defaults to `1.0`): + The strength of the noise added to the latents for editing. Higher strength leads to more noise added + to the latents, therefore leading to more differences between original video and generated video. This + is useful for video-to-video editing. 
prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. @@ -918,8 +947,6 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - if latents is not None: - raise ValueError("Passing latents is not yet supported.") # 1. Check inputs. Raise error if not correct self.check_inputs( @@ -929,6 +956,7 @@ def __call__( video=video, frame_index=frame_index, strength=strength, + denoise_strength=denoise_strength, height=height, width=width, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, @@ -977,8 +1005,9 @@ def __call__( strength = [strength] * num_conditions device = self._execution_device + vae_dtype = self.vae.dtype - # 3. Prepare text embeddings + # 3. Prepare text embeddings & conditioning image/video ( prompt_embeds, prompt_attention_mask, @@ -1000,8 +1029,6 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) - vae_dtype = self.vae.dtype - conditioning_tensors = [] is_conditioning_image_or_video = image is not None or video is not None if is_conditioning_image_or_video: @@ -1032,7 +1059,25 @@ def __call__( ) conditioning_tensors.append(condition_tensor) - # 4. Prepare latent variables + # 4. Prepare timesteps + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + sigmas = linear_quadratic_schedule(num_inference_steps) + timesteps = sigmas * 1000 + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + latent_sigma = None + if denoise_strength < 1: + sigmas, timesteps, num_inference_steps = self.get_timesteps( + sigmas, timesteps, num_inference_steps, denoise_strength + ) + latent_sigma = sigmas[:1].repeat(batch_size * num_videos_per_prompt) + + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables num_channels_latents = self.transformer.config.in_channels latents, conditioning_mask, video_coords, extra_conditioning_num_latents = self.prepare_latents( conditioning_tensors, @@ -1043,6 +1088,8 @@ def __call__( height=height, width=width, num_frames=num_frames, + sigma=latent_sigma, + latents=latents, generator=generator, device=device, dtype=torch.float32, @@ -1056,21 +1103,6 @@ def __call__( if self.do_classifier_free_guidance: video_coords = torch.cat([video_coords, video_coords], dim=0) - # 5. Prepare timesteps - latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 - latent_height = height // self.vae_spatial_compression_ratio - latent_width = width // self.vae_spatial_compression_ratio - sigmas = linear_quadratic_schedule(num_inference_steps) - timesteps = sigmas * 1000 - timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, - num_inference_steps, - device, - timesteps=timesteps, - ) - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - self._num_timesteps = len(timesteps) - # 6. 
Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -1168,7 +1200,7 @@ def __call__( if not self.vae.config.timestep_conditioning: timestep = None else: - noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype) + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) if not isinstance(decode_timestep, list): decode_timestep = [decode_timestep] * batch_size if decode_noise_scale is None: diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py b/src/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py new file mode 100644 index 000000000000..a90e52e71709 --- /dev/null +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py @@ -0,0 +1,241 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Union + +import torch + +from ...image_processor import PipelineImageInput +from ...models import AutoencoderKLLTXVideo +from ...utils import get_logger +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .modeling_latent_upsampler import LTXLatentUpsamplerModel +from .pipeline_output import LTXPipelineOutput + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class LTXLatentUpsamplePipeline(DiffusionPipeline): + def __init__( + self, + vae: AutoencoderKLLTXVideo, + latent_upsampler: LTXLatentUpsamplerModel, + ) -> None: + super().__init__() + + self.register_modules(vae=vae, latent_upsampler=latent_upsampler) + + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + + def prepare_latents( + self, + video: Optional[torch.Tensor] = None, + batch_size: int = 1, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return 
latents.to(device=device, dtype=dtype) + + video = video.to(device=device, dtype=self.vae.dtype) + if isinstance(generator, list): + if len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + init_latents = [ + retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size) + ] + else: + init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video] + + init_latents = torch.cat(init_latents, dim=0).to(dtype) + init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std) + return init_latents + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Normalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._denormalize_latents + def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Denormalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std / scaling_factor + latents_mean + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. 
+ """ + self.vae.disable_tiling() + + def check_inputs(self, video, height, width, latents): + if height % self.vae_spatial_compression_ratio != 0 or width % self.vae_spatial_compression_ratio != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if video is not None and latents is not None: + raise ValueError("Only one of `video` or `latents` can be provided.") + if video is None and latents is None: + raise ValueError("One of `video` or `latents` has to be provided.") + + @torch.no_grad() + def __call__( + self, + video: Optional[List[PipelineImageInput]] = None, + height: int = 512, + width: int = 704, + latents: Optional[torch.Tensor] = None, + decode_timestep: Union[float, List[float]] = 0.0, + decode_noise_scale: Optional[Union[float, List[float]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + self.check_inputs( + video=video, + height=height, + width=width, + latents=latents, + ) + + if video is not None: + # Batched video input is not yet tested/supported. TODO: take a look later + batch_size = 1 + else: + batch_size = latents.shape[0] + device = self._execution_device + + if video is not None: + num_frames = len(video) + if num_frames % self.vae_temporal_compression_ratio != 1: + num_frames = ( + num_frames // self.vae_temporal_compression_ratio * self.vae_temporal_compression_ratio + 1 + ) + video = video[:num_frames] + logger.warning( + f"Video length expected to be of the form `k * {self.vae_temporal_compression_ratio} + 1` but is {len(video)}. Truncating to {num_frames} frames." + ) + video = self.video_processor.preprocess_video(video, height=height, width=width) + video = video.to(device=device, dtype=torch.float32) + + latents = self.prepare_latents( + video=video, + batch_size=batch_size, + dtype=torch.float32, + device=device, + generator=generator, + latents=latents, + ) + + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + latents = latents.to(self.latent_upsampler.dtype) + latents = self.latent_upsampler(latents) + + if output_type == "latent": + latents = self._normalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + video = latents + else: + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + video = self.vae.decode(latents, timestep, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return LTXPipelineOutput(frames=video) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py 
index c36c758d8752..eb4e5f0ec8c7 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1307,6 +1307,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class LTXLatentUpsamplePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class LTXPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/ltx/test_ltx_latent_upsample.py b/tests/pipelines/ltx/test_ltx_latent_upsample.py new file mode 100644 index 000000000000..f9ddb121869d --- /dev/null +++ b/tests/pipelines/ltx/test_ltx_latent_upsample.py @@ -0,0 +1,159 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch + +from diffusers import AutoencoderKLLTXVideo, LTXLatentUpsamplePipeline +from diffusers.pipelines.ltx.modeling_latent_upsampler import LTXLatentUpsamplerModel +from diffusers.utils.testing_utils import enable_full_determinism + +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class LTXLatentUpsamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = LTXLatentUpsamplePipeline + params = {"video", "generator"} + batch_params = {"video", "generator"} + required_optional_params = frozenset(["generator", "latents", "return_dict"]) + test_xformers_attention = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + vae = AutoencoderKLLTXVideo( + in_channels=3, + out_channels=3, + latent_channels=8, + block_out_channels=(8, 8, 8, 8), + decoder_block_out_channels=(8, 8, 8, 8), + layers_per_block=(1, 1, 1, 1, 1), + decoder_layers_per_block=(1, 1, 1, 1, 1), + spatio_temporal_scaling=(True, True, False, False), + decoder_spatio_temporal_scaling=(True, True, False, False), + decoder_inject_noise=(False, False, False, False, False), + upsample_residual=(False, False, False, False), + upsample_factor=(1, 1, 1, 1), + timestep_conditioning=False, + patch_size=1, + patch_size_t=1, + encoder_causal=True, + decoder_causal=False, + ) + vae.use_framewise_encoding = False + vae.use_framewise_decoding = False + + torch.manual_seed(0) + latent_upsampler = LTXLatentUpsamplerModel( + in_channels=8, + mid_channels=32, + num_blocks_per_stage=1, + dims=3, + spatial_upsample=True, + temporal_upsample=False, + ) + + components = { + "vae": vae, + "latent_upsampler": latent_upsampler, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = 
torch.Generator(device=device).manual_seed(seed) + + video = torch.randn((5, 3, 32, 32), generator=generator, device=device) + + inputs = { + "video": video, + "generator": generator, + "height": 16, + "width": 16, + "output_type": "pt", + } + + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (5, 3, 32, 32)) + expected_video = torch.randn(5, 3, 32, 32) + max_diff = np.abs(generated_video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + def test_vae_tiling(self, expected_diff_max: float = 0.25): + generator_device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + # Without tiling + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_without_tiling = pipe(**inputs)[0] + + # With tiling + pipe.vae.enable_tiling( + tile_sample_min_height=96, + tile_sample_min_width=96, + tile_sample_stride_height=64, + tile_sample_stride_width=64, + ) + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_with_tiling = pipe(**inputs)[0] + + self.assertLess( + (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), + expected_diff_max, + "VAE tiling should not affect the inference results", + ) + + @unittest.skip("Test is not applicable.") + def test_callback_inputs(self): + pass + + @unittest.skip("Test is not applicable.") + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + pass + + @unittest.skip("Test is not applicable.") + def test_inference_batch_consistent(self): + pass + + @unittest.skip("Test is not applicable.") + def test_inference_batch_single_identical(self): + pass