-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathinteractive_inference.py
More file actions
233 lines (192 loc) · 8.38 KB
/
interactive_inference.py
File metadata and controls
233 lines (192 loc) · 8.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# To view a copy of this license, visit http://www.apache.org/licenses/LICENSE-2.0
#
# No warranties are given. The work is provided "AS IS", without warranty of any kind, express or implied.
#
# SPDX-License-Identifier: Apache-2.0
import argparse
import os
from typing import List
import torch
import torch.distributed as dist
from omegaconf import OmegaConf
from tqdm import tqdm
from torch.utils.data import DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from torchvision.io import write_video
from torchvision import transforms # noqa: F401
from einops import rearrange
from utils.misc import set_seed
from utils.distributed import barrier
from utils.memory import gpu, get_cuda_free_memory_gb, DynamicSwapInstaller
from pipeline.interactive_causal_inference import (
InteractiveCausalInferencePipeline,
)
from utils.dataset import MultiTextDataset
# ----------------------------- Argument parsing -----------------------------
# The single CLI option is the path of the YAML file that OmegaConf loads into
# `config`; every later section of this script reads its settings from there.
parser = argparse.ArgumentParser(prog="Interactive causal inference")
parser.add_argument(
    "--config_path",
    type=str,
    default="configs/longlive_interactive_inference.yaml",
)
args = parser.parse_args()
config = OmegaConf.load(args.config_path)
# ----------------------------- Distributed setup -----------------------------
# The presence of LOCAL_RANK in the environment selects multi-process mode;
# without it the script falls back to a single CUDA device.
if "LOCAL_RANK" in os.environ:
    from datetime import timedelta

    os.environ["NCCL_CROSS_NIC"] = "1"
    # Keep values the user already exported; only fill in missing ones.
    os.environ.setdefault("NCCL_DEBUG", "INFO")
    os.environ.setdefault("NCCL_TIMEOUT", "1800")

    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    rank = int(os.environ.get("RANK", str(local_rank)))

    # Set device first
    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")

    # NCCL_TIMEOUT used to be exported but never handed to the process group,
    # which therefore always ran with the library default. Apply it here.
    # NOTE(review): the value is interpreted as seconds (default 1800) - confirm.
    try:
        timeout_seconds = int(os.environ["NCCL_TIMEOUT"])
    except ValueError:
        timeout_seconds = 0
    pg_timeout = (
        timedelta(seconds=timeout_seconds)
        if timeout_seconds > 0
        # Unparseable or non-positive values keep the previous behaviour.
        else torch.distributed.constants.default_pg_timeout
    )

    # Initialize process group with backend and timeout
    if not dist.is_initialized():
        dist.init_process_group(
            backend="nccl",
            rank=rank,
            world_size=world_size,
            timeout=pg_timeout,
        )

    # Offset the seed per process so ranks do not draw identical noise.
    set_seed(config.seed + local_rank)
    print(f"[Rank {rank}] Initialized distributed processing on device {device}")
else:
    local_rank = 0
    rank = 0
    # Defined for symmetry with the distributed branch.
    world_size = 1
    device = torch.device("cuda")
    set_seed(config.seed)
    print(f"Single GPU mode on device {device}")
# Less than 40 GB of free GPU memory switches on the swap installer that is
# applied to the text encoder further down in this script.
low_memory = get_cuda_free_memory_gb(device) < 40

# This script only runs inference, so autograd is disabled globally.
torch.set_grad_enabled(False)

pipeline = InteractiveCausalInferencePipeline(config, device=device)

if config.generator_ckpt:
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(config.generator_ckpt, map_location="cpu")

    if not config.use_ema:
        # Regular weights are loaded strictly, under their stored key names.
        pipeline.generator.load_state_dict(state_dict["generator"])
    else:
        # EMA weights have the "_fsdp_wrapped_module." fragment removed from
        # their key names and are loaded non-strictly; mismatches are reported.
        ema_weights = {
            key.replace("_fsdp_wrapped_module.", ""): tensor
            for key, tensor in state_dict["generator_ema"].items()
        }
        load_result = pipeline.generator.load_state_dict(ema_weights, strict=False)
        missing, unexpected = load_result
        if local_rank == 0 and missing:
            print(f"[Warning] {len(missing)} parameters missing: {missing[:8]} ...")
        if local_rank == 0 and unexpected:
            print(f"[Warning] {len(unexpected)} unexpected params: {unexpected[:8]} ...")
# --------------------------- LoRA support (optional) ---------------------------
# NOTE(review): the original nesting of this section was lost; this version
# assumes LoRA is applied only when both a LoRA checkpoint and an `adapter`
# config are present - confirm against the intended behaviour.
#
# The flag is always initialised so later code can read it unconditionally.
pipeline.is_lora_enabled = False

# Both settings are optional, so look them up once with a None fallback.
lora_ckpt_path = getattr(config, "lora_ckpt", None)
lora_config = getattr(config, "adapter", None)

if lora_ckpt_path and lora_config:
    # Imported lazily: runs without LoRA do not need these modules.
    from utils.lora_utils import configure_lora_for_model
    import peft

    if local_rank == 0:
        print(f"LoRA enabled with config: {config.adapter}")
        print("Applying LoRA to generator (inference)...")
    pipeline.generator.model = configure_lora_for_model(
        pipeline.generator.model,
        model_name="generator",
        lora_config=lora_config,
        is_main_process=(local_rank == 0),
    )

    if local_rank == 0:
        print(f"Loading LoRA checkpoint from {lora_ckpt_path}")
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    lora_checkpoint = torch.load(lora_ckpt_path, map_location="cpu")

    # The adapter weights may sit under "generator_lora" or be the whole file.
    if isinstance(lora_checkpoint, dict) and "generator_lora" in lora_checkpoint:
        lora_state = lora_checkpoint["generator_lora"]
    else:
        lora_state = lora_checkpoint
    peft.set_peft_model_state_dict(pipeline.generator.model, lora_state)  # type: ignore
    if local_rank == 0:
        print("LoRA weights loaded for generator")

    # Replace the adapter-wrapped model with the result of merge_and_unload().
    pipeline.generator.model = pipeline.generator.model.merge_and_unload()
    pipeline.is_lora_enabled = True
# Move pipeline to appropriate dtype and device
# The generator dtype is printed before the cast to bfloat16.
print("dtype", pipeline.generator.model.dtype)
pipeline = pipeline.to(dtype=torch.bfloat16)

if low_memory:
    # With little free GPU memory the text encoder is handed to the swap
    # installer; it is not moved by the loop below.
    DynamicSwapInstaller.install_model(pipeline.text_encoder, device=device)

# The generator and the VAE are always placed on the target device.
for _module in (pipeline.generator, pipeline.vae):
    _module.to(device=device)
# ----------------------------- Build dataset -----------------------------
def _parse_switch_frame_indices(raw) -> List[int]:
    """Normalise the ``switch_frame_indices`` config value to a list of ints.

    Accepted forms:
      * ``None``                       -> ``[]``
      * a single int                   -> ``[value]``
      * a list / tuple / ListConfig    -> each element converted with ``int``
      * anything else                  -> ``str(raw)`` split on commas, with
                                          blank tokens ignored

    Raises:
        ValueError: if a token or element cannot be converted to ``int``.
    """
    if raw is None:
        return []
    if isinstance(raw, int):
        return [int(raw)]
    # A YAML sequence arrives as an OmegaConf ListConfig, not a plain list.
    if isinstance(raw, (list, tuple)) or OmegaConf.is_list(raw):
        return [int(item) for item in raw]
    return [int(token) for token in str(raw).split(",") if token.strip()]


# Parse switch_frame_indices
switch_frame_indices: List[int] = _parse_switch_frame_indices(config.switch_frame_indices)

# Create dataset
dataset = MultiTextDataset(config.data_path)
num_prompts_total = len(dataset)
if num_prompts_total == 0:
    raise ValueError(f"No prompt lines found in {config.data_path}")

# Validate number of segments & switch_frame_indices length.
# The segment count is read from the first item only.
# Raised explicitly because an `assert` would be skipped under `python -O`.
num_segments = len(dataset[0]["prompts_list"])
if len(switch_frame_indices) != num_segments - 1:
    raise ValueError(
        "The number of switch_frame_indices should be the number of prompt segments minus 1"
        f" (got {len(switch_frame_indices)} indices for {num_segments} segments)"
    )

print("Number of segments:", num_segments)
print("Switch frame indices:", switch_frame_indices)
print(f"Number of prompt lines: {num_prompts_total}")
# In a distributed run the dataset is split across ranks without shuffling;
# drop_last=True on the sampler drops the tail that does not divide evenly
# across ranks. Otherwise items are visited in order.
is_distributed = dist.is_initialized()
sampler = (
    DistributedSampler(dataset, shuffle=False, drop_last=True)
    if is_distributed
    else SequentialSampler(dataset)
)
dataloader = DataLoader(
    dataset,
    batch_size=1,
    sampler=sampler,
    num_workers=0,
    drop_last=False,
)

# Create output directory
if local_rank == 0:
    os.makedirs(config.output_folder, exist_ok=True)
# All ranks wait here so the directory exists before any of them writes to it.
if is_distributed:
    dist.barrier()
# ----------------------------- Inference loop -----------------------------
# Values that are identical for every batch are computed once, before the loop.
rank = dist.get_rank() if dist.is_initialized() else 0

# Determine model type for filename
if getattr(pipeline, "is_lora_enabled", False):
    model_type = "lora"
elif getattr(config, "use_ema", False):
    model_type = "ema"
else:
    model_type = "regular"

# Noise shape: samples x output frames x fixed trailing dims 16 x 60 x 104.
noise_shape = [config.num_samples, config.num_output_frames, 16, 60, 104]

# tqdm wraps the dataloader itself (not the enumerate generator) so that the
# progress bar can pick up the total from len(dataloader).
for i, batch_data in enumerate(tqdm(dataloader, disable=(local_rank != 0))):
    idx = batch_data["idx"].item()
    prompts_list: List[str] = batch_data["prompts_list"]  # type: ignore

    sampled_noise = torch.randn(noise_shape, device=device, dtype=torch.bfloat16)

    video = pipeline.inference(
        noise=sampled_noise,
        text_prompts_list=prompts_list,
        switch_frame_indices=switch_frame_indices,
        return_latents=False,
    )

    # Channels-last layout for write_video, scaled to the 0..255 byte range.
    # NOTE(review): assumes the pipeline returns values in [0, 1] - confirm.
    # The clamp keeps any out-of-range value from wrapping in the uint8 cast.
    current_video = (
        (rearrange(video, "b t c h w -> b t h w c").cpu() * 255.0)
        .clamp(0, 255)
        .to(torch.uint8)
    )

    for seed_idx in range(config.num_samples):
        if config.save_with_index:
            file_name = f"rank{rank}-{idx}-{seed_idx}_{model_type}.mp4"
        else:
            # Use the first prompt segment as the filename prefix to avoid overly long names
            # NOTE(review): the [0][0] indexing presumably unwraps the
            # batch_size=1 collation of each prompt - confirm.
            short_name = prompts_list[0][0][:100].replace("/", "_")
            file_name = f"rank{rank}-{short_name}-{seed_idx}_{model_type}.mp4"
        output_path = os.path.join(config.output_folder, file_name)
        write_video(output_path, current_video[seed_idx], fps=16)

    # NOTE(review): this check runs after the batch has been processed, so
    # inference_iter=N stops after N+1 batches - confirm this is intended.
    if config.inference_iter != -1 and i >= config.inference_iter:
        break

# Tear down the process group if one was created.
if dist.is_initialized():
    dist.destroy_process_group()