LTX Video 0.9.7 #11516

Open · wants to merge 14 commits into main

Conversation

@a-r-r-o-w (Member) commented May 7, 2025

Checkpoints (only for the time being; unofficial):

Standalone latent upscale pipeline test (image):

import torch
from diffusers import LTXLatentUpsamplePipeline
from diffusers.utils import load_video, load_image, export_to_video

pipe = LTXLatentUpsamplePipeline.from_pretrained("/raid/aryan/diffusers-ltx/ltx_upsample_pipeline", torch_dtype=torch.bfloat16)
pipe.to("cuda")
pipe.vae.enable_tiling()

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png")
video = [image]

output = pipe(video=video, height=832, width=480, generator=torch.Generator().manual_seed(42)).frames[0]
export_to_video(output, "output.mp4", fps=16)

Standalone latent upscale pipeline test (video):

import torch
from diffusers import LTXLatentUpsamplePipeline
from diffusers.utils import load_video, load_image, export_to_video

pipe = LTXLatentUpsamplePipeline.from_pretrained("/raid/aryan/diffusers-ltx/ltx_upsample_pipeline", torch_dtype=torch.bfloat16)
pipe.to("cuda")
pipe.vae.enable_tiling()

video = load_video("inputs/peter-dance.mp4")[::2][:81]

output = pipe(video=video, height=480, width=832, generator=torch.Generator().manual_seed(42)).frames[0]
export_to_video(output, "output.mp4", fps=16)

Full inference:

import torch
from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
from diffusers.utils import export_to_video, load_video

pipe = LTXConditionPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.7-diffusers", torch_dtype=torch.bfloat16)
pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.7-Latent-Spatial-Upsampler-diffusers", vae=pipe.vae, torch_dtype=torch.bfloat16)
pipe.to("cuda")
pipe_upsample.to("cuda")
pipe.vae.enable_tiling()

def round_to_nearest_resolution_acceptable_by_vae(height, width):
    # height/width must be multiples of the VAE's spatial compression ratio
    height = height - (height % pipe.vae_spatial_compression_ratio)
    width = width - (width % pipe.vae_spatial_compression_ratio)
    return height, width

video = load_video(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4"
)[:21]  # Use only the first 21 frames as conditioning
condition1 = LTXVideoCondition(video=video, frame_index=0)

prompt = "The video depicts a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region."
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
expected_height, expected_width = 768, 1152
downscale_factor = 2 / 3
num_frames = 161

# Part 1. Generate video at smaller resolution
# Text-only conditioning is also supported without the need to pass `conditions`
downscaled_height, downscaled_width = int(expected_height * downscale_factor), int(expected_width * downscale_factor)
downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(downscaled_height, downscaled_width)
latents = pipe(
    conditions=[condition1],
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=downscaled_width,
    height=downscaled_height,
    num_frames=num_frames,
    num_inference_steps=30,
    generator=torch.Generator().manual_seed(0),
    output_type="latent",
).frames

# Part 2. Upscale generated video using latent upsampler with fewer inference steps
# The available latent upsampler upscales the height/width by 2x
upscaled_height, upscaled_width = downscaled_height * 2, downscaled_width * 2
upscaled_latents = pipe_upsample(
    latents=latents,
    output_type="latent"
).frames

# Part 3. Denoise the upscaled video with few steps to improve texture (optional, but recommended)
video = pipe(
    conditions=[condition1],
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=upscaled_width,
    height=upscaled_height,
    num_frames=num_frames,
    denoise_strength=0.4,  # Effectively, 4 inference steps out of 10
    num_inference_steps=10,
    latents=upscaled_latents,
    decode_timestep=0.05,
    image_cond_noise_scale=0.025,
    generator=torch.Generator().manual_seed(0),
    output_type="pil",
).frames[0]

# Part 4. Downscale the video to the expected resolution
video = [frame.resize((expected_width, expected_height)) for frame in video]

export_to_video(video, "output.mp4", fps=24)
[attached video: output.mp4]

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@nitinmukesh

Hello Aryan,

Please also consider 0.9.6; there are 2 models in it.

@a-r-r-o-w (Member Author)

Oh, I was under the impression that it was already supported by someone else's PR :/ I'll try to take a look after this PR is complete

@nitinmukesh commented May 7, 2025

Unfortunately the output was not up to the mark. I guess there must be a few changes in the code w.r.t. 0.9.6
#11359

@a-r-r-o-w a-r-r-o-w marked this pull request as ready for review May 8, 2025 09:56
@a-r-r-o-w a-r-r-o-w requested a review from yiyixuxu May 8, 2025 09:56
@nitinmukesh

@a-r-r-o-w

Is this correct code for T2V?

import torch
from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
from diffusers.utils import export_to_video, load_video

pipe = LTXConditionPipeline.from_pretrained(
    "a-r-r-o-w/LTX-Video-0.9.7-diffusers",
    torch_dtype=torch.bfloat16,
)
pipe.enable_sequential_cpu_offload()
pipe.vae.enable_tiling()
prompt = "A woman with light skin, wearing a blue jacket and a black hat with a veil, looks down and to her right, then back up as she speaks; she has brown hair styled in an updo, light brown eyebrows, and is wearing a white collared shirt under her jacket; the camera remains stationary on her face as she speaks; the background is out of focus, but shows trees and people in period clothing; the scene is captured in real-life footage."
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
generator = torch.Generator(device="cuda").manual_seed(42)
video = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=1024, # 1216
    height=576, # 704
    num_frames=121, # 257
    num_inference_steps=30,
    guidance_scale=1,
    generator=generator
).frames[0]

export_to_video(video, "LTX097_1.mp4", fps=24)
[attached video: LTX097_1.mp4]

@a-r-r-o-w (Member Author)

Looks correct to me except for the guidance scale. I don't think 0.9.7 is guidance-distilled, so it probably needs a higher guidance scale, like 5.0, to produce good results.

@a-r-r-o-w (Member Author)

Additionally, the recommended values are decode_timestep=0.05 and image_cond_noise_scale=0.025 for LTX 0.9.1 and above. I'll add this to the documentation to clarify.

Code:
import torch
from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
from diffusers.utils import export_to_video, load_video

pipe = LTXConditionPipeline.from_pretrained("/raid/aryan/diffusers-ltx/ltx_pipeline", torch_dtype=torch.bfloat16)
pipe.to("cuda")
pipe.vae.enable_tiling()
prompt = "A woman with light skin, wearing a blue jacket and a black hat with a veil, looks down and to her right, then back up as she speaks; she has brown hair styled in an updo, light brown eyebrows, and is wearing a white collared shirt under her jacket; the camera remains stationary on her face as she speaks; the background is out of focus, but shows trees and people in period clothing; the scene is captured in real-life footage."
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
generator = torch.Generator().manual_seed(42)
video = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=1024, # 1216
    height=576, # 704
    num_frames=121, # 257
    num_inference_steps=30,
    guidance_scale=5.0,
    decode_timestep=0.05,
    image_cond_noise_scale=0.025,
    generator=generator
).frames[0]

export_to_video(video, "output3.mp4", fps=24)
[attached video: output3.mp4]

@nitinmukesh

Thank you, will try now.

@nicollegah commented May 8, 2025

Here's an img2vid version that works, but the results could be better; I'm not sure why.


from PIL import Image
import torch
from diffusers import LTXConditionPipeline
from diffusers.utils import export_to_video

# ── 1. load & resize the frame ──────────────────────────────────────────────────
path = "/workspace/fa1ed4f5-a0f3-46c0-b5eb-44b1bfb8f1a7.png"
img = Image.open(path).convert("RGB")

# ── 2. set up the pipeline ──────────────────────────────────────────────────────
pipe = LTXConditionPipeline.from_pretrained(
    "a-r-r-o-w/LTX-Video-0.9.7-diffusers",
    torch_dtype=torch.bfloat16,
).to("cuda")
pipe.vae.enable_tiling()

# ── 3. run img → video ──────────────────────────────────────────────────────────
prompt=""
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"


generator = torch.Generator(device="cuda").manual_seed(42)
video = pipe(
    image           = img,
    prompt          = prompt,
    negative_prompt = negative_prompt,
    strength        = 0.75,
    num_frames      = 121,
    num_inference_steps = 80,
    guidance_scale  = 4.0,
    decode_timestep = 0.05,
    decode_noise_scale   = 0.05,
    image_cond_noise_scale=0.05,
    generator       = generator,
).frames[0]

export_to_video(video, "LTX097_img2vid.mp4", fps=24)
[attached video: LTX097_img2vid.5.mp4]

@yiyixuxu (Collaborator) left a comment

thanks @a-r-r-o-w


latent_sigma = None
if denoise_strength < 1:
    sigmas, timesteps, num_inference_steps = self.get_timesteps(
@yiyixuxu (Collaborator):

Don't we need to call retrieve_timesteps first to get the timesteps and sigmas, and use those to set the timesteps on the scheduler?

@a-r-r-o-w (Member Author):

Oh right, nice catch. I'll test and fix this tomorrow

@a-r-r-o-w (Member Author) commented May 9, 2025

I've verified this is the same order as the original inference code and is intended. I think the only thing remaining is moving self._num_timesteps = len(timesteps) below this.

There are some more changes in the original codebase related to removing some boundaries of conditioning latents. I'm not yet sure how effective that is (we should support it anyway), so I'll test more and then merge this PR

@a-r-r-o-w (Member Author):

I looked at the changes related to stripping the latent boundaries and the other changes that were added. It seems like these are additional generation-quality improvements and are not required for the basic sampling mechanism itself. Since we want to keep our pipeline examples as simple as possible, I think it's alright not to support it, and instead we can look at adding more complex logic like this in modular diffusers.

LMK if you'd like me to write a full LTXMultiscalePipeline (similar to the official repo) by combining the normal and upscale LTX pipelines to wrap the example code logic, plus the latent-boundary stripping mentioned above.
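
For concreteness, here is a rough sketch of what such a wrapper could look like, built only from the two pipelines and call patterns already shown in the Full inference example above. The function name run_multiscale and its default arguments are made up for illustration; this is not the official LTXMultiscalePipeline, and it omits the latent-boundary stripping being discussed.

import torch
from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline
from diffusers.utils import export_to_video

def run_multiscale(pipe, pipe_upsample, prompt, negative_prompt, conditions=None,
                   height=512, width=768, num_frames=161, refine_strength=0.4, seed=0):
    # Stage 1: base generation at the lower resolution, kept in latent space
    latents = pipe(
        conditions=conditions,
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        num_frames=num_frames,
        num_inference_steps=30,
        generator=torch.Generator().manual_seed(seed),
        output_type="latent",
    ).frames
    # Stage 2: 2x spatial upscale of the latents
    upscaled_latents = pipe_upsample(latents=latents, output_type="latent").frames
    # Stage 3: short refinement pass at the upscaled resolution
    return pipe(
        conditions=conditions,
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=width * 2,
        height=height * 2,
        num_frames=num_frames,
        denoise_strength=refine_strength,
        num_inference_steps=10,
        latents=upscaled_latents,
        decode_timestep=0.05,
        image_cond_noise_scale=0.025,
        generator=torch.Generator().manual_seed(seed),
        output_type="pil",
    ).frames[0]

pipe = LTXConditionPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.7-diffusers", torch_dtype=torch.bfloat16).to("cuda")
pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained(
    "a-r-r-o-w/LTX-Video-0.9.7-Latent-Spatial-Upsampler-diffusers", vae=pipe.vae, torch_dtype=torch.bfloat16
).to("cuda")
pipe.vae.enable_tiling()

video = run_multiscale(pipe, pipe_upsample, prompt="A winding mountain road covered in snow", negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted")
export_to_video(video, "multiscale.mp4", fps=24)

The official multiscale wrapper additionally strips the boundaries of the conditioning latents between stages; that part is intentionally left out of this sketch, per the discussion above.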

raise AttributeError("Could not access latents of provided encoder_output")


class LTXLatentUpsamplePipeline(DiffusionPipeline):
@yiyixuxu (Collaborator):

So this is technically just an optional component, I think? It does not have a denoising loop.
OK if it is faster to get the model this way & easier to use; I will leave it up to you to decide.

@a-r-r-o-w (Member Author):

I think the LTX team may release more models, and handling everything in the same pipeline will make it harder. For example, they have also released a temporal upscaler, but I don't see any inference code for it yet, so I haven't added it here.

Also, the upsampler seems to be usable with other LTX models and not just 0.9.7, so I think it makes sense to keep a separate pipeline; otherwise we'll have to add it to all three pipelines, no?

@yiyixuxu (Collaborator):

sounds good -

@tin2tin commented May 9, 2025

Using this (instead of pipe.to("cuda")/pipe_upsample.to("cuda")):

pipe.enable_sequential_cpu_offload()
pipe.vae.enable_tiling()

in the Full Inference example, I get this error:

torch\nn\modules\conv.py", line 720, in _conv_forward
    return F.conv3d(
           ^^^^^^^^^
RuntimeError: Input type (CUDABFloat16Type) and weight type (CPUBFloat16Type) should be the same

That said, the conditional pipeline is a winner! Amazing to be able to load a pipeline supporting txt, img and video input in one go!

@nitinmukesh commented May 9, 2025

@tin2tin

I don't think it will work.
You will have to separate the pipelines:

  1. Do inference with enable_sequential_cpu_offload enabled (working fine, I tested).
  2. Then either delete the pipeline or move it to the CPU, and create the upsampling pipeline.
  3. Upscale, save the video, and then delete that pipeline (or move it to the CPU); a sketch of steps 1-3 follows below.

This model is not for me (so I can't provide code), as it is taking 1 hr for inference (can't complain, the model is too big for 8 GB VRAM).
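
For reference, a minimal sketch of that split workflow, under a few assumptions: the hub repos from the Full inference example, the upsampler taking decoded frames via video= (as in the standalone examples at the top), and the VAE for the upsampler being loaded separately from the base repo. Resolutions, frame count, and file names are placeholders.

import gc
import torch
from diffusers import AutoencoderKLLTXVideo, LTXConditionPipeline, LTXLatentUpsamplePipeline
from diffusers.utils import export_to_video

prompt = "A winding mountain road covered in snow"
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

# 1. Base generation with sequential CPU offload enabled
pipe = LTXConditionPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.7-diffusers", torch_dtype=torch.bfloat16)
pipe.enable_sequential_cpu_offload()
pipe.vae.enable_tiling()
video = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=512,
    height=320,
    num_frames=81,
    num_inference_steps=30,
    guidance_scale=5.0,
    decode_timestep=0.05,
    image_cond_noise_scale=0.025,
    generator=torch.Generator().manual_seed(0),
).frames[0]

# 2. Delete the base pipeline and free VRAM before creating the upsampling pipeline
del pipe
gc.collect()
torch.cuda.empty_cache()

# 3. Upscale the decoded frames (the upsampler doubles height/width), then save
vae = AutoencoderKLLTXVideo.from_pretrained("a-r-r-o-w/LTX-Video-0.9.7-diffusers", subfolder="vae", torch_dtype=torch.bfloat16)
pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained(
    "a-r-r-o-w/LTX-Video-0.9.7-Latent-Spatial-Upsampler-diffusers", vae=vae, torch_dtype=torch.bfloat16
)
pipe_upsample.enable_sequential_cpu_offload()
pipe_upsample.vae.enable_tiling()
upscaled = pipe_upsample(video=video, height=320, width=512, generator=torch.Generator().manual_seed(0)).frames[0]
export_to_video(upscaled, "upscaled.mp4", fps=24)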

@tin2tin commented May 9, 2025

@nitinmukesh This code with bitsandbytes quantizing the transformer works for txt, img and vid input and runs on 14 GB VRAM. Maybe sequential_offload will bring it further down (should probably add it as low_ram handling): https://github.com/tin2tin/Pallaidium/blob/fa1da79faf817227a204da15cdfae4dfb0d5452e/__init__.py#L3170

The error I got was in the "Full Inference" example.
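
For readers who don't want to dig through the linked file, a rough sketch of that transformer-only NF4 approach using diffusers' bitsandbytes integration is shown below. This is not the Pallaidium code; the repo ID, resolution, and the use of model-level CPU offload (rather than sequential offload, which reportedly fails with NF4, see the comment further down) are assumptions.

import torch
from diffusers import BitsAndBytesConfig, LTXConditionPipeline, LTXVideoTransformer3DModel
from diffusers.utils import export_to_video

# Quantize only the transformer to NF4; the text encoder and VAE stay in bf16
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
transformer = LTXVideoTransformer3DModel.from_pretrained(
    "a-r-r-o-w/LTX-Video-0.9.7-diffusers",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
pipe = LTXConditionPipeline.from_pretrained(
    "a-r-r-o-w/LTX-Video-0.9.7-diffusers",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()  # model-level offload; sequential offload + NF4 is reported not to work
pipe.vae.enable_tiling()

video = pipe(
    prompt="A woman with light skin, wearing a blue jacket and a black hat with a veil",
    negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted",
    width=1024,
    height=576,
    num_frames=121,
    num_inference_steps=30,
    guidance_scale=5.0,
    decode_timestep=0.05,
    image_cond_noise_scale=0.025,
    generator=torch.Generator().manual_seed(42),
).frames[0]
export_to_video(video, "ltx097_nf4.mp4", fps=24)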

@nitinmukesh

> Oh, I was under the impression that it was already supported by someone else's PR :/ I'll try to take a look after this PR is complete

0.9.6 distilled is no longer needed. I tested with LTXPipeline and LTXImageToVideoPipeline, and the distilled model is working as expected. Time to test 0.9.6-dev.
I guess documentation is needed to cover the different models and their corresponding pipelines, as this might confuse a few people.

@nitinmukesh

@tin2tin

Unfortunately sequential offload doesn't work with NF4. I did log an issue but the response didn't make sense to me. I had quantized the hunyuancommunity model to NF4, so it should have worked.
bitsandbytes-foundation/bitsandbytes#1525
