
Conversation


@leffff leffff commented Nov 14, 2025

This PR updates the Kandinsky 5.0 Video code to handle the Pro model.

@yiyixuxu @sayakpaul

leffff and others added 30 commits October 4, 2025 10:10

leffff commented Nov 14, 2025

@sayakpaul @yiyixuxu

I have committed the needed code.
However, I am facing the following problem:

import torch
from diffusers import Kandinsky5T2VPipeline

pipeline = Kandinsky5T2VPipeline.from_pretrained(
    "kandinskylab/Kandinsky-5.0-I2V-Pro-sft-5s-Diffusers",
    torch_dtype=torch.bfloat16,
)

pipeline = pipeline.to("cuda:0")
pipeline.transformer.set_attention_backend("flex")
pipeline.enable_model_cpu_offload()
pipeline.transformer.compile(mode="max-autotune-no-cudagraphs", dynamic=True)

from diffusers.utils import export_to_video

prompt = [
    "Photorealistic closeup video of two intricately detailed pirate ships locked in a fierce battle, complete with cannon fire and billowing sails, as they sail through the swirling waters of a steaming cup of coffee. The ships are miniature but highly realistic, with wooden textures and flags fluttering in the liquid breeze. Coffee splashes and foam ripple around them as they maneuver through the turbulent surface, dodging each other's attacks. A detailed reflection of the battle appears on the glossy surface of the coffee, adding to the dynamic realism. The camera pans and zooms to capture every dramatic moment of the high-seas clash within this tiny, unexpected world.",
]
negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"

output = pipeline(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=512,
    width=768,
    num_frames=21,
    num_inference_steps=40,
    guidance_scale=5.0,
    num_videos_per_prompt=1,
    generator=torch.Generator(42)
)
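
For completeness, saving the result would go through the export_to_video helper imported above (a minimal sketch; I am assuming the decoded frames sit in output.frames[0], and the output filename and fps=24 are placeholders):

# Save the generated video; the frames attribute and fps value are assumptions
export_to_video(output.frames[0], "output.mp4", fps=24)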

The script above works fine on the first call. But if I call the pipeline a second time, I hit an error because compilation conflicts with offloading:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[17], line 1
----> 1 output = pipeline(
      2     prompt=prompt,
      3     negative_prompt=negative_prompt,
      4     height=512,
      5     width=768,
      6     num_frames=21,
      7     num_inference_steps=40,
      8     guidance_scale=5.0,
      9     num_videos_per_prompt=1,
     10     generator=torch.Generator(42)
     11 )

File /home/user/conda/envs/kandinsky-cuda12.8/lib/python3.12/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File /home/user/conda/envs/kandinsky-cuda12.8/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py:828, in Kandinsky5T2VPipeline.__call__(self, prompt, negative_prompt, height, width, num_frames, num_inference_steps, guidance_scale, num_videos_per_prompt, generator, latents, prompt_embeds_qwen, prompt_embeds_clip, negative_prompt_embeds_qwen, negative_prompt_embeds_clip, prompt_cu_seqlens, negative_prompt_cu_seqlens, output_type, return_dict, callback_on_step_end, callback_on_step_end_tensor_inputs, max_sequence_length, **kwargs)
    826 # 3. Encode input prompt
    827 if prompt_embeds_qwen is None:
--> 828     prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = self.encode_prompt(
    829         prompt=prompt,
    830         max_sequence_length=max_sequence_length,
    831         device=device,
    832         dtype=dtype,
    833     )
    835 if self.do_classifier_free_guidance:
    836     if negative_prompt is None:

File /home/user/conda/envs/kandinsky-cuda12.8/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py:492, in Kandinsky5T2VPipeline.encode_prompt(self, prompt, num_videos_per_prompt, max_sequence_length, device, dtype)
    489 prompt = [prompt_clean(p) for p in prompt]
    491 # Encode with Qwen2.5-VL
--> 492 prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen(
    493     prompt=prompt,
    494     device=device,
    495     max_sequence_length=max_sequence_length,
    496     dtype=dtype,
    497 )
    498 # prompt_embeds_qwen shape: [batch_size, seq_len, embed_dim]
    499 
    500 # Encode with CLIP
    501 prompt_embeds_clip = self._encode_prompt_clip(
    502     prompt=prompt,
    503     device=device,
    504     dtype=dtype,
    505 )

File /home/user/conda/envs/kandinsky-cuda12.8/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py:399, in Kandinsky5T2VPipeline._encode_prompt_qwen(self, prompt, device, max_sequence_length, dtype)
    387 full_texts = [self.prompt_template.format(p) for p in prompt]
    389 inputs = self.tokenizer(
    390     text=full_texts,
    391     images=None,
   (...)    396     padding=True,
    397 ).to(device)
--> 399 embeds = self.text_encoder(
    400     input_ids=inputs["input_ids"],
    401     return_dict=True,
    402     output_hidden_states=True,
    403 )["hidden_states"][-1][:, self.prompt_template_encode_start_idx :]
    405 attention_mask = inputs["attention_mask"][:, self.prompt_template_encode_start_idx :]
    406 cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0)

File /home/user/conda/envs/kandinsky-cuda12.8/lib/python3.12/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File /home/user/conda/envs/kandinsky-cuda12.8/lib/python3.12/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1757 # If we don't have any hooks, we want to skip the rest of the logic in
   1758 # this function, and just call forward.
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)
   1764 result = None
   1765 called_always_called_hooks = set()

File /home/user/conda/envs/kandinsky-cuda12.8/lib/python3.12/site-packages/accelerate/hooks.py:171, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    170 def new_forward(module, *args, **kwargs):
--> 171     args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
    172     if module._hf_hook.no_grad:
    173         with torch.no_grad():

AttributeError: 'NoneType' object has no attribute 'pre_forward'

Please help me fix this issue: judging from the traceback, the accelerate offload hook on the text encoder is gone on the second call (module._hf_hook is None). Offloading is crucial for single-GPU inference.
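
One untested workaround sketch would be to skip the explicit .to("cuda:0") so that enable_model_cpu_offload() fully owns device placement, and to rebuild the offload hooks before a repeated call. remove_all_hooks() and enable_model_cpu_offload() are existing pipeline methods; whether this actually avoids the stale hook is an assumption:

import torch
from diffusers import Kandinsky5T2VPipeline

pipeline = Kandinsky5T2VPipeline.from_pretrained(
    "kandinskylab/Kandinsky-5.0-I2V-Pro-sft-5s-Diffusers",
    torch_dtype=torch.bfloat16,
)

# Let enable_model_cpu_offload() manage device placement instead of calling .to("cuda:0") first
pipeline.enable_model_cpu_offload()
pipeline.transformer.set_attention_backend("flex")
pipeline.transformer.compile(mode="max-autotune-no-cudagraphs", dynamic=True)

# ... first pipeline(...) call ...

# Before the second call: drop and recreate the accelerate offload hooks,
# which may avoid the _hf_hook=None state seen in the traceback (assumption)
pipeline.remove_all_hooks()
pipeline.enable_model_cpu_offload()

# ... second pipeline(...) call ...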


leffff commented Nov 14, 2025

Also, #12666


leffff commented Nov 15, 2025

@asomoza
