Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements/pt2.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,4 @@ wheel>=0.41.0
xformers>=0.0.20
gradio
streamlit-keyup==0.2.0
imageio[ffmpeg]==2.26.1
89 changes: 53 additions & 36 deletions scripts/demo/streamlit_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import cv2
import imageio
import numpy as np
import shutil
import streamlit as st
import torch
import torch.nn as nn
Expand Down Expand Up @@ -40,6 +41,12 @@
from torchvision import transforms
from torchvision.utils import make_grid, save_image

# Additional options for lower end setups
USE_CUDA = True # Set this to `False` if you want to force CPU-only mode
lowvram_mode = False # Set to `True` to enable low VRAM mode
# (low VRAM mode casts the model from float32 to float16; tested to work well on an RTX 3060 w/ 12GB VRAM)

device = torch.device("cuda" if USE_CUDA and torch.cuda.is_available() else "cpu")

@st.cache_resource()
def init_st(version_dict, load_ckpt=True, load_filter=True):
Expand All @@ -59,35 +66,29 @@ def init_st(version_dict, load_ckpt=True, load_filter=True):
state["filter"] = DeepFloydDataFiltering(verbose=False)
return state


def load_model(model):
model.cuda()


lowvram_mode = False

device = torch.device("cuda" if USE_CUDA and torch.cuda.is_available() else "cpu")
model.to(device)

def set_lowvram_mode(mode):
global lowvram_mode
lowvram_mode = mode


def initial_model_load(model):
device = torch.device("cuda" if USE_CUDA and torch.cuda.is_available() else "cpu")
global lowvram_mode
if lowvram_mode:
model.model.half()
model.model.half().to(device)
else:
model.cuda()
model.to(device)
return model


def unload_model(model):
global lowvram_mode
if lowvram_mode:
model.cpu()
if lowvram_mode or not USE_CUDA:
model.cpu() # Move model to CPU to free GPU memory
torch.cuda.empty_cache()


def load_model_from_config(config, ckpt=None, verbose=True):
model = instantiate_from_config(config.model)

Expand Down Expand Up @@ -497,13 +498,14 @@ def load_img(
st.text(f"input min/max/mean: {img.min():.3f}/{img.max():.3f}/{img.mean():.3f}")
return img


def get_init_img(batch_size=1, key=None):
init_image = load_img(key=key).cuda()
device = torch.device("cuda" if USE_CUDA and torch.cuda.is_available() else "cpu")

init_image = load_img(key=key).to(device) # Use `to(device)` to move to the correct device
init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)

return init_image


def do_sample(
model,
sampler,
Expand All @@ -529,9 +531,9 @@ def do_sample(
st.text("Sampling")

outputs = st.empty()
precision_scope = autocast
precision_scope = autocast if USE_CUDA else lambda device: device
with torch.no_grad():
with precision_scope("cuda"):
with precision_scope("cuda" if USE_CUDA else "cpu"):
with model.ema_scope():
if T is not None:
num_samples = [num_samples, T]
Expand Down Expand Up @@ -754,7 +756,7 @@ def do_img2img(
outputs = st.empty()
precision_scope = autocast
with torch.no_grad():
with precision_scope("cuda"):
with precision_scope("cuda" if USE_CUDA else "cpu"):
with model.ema_scope():
load_model(model.conditioner)
batch, batch_uc = get_batch(
Expand Down Expand Up @@ -783,20 +785,25 @@ def do_img2img(

noise = torch.randn_like(z)

sigmas = sampler.discretization(sampler.num_steps).cuda()
# Move sigmas to the correct device (CUDA or CPU)
sigmas = sampler.discretization(sampler.num_steps).to(device)
sigma = sigmas[0]

st.info(f"all sigmas: {sigmas}")
st.info(f"noising sigma: {sigma}")

# Offset noise level handling
if offset_noise_level > 0.0:
noise = noise + offset_noise_level * append_dims(
torch.randn(z.shape[0], device=z.device), z.ndim
torch.randn(z.shape[0], device=device), z.ndim
)

# Add noise handling
if add_noise:
noised_z = z + noise * append_dims(sigma, z.ndim).cuda()
noised_z = z + noise * append_dims(sigma, z.ndim).to(device)
noised_z = noised_z / torch.sqrt(
1.0 + sigmas[0] ** 2.0
) # Note: hardcoded to DDPM-like scaling. need to generalize later.
) # Hardcoded to DDPM-like scaling; generalize if needed
else:
noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0)

Expand Down Expand Up @@ -893,29 +900,39 @@ def load_img_for_prediction(
st.image(pil_image)
return image.to(device) * 2.0 - 1.0


def save_video_as_grid_and_mp4(
video_batch: torch.Tensor, save_path: str, T: int, fps: int = 5
):
# Check if FFmpeg is available
try:
import imageio_ffmpeg
except ImportError:
raise RuntimeError("FFmpeg support is not installed. Use 'pip install imageio[ffmpeg]' to install it.")

if not shutil.which("ffmpeg"):
raise RuntimeError("System-level FFmpeg not found. Please install it and ensure it's in your PATH.")

os.makedirs(save_path, exist_ok=True)
base_count = len(glob(os.path.join(save_path, "*.mp4")))

video_batch = rearrange(video_batch, "(b t) c h w -> b t c h w", t=T)
video_batch = embed_watermark(video_batch)

for vid in video_batch:
save_image(vid, fp=os.path.join(save_path, f"{base_count:06d}.png"), nrow=4)

video_path = os.path.join(save_path, f"{base_count:06d}.mp4")
vid = (
(rearrange(vid, "t c h w -> t h w c") * 255).cpu().numpy().astype(np.uint8)
)
imageio.mimwrite(video_path, vid, fps=fps)

video_path_h264 = video_path[:-4] + "_h264.mp4"
os.system(f"ffmpeg -i '{video_path}' -c:v libx264 '{video_path_h264}'")
with open(video_path_h264, "rb") as f:
video_bytes = f.read()
os.remove(video_path_h264)
st.video(video_bytes)

vid = (rearrange(vid, "t c h w -> t h w c") * 255).cpu().numpy().astype(np.uint8)

# Use the correct writer for MP4 format
writer = imageio.get_writer(video_path, fps=fps, format='ffmpeg', codec='libx264')
for frame in vid:
writer.append_data(frame)
writer.close()

# Confirm that the file was created
if os.path.exists(video_path):
print(f"Video saved successfully at: {video_path}")
base_count += 1

6 changes: 6 additions & 0 deletions scripts/demo/video_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,12 @@

if mode == "img2vid":
img = load_img_for_prediction(W, H)

# Check if the image is None and use a dummy image if necessary
if img is None:
st.warning("No image provided. Using a dummy tensor for initialization.")
img = torch.zeros([1, 3, H, W]).to(device) # Dummy tensor

if "sv3d" in version:
cond_aug = 1e-5
else:
Expand Down