diff --git a/.gitignore b/.gitignore
index c83750f5..8154f988 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,6 +5,21 @@
*.pyc
.ipynb_checkpoints
results/
+data/audio/*.wav
+data/video/*.mp4
+data/video/*.mov
+data/video/ma/*.mp4
+avatars.zip
+avatars/
+avator_1_skipImage.zip
+avator_1_skipImage/
+cuda_11.7.0_515.43.04_linux.run
+eaudio --start
+ffmpeg-7.0.2-amd64-static/
+ffmpeg-release-amd64-static.tar.xz
+musetalk.zip
+scripts.zip
+venv_PhotonSync/
models/
**/__pycache__/
*.py[cod]
@@ -15,4 +30,4 @@ ffmprobe*
ffplay*
debug
exp_out
-.gradio
\ No newline at end of file
+.gradio
diff --git a/PhotonSync_Diagram.md b/PhotonSync_Diagram.md
new file mode 100644
index 00000000..cbfbaea8
--- /dev/null
+++ b/PhotonSync_Diagram.md
@@ -0,0 +1,52 @@
+graph TD
+ %% Define Styles
+ classDef process fill:#cde4ff,stroke:#6699ff,stroke-width:2px;
+ classDef data fill:#e6ffcc,stroke:#99cc66,stroke-width:2px,rx:10px,ry:10px;
+ classDef component fill:#fff2cc,stroke:#ffcc66,stroke-width:2px;
+ classDef io fill:#f2f2f2,stroke:#333,stroke-width:2px,stroke-dasharray: 5 5;
+ classDef hardware fill:#e0e0e0,stroke:#666,stroke-width:2px,rx:5px,ry:5px;
+ classDef title fill:#ffffff,stroke:#ffffff,font-weight:bold,font-size:18px;
+
+ %% One-Time Preparation Phase
+ prep_title("One-Time Preparation
(一次性素材准备)"):::title
+ prep_video(Input Video/Images
输入视频/图像) ==> prep_frames(Extract Frames
提取帧)
+ prep_frames --> prep_landmark(Get Face BBox & Landmarks
获取人脸框和关键点)
+ prep_landmark -- Face Coords (人脸坐标) --> avatar_data
+ prep_landmark -- Cropped Face (裁剪的人脸) --> prep_vae(VAE Encoder
VAE编码器)
+ prep_landmark -- Full Frame (完整帧) --> prep_parse(Face Parsing
人脸解析)
+ prep_vae --> prep_latents(Latent Vectors
潜向量)
+ prep_parse --> prep_masks(Blending Masks
融合蒙版)
+ prep_latents & prep_masks --> avatar_data(Avatar Data Storage
虚拟人数据存储
Frames, Coords, Latents, Masks)
+
+ %% Sender Application Phase
+ sender_title("PhotonSync Sender Application
(发送端应用)"):::title
+ photon_gpt[PhotonGPT Audio Input
PhotonGPT音频输入] --> audio_proc(Audio Feature Extraction
音频特征提取
Whisper)
+ photon_gpt --> audio_enc(Audio Encoding
音频编码
GStreamerAudio / opusenc)
+
+ audio_proc -- Audio Features (音频特征) --> rt_unet
+ avatar_data -- Pre-calculated Latents (预计算潜向量) --> rt_unet
+
+ rt_unet(UNet Inference
UNet推理
Generate Lip-Synced Latents
生成口型同步的潜向量) --> rt_vae(VAE Decoder
VAE解码器
Latents to Image Frame
潜向量转图像帧)
+ rt_vae -- Generated Face Frame (生成的面部帧) --> rt_blend
+
+ avatar_data -- Original Frame, Mask, Coords (原始帧、蒙版、坐标) --> rt_blend
+ rt_blend(Real-time Blending
实时融合
Combine face and background
合并面部与背景) -- Final Video Frame (最终视频帧) --> video_enc(Video Encoding
视频编码
GStreamerPipeline / nvh264enc)
+
+ video_enc -- H.264 RTP Stream --> network((Network
网络))
+ audio_enc -- Opus RTP Stream --> network
+
+ %% Receiver Phase
+ receiver_title("Holobot Receiver
(接收端)"):::title
+ network -- Video Stream (视频流) --> vid_receiver(Video UDP Source
视频UDP源)
+ network -- Audio Stream (音频流) --> aud_receiver(Audio UDP Source
音频UDP源)
+
+ vid_receiver --> vid_jitter(Video Jitter Buffer
视频抖动缓冲)
+ vid_jitter --> vid_depay(Video RTP Depayload
视频RTP解包)
+ vid_depay --> vid_parse(H.264 Parse
H.264解析)
+ vid_parse --> vid_dec(NVDEC Decode
NVDEC解码
GPU)
+ vid_dec --> vid_sink(Video Sink
视频接收器
d3d11videosink)
+
+ aud_receiver --> aud_jitter(Audio Jitter Buffer
音频抖动缓冲)
+ aud_jitter --> aud_depay(Audio RTP Depayload
音频RTP解包)
+ aud_depay --> aud_parse(Opus Parse
Opus解析)
+    aud_parse --> aud_dec(Opus Decode<br/>Opus解码<br/>opusdec)
+    aud_dec --> aud_sink(Audio Sink<br/>音频接收器<br/>wasapisink)
\ No newline at end of file
diff --git a/Photonsync.ps1 b/Photonsync.ps1
new file mode 100644
index 00000000..7894afb7
--- /dev/null
+++ b/Photonsync.ps1
@@ -0,0 +1,2 @@
+# Start the real-time streaming script as a module, using the realtime inference config.
+# --skip_save_images is forwarded to the script unchanged.
+python -m scripts.realtime_stream_gst_15 --inference_config configs/inference/realtime.yaml --skip_save_images
+
diff --git a/configs/inference/realtime-stable.yaml b/configs/inference/realtime-stable.yaml
new file mode 100644
index 00000000..ea67de19
--- /dev/null
+++ b/configs/inference/realtime-stable.yaml
@@ -0,0 +1,7 @@
+avator_1:
+ preparation: False
+ bbox_shift: 5
+ video_path: "data/video/sun.mp4"
+ audio_clips:
+ audio_0: "data/audio/sun.wav"
+ audio_1: "data/audio/yongen.wav"
diff --git a/configs/inference/realtime.yaml b/configs/inference/realtime.yaml
index 9319e987..c740a758 100644
--- a/configs/inference/realtime.yaml
+++ b/configs/inference/realtime.yaml
@@ -1,10 +1,8 @@
-avator_1:
- preparation: True # your can set it to False if you want to use the existing avator, it will save time
- bbox_shift: 5
- video_path: "data/video/yongen.mp4"
+avatar_3:
+ preparation: False
+ bbox_shift: 0
+ batch_size: 16
+ video_path: "data/video/aiden-glasses-processed.mp4"
audio_clips:
- audio_0: "data/audio/yongen.wav"
- audio_1: "data/audio/eng.wav"
-
-
-
+ audio_0: "data/audio/sun.wav"
+ audio_1: "data/audio/yongen.wav"
diff --git a/data/video/musk.png b/data/video/musk.png
new file mode 100644
index 00000000..bca50eef
Binary files /dev/null and b/data/video/musk.png differ
diff --git a/data/video/younglook.7z b/data/video/younglook.7z
new file mode 100644
index 00000000..9d8a028e
Binary files /dev/null and b/data/video/younglook.7z differ
diff --git a/download-gi-deps.py b/download-gi-deps.py
new file mode 100644
index 00000000..79f647ec
--- /dev/null
+++ b/download-gi-deps.py
@@ -0,0 +1,40 @@
+import os
+import urllib.request
+import zipfile
+import shutil
+
+# URLs of the archives that provide the missing DLL dependencies, keyed by DLL name
+dll_sources = {
+    "z.dll": "https://github.com/winlibs/zlib/releases/download/zlib-1.3/zlib-1.3-msvc-x64.zip",
+    "intl-8.dll": "https://github.com/mlocati/gettext-iconv-windows/releases/download/v0.21-v1.16/gettext0.21-iconv1.16-static-64.zip"
+}
+
+download_dir = "gtk_deps_download"
+gtk_bin = r'C:\gtk\bin'
+
+
+def lib_alias_name(file_name):
+    """Return the ``lib``-prefixed alias for *file_name*.
+
+    Returns None when the name already starts with ``lib`` so that
+    ``libintl-8.dll`` is not duplicated as ``liblibintl-8.dll``.
+    """
+    if file_name.lower().startswith("lib"):
+        return None
+    return f"lib{file_name}"
+
+
+def download_and_extract(sources, target_dir):
+    """Download every archive in *sources* into *target_dir* and unpack it there.
+
+    Args:
+        sources: mapping of DLL name -> archive URL.
+        target_dir: directory that receives both the zip files and their contents.
+    """
+    os.makedirs(target_dir, exist_ok=True)
+    for dll_name, url in sources.items():
+        zip_path = os.path.join(target_dir, f"{dll_name}.zip")
+        print(f"Downloading {url}")
+        urllib.request.urlretrieve(url, zip_path)
+
+        print(f"Extracting {zip_path}")
+        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
+            zip_ref.extractall(target_dir)
+
+
+def install_dlls(source_dir, dest_dir):
+    """Copy every ``*.dll`` found under *source_dir* into *dest_dir*.
+
+    Each DLL is copied under its own name and, unless it already has one,
+    under a ``lib``-prefixed alias as well.
+    """
+    # The copy would fail with FileNotFoundError if the GTK bin directory were missing.
+    os.makedirs(dest_dir, exist_ok=True)
+    for root, _dirs, files in os.walk(source_dir):
+        for file in files:
+            if not file.lower().endswith('.dll'):
+                continue
+            source = os.path.join(root, file)
+            dest = os.path.join(dest_dir, file)
+            print(f"Copying {source} to {dest}")
+            shutil.copy2(source, dest)
+
+            # Also create lib* version
+            alias = lib_alias_name(file)
+            if alias is None:
+                continue
+            lib_dest = os.path.join(dest_dir, alias)
+            print(f"Creating lib version at {lib_dest}")
+            shutil.copy2(source, lib_dest)
+
+
+def main():
+    """Download the dependency archives and install their DLLs into the GTK bin directory."""
+    download_and_extract(dll_sources, download_dir)
+    install_dlls(download_dir, gtk_bin)
+    print("Done installing dependencies!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/find -dll-dep.py b/find -dll-dep.py
new file mode 100644
index 00000000..fe822fc9
--- /dev/null
+++ b/find -dll-dep.py
@@ -0,0 +1,36 @@
+import os
+import sys
+import ctypes
+
+# Add the OpenCV bin directory to the DLL search path
+opencv_bin = r"D:\tencent\devel\cv\opencv-4.5.5\build\install\x64\vc16\bin"
+if not os.path.isdir(opencv_bin):
+    # os.add_dll_directory raises on a missing directory; exit with a readable message instead.
+    sys.exit(f"OpenCV bin directory not found: {opencv_bin}")
+os.add_dll_directory(opencv_bin)
+
+# Add the directory containing the cv2.pyd file to the Python path
+opencv_pyd_dir = r"D:\tencent\devel\cv\opencv-4.5.5\build\lib\python3\Release"
+sys.path.insert(0, opencv_pyd_dir)
+
+# Pre-load only the essential DLLs (skip highgui)
+essential_dlls = [
+    "opencv_core455.dll",
+    "opencv_imgproc455.dll",
+    "opencv_imgcodecs455.dll",
+    "opencv_videoio455.dll",
+    "opencv_flann455.dll",
+    "opencv_features2d455.dll"
+]
+
+for dll in essential_dlls:
+    dll_path = os.path.join(opencv_bin, dll)
+    try:
+        ctypes.CDLL(dll_path)
+    except OSError as e:
+        # A missing DLL or an unresolved dependency surfaces as OSError; report it and keep going
+        # so every DLL in the list gets checked.
+        print(f"Failed to load {dll}: {e}")
+    else:
+        print(f"Successfully pre-loaded {dll}")
+
+# Now try importing cv2
+try:
+    import cv2
+    print(f"\nSuccess! OpenCV version: {cv2.__version__}")
+except ImportError as e:
+    print(f"\nStill failed: {e}")
diff --git a/fix_gtk.py b/fix_gtk.py
new file mode 100644
index 00000000..ca706b55
--- /dev/null
+++ b/fix_gtk.py
@@ -0,0 +1,71 @@
+import os
+import sys
+import ctypes
+from ctypes import windll
+import glob
+import platform
+
+def inspect_gtk_installation():
+    """Report on the GTK installation under ``C:\\gtk`` and test whether PyGObject imports.
+
+    Prepends the GTK ``bin`` and ``lib`` directories to PATH, registers them as
+    DLL search directories, sets GI_TYPELIB_PATH, tries to load a fixed list of
+    DLLs from ``bin`` and finally attempts ``import gi``.
+
+    Returns:
+        True if ``import gi`` succeeded, False otherwise.
+    """
+    print("\n#### GTK Installation Info ####")
+    gtk_bin = r'C:\gtk\bin'
+    gtk_lib = r'C:\gtk\lib'
+
+    # Add both directories to PATH
+    os.environ['PATH'] = f"{gtk_bin};{gtk_lib};{os.environ.get('PATH', '')}"
+    # Since Python 3.8, PATH is not consulted when resolving the DLL dependencies of
+    # extension modules, so the directories must also be registered explicitly.
+    if hasattr(os, "add_dll_directory"):
+        for dll_dir in (gtk_bin, gtk_lib):
+            if os.path.isdir(dll_dir):
+                os.add_dll_directory(dll_dir)
+    print(f"Python version: {platform.python_version()}")
+
+    # Set additional environment variables
+    os.environ['GI_TYPELIB_PATH'] = r'C:\gtk\lib\girepository-1.0'
+    print(f"GI_TYPELIB_PATH: {os.environ.get('GI_TYPELIB_PATH', 'Not set')}")
+
+    # Find all DLLs in GTK directories
+    bin_dlls = glob.glob(os.path.join(gtk_bin, "*.dll"))
+    lib_dlls = glob.glob(os.path.join(gtk_lib, "*.dll"))
+    print(f"Found {len(bin_dlls)} DLLs in {gtk_bin}")
+    print(f"Found {len(lib_dlls)} DLLs in {gtk_lib}")
+
+    # Critical DLLs, attempted in this order
+    critical_dll_names = (
+        "glib-2.0-0.dll",
+        "gobject-2.0-0.dll",
+        "gmodule-2.0-0.dll",
+        "girepository-1.0-1.dll",
+        "gio-2.0-0.dll",
+        "ffi-8.dll",
+        "z.dll",
+        "libintl-8.dll",
+    )
+
+    # Try to load critical DLLs first
+    print("\n#### Loading Critical DLLs ####")
+    for dll_name in critical_dll_names:
+        dll = os.path.join(gtk_bin, dll_name)
+        if not os.path.exists(dll):
+            print(f"! Missing {dll_name}")
+            continue
+        try:
+            windll.LoadLibrary(dll)
+        except OSError as e:
+            print(f"✗ Failed loading {dll_name}: {e}")
+        else:
+            print(f"✓ Loaded {dll_name}")
+
+    # Try importing gi
+    print("\n#### Testing PyGObject Import ####")
+    try:
+        import gi
+    except ImportError as e:
+        print(f"✗ Failed to import gi: {e}")
+        return False
+    print(f"✓ Successfully imported gi ({gi.__file__})")
+    return True
+
+if __name__ == "__main__":
+    # Print troubleshooting advice only when the gi import test failed.
+    if not inspect_gtk_installation():
+        advice = (
+            "\n#### Troubleshooting Tips ####",
+            "1. Your PyGObject installation might not be compatible with your GTK installation.",
+            "2. Try reinstalling PyGObject with:",
+            " pip uninstall pygobject",
+            " pip install pygobject",
+            "3. If that doesn't work, try installing from an alternate source:",
+            " pip install --no-binary :all: pygobject",
+        )
+        for line in advice:
+            print(line)
diff --git a/holobot-receiver.ps1 b/holobot-receiver.ps1
new file mode 100644
index 00000000..9addd689
--- /dev/null
+++ b/holobot-receiver.ps1
@@ -0,0 +1,44 @@
+<#
+.SYNOPSIS
+ Launches the GStreamer pipeline to decode and play the MuseTalk real-time A/V stream.
+
+.DESCRIPTION
+ This script uses a "zero-copy" video pipeline. The video frame is decoded on the GPU
+ and rendered directly to the screen using d3d11videosink without ever being copied
+ to system RAM, providing the lowest possible latency and highest performance.
+
+.NOTES
+ - Requires GStreamer 1.0 (with msvc_x86_64 and nvcodec packages) to be installed and in the system's PATH.
+ - Run this script from a PowerShell terminal.
+ - To stop the stream, press Ctrl+C.
+#>
+
+# --- Configuration ---
+# UDP ports this receiver listens on; the sender scripts in this change push audio to port 5001
+# and video to port 5000.
+$videoPort = 5000
+$audioPort = 5001
+
+# --- User Feedback ---
+Write-Host "🚀 Launching GStreamer Decoder (Zero-Copy Video Pipeline)..." -ForegroundColor Green
+Write-Host " - Listening for VIDEO on UDP port: $videoPort"
+Write-Host " - Listening for AUDIO on UDP port: $audioPort"
+Write-Host " (Press Ctrl+C to stop the stream)"
+Write-Host ""
+
+
+# --- Launch GStreamer Pipeline ---
+# A single gst-launch-1.0 invocation runs two independent branches:
+#   video: udpsrc (RTP/H264, payload 96) -> rtpjitterbuffer -> queue -> rtph264depay -> h264parse
+#          -> nvh264dec -> d3d11videosink
+#   audio: udpsrc (RTP/OPUS, payload 97) -> rtpjitterbuffer -> rtpopusdepay -> opusdec
+#          -> audioconvert -> audioresample -> wasapisink
+# Both sinks run with sync=true. The two jitter buffers use different latency values (200 vs 375).
+# NOTE(review): the backtick-only line between the branches continues the command across an empty
+# line; comments cannot be placed inside the continued command.
+gst-launch-1.0 -v `
+ udpsrc port=$videoPort caps="application/x-rtp, media=video, clock-rate=90000, encoding-name=H264, payload=96" `
+ ! rtpjitterbuffer latency=200 `
+ ! queue `
+ ! rtph264depay `
+ ! h264parse `
+ ! nvh264dec `
+ ! d3d11videosink sync=true qos=true max-lateness=200000000 `
+`
+ udpsrc port=$audioPort caps="application/x-rtp, media=audio, clock-rate=48000, encoding-name=OPUS, payload=97" `
+ ! rtpjitterbuffer latency=375 drop-on-latency=true `
+ ! rtpopusdepay `
+ ! opusdec plc=true `
+ ! audioconvert `
+ ! audioresample `
+ ! wasapisink sync=true
\ No newline at end of file
diff --git a/musetalk/utils/blending.py b/musetalk/utils/blending.py
old mode 100755
new mode 100644
diff --git a/musetalk/utils/face_parsing/resnet.py b/musetalk/utils/face_parsing/resnet.py
index e2e5d87e..a306abb7 100755
--- a/musetalk/utils/face_parsing/resnet.py
+++ b/musetalk/utils/face_parsing/resnet.py
@@ -80,7 +80,7 @@ def forward(self, x):
return feat8, feat16, feat32
def init_weight(self, model_path):
- state_dict = torch.load(model_path) #modelzoo.load_url(resnet18_url)
+ state_dict = torch.load(model_path, weights_only=False) #modelzoo.load_url(resnet18_url)
self_state_dict = self.state_dict()
for k, v in state_dict.items():
if 'fc' in k: continue
diff --git a/requirements-pip.txt b/requirements-pip.txt
new file mode 100644
index 00000000..9f628804
Binary files /dev/null and b/requirements-pip.txt differ
diff --git a/requirements.txt b/requirements.txt
index e87aa41d..b2f95cd3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,8 @@
diffusers==0.30.2
accelerate==0.28.0
-numpy==1.23.5
-tensorflow==2.12.0
-tensorboard==2.12.0
+numpy
+tensorflow
+tensorboard
opencv-python==4.9.0.80
soundfile==0.12.1
transformers==4.39.2
diff --git a/scripts/.env b/scripts/.env
new file mode 100644
index 00000000..c1e10c3a
--- /dev/null
+++ b/scripts/.env
@@ -0,0 +1,27 @@
+# .env file for realtime_stream_sync watcher
+
+# --- Required ---
+# Full path to the WAV file that PhotonGPT writes
+#WATCHED_WAV_FILE_PATH=E:\devel\NanoAR\HoloBot\PhotonGPT\latest_response.opus
+
+# Path to the avatar configuration YAML file used by realtime_stream_sync
+AVATAR_CONFIG_PATH=configs/inference/realtime.yaml # Or your actual path
+
+# The specific Avatar ID from the config file to use for lipsyncing
+AVATAR_ID_TO_USE=avatar_3 # Replace with your actual avatar ID from the YAML
+
+# --- Optional (Defaults will be used if not set) ---
+# Target FPS for the GStreamer output (should match video source)
+TARGET_FPS=25
+
+# Frame skipping parameters - lower values = more aggressive
+FRAME_SKIP_THRESHOLD=1000
+OVERLOAD_MULTIPLIER=10.0
+MAX_FRAMES_TO_SKIP=0
+
+#FRAME_SKIP_THRESHOLD=2
+#OVERLOAD_MULTIPLIER=1.2
+#MAX_FRAMES_TO_SKIP=1
+
+STREAM_PIPE_PATH=D:\photon_audio.opus
+GSTREAMER_LAUNCH_PATH=C:\gstreamer\1.0\msvc_x86_64\bin\gst-launch-1.0.exe
diff --git a/scripts/Downloads - Shortcut.lnk b/scripts/Downloads - Shortcut.lnk
new file mode 100644
index 00000000..e1e17cb8
Binary files /dev/null and b/scripts/Downloads - Shortcut.lnk differ
diff --git a/scripts/audiostream.py b/scripts/audiostream.py
new file mode 100644
index 00000000..d57d0180
--- /dev/null
+++ b/scripts/audiostream.py
@@ -0,0 +1,122 @@
+import gi
+import numpy as np
+import subprocess
+import time
+import ffmpeg
+
+gi.require_version('Gst', '1.0')
+from gi.repository import Gst, GLib, GObject
+
+
+class GStreamerAudio:
+    """Stream raw PCM audio through GStreamer as Opus-over-RTP on UDP.
+
+    The pipeline takes interleaved S16LE stereo audio at 48 kHz from an
+    ``appsrc``, encodes it with ``opusenc`` and sends it with ``udpsink``.
+
+    Args:
+        host: destination host for the UDP sink.
+        port: destination UDP port.
+    """
+
+    def __init__(self, host="127.0.0.1", port=5001):
+        Gst.init(None)
+        self.host = host
+        self.port = port
+
+        # Build the GStreamer pipeline
+        self.pipeline = Gst.parse_launch(
+            "appsrc name=audio_source format=time is-live=true "
+            "caps=audio/x-raw,format=S16LE,channels=2,rate=48000,layout=interleaved ! "
+            "queue ! audioconvert ! queue ! audioresample ! "
+            "queue ! opusenc ! queue ! rtpopuspay ! "
+            f"udpsink host={host} port={port} sync=false"
+        )
+
+        self.appsrc = self.pipeline.get_by_name("audio_source")
+        self.appsrc.set_property("blocksize", 65536)  # larger blocksize to avoid stalls
+        self.appsrc.set_property("format", Gst.Format.TIME)
+
+        self.pipeline.set_state(Gst.State.PLAYING)
+
+    def send_audio(self, audio_data):
+        """Push one block of PCM samples (a NumPy array) into the pipeline."""
+        # Serialize once; the original called tobytes() twice and copied the data twice.
+        payload = audio_data.tobytes()
+        buffer = Gst.Buffer.new_allocate(None, len(payload), None)
+        buffer.fill(0, payload)
+        self.appsrc.emit("push-buffer", buffer)
+        print(f"✅ 推送音频: {len(audio_data)} samples")
+
+    def stop(self, drain_timeout=5.0):
+        """Shut the pipeline down.
+
+        Args:
+            drain_timeout: seconds to wait for queued audio to be encoded and
+                sent before tearing down. Pass 0 to stop immediately.
+        """
+        if drain_timeout and drain_timeout > 0:
+            # Going straight to NULL discards whatever is still queued, so signal
+            # end-of-stream and wait (bounded) for EOS or an error on the bus.
+            self.appsrc.emit("end-of-stream")
+            bus = self.pipeline.get_bus()
+            bus.timed_pop_filtered(
+                int(drain_timeout * Gst.SECOND),
+                Gst.MessageType.EOS | Gst.MessageType.ERROR,
+            )
+        self.pipeline.set_state(Gst.State.NULL)
+        print("✅ GStreamer 音频推流已关闭")
+
+
+class FFmpegAudioReader:
+    """Decode a whole audio file to PCM with FFmpeg.
+
+    Args:
+        audio_file: path of the file to probe and decode.
+
+    Raises:
+        ValueError: if the probed file contains no audio stream.
+    """
+
+    def __init__(self, audio_file):
+        self.audio_file = audio_file
+        probe = ffmpeg.probe(audio_file)
+        # Stream 0 is not necessarily audio (it can be video or cover art),
+        # so select the first stream whose codec_type is "audio".
+        audio_stream = next(
+            (s for s in probe['streams'] if s.get('codec_type') == 'audio'),
+            None,
+        )
+        if audio_stream is None:
+            raise ValueError(f"No audio stream found in {audio_file}")
+        self.sample_rate = int(audio_stream['sample_rate'])
+        self.channels = int(audio_stream['channels'])
+
+    def read_full_audio(self):
+        """Decode the file to 48 kHz stereo S16LE.
+
+        Returns:
+            An int16 NumPy array of shape (frames, 2), or None if FFmpeg
+            produced no output.
+        """
+        result = subprocess.run(
+            ["ffmpeg", "-i", self.audio_file, "-f", "s16le", "-ac", "2", "-ar", "48000", "-"],
+            stdout=subprocess.PIPE, stderr=subprocess.DEVNULL
+        )
+        raw_data = result.stdout
+
+        if not raw_data:
+            print("❌ 读取音频文件失败!")
+            return None
+
+        # One stereo S16LE frame is 4 bytes; drop a trailing partial frame so the
+        # reshape below cannot fail on truncated output.
+        usable = len(raw_data) - (len(raw_data) % 4)
+        audio_data = np.frombuffer(raw_data[:usable], dtype=np.int16).reshape(-1, 2)
+        print(f"✅ 读取完整音频,共 {len(audio_data)} samples")
+        return audio_data
+
+
+def split_audio(audio_data, num_chunks):
+    """Split *audio_data* along its first axis into *num_chunks* consecutive pieces.
+
+    Every piece has ``len(audio_data) // num_chunks`` samples except the last,
+    which also takes the remainder. Works for 1-D (mono) and 2-D
+    (frames x channels) arrays; the pieces are views, not copies.
+
+    Args:
+        audio_data: NumPy array of samples.
+        num_chunks: number of pieces; values <= 1 return ``[audio_data]``.
+
+    Returns:
+        A list of ``num_chunks`` arrays.
+    """
+    if num_chunks <= 1:
+        return [audio_data]
+
+    chunk_size = len(audio_data) // num_chunks
+    chunks = [audio_data[i * chunk_size: (i + 1) * chunk_size] for i in range(num_chunks - 1)]
+    # Let the last slice run to the end so leftover samples land in the final chunk.
+    # The original appended them with np.vstack, which fails on 1-D input.
+    chunks.append(audio_data[(num_chunks - 1) * chunk_size:])
+
+    print(f"✅ 音频已分割为 {num_chunks} 份,每份约 {chunk_size} samples")
+    return chunks
+
+
+def main(audio_file, split_chunks=1):
+    """Decode *audio_file* completely and stream it through GStreamer.
+
+    Args:
+        audio_file: path of the audio file to stream.
+        split_chunks: number of pieces to push the audio in (<= 1 means one piece).
+    """
+    audio_reader = FFmpegAudioReader(audio_file)
+    audio_data = audio_reader.read_full_audio()
+
+    if audio_data is None:
+        return
+
+    # Split the audio
+    audio_chunks = split_audio(audio_data, split_chunks)
+    # Report progress against the real chunk count: split_audio returns a single
+    # chunk for split_chunks <= 1, so split_chunks itself may be 0 or negative.
+    total_chunks = len(audio_chunks)
+
+    gst_audio = GStreamerAudio()
+    try:
+        for i, chunk in enumerate(audio_chunks, start=1):
+            gst_audio.send_audio(chunk)
+            print(f"📡 推送第 {i}/{total_chunks} 份音频...")
+            time.sleep(0.1)  # throttle the send rate
+    except KeyboardInterrupt:
+        print("\n⏹️ 手动停止音频流")
+    finally:
+        gst_audio.stop()
+
+
+if __name__ == "__main__":
+    import sys
+
+    # CLI: <audio file path> [chunk count]; the chunk count defaults to 1 (no splitting).
+    cli_args = sys.argv[1:]
+    if not cli_args:
+        print("❌ 用法: python audiostream.py <音频文件路径> [分块数量]")
+        sys.exit(1)
+
+    audio_file = cli_args[0]
+    if len(cli_args) > 1:
+        split_chunks = int(cli_args[1])
+    else:
+        split_chunks = 1
+
+    main(audio_file, split_chunks)
diff --git a/scripts/realtime_stream.py b/scripts/realtime_stream.py
new file mode 100644
index 00000000..16263498
--- /dev/null
+++ b/scripts/realtime_stream.py
@@ -0,0 +1,460 @@
+import ffmpeg
+import argparse
+import os
+import pyaudio
+import gi
+gi.require_version('Gst', '1.0')
+from gi.repository import Gst, GLib
+import threading
+import queue
+from omegaconf import OmegaConf
+import soundfile as sf
+import wave
+import subprocess
+import numpy as np
+import cv2
+import torch
+import glob
+import pickle
+import sys
+from tqdm import tqdm
+import copy
+import json
+from musetalk.utils.utils import get_file_type,get_video_fps,datagen
+from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
+from musetalk.utils.blending import get_image,get_image_prepare_material,get_image_blending
+from musetalk.utils.utils import load_all_model
+import shutil
+
+import threading
+import queue
+
+import time
+
+# load model weights
+# Runs at import time: all four models are loaded as soon as this module is imported.
+audio_processor, vae, unet, pe = load_all_model()
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+# Single-element timestep tensor (value 0) passed to every UNet call in Avatar.inference.
+timesteps = torch.tensor([0], device=device)
+# Convert the positional encoder, the VAE and the UNet to half precision.
+pe = pe.half()
+vae.vae = vae.vae.half()
+unet.model = unet.model.half()
+
+class GStreamerAudio:
+    """Stream raw PCM audio through GStreamer as Opus-over-RTP on UDP.
+
+    The pipeline takes interleaved S16LE stereo audio at 48 kHz from an
+    ``appsrc``, encodes it with ``opusenc`` and sends it with ``udpsink``.
+
+    Args:
+        host: destination host for the UDP sink.
+        port: destination UDP port.
+    """
+
+    def __init__(self, host="127.0.0.1", port=5001):
+        Gst.init(None)
+        self.host = host
+        self.port = port
+
+        # Build the GStreamer pipeline
+        self.pipeline = Gst.parse_launch(
+            "appsrc name=audio_source format=time is-live=true "
+            "caps=audio/x-raw,format=S16LE,channels=2,rate=48000,layout=interleaved ! "
+            "queue ! audioconvert ! queue ! audioresample ! "
+            "queue ! opusenc ! queue ! rtpopuspay ! "
+            f"udpsink host={host} port={port} sync=false"
+        )
+
+        self.appsrc = self.pipeline.get_by_name("audio_source")
+        self.appsrc.set_property("blocksize", 65536)  # larger blocksize to avoid stalls
+        self.appsrc.set_property("format", Gst.Format.TIME)
+
+        self.pipeline.set_state(Gst.State.PLAYING)
+
+    def send_audio(self, audio_data):
+        """Push one block of PCM samples (a NumPy array) into the pipeline."""
+        # Serialize once; the original called tobytes() twice and copied the data twice.
+        payload = audio_data.tobytes()
+        buffer = Gst.Buffer.new_allocate(None, len(payload), None)
+        buffer.fill(0, payload)
+        self.appsrc.emit("push-buffer", buffer)
+        print(f"✅ 已推送完整音频,共 {len(audio_data)} samples")
+
+    def stop(self, drain_timeout=5.0):
+        """Shut the pipeline down.
+
+        Args:
+            drain_timeout: seconds to wait for queued audio to be encoded and
+                sent before tearing down. Pass 0 to stop immediately.
+        """
+        if drain_timeout and drain_timeout > 0:
+            # Going straight to NULL discards whatever is still queued, so signal
+            # end-of-stream and wait (bounded) for EOS or an error on the bus.
+            self.appsrc.emit("end-of-stream")
+            bus = self.pipeline.get_bus()
+            bus.timed_pop_filtered(
+                int(drain_timeout * Gst.SECOND),
+                Gst.MessageType.EOS | Gst.MessageType.ERROR,
+            )
+        self.pipeline.set_state(Gst.State.NULL)
+        print("✅ GStreamer 音频推流已关闭")
+
+class FFmpegAudioReader:
+    """Decode a whole audio file to PCM with FFmpeg.
+
+    Args:
+        audio_file: path of the file to probe and decode.
+
+    Raises:
+        ValueError: if the probed file contains no audio stream.
+    """
+
+    def __init__(self, audio_file):
+        self.audio_file = audio_file
+        probe = ffmpeg.probe(audio_file)
+        # Stream 0 is not necessarily audio (it can be video or cover art),
+        # so select the first stream whose codec_type is "audio".
+        audio_stream = next(
+            (s for s in probe['streams'] if s.get('codec_type') == 'audio'),
+            None,
+        )
+        if audio_stream is None:
+            raise ValueError(f"No audio stream found in {audio_file}")
+        self.sample_rate = int(audio_stream['sample_rate'])
+        self.channels = int(audio_stream['channels'])
+
+    def read_full_audio(self):
+        """Decode the file to 48 kHz stereo S16LE.
+
+        Returns:
+            An int16 NumPy array of shape (frames, 2), or None if FFmpeg
+            produced no output.
+        """
+        result = subprocess.run(
+            ["ffmpeg", "-i", self.audio_file, "-f", "s16le", "-ac", "2", "-ar", "48000", "-"],
+            stdout=subprocess.PIPE, stderr=subprocess.DEVNULL
+        )
+        raw_data = result.stdout
+
+        if not raw_data:
+            print("❌ 读取音频文件失败!")
+            return None
+
+        # One stereo S16LE frame is 4 bytes; drop a trailing partial frame so the
+        # reshape below cannot fail on truncated output.
+        usable = len(raw_data) - (len(raw_data) % 4)
+        audio_data = np.frombuffer(raw_data[:usable], dtype=np.int16).reshape(-1, 2)
+        print(f"✅ 读取完整音频,共 {len(audio_data)} samples")
+        return audio_data
+
+def split_audio(audio_data, num_chunks):
+    """Split *audio_data* along its first axis into *num_chunks* consecutive pieces.
+
+    Every piece has ``len(audio_data) // num_chunks`` samples except the last,
+    which also takes the remainder. Works for 1-D (mono) and 2-D
+    (frames x channels) arrays; the pieces are views, not copies.
+
+    Args:
+        audio_data: NumPy array of samples.
+        num_chunks: number of pieces; values <= 1 return ``[audio_data]``.
+
+    Returns:
+        A list of ``num_chunks`` arrays.
+    """
+    if num_chunks <= 1:
+        return [audio_data]
+
+    chunk_size = len(audio_data) // num_chunks
+    chunks = [audio_data[i * chunk_size: (i + 1) * chunk_size] for i in range(num_chunks - 1)]
+    # Let the last slice run to the end so leftover samples land in the final chunk.
+    # The original appended them with np.vstack, which fails on 1-D input.
+    chunks.append(audio_data[(num_chunks - 1) * chunk_size:])
+
+    print(f"✅ 音频已分割为 {num_chunks} 份,每份约 {chunk_size} samples")
+    return chunks
+
+class GStreamerPipeline:
+    """Send BGR frames as an H.264 RTP stream over UDP via OpenCV's GStreamer backend.
+
+    Args:
+        width: output frame width in pixels; frames are resized to it.
+        height: output frame height in pixels.
+        fps: frame rate declared to the video writer.
+        host: destination host for the UDP sink.
+        port: destination UDP port.
+
+    Raises:
+        RuntimeError: if the GStreamer-backed ``cv2.VideoWriter`` cannot be opened.
+    """
+
+    def __init__(self, width=640, height=480, fps=25, host="127.0.0.1", port=5000):
+        self.width = width
+        self.height = height
+        self.fps = fps
+        self.host = host
+        self.port = port
+
+        # GStreamer streaming pipeline. The original hard-coded host, port, fps and frame
+        # size here, so the constructor arguments had no effect.
+        self.GSTREAMER_PIPELINE = (
+            "appsrc ! videoconvert ! video/x-raw ! "
+            "queue ! x264enc bitrate=8000 tune=zerolatency ! "
+            f"rtph264pay ! udpsink host={self.host} port={self.port}"
+        )
+
+        # Open the GStreamer pipeline through OpenCV's VideoWriter
+        self.video_writer = cv2.VideoWriter(
+            self.GSTREAMER_PIPELINE,
+            cv2.CAP_GSTREAMER,
+            0,
+            self.fps,
+            (self.width, self.height),
+            True,
+        )
+
+        if not self.video_writer.isOpened():
+            raise RuntimeError("❌ GStreamer 推流初始化失败!请检查 GStreamer 是否安装")
+
+    def send_frame(self, frame):
+        """Resize *frame* to the configured size and write it to the stream; None is ignored."""
+        if frame is None:
+            return
+
+        frame = cv2.resize(frame, (self.width, self.height))  # match the writer's frame size
+        frame = frame.astype(np.uint8)  # the writer expects uint8 pixels
+        self.video_writer.write(frame)
+
+    def stop(self):
+        """Release the video writer, closing the stream."""
+        self.video_writer.release()
+
+
+
+def video2imgs(vid_path, save_path, ext = '.png',cut_frame = 10000000):
+    """Dump the frames of a video as numbered image files.
+
+    Frames are written to ``{save_path}/{index:08d}{ext}`` starting at index 0.
+
+    Args:
+        vid_path: path of the video to read.
+        save_path: existing directory that receives the images.
+        ext: image extension, with or without the leading dot
+            (``'png'`` and ``'.png'`` are equivalent).
+        cut_frame: highest frame index to write; reading stops after it.
+    """
+    # Callers in this file pass 'png' without a dot while the default is '.png'; normalize both.
+    suffix = ext if ext.startswith('.') else f".{ext}"
+    cap = cv2.VideoCapture(vid_path)
+    count = 0
+    try:
+        while count <= cut_frame:
+            ret, frame = cap.read()
+            if not ret:
+                break
+            cv2.imwrite(f"{save_path}/{count:08d}{suffix}", frame)
+            count += 1
+    finally:
+        # The original never released the capture handle.
+        cap.release()
+
+def osmakedirs(path_list):
+    """Create every directory in *path_list*, including parents; existing directories are left alone."""
+    for path in path_list:
+        # exist_ok avoids the check-then-create race of the original exists() test.
+        os.makedirs(path, exist_ok=True)
+
+
+class Avatar:
+    """A pre-processed avatar: source frames, face boxes, VAE latents and blending masks.
+
+    Materials live under ``./results/avatars/{avatar_id}``. They are either
+    created from ``video_path`` (``preparation=True``) or loaded from disk.
+
+    Args:
+        avatar_id: name of the avatar, used as its directory name.
+        video_path: video file, or directory of PNG frames, to build the avatar from.
+        bbox_shift: value forwarded to ``get_landmark_and_bbox``.
+        batch_size: number of frames per UNet batch in ``inference``.
+        preparation: whether to (re)create the materials instead of loading them.
+    """
+
+    # The decorator used to sit on the class, where it only wraps construction (and is
+    # deprecated by PyTorch). It is now on __init__, which keeps construction under
+    # no_grad, and on inference, which previously ran with autograd enabled.
+    @torch.no_grad()
+    def __init__(self, avatar_id, video_path, bbox_shift, batch_size, preparation):
+        self.avatar_id = avatar_id
+        self.video_path = video_path
+        self.bbox_shift = bbox_shift
+        self.avatar_path = f"./results/avatars/{avatar_id}"
+        self.full_imgs_path = f"{self.avatar_path}/full_imgs"
+        self.coords_path = f"{self.avatar_path}/coords.pkl"
+        self.latents_out_path= f"{self.avatar_path}/latents.pt"
+        self.video_out_path = f"{self.avatar_path}/vid_output/"
+        self.mask_out_path =f"{self.avatar_path}/mask"
+        self.mask_coords_path =f"{self.avatar_path}/mask_coords.pkl"
+        self.avatar_info_path = f"{self.avatar_path}/avator_info.json"
+        self.avatar_info = {
+            "avatar_id":avatar_id,
+            "video_path":video_path,
+            "bbox_shift":bbox_shift
+        }
+        self.preparation = preparation
+        self.batch_size = batch_size
+        self.idx = 0
+        self.init()
+
+    def init(self):
+        """Create the avatar materials or load them from disk, prompting where needed.
+
+        With ``preparation`` set, an existing avatar is re-created only after the
+        user confirms with "y"; otherwise it is loaded. Without ``preparation``,
+        the avatar must exist, and a changed ``bbox_shift`` forces re-creation
+        ("c") or exit. Calls ``sys.exit()`` when it cannot continue.
+        """
+        if self.preparation:
+            if not os.path.exists(self.avatar_path):
+                self._create_avatar()
+                return
+            response = input(f"{self.avatar_id} exists, Do you want to re-create it ? (y/n)")
+            if response.lower() == "y":
+                shutil.rmtree(self.avatar_path)
+                self._create_avatar()
+            else:
+                self._load_materials()
+            return
+
+        if not os.path.exists(self.avatar_path):
+            print(f"{self.avatar_id} does not exist, you should set preparation to True")
+            sys.exit()
+
+        with open(self.avatar_info_path, "r") as f:
+            avatar_info = json.load(f)
+
+        if avatar_info['bbox_shift'] == self.avatar_info['bbox_shift']:
+            self._load_materials()
+            return
+
+        response = input(f" 【bbox_shift】 is changed, you need to re-create it ! (c/continue)")
+        if response.lower() != "c":
+            sys.exit()
+        shutil.rmtree(self.avatar_path)
+        self._create_avatar()
+
+    def _create_avatar(self):
+        """Create the avatar directories and build all materials from the source video."""
+        print("*********************************")
+        print(f" creating avator: {self.avatar_id}")
+        print("*********************************")
+        osmakedirs([self.avatar_path,self.full_imgs_path,self.video_out_path,self.mask_out_path])
+        self.prepare_material()
+
+    @staticmethod
+    def _sorted_images(directory):
+        """Return the image paths in *directory*, ordered by their numeric file name."""
+        paths = glob.glob(os.path.join(directory, '*.[jpJP][pnPN]*[gG]'))
+        return sorted(paths, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
+
+    def _load_materials(self):
+        """Load latents, face boxes, frames, mask boxes and masks saved by prepare_material."""
+        self.input_latent_list_cycle = torch.load(self.latents_out_path)
+        # NOTE(review): pickle files are trusted here; they are written by prepare_material.
+        with open(self.coords_path, 'rb') as f:
+            self.coord_list_cycle = pickle.load(f)
+        self.frame_list_cycle = read_imgs(self._sorted_images(self.full_imgs_path))
+        with open(self.mask_coords_path, 'rb') as f:
+            self.mask_coords_list_cycle = pickle.load(f)
+        self.mask_list_cycle = read_imgs(self._sorted_images(self.mask_out_path))
+
+    def prepare_material(self):
+        """Build and persist every per-frame material for this avatar.
+
+        Extracts (or copies) the source frames, detects face boxes, encodes each
+        face crop to a latent, builds forward+reversed "cycle" lists, computes a
+        blending mask per frame, and writes frames, masks, boxes and latents to
+        the avatar directory.
+        """
+        print("preparing data materials ... ...")
+        with open(self.avatar_info_path, "w") as f:
+            json.dump(self.avatar_info, f)
+
+        if os.path.isfile(self.video_path):
+            video2imgs(self.video_path, self.full_imgs_path, ext = 'png')
+        else:
+            print(f"copy files in {self.video_path}")
+            files = os.listdir(self.video_path)
+            files.sort()
+            files = [file for file in files if file.split(".")[-1]=="png"]
+            for filename in files:
+                shutil.copyfile(f"{self.video_path}/{filename}", f"{self.full_imgs_path}/{filename}")
+        input_img_list = sorted(glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]')))
+
+        print("extracting landmarks...")
+        coord_list, frame_list = get_landmark_and_bbox(input_img_list, self.bbox_shift)
+        input_latent_list = []
+        # Marker box for frames where no usable face box was found
+        coord_placeholder = (0.0,0.0,0.0,0.0)
+        for bbox, frame in zip(coord_list, frame_list):
+            if bbox == coord_placeholder:
+                # NOTE(review): skipped frames get no latent, so the latent list can end up
+                # shorter than the frame/coord lists — confirm this is intended.
+                continue
+            x1, y1, x2, y2 = bbox
+            crop_frame = frame[y1:y2, x1:x2]
+            resized_crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
+            latents = vae.get_latents_for_unet(resized_crop_frame)
+            input_latent_list.append(latents)
+
+        # Forward followed by reversed, so cycling through the lists plays back and forth
+        self.frame_list_cycle = frame_list + frame_list[::-1]
+        self.coord_list_cycle = coord_list + coord_list[::-1]
+        self.input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
+        self.mask_coords_list_cycle = []
+        self.mask_list_cycle = []
+
+        for i,frame in enumerate(tqdm(self.frame_list_cycle)):
+            cv2.imwrite(f"{self.full_imgs_path}/{str(i).zfill(8)}.png",frame)
+
+            face_box = self.coord_list_cycle[i]
+            mask,crop_box = get_image_prepare_material(frame,face_box)
+            cv2.imwrite(f"{self.mask_out_path}/{str(i).zfill(8)}.png",mask)
+            self.mask_coords_list_cycle.append(crop_box)
+            self.mask_list_cycle.append(mask)
+
+        with open(self.mask_coords_path, 'wb') as f:
+            pickle.dump(self.mask_coords_list_cycle, f)
+
+        with open(self.coords_path, 'wb') as f:
+            pickle.dump(self.coord_list_cycle, f)
+
+        torch.save(self.input_latent_list_cycle, self.latents_out_path)
+
+    def process_frames(self,
+                       res_frame_queue,
+                       video_len,
+                       skip_save_images):
+        """Blend generated face frames from a queue back into the original frames.
+
+        Args:
+            res_frame_queue: queue of generated face frames, one per output frame.
+            video_len: number of frames expected; the loop stops once
+                ``self.idx`` reaches ``video_len - 1``.
+            skip_save_images: when exactly False, each blended frame is written
+                to ``{avatar_path}/tmp``.
+        """
+        print(video_len)
+        while True:
+            if self.idx>=video_len-1:
+                break
+            try:
+                res_frame = res_frame_queue.get(block=True, timeout=1)
+            except queue.Empty:
+                continue
+
+            bbox = self.coord_list_cycle[self.idx%(len(self.coord_list_cycle))]
+            ori_frame = copy.deepcopy(self.frame_list_cycle[self.idx%(len(self.frame_list_cycle))])
+            x1, y1, x2, y2 = bbox
+            try:
+                res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
+            except (cv2.error, TypeError, ValueError, AttributeError) as err:
+                # The original used a bare except and did not advance idx: the consumed
+                # frame was lost but still counted as pending, so the loop could wait
+                # forever for a frame that never arrives.
+                print(f"skipping frame {self.idx}: {err}")
+                self.idx = self.idx + 1
+                continue
+            mask = self.mask_list_cycle[self.idx%(len(self.mask_list_cycle))]
+            mask_crop_box = self.mask_coords_list_cycle[self.idx%(len(self.mask_coords_list_cycle))]
+            #combine_frame = get_image(ori_frame,res_frame,bbox)
+            combine_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box)
+
+            if skip_save_images is False:
+                cv2.imwrite(f"{self.avatar_path}/tmp/{str(self.idx).zfill(8)}.png",combine_frame)
+            self.idx = self.idx + 1
+
+    @torch.no_grad()
+    def inference(self,
+                  audio_path,
+                  out_vid_name,
+                  fps,
+                  skip_save_images):
+        """Generate lip-synced frames for *audio_path* and stream audio and video.
+
+        For every batch, the matching slice of audio is pushed first, then the
+        decoded frames are pushed one by one to the video pipeline.
+
+        Args:
+            audio_path: audio file driving the lip-sync.
+            out_vid_name: unused; kept for interface compatibility.
+            fps: frame rate used for audio chunking and the video stream.
+            skip_save_images: unused in this method; kept for interface compatibility.
+        """
+        os.makedirs(self.avatar_path+'/tmp',exist_ok =True)
+        print("start inference")
+
+        gst_pipeline = GStreamerPipeline(width=640, height=480, fps=fps, host="127.0.0.1", port=5000)
+        audio_sender = GStreamerAudio()
+        # try/finally so both pipelines are shut down even if inference raises
+        try:
+            ############################################## extract audio feature ##############################################
+            whisper_feature = audio_processor.audio2feat(audio_path)
+            whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
+
+            video_num = len(whisper_chunks)
+            total_iters = int(np.ceil(float(video_num) / self.batch_size))
+            audio_data = FFmpegAudioReader(audio_path).read_full_audio()
+            if audio_data is None:
+                # read_full_audio returns None on decode failure; splitting None would crash.
+                print(f"could not decode audio: {audio_path}, skipping")
+                return
+            # One audio chunk per inference batch
+            audio_chunks = split_audio(audio_data, total_iters)
+
+            ############################################## inference batch by batch ##############################################
+            gen = datagen(whisper_chunks, self.input_latent_list_cycle, self.batch_size)
+
+            frame_count = 0
+            start_time = time.time()  # start of the first batch
+
+            for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=total_iters)):
+                # Audio features -> positional encoding
+                audio_feature_batch = torch.from_numpy(whisper_batch)
+                audio_feature_batch = audio_feature_batch.to(device=unet.device, dtype=unet.model.dtype)
+                audio_feature_batch = pe(audio_feature_batch)
+
+                latent_batch = latent_batch.to(dtype=unet.model.dtype)
+
+                # UNet predicts the lip-synced latents; the VAE decodes them to frames
+                pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
+                recon = vae.decode_latents(pred_latents)
+
+                audio_sender.send_audio(audio_chunks[i])
+                # Push the frames to GStreamer in order
+                for res_frame in recon:
+                    frame_count += 1
+                    print("✅ 正在推送视频帧...")
+                    gst_pipeline.send_frame(res_frame)
+
+            ##############################################
+            # Final throughput figures
+            ##############################################
+            total_elapsed_time = time.time() - start_time
+            print(f"\nelapsed_time: {total_elapsed_time:.2f} s")
+            # frame_count is a count, not a duration: the original printed it with an "s" unit
+            print(f"\nframe_count: {frame_count}")
+            avg_fps = frame_count / total_elapsed_time if total_elapsed_time > 0 else 0
+            print(f"\n最终计算得到的平均帧率: {avg_fps:.2f} FPS")
+        finally:
+            gst_pipeline.stop()
+            audio_sender.stop()
+
+
+
+
+if __name__ == "__main__":
+    # This script is used to simulate online chatting and applies necessary pre-processing
+    # such as face detection and face parsing in advance. During online chatting, only UNet
+    # and the VAE decoder are involved, which makes MuseTalk real-time.
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--inference_config", type=str, default="configs/inference/realtime.yaml")
+    parser.add_argument("--fps", type=int, default=25)
+    parser.add_argument("--batch_size", type=int, default=4)
+    parser.add_argument(
+        "--skip_save_images",
+        action="store_true",
+        help="Whether skip saving images for better generation speed calculation",
+    )
+    args = parser.parse_args()
+
+    inference_config = OmegaConf.load(args.inference_config)
+    print(inference_config)
+
+    # One Avatar per top-level key of the config; each avatar is driven by all of its audio clips.
+    for avatar_id in inference_config:
+        avatar_cfg = inference_config[avatar_id]
+        data_preparation = avatar_cfg["preparation"]
+        video_path = avatar_cfg["video_path"]
+        bbox_shift = avatar_cfg["bbox_shift"]
+        avatar = Avatar(
+            avatar_id=avatar_id,
+            video_path=video_path,
+            bbox_shift=bbox_shift,
+            batch_size=args.batch_size,
+            preparation=data_preparation,
+        )
+
+        audio_clips = avatar_cfg["audio_clips"]
+        for audio_num, audio_path in audio_clips.items():
+            print("Inferring using:",audio_path)
+            avatar.inference(audio_path, audio_num, args.fps, args.skip_save_images)
diff --git a/scripts/realtime_stream_gst_15.py b/scripts/realtime_stream_gst_15.py
new file mode 100644
index 00000000..b92f52ec
--- /dev/null
+++ b/scripts/realtime_stream_gst_15.py
@@ -0,0 +1,1455 @@
+import argparse
+import os
+from omegaconf import OmegaConf
+import numpy as np
+import cv2
+import torch
+import glob
+import pickle
+import sys
+from tqdm import tqdm
+import copy
+import json
+import shutil
+import threading
+import queue
+import time
+from termcolor import colored # Add this import
+import subprocess
+import io
+import concurrent.futures
+import traceback
+import logging
+from dotenv import load_dotenv # Keep this at the very top!
+from PIL import Image
+import tempfile
+import ffmpeg
+import soundfile as sf # Added for sf.write, assumed to be imported
+import cpuinfo # For detailed CPU information
+
+print("Script started!")
+
+# --- Set up basic logging ---
+# This must run before the first logging.* call below. Calling logging.info()
+# on an unconfigured root logger installs a default handler at WARNING level,
+# which silently drops the INFO records emitted while loading configuration.
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', force=True)  # force= requires Python >= 3.8
+print("Logging configured!")
+
+# --- Load environment variables FIRST to make them available globally ---
+load_dotenv()
+logging.info("Environment variables loaded from .env file.")
+
+# --- Global Configuration from .env or Hardcoded Fallbacks ---
+# Each value is read from the environment (populated by load_dotenv above) and
+# falls back to the default given here, so the rest of the module can use the
+# constants directly without an 'args.' prefix.
+
+# GStreamer related paths and parameters
+GSTREAMER_LAUNCH_PATH = os.getenv("GSTREAMER_LAUNCH_PATH", "gst-launch-1.0")  # Global path for gst-launch
+# Sanitize a value that accidentally contains its own key
+# (e.g. "GSTREAMER_LAUNCH_PATH=/usr/bin/gst-launch-1.0"): keep only what
+# follows the first "GSTREAMER_LAUNCH_PATH=".
+if "GSTREAMER_LAUNCH_PATH=" in GSTREAMER_LAUNCH_PATH:
+    GSTREAMER_LAUNCH_PATH = GSTREAMER_LAUNCH_PATH.split("GSTREAMER_LAUNCH_PATH=", 1)[-1]
+    logging.warning(f"⚠️ Cleaned malformed GSTREAMER_LAUNCH_PATH value. Now: {GSTREAMER_LAUNCH_PATH}")
+
+# Log the value that will actually be used to launch GStreamer.
+logging.info(f"🔍 Final GSTREAMER_LAUNCH_PATH: '{GSTREAMER_LAUNCH_PATH}'")
+STREAM_PIPE_PATH = os.getenv("STREAM_PIPE_PATH", "./hot_file.opus")
+TARGET_FPS = int(os.getenv("TARGET_FPS", "25"))  # Use for general FPS calculations
+FRAME_SKIP_THRESHOLD = int(os.getenv("FRAME_SKIP_THRESHOLD", "3"))  # Use for consumer queue management
+OVERLOAD_MULTIPLIER = float(os.getenv("OVERLOAD_MULTIPLIER", "1.2"))
+MAX_FRAMES_TO_SKIP = int(os.getenv("MAX_FRAMES_TO_SKIP", "1"))
+
+# MuseTalk Model Paths (primarily from .env)
+MUSE_VERSION = os.getenv("MUSE_VERSION", "v15")
+FFMPEG_PATH = os.getenv("FFMPEG_PATH", "./ffmpeg-4.4-amd64-static/")
+GPU_ID = int(os.getenv("GPU_ID", "0"))
+VAE_TYPE = os.getenv("VAE_TYPE", "sd-vae")
+# NOTE: the environment key is "UNET_CONFIG" (no "_PATH" suffix), unlike the constant name.
+UNET_CONFIG_PATH = os.getenv("UNET_CONFIG", "./models/musetalkV15/musetalk.json")
+UNET_MODEL_PATH = os.getenv("UNET_MODEL_PATH", "./models/musetalkV15/unet.pth")
+WHISPER_DIR = os.getenv("WHISPER_DIR", "./models/whisper")
+RESULT_DIR = os.getenv("RESULT_DIR", './results')
+
+# Avatar/Inference specific parameters (primarily from .env)
+EXTRA_MARGIN = int(os.getenv("EXTRA_MARGIN", "10"))
+# NOTE: the environment keys are "AUDIO_PADDING_LEFT"/"AUDIO_PADDING_RIGHT" (no "_LENGTH").
+AUDIO_PADDING_LENGTH_LEFT = int(os.getenv("AUDIO_PADDING_LEFT", "2"))
+AUDIO_PADDING_LENGTH_RIGHT = int(os.getenv("AUDIO_PADDING_RIGHT", "2"))
+BATCH_SIZE = int(os.getenv("BATCH_SIZE", "4"))
+PARSING_MODE = os.getenv("PARSING_MODE", 'jaw')
+LEFT_CHEEK_WIDTH = int(os.getenv("LEFT_CHEEK_WIDTH", "90"))
+RIGHT_CHEEK_WIDTH = int(os.getenv("RIGHT_CHEEK_WIDTH", "90"))
+AVATAR_CONFIG_PATH = os.getenv("AVATAR_CONFIG_PATH", "configs/inference/realtime.yaml")
+AVATAR_ID_TO_USE = os.getenv("AVATAR_ID_TO_USE", "default_avatar_id")
+
+# --- PyTorch device setup & system information ---
+# Logs PyTorch/CPU/GPU details at import time and defines the module-level
+# `cuda_available` and `device` globals.
+logging.info("--- System & Hardware Information ---")
+
+# PyTorch Version
+try:
+ torch_version = torch.__version__
+ logging.info(colored(f"PyTorch Version: {torch_version}", 'cyan'))
+except Exception as e:
+ logging.warning(f"Could not determine PyTorch version: {e}")
+
+# CPU information (best effort: any failure is only logged as a warning)
+try:
+ cpu_info = cpuinfo.get_cpu_info()
+ cpu_brand = cpu_info.get('brand_raw', 'N/A')
+ cpu_hz = cpu_info.get('hz_actual_friendly', 'N/A')
+ logging.info(colored(f"CPU: {cpu_brand} @ {cpu_hz}", 'cyan'))
+except Exception as e:
+ logging.warning(f"Could not determine CPU info: {e}")
+
+
+# Define `device` before the GPU report below uses it: the CUDA device
+# selected by GPU_ID when CUDA is available, otherwise the CPU.
+cuda_available = torch.cuda.is_available()
+if cuda_available:
+ device = torch.device(f"cuda:{GPU_ID}")
+else:
+ device = torch.device("cpu")
+# -------------------------------------------------------------
+
+
+# CUDA and GPU information (only queried when CUDA is available)
+if cuda_available:
+ try:
+ # Static device properties: name and total memory (converted to GiB)
+ properties = torch.cuda.get_device_properties(device)
+ gpu_name = properties.name
+ total_vram_gb = properties.total_memory / (1024**3)
+
+ # Current memory usage: mem_get_info returns (free, total) in bytes
+ free_vram_bytes, total_vram_bytes = torch.cuda.mem_get_info(device)
+ used_vram_gb = (total_vram_bytes - free_vram_bytes) / (1024**3)
+ available_vram_gb = free_vram_bytes / (1024**3)
+
+ # CUDA version PyTorch was built against
+ cuda_version = torch.version.cuda
+
+ # Log formatted, color-coded information
+ logging.info(colored(f"CUDA Version: {cuda_version}", 'green'))
+ logging.info(colored(f"GPU: {gpu_name} (ID: {GPU_ID})", 'green', attrs=['bold']))
+ logging.info(colored(f"VRAM: {used_vram_gb:.2f} GB Used / {available_vram_gb:.2f} GB Available / {total_vram_gb:.2f} GB Total", 'green'))
+
+ except Exception as e:
+ # Reporting is informational only; `device` stays as selected above.
+ logging.warning(colored(f"⚠️ Could not retrieve detailed GPU/CUDA information: {e}", 'yellow'))
+else:
+ logging.info(colored("GPU: ❌ CUDA (GPU) not available. Using CPU.", 'red', attrs=['bold']))
+
+# Selected Device
+logging.info(f"Selected Device: {device}")
+logging.info("---------------------------------------")
+
+def get_optimal_gstreamer_pipeline(gpu_name, width, height, fps, host, port, bitrate=5000):
+    """
+    Returns the GStreamer pipeline string selected for the detected GPU.
+
+    All variants read raw BGR frames from stdin (fd 0), encode them as H.264
+    and send them as RTP over UDP; only the upload/encode stage differs.
+
+    Args:
+        gpu_name (str): Name of the GPU, e.g. from torch.cuda.get_device_name().
+            An empty string or None selects the CPU fallback.
+        width, height, fps: Video parameters
+        host, port: Streaming destination
+        bitrate: Target bitrate in kbps, applied to whichever encoder is chosen
+
+    Returns:
+        str: GStreamer pipeline description (without the gst-launch executable)
+    """
+    gpu_name = gpu_name or ""
+
+    leaky_queue = "queue max-size-buffers=1 leaky=downstream ! "
+    # Shared front end: raw frames from stdin, parsed as BGR at the given size.
+    head = (
+        "fdsrc fd=0 do-timestamp=true is-live=true ! "
+        f"videoparse format=bgr width={width} height={height} framerate={fps}/1 ! "
+        + leaky_queue
+    )
+    # Shared back end: RTP payloading and UDP output.
+    tail = (
+        "h264parse ! rtph264pay pt=96 config-interval=1 ! "
+        f"udpsink host={host} port={port} sync=true async=false"
+    )
+
+    # Detect RTX 50-series (Blackwell) - these use the D3D11 / Media Foundation path
+    if "RTX 50" in gpu_name or "5090" in gpu_name or "5080" in gpu_name:
+        logging.info(f"🎮 Detected RTX 50-series GPU ({gpu_name}). Using D3D11 hardware acceleration.")
+        # NOTE(review): d3d11upload/mfh264enc are Windows-only elements - confirm
+        # this branch is never reached on other platforms.
+        # mfh264enc expects bitrate/max-bitrate in kbit/s, like the other encoders.
+        encode = (
+            "videoconvert ! video/x-raw,format=NV12 ! "
+            "d3d11upload ! "
+            + leaky_queue
+            + f"mfh264enc rc-mode=cbr bitrate={bitrate} low-latency=true max-bitrate={bitrate} ! "
+        )
+
+    # RTX 40-series and earlier - use CUDA pipeline
+    elif "RTX" in gpu_name or "GTX" in gpu_name or "Tesla" in gpu_name or "Quadro" in gpu_name:
+        logging.info(f"🎮 Detected NVIDIA GPU with CUDA support ({gpu_name}). Using CUDA pipeline.")
+        encode = (
+            "cudaupload ! "
+            + leaky_queue
+            + "cudaconvert ! videorate ! "
+            "video/x-raw(memory:CUDAMemory),format=NV12 ! "
+            + leaky_queue
+            + f"nvh264enc preset=p1 tune=ultra-low-latency zerolatency=true rc-mode=cbr bitrate={bitrate} ! "
+        )
+
+    # Fallback to CPU encoding for unknown/non-NVIDIA GPUs
+    else:
+        logging.warning(f"⚠️ Unknown GPU type ({gpu_name}). Falling back to CPU-based x264 encoding.")
+        encode = (
+            "videoconvert ! video/x-raw,format=I420 ! "
+            f"x264enc tune=zerolatency speed-preset=ultrafast bitrate={bitrate} ! "
+        )
+
+    return head + encode + tail
+
+# --- Platform-specific imports for enhanced functionality ---
+# On Windows only: try to raise this process to HIGH priority via psutil.
+# Best effort - a missing psutil or a refused request is logged and ignored.
+if sys.platform == "win32":
+ try:
+ import psutil
+ p = psutil.Process(os.getpid())
+ p.nice(psutil.HIGH_PRIORITY_CLASS)
+ logging.info("INFO: Process priority set to HIGH on Windows (if psutil installed and permitted).")
+ except ImportError:
+ logging.warning("Warning: psutil not found. Cannot set process priority.")
+ except Exception as e:
+ logging.warning(f"Warning: Could not set process priority: {e}")
+
+
+# --- MuseTalk Specific Imports (Ensure these are in your PYTHONPATH) ---
+try:
+ from transformers import WhisperModel
+ from musetalk.utils.face_parsing import FaceParsing
+ from musetalk.utils.utils import datagen, load_all_model
+ from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs
+ from musetalk.utils.blending import get_image_prepare_material, get_image_blending
+ from musetalk.utils.audio_processor import AudioProcessor
+except ImportError as e:
+ logging.critical(f"Error importing MuseTalk utilities: {e}. Ensure the library is installed and 'musetalk' package is in your PYTHONPATH.", exc_info=True)
+ sys.exit(1)
+
+# --- Global Models and Device ---
+# Module-level placeholders for the shared model objects. They start as None
+# here; per the original author they are assigned once in __main__ and then
+# read by the Avatar class and worker threads (e.g. `vae` and `fp` are used by
+# Avatar._prepare_material_core).
+vae = None
+unet = None
+pe = None
+timesteps = None
+audio_processor = None
+whisper = None
+fp = None # FaceParsing instance
+weight_dtype = torch.float16 # Default to FP16 as per MuseTalk 1.5 optimizations
+
+# --- Utility Functions ---
+def fast_check_ffmpeg():
+    """Return True when ``ffmpeg -version`` runs successfully from the PATH.
+
+    A missing executable, a non-zero exit status, or no answer within five
+    seconds all count as "not available" and yield False.
+    """
+    probe_failures = (
+        subprocess.CalledProcessError,
+        FileNotFoundError,
+        subprocess.TimeoutExpired,
+    )
+    try:
+        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True, timeout=5)
+    except probe_failures:
+        return False
+    else:
+        return True
+
+def video2imgs(vid_path, save_path):
+    """Extracts frames from a video file and saves them as PNG images.
+
+    Frames are written to ``save_path`` as zero-padded, eight-digit file names
+    (``00000000.png``, ``00000001.png``, ...) in decoding order.
+
+    Args:
+        vid_path: Path of the video file to read.
+        save_path: Existing directory that receives the PNG files.
+
+    Returns:
+        None. If the video cannot be opened an error is logged and nothing is
+        written.
+    """
+    logging.info(f"Extracting frames from {vid_path} to {save_path}...")
+    cap = cv2.VideoCapture(vid_path)
+    if not cap.isOpened():
+        logging.error(f"Error: Could not open video file: {vid_path}")
+        cap.release()
+        return
+    count = 0
+    frame_index = 0
+    try:
+        while True:
+            ret, frame = cap.read()
+            if not ret:
+                break
+            out_path = os.path.join(save_path, f"{str(frame_index).zfill(8)}.png")
+            # cv2.imwrite reports failure through its return value instead of
+            # raising, so a failed write must be detected explicitly.
+            if cv2.imwrite(out_path, frame):
+                count += 1
+            else:
+                logging.error(f"Error: Could not write frame {frame_index} to {out_path}")
+            frame_index += 1
+    finally:
+        # Release the capture even if reading or writing raised.
+        cap.release()
+    logging.info(f"Finished extracting {count} frames.")
+
+def osmakedirs(path_list):
+    """Ensure every directory in ``path_list`` exists.
+
+    Missing parent directories are created as well; directories that already
+    exist are left untouched.
+    """
+    for directory in path_list:
+        os.makedirs(directory, exist_ok=True)
+
+def _log_subprocess_output(pipe, logger_func, prefix):
+    """Forward a subprocess pipe to a logger, one line per call.
+
+    Each line read from ``pipe`` is decoded (undecodable bytes are dropped),
+    stripped of trailing whitespace and passed to ``logger_func`` as
+    ``"[<prefix>]: <text>"``. Reading stops at end of stream; the pipe is
+    closed on the way out if it is still open.
+    """
+    try:
+        while True:
+            raw_line = pipe.readline()
+            if raw_line == b'':
+                break  # end of stream
+            text = raw_line.decode(errors='ignore').rstrip()
+            logger_func(f"[{prefix}]: {text}")
+    except ValueError:
+        # Raised when the pipe is closed while a read is in progress.
+        pass
+    except Exception as e:
+        logging.error(f"Error reading from {prefix} pipe: {e}", exc_info=True)
+    finally:
+        still_open = pipe and not pipe.closed
+        if still_open:
+            pipe.close()
+
+# --- FFmpeg Audio Reader (User's custom class for flexible audio input) ---
+class FFmpegAudioReader:
+    """Uses FFmpeg to read an audio source (file path or bytes) and convert it to raw PCM.
+
+    Args:
+        audio_source: Either a file path (``str`` or ``os.PathLike``) or the
+            encoded audio itself as a bytes-like object, which is fed to FFmpeg
+            through stdin.
+    """
+    def __init__(self, audio_source):
+        self.audio_source = audio_source
+        # pathlib.Path and other os.PathLike objects are paths too, not payloads.
+        self.is_file_path = isinstance(audio_source, (str, os.PathLike))
+
+    def read_full_audio(self):
+        """Reads the entire audio source and converts it to PCM s16le, 48kHz, Stereo.
+
+        Returns:
+            A NumPy int16 array of shape ``(frames, 2)``, or ``None`` when
+            FFmpeg fails or produces no complete audio frame.
+        """
+        logging.info(f"Reading and converting audio from {'file' if self.is_file_path else 'memory'}...")
+        target_sr, target_ac, target_format = 48000, 2, "s16le"  # Target format for GStreamer audio pipeline
+
+        if self.is_file_path:
+            input_filename = os.fspath(self.audio_source)
+            input_data = None
+        else:
+            input_filename = 'pipe:0'  # Read from stdin if bytes are provided
+            input_data = self.audio_source
+
+        try:
+            out, err = (
+                ffmpeg
+                .input(input_filename)
+                .output('pipe:', format=target_format, ac=target_ac, ar=target_sr)
+                .run(capture_stdout=True, capture_stderr=True, input=input_data, quiet=True)
+            )
+            if err:
+                logging.debug(f"FFmpeg stderr: {err.decode(errors='ignore')}")
+        except ffmpeg.Error as e:
+            logging.error(f"❌ FFmpeg error during audio conversion: {e.stderr.decode(errors='ignore') if e.stderr else 'Unknown FFmpeg error'}")
+            return None
+        except Exception as e:
+            logging.error(f"❌ Unexpected error during FFmpeg execution for audio: {e}", exc_info=True)
+            return None
+
+        # One frame is one int16 sample per channel. Drop a truncated trailing
+        # frame so the reshape below cannot fail on an odd byte count.
+        frame_bytes = target_ac * np.dtype(np.int16).itemsize
+        usable_bytes = len(out) - (len(out) % frame_bytes) if out else 0
+        if usable_bytes == 0:
+            logging.error("❌ Failed to read audio: FFmpeg produced no PCM data!")
+            return None
+        if usable_bytes != len(out):
+            logging.warning(f"FFmpeg output ended with an incomplete frame; dropping {len(out) - usable_bytes} trailing byte(s).")
+
+        audio_data = np.frombuffer(out[:usable_bytes], dtype=np.int16).reshape(-1, target_ac)
+        logging.info(f"✅ Read and converted audio: {len(audio_data)} samples at {target_sr}Hz, {target_ac}ch.")
+        return audio_data
+
+# --- GStreamer Classes (User's custom classes for real-time streaming) ---
+class GStreamerPipeline:
+    """Manages the GStreamer video pipeline subprocess.
+
+    The constructor launches ``GSTREAMER_LAUNCH_PATH`` with the pipeline chosen
+    by ``get_optimal_gstreamer_pipeline`` and keeps its stdin open; frames are
+    pushed with :meth:`send_frame` and the process is ended with :meth:`stop`.
+    A failed launch is logged, not raised: ``process`` stays ``None`` and
+    ``is_running`` stays ``False``.
+    """
+    def __init__(self, width=1280, height=720, fps=TARGET_FPS, host="127.0.0.1", port=5000):
+        self.width, self.height, self.fps, self.host, self.port = width, height, fps, host, port
+        self.process = None
+        self.stdout_thread = None
+        self.stderr_thread = None
+        self.is_running = False
+
+        # --- DYNAMIC PIPELINE SELECTION ---
+        # An empty name (no CUDA) or "Unknown" (lookup failed) both select the
+        # CPU fallback pipeline.
+        gpu_name = ""
+        if torch.cuda.is_available():
+            try:
+                gpu_name = torch.cuda.get_device_name(device)
+            except Exception:
+                gpu_name = "Unknown"
+
+        pipeline_str = get_optimal_gstreamer_pipeline(
+            gpu_name=gpu_name,
+            width=self.width,
+            height=self.height,
+            fps=self.fps,
+            host=self.host,
+            port=self.port,
+            bitrate=5000
+        )
+
+        logging.info(f"Attempting to start GStreamer video pipeline ({self.width}x{self.height}@{self.fps}fps) to {self.host}:{self.port}...")
+        env_vars = os.environ.copy()
+
+        try:
+            # Launch gst-launch directly with an argument list instead of going
+            # through a shell: caps such as "video/x-raw(memory:CUDAMemory)"
+            # are not valid unquoted shell syntax, and without the intermediate
+            # shell terminate()/kill() act on the GStreamer process itself.
+            command_args = [GSTREAMER_LAUNCH_PATH, "-v", *pipeline_str.split()]
+            logging.info(f"DEBUG: GStreamer VIDEO command (Popen args): {command_args}")
+
+            self.process = subprocess.Popen(
+                command_args,
+                stdin=subprocess.PIPE,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+                bufsize=0,
+                env=env_vars
+            )
+            # Popen either raises or returns a started process, so reaching
+            # this point means the launch succeeded.
+            logging.info(colored(f"✅ GStreamer video process launched (PID: {self.process.pid}).", 'green', attrs=['bold']))
+            self.is_running = True
+
+            # Drain stdout and stderr on daemon threads so the child never
+            # blocks on a full pipe.
+            self.stdout_thread = threading.Thread(
+                target=_log_subprocess_output,
+                args=(self.process.stdout, logging.info, "GST_VIDEO_STDOUT"),
+                daemon=True
+            )
+            self.stderr_thread = threading.Thread(
+                target=_log_subprocess_output,
+                args=(self.process.stderr, logging.error, "GST_VIDEO_STDERR"),
+                daemon=True
+            )
+            self.stdout_thread.start()
+            self.stderr_thread.start()
+
+        except Exception as e:
+            logging.error(colored(f"❌ Failed to start GStreamer video pipeline: {e}", 'red', attrs=['bold']), exc_info=True)
+            self.process = None
+            self.is_running = False
+
+    def send_frame(self, frame):
+        """Sends a NumPy array frame to the GStreamer pipeline's stdin.
+
+        Returns:
+            True when the frame bytes were written and flushed, False when the
+            pipeline is not running or the write failed. A broken pipe also
+            stops the pipeline.
+        """
+        if not self.process or not self.is_running or self.process.stdin.closed:
+            logging.debug("GStreamer video pipeline not running or stdin closed. Cannot send frame.")
+            return False
+        try:
+            # tobytes() always emits C order; normalising first also forces uint8.
+            if not frame.flags['C_CONTIGUOUS']:
+                frame = np.ascontiguousarray(frame, dtype=np.uint8)
+            self.process.stdin.write(frame.tobytes())
+            self.process.stdin.flush()
+            logging.debug(colored("Sent video frame to GStreamer.", 'cyan'))
+            return True
+        except (BrokenPipeError, OSError):
+            logging.error(colored("❌ GStreamer video pipeline: Broken pipe. The process may have crashed.", 'red', attrs=['bold']))
+            self.stop()
+            return False
+        except Exception as e:
+            logging.error(colored(f"❌ Error pushing video frame: {e}", 'red', attrs=['bold']), exc_info=True)
+            return False
+
+    def stop(self):
+        """Stops the GStreamer video subprocess gracefully.
+
+        Closes stdin, asks the process to terminate, kills it if it has not
+        exited after three seconds, then waits briefly for the log threads.
+        Safe to call more than once.
+        """
+        self.is_running = False
+        if self.process:
+            logging.info(f"Stopping GStreamer video pipeline (PID: {self.process.pid})...")
+            # Detach first so concurrent send_frame() calls see "no process".
+            proc_to_stop, self.process = self.process, None
+            if proc_to_stop.stdin and not proc_to_stop.stdin.closed:
+                try:
+                    proc_to_stop.stdin.close()
+                except Exception:
+                    pass  # best effort: the pipe may already be broken
+            proc_to_stop.terminate()
+            try:
+                proc_to_stop.wait(timeout=3)
+                logging.info(colored(f"✅ GStreamer video process terminated.", 'green'))
+            except subprocess.TimeoutExpired:
+                logging.warning(colored(f"⚠️ GStreamer video process did not terminate gracefully, killing...", 'yellow'))
+                proc_to_stop.kill()
+                try:
+                    # Reap the killed child so it does not linger as a zombie.
+                    proc_to_stop.wait(timeout=3)
+                except subprocess.TimeoutExpired:
+                    pass
+
+        if self.stdout_thread and self.stdout_thread.is_alive():
+            self.stdout_thread.join(timeout=1)
+        if self.stderr_thread and self.stderr_thread.is_alive():
+            self.stderr_thread.join(timeout=1)
+        logging.info("GStreamer video pipeline stop complete.")
+
+
+class GStreamerAudio:
+    """Manages the GStreamer audio pipeline subprocess.
+
+    The constructor launches ``GSTREAMER_LAUNCH_PATH`` with a pipeline that
+    reads interleaved S16LE PCM from stdin, encodes it as Opus and sends it as
+    RTP over UDP. Audio is pushed with :meth:`send_audio` and the process is
+    ended with :meth:`stop`. A failed launch is logged, not raised: ``process``
+    stays ``None`` and ``is_running`` stays ``False``.
+    """
+    def __init__(self, host="127.0.0.1", port=5001, sample_rate=48000, channels=2):
+        self.host, self.port, self.sample_rate, self.channels = host, port, sample_rate, channels
+        self.process = None
+        self.stdout_thread = None
+        self.stderr_thread = None
+        self.is_running = False
+
+        pipeline_str = (
+            f"fdsrc fd=0 do-timestamp=true is-live=true ! "
+            "queue max-size-buffers=2 leaky=downstream ! "
+            f"audio/x-raw,format=S16LE,channels={self.channels},rate={self.sample_rate},layout=interleaved ! "
+            "audioconvert ! audioresample ! "
+            "opusenc bitrate=64000 ! rtpopuspay pt=97 ! "
+            f"udpsink host={self.host} port={self.port} sync=true"
+        )
+        logging.info(f"Attempting to start GStreamer audio pipeline ({self.sample_rate}Hz, {self.channels}ch) to {self.host}:{self.port}...")
+
+        # The child gets a copy of the environment with GStreamer debug level 3.
+        env_vars = os.environ.copy()
+        env_vars['GST_DEBUG'] = '3'
+
+        try:
+            # Launch gst-launch directly with an argument list rather than via
+            # a shell, so terminate()/kill() act on the GStreamer process
+            # itself and no shell quoting rules apply to the pipeline text.
+            command_args = [GSTREAMER_LAUNCH_PATH, "-v", *pipeline_str.split()]
+            logging.debug(f"DEBUG: GStreamer AUDIO command (Popen args): {command_args}")
+
+            self.process = subprocess.Popen(
+                command_args,
+                stdin=subprocess.PIPE,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+                bufsize=0,
+                env=env_vars
+            )
+            # Popen either raises or returns a started process, so reaching
+            # this point means the launch succeeded.
+            logging.info(colored(f"✅ GStreamer audio process launched (PID: {self.process.pid}).", 'green', attrs=['bold']))
+            self.is_running = True
+
+            # Drain stdout and stderr on daemon threads so the child never
+            # blocks on a full pipe.
+            self.stdout_thread = threading.Thread(
+                target=_log_subprocess_output,
+                args=(self.process.stdout, logging.info, "GST_AUDIO_STDOUT"),
+                daemon=True
+            )
+            self.stderr_thread = threading.Thread(
+                target=_log_subprocess_output,
+                args=(self.process.stderr, logging.error, "GST_AUDIO_STDERR"),
+                daemon=True
+            )
+            self.stdout_thread.start()
+            self.stderr_thread.start()
+
+        except Exception as e:
+            logging.error(colored(f"❌ Failed to start GStreamer audio pipeline: {e}", 'red', attrs=['bold']), exc_info=True)
+            self.process = None
+            self.is_running = False
+
+    def send_audio(self, audio_data_pcm):
+        """Sends raw PCM audio data to the GStreamer pipeline's stdin.
+
+        Args:
+            audio_data_pcm: Object exposing ``tobytes()`` (e.g. a NumPy array)
+                whose bytes are written unchanged.
+
+        Returns:
+            True when the bytes were written and flushed, False when the
+            pipeline is not running or the write failed. A broken pipe also
+            stops the pipeline.
+        """
+        if not self.process or not self.is_running or self.process.stdin.closed:
+            logging.debug("GStreamer audio pipeline not running or stdin closed. Cannot send audio.")
+            return False
+        try:
+            self.process.stdin.write(audio_data_pcm.tobytes())
+            self.process.stdin.flush()
+            logging.debug(colored("Sent audio chunk to GStreamer.", 'cyan'))
+            return True
+        except (BrokenPipeError, OSError):
+            logging.error(colored("❌ GStreamer audio pipeline: Broken pipe.", 'red', attrs=['bold']))
+            self.stop()
+            return False
+        except Exception as e:
+            logging.error(colored(f"❌ Error pushing audio chunk: {e}", 'red', attrs=['bold']), exc_info=True)
+            return False
+
+    def stop(self):
+        """Stops the GStreamer audio subprocess gracefully.
+
+        Closes stdin, asks the process to terminate, kills it if it has not
+        exited after three seconds, then waits briefly for the log threads.
+        Safe to call more than once.
+        """
+        self.is_running = False
+        if self.process:
+            logging.info(f"Stopping GStreamer audio pipeline (PID: {self.process.pid})...")
+            # Detach first so concurrent send_audio() calls see "no process".
+            proc_to_stop, self.process = self.process, None
+            if proc_to_stop.stdin and not proc_to_stop.stdin.closed:
+                try:
+                    proc_to_stop.stdin.close()
+                except Exception:
+                    pass  # best effort: the pipe may already be broken
+            proc_to_stop.terminate()
+            try:
+                proc_to_stop.wait(timeout=3)
+                logging.info(colored("✅ GStreamer audio process terminated.", 'green'))
+            except subprocess.TimeoutExpired:
+                logging.warning(colored("⚠️ GStreamer audio process did not terminate gracefully, killing...", 'yellow'))
+                proc_to_stop.kill()
+                try:
+                    # Reap the killed child so it does not linger as a zombie.
+                    proc_to_stop.wait(timeout=3)
+                except subprocess.TimeoutExpired:
+                    pass
+
+        if self.stdout_thread and self.stdout_thread.is_alive():
+            self.stdout_thread.join(timeout=1)
+        if self.stderr_thread and self.stderr_thread.is_alive():
+            self.stderr_thread.join(timeout=1)
+        logging.info("GStreamer audio pipeline stop complete.")
+
+
+# --- Avatar Class (Unified logic from both original and user's script) ---
+@torch.no_grad()
+class Avatar:
+ def __init__(self, avatar_id, video_path, bbox_shift, batch_size, preparation, version_str, extra_margin=0, parsing_mode='jaw'):
+     """Store the avatar settings, load or build its material, and pre-convert it.
+
+     The storage location is derived from RESULT_DIR, the version string and
+     the avatar id. After ``init_avatar_data`` has filled the cycle lists, the
+     reference frames and masks are additionally kept as PIL images
+     (``body_pil_cycle`` in RGB, ``mask_pil_cycle`` in mode "L").
+     """
+     logging.info(f"Initializing Avatar: {avatar_id}")
+
+     # Plain settings
+     self.avatar_id = str(avatar_id)
+     self.video_path = video_path
+     self.bbox_shift = bbox_shift
+     self.batch_size = batch_size
+     self.preparation = preparation
+     self.version_str = version_str
+     self.extra_margin = extra_margin
+     self.parsing_mode = parsing_mode
+
+     # Storage layout: only "v15" nests the avatars under a version directory.
+     if self.version_str == "v15":
+         base_parts = (RESULT_DIR, self.version_str, "avatars", self.avatar_id)
+     else:
+         base_parts = (RESULT_DIR, "avatars", self.avatar_id)
+     base_dir = os.path.join(*base_parts)
+
+     self.avatar_base_path = base_dir
+     self.full_imgs_path = os.path.join(base_dir, "full_imgs")
+     self.mask_out_path = os.path.join(base_dir, "masks")
+     self.coords_path = os.path.join(base_dir, "coords.pkl")
+     self.latents_out_path = os.path.join(base_dir, "latents.pt")
+     self.mask_coords_path = os.path.join(base_dir, "mask_coords.pkl")
+     self.avatar_info_path = os.path.join(base_dir, "avatar_info.json")
+
+     # Per-frame material, filled in by init_avatar_data().
+     self.input_latent_list_cycle = []
+     self.coord_list_cycle = []
+     self.frame_list_cycle = []
+     self.mask_coords_list_cycle = []
+     self.mask_list_cycle = []
+
+     self.idx = 0
+
+     self.init_avatar_data()
+
+     # Convert once up front instead of on every use.
+     logging.info("Pre-processing avatar data for optimized performance...")
+
+     # Frames: BGR NumPy arrays -> RGB PIL images
+     rgb_frames = []
+     for bgr_frame in self.frame_list_cycle:
+         rgb_frames.append(Image.fromarray(cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)))
+     self.body_pil_cycle = rgb_frames
+
+     # Masks: arrays -> grayscale PIL images
+     gray_masks = []
+     for mask_array in self.mask_list_cycle:
+         gray_masks.append(Image.fromarray(mask_array).convert("L"))
+     self.mask_pil_cycle = gray_masks
+
+     logging.info("✅ Avatar data pre-processing complete.")
+     logging.info(f"✅ Avatar '{self.avatar_id}' initialized with {len(self.frame_list_cycle)} reference frames.")
+
+ def init_avatar_data(self):
+     """Load the avatar's prepared material, or (re)build it when asked to.
+
+     With ``preparation`` unset the stored data is reloaded. Otherwise the
+     material is built from scratch, unless the avatar directory already
+     exists and the user declines, at the interactive prompt, to re-create it.
+     """
+     if not self.preparation:
+         logging.info("Preparation=False. Loading existing prepared data...")
+         self._reload_prepared_data()
+         return
+
+     if not os.path.exists(self.avatar_base_path):
+         self._prepare_material_core()
+         return
+
+     # Existing data: ask before deleting anything.
+     answer = input(f"Avatar '{self.avatar_id}' data exists. Re-create all material? (y/n): ").strip().lower()
+     if answer != "y":
+         logging.info("Loading existing data as per user request.")
+         self._reload_prepared_data()
+         return
+
+     logging.info(f"User chose to re-create. Removing: {self.avatar_base_path}")
+     shutil.rmtree(self.avatar_base_path)
+     self._prepare_material_core()
+
+ def _reload_prepared_data(self):
+ """Loads pre-processed avatar data from disk.
+
+ Fills the latent, coordinate, mask-coordinate, frame and mask cycle lists
+ from the files under the avatar directory, checks that all five lists
+ have the same length, and compares the stored avatar_info.json with the
+ current settings. On a settings mismatch the user is asked on stdin
+ whether to delete and re-create the material.
+
+ Raises:
+ SystemExit: if a required file is missing or any other error occurs
+ while loading (the original error is logged with its traceback).
+ """
+ logging.info(f"Reloading prepared data from: {self.avatar_base_path}")
+ try:
+ # NOTE(review): torch.load and pickle.load can execute code embedded in
+ # the files they read; only point RESULT_DIR at trusted data.
+ loaded_latents = torch.load(self.latents_out_path, map_location='cpu')
+ # A single stacked tensor is split into a list along its first dimension.
+ self.input_latent_list_cycle = list(loaded_latents) if isinstance(loaded_latents, torch.Tensor) else loaded_latents
+ with open(self.coords_path, 'rb') as f: self.coord_list_cycle = pickle.load(f)
+ with open(self.mask_coords_path, 'rb') as f: self.mask_coords_list_cycle = pickle.load(f)
+
+ # The coordinate list defines how many frames/masks are expected on disk.
+ num_items = len(self.coord_list_cycle)
+ if num_items == 0: raise ValueError("Loaded coordinate data is empty.")
+
+ # Frames and masks are stored as zero-padded 8-digit PNG names: 00000000.png, ...
+ frame_files = [os.path.join(self.full_imgs_path, f"{str(i).zfill(8)}.png") for i in range(num_items)]
+ self.frame_list_cycle = read_imgs(frame_files)
+ mask_files = [os.path.join(self.mask_out_path, f"{str(i).zfill(8)}.png") for i in range(num_items)]
+ self.mask_list_cycle = read_imgs(mask_files)
+
+ # Consistency check: every list must hold exactly num_items entries.
+ data_map = {
+ "Latents": self.input_latent_list_cycle, "Coords": self.coord_list_cycle,
+ "Frames": self.frame_list_cycle, "Masks": self.mask_list_cycle, "MaskCoords": self.mask_coords_list_cycle
+ }
+ if not all(len(lst) == num_items for lst in data_map.values()):
+ lengths = {name: len(lst) for name, lst in data_map.items()}
+ raise ValueError(f"Data lists have mismatched lengths after loading: {lengths}")
+
+ # Compare the settings the material was built with against the current ones.
+ with open(self.avatar_info_path, "r") as f:
+ avatar_info_loaded = json.load(f)
+ if avatar_info_loaded.get('bbox_shift') != self.bbox_shift or \
+ avatar_info_loaded.get('version') != self.version_str or \
+ avatar_info_loaded.get('extra_margin') != self.extra_margin or \
+ avatar_info_loaded.get('parsing_mode') != self.parsing_mode:
+
+ logging.warning("Avatar config (bbox_shift, version, extra_margin, or parsing_mode) has changed.")
+ response = input(f"Config change detected. Re-create avatar materials? (y/n) (Old: {avatar_info_loaded}, New: {{'bbox_shift': {self.bbox_shift}, 'version': '{self.version_str}', 'extra_margin': {self.extra_margin}, 'parsing_mode': '{self.parsing_mode}'}}): ").strip().lower()
+ if response == "y":
+ logging.info("User chose to re-create due to config change.")
+ # Any exception raised while re-creating is also handled by the except clauses below.
+ shutil.rmtree(self.avatar_base_path)
+ self._prepare_material_core()
+ else:
+ logging.info("Continuing with old avatar data despite config change. This might lead to unexpected results.")
+
+ except FileNotFoundError:
+ logging.critical(f"❌ Prepared data not found for avatar '{self.avatar_id}'. You must run with preparation=True first.", exc_info=True)
+ raise SystemExit(f"Exiting: Prepared data not found for {self.avatar_id}.")
+ except Exception as e:
+ logging.critical(f"❌ Error reloading prepared data for avatar '{self.avatar_id}'. You may need to run with preparation=True.", exc_info=True)
+ raise SystemExit(f"Exiting: Failed to reload data for {self.avatar_id}.")
+
+ @torch.no_grad()
+ def _prepare_material_core(self):
+ logging.info(f"--- Preparing new material for avatar: {self.avatar_id} ---")
+ osmakedirs([self.avatar_base_path, self.full_imgs_path, self.mask_out_path])
+
+ # Store avatar info (unchanged)
+ avatar_info_data = {
+ "avatar_id": self.avatar_id,
+ "video_path": self.video_path,
+ "bbox_shift": self.bbox_shift,
+ "version": self.version_str,
+ "extra_margin": self.extra_margin,
+ "parsing_mode": self.parsing_mode
+ }
+
+ with open(self.avatar_info_path, "w") as f:
+ json.dump(avatar_info_data, f)
+
+ # 1. Extract frames from video or copy from image folder
+ if os.path.isfile(self.video_path):
+ logging.debug(f"DEBUG: Starting video frame extraction from {self.video_path} to {self.full_imgs_path}...")
+ video2imgs(self.video_path, self.full_imgs_path)
+ elif os.path.isdir(self.video_path):
+ logging.debug(f"DEBUG: Starting image frame copying from {self.video_path} to {self.full_imgs_path}...")
+ source_files = sorted([f for f in os.listdir(self.video_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
+ for i, filename in enumerate(tqdm(source_files, desc="Copying frames")):
+ shutil.copy(os.path.join(self.video_path, filename), os.path.join(self.full_imgs_path, f"{i:08d}.png"))
+ else:
+ raise FileNotFoundError(f"video_path '{self.video_path}' is not a valid file or directory.")
+
+ logging.debug("DEBUG: Frame extraction/copy completed. Proceeding to landmark extraction.")
+
+ # 2. Get landmarks and filter out invalid frames
+ source_images = sorted(glob.glob(os.path.join(self.full_imgs_path, '*.png')))
+ logging.debug(f"DEBUG: Starting face landmark and bbox extraction using {len(source_images)} images...")
+
+ initial_coords, initial_frames = [], []
+ try:
+ # Call the original get_landmark_and_bbox
+ initial_coords, initial_frames = get_landmark_and_bbox(source_images, self.bbox_shift)
+
+ # Count valid/invalid frames
+ valid_frame_count = sum(1 for bbox in initial_coords if bbox != (0.0, 0.0, 0.0, 0.0))
+ logging.info(f"DEBUG: get_landmark_and_bbox returned {len(initial_coords)} frames.")
+ logging.info(f"DEBUG: Detected {valid_frame_count} valid frames (non-placeholder bbox).")
+
+ if valid_frame_count == 0:
+ logging.error("CRITICAL: No valid bounding boxes were detected in any frame. Check video content, bbox_shift, or face detection model.")
+ # Save first few problematic frames for visual inspection
+ debug_output_dir = os.path.join(self.avatar_base_path, "debug_invalid_frames")
+ os.makedirs(debug_output_dir, exist_ok=True)
+ saved_debug_frames = 0
+ for i, (bbox, frame) in enumerate(zip(initial_coords, initial_frames)):
+ if bbox == (0.0, 0.0, 0.0, 0.0): # This is the placeholder for no face detected
+ if saved_debug_frames < 20: # Save up to 20 debug frames
+ debug_frame_path = os.path.join(debug_output_dir, f"invalid_frame_{i:08d}.png")
+ cv2.imwrite(debug_frame_path, frame)
+ logging.debug(f"Saved debug image for invalid frame {i} to {debug_frame_path}")
+ saved_debug_frames += 1
+ else:
+ break # Stop saving debug frames after 20
+
+ raise RuntimeError("No valid frames survived the preparation process.")
+
+ except Exception as e:
+ logging.critical(f"❌ Error during get_landmark_and_bbox: {e}", exc_info=True)
+ raise # Re-raise to ensure the main error propagates
+
+ logging.debug("DEBUG: Face landmark and bbox extraction completed. Starting VAE encoding and mask generation.")
+
+ # 3. Process valid frames: VAE encoding and mask generation
+ valid_latents, valid_coords, valid_frames, valid_masks, valid_mask_coords = [], [], [], [], []
+ logging.debug("DEBUG: Entering VAE Encoding & Masking loop.")
+
+ for i, (bbox, frame) in enumerate(tqdm(zip(initial_coords, initial_frames), total=len(initial_coords), desc="VAE Encoding & Masking")):
+ if bbox == (0.0, 0.0, 0.0, 0.0):
+ continue
+
+ # --- THE FIX: Do NOT convert bbox to integers. Keep them as floats. ---
+ x1, y1, x2, y2 = bbox
+
+ # For V15, the enlarged bbox is used for BOTH the parser and the VAE.
+ if self.version_str == "v15":
+ y2_adjusted = y2 + self.extra_margin
+ y2_adjusted = min(y2_adjusted, frame.shape[0])
+ final_bbox = [x1, y1, x2, y2_adjusted]
+ else:
+ final_bbox = [x1, y1, x2, y2]
+
+ # --- Step 1: Face Parsing with original float precision ---
+ if self.version_str == "v15":
+ mode = self.parsing_mode
+ else:
+ mode = "raw"
+
+ # This call now uses the original float coordinates and should succeed.
+ parsing_result = get_image_prepare_material(frame, final_bbox, fp=fp, mode=mode)
+
+ # The robustness check remains as a safeguard for genuinely bad frames.
+ if parsing_result is None:
+ logging.warning(f"DIAGNOSTIC: Frame {i} is a genuinely problematic frame and was skipped.")
+ continue
+
+ mask, mask_crop_box = parsing_result
+
+ # --- Step 2: VAE Encoding ---
+ # Cropping with floats is fine; array slicing will implicitly convert to integers.
+ vae_x1, vae_y1, vae_x2, vae_y2 = final_bbox
+ crop_frame = frame[int(vae_y1):int(vae_y2), int(vae_x1):int(vae_x2)]
+
+ if crop_frame.size == 0:
+ logging.warning(f"Skipping frame {i} as VAE cropped frame is empty. Bbox: {final_bbox}")
+ continue
+
+ resized_crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
+ latents = vae.get_latents_for_unet(resized_crop_frame)
+
+ # --- Step 3: Append all data from the successful frame ---
+ valid_latents.append(latents)
+ valid_coords.append(final_bbox)
+ valid_frames.append(frame)
+ valid_masks.append(mask)
+ valid_mask_coords.append(mask_crop_box)
+
+ logging.debug("DEBUG: VAE Encoding & Masking completed.")
+
+ if not valid_frames:
+ logging.error("All frames failed the parsing step. This could indicate a problem with the source video or a deeper issue with the face parsing model's environment/dependencies.")
+ raise RuntimeError("No valid frames survived the preparation process.")
+
+
+
+ # 4. Create looping cycle (forward and reverse) and save all data
+ self.frame_list_cycle = valid_frames + valid_frames[::-1]
+ self.coord_list_cycle = valid_coords + valid_coords[::-1]
+ self.input_latent_list_cycle = valid_latents + valid_latents[::-1]
+ self.mask_list_cycle = valid_masks + valid_masks[::-1]
+ self.mask_coords_list_cycle = valid_mask_coords + valid_mask_coords[::-1]
+
+ logging.debug("DEBUG: Saving final cycle data to disk.")
+ shutil.rmtree(self.full_imgs_path, ignore_errors=True)
+ os.makedirs(self.full_imgs_path)
+ shutil.rmtree(self.mask_out_path, ignore_errors=True)
+ os.makedirs(self.mask_out_path)
+
+ for i, (frame, mask) in enumerate(tqdm(zip(self.frame_list_cycle, self.mask_list_cycle), total=len(self.frame_list_cycle), desc="Saving final cycle data")):
+ cv2.imwrite(os.path.join(self.full_imgs_path, f"{i:08d}.png"), frame)
+ cv2.imwrite(os.path.join(self.mask_out_path, f"{i:08d}.png"), mask)
+
+ with open(self.coords_path, 'wb') as f:
+ pickle.dump(self.coord_list_cycle, f)
+ with open(self.mask_coords_path, 'wb') as f:
+ pickle.dump(self.mask_coords_list_cycle, f)
+ torch.save(torch.stack(self.input_latent_list_cycle), self.latents_out_path)
+
+ logging.info(f"--- Material prep complete. Final cycle length: {len(self.frame_list_cycle)} frames. ---")
+ logging.debug("DEBUG: Avatar material preparation function finished.")
+
+ @torch.no_grad()
+ def inference(self, audio_source, target_fps):
+     """
+     Run one streaming inference pass for a single audio clip.
+
+     Starts the GStreamer video and audio pipelines, pre-rolls a quarter second
+     of black frames, decodes the audio to PCM, extracts Whisper feature chunks,
+     launches the consumer thread (process_and_send_frames) and then pushes
+     batches of decoded frames plus the matching PCM slice onto a bounded queue.
+
+     Args:
+         audio_source: Audio to process. It is always decoded through
+             FFmpegAudioReader; unless it is a non-empty str, the decoded PCM
+             is also written to a temporary .wav for feature extraction.
+         target_fps: Frame rate used for the video pipeline, the Whisper
+             chunking and the audio-samples-per-batch calculation.
+
+     Returns:
+         None. Exceptions are logged at CRITICAL level and not re-raised; the
+         cleanup in the finally clause always runs.
+     """
+     run_id = f"stream_{int(time.time())}"
+     logging.info(f"🎬 Starting inference run ID: {run_id}")
+
+     # Bounded hand-off queue between this producer and the consumer thread.
+     vae_to_blend_queue = queue.Queue(maxsize=self.batch_size * 2)
+     gst_video_pipeline, gst_audio_pipeline, frame_processor_thread = None, None, None
+     start_time = time.time()
+
+     try:
+         # 1. Setup GStreamer pipelines
+         if self.frame_list_cycle:
+             h, w, _ = self.frame_list_cycle[0].shape
+         else:
+             logging.warning("Avatar reference frames not loaded, using default GStreamer resolution (1280x720).")
+             # FIX: h and w were never assigned on this path, so the pre-roll
+             # below failed with a NameError when building the black frame.
+             h, w = 720, 1280
+         gst_video_pipeline = GStreamerPipeline(width=w, height=h, fps=target_fps)
+
+         gst_audio_pipeline = GStreamerAudio(sample_rate=48000)
+         if not gst_video_pipeline.is_running or not gst_audio_pipeline.is_running:
+             raise RuntimeError("GStreamer pipeline(s) failed to initialize. Check GStreamer installation and plugins.")
+
+         # --- Pipeline warm-up / pre-roll ---
+         # A few black frames are pushed before the real content so the whole
+         # chain (Python -> OS buffers -> GStreamer) has time to settle.
+         logging.info("Warming up streaming pipeline with pre-roll frames...")
+         frame_duration = 1.0 / target_fps
+
+         # One black frame is reused for the whole pre-roll.
+         black_frame = np.zeros((h, w, 3), dtype=np.uint8)
+
+         # Quarter of a second worth of frames.
+         num_preroll_frames = int(target_fps / 4)
+
+         for _ in range(num_preroll_frames):
+             if not gst_video_pipeline.send_frame(black_frame):
+                 raise RuntimeError("GStreamer video pipe broke during pre-roll.")
+             # A plain sleep is enough here; precise pacing only matters for
+             # the real content.
+             time.sleep(frame_duration)
+         logging.info("Pre-roll complete. Starting main content stream.")
+
+         # 2. Process audio input
+         audio_reader = FFmpegAudioReader(audio_source)
+         full_audio_pcm = audio_reader.read_full_audio()
+         if full_audio_pcm is None or full_audio_pcm.size == 0:
+             raise ValueError("PCM audio is empty after FFmpeg conversion. Cannot proceed with inference.")
+
+         feature_extraction_path = audio_source if isinstance(audio_source, str) else None
+         temp_audio_path = None
+         if not feature_extraction_path:
+             # Reserve a file name, close the handle, then let soundfile write
+             # the PCM to that path.
+             with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file_handle:
+                 temp_audio_path = temp_file_handle.name
+             sf.write(temp_audio_path, full_audio_pcm, gst_audio_pipeline.sample_rate)
+             feature_extraction_path = temp_audio_path
+
+         try:
+             whisper_input_features, librosa_length = audio_processor.get_audio_feature(feature_extraction_path, weight_dtype=weight_dtype)
+             whisper_chunks = audio_processor.get_whisper_chunk(
+                 whisper_input_features,
+                 device,
+                 weight_dtype,
+                 whisper,
+                 librosa_length,
+                 fps=target_fps,
+                 audio_padding_length_left=AUDIO_PADDING_LENGTH_LEFT,
+                 audio_padding_length_right=AUDIO_PADDING_LENGTH_RIGHT,
+             )
+         finally:
+             # FIX: the temporary .wav used to be leaked whenever feature
+             # extraction raised; it is now removed on every path.
+             if temp_audio_path is not None:
+                 try:
+                     os.unlink(temp_audio_path)
+                 except OSError as cleanup_err:
+                     logging.warning(f"Could not remove temporary audio file {temp_audio_path}: {cleanup_err}")
+
+         num_frames_to_generate = len(whisper_chunks)
+         if num_frames_to_generate == 0:
+             logging.warning("No frames to generate based on audio features. Skipping inference.")
+             return
+
+         # Ceiling division: number of batches needed to cover every frame.
+         num_vae_batches = (num_frames_to_generate + self.batch_size - 1) // self.batch_size
+         logging.info(f"Audio processed. Planning {num_frames_to_generate} frames in {num_vae_batches} VAE/UNet batches.")
+
+         # 3. Start the consumer thread
+         self.idx = 0
+         frame_processor_thread = threading.Thread(
+             target=self.process_and_send_frames,
+             args=(vae_to_blend_queue, gst_video_pipeline, gst_audio_pipeline, num_frames_to_generate, FRAME_SKIP_THRESHOLD),
+             daemon=True, name=f"FrameProcessor_{run_id}"
+         )
+         frame_processor_thread.start()
+
+         # 4. Main Generation Loop (Producer: VAE/UNet inference)
+         data_gen = datagen(whisper_chunks, self.input_latent_list_cycle, self.batch_size)
+         total_audio_samples = len(full_audio_pcm)
+         audio_samples_sent = 0
+
+         vae_pbar = tqdm(data_gen, total=num_vae_batches, desc=f"VAE/UNET [{run_id}]", unit="batch")
+         for i, batch_data in enumerate(vae_pbar):
+             current_batch_start_time = time.perf_counter()
+
+             if not frame_processor_thread.is_alive():
+                 logging.error("Frame processor thread died unexpectedly. Halting generation.")
+                 break
+             if not batch_data or len(batch_data) != 2:
+                 logging.warning(f"Skipping malformed batch {i}.")
+                 continue
+
+             whisper_batch, latent_batch = batch_data
+             num_frames_in_batch = len(latent_batch)
+
+             # --- Core model inference: audio features -> UNet -> VAE decode ---
+             audio_feature = pe(whisper_batch.to(device, dtype=weight_dtype))
+             latent_input = latent_batch.to(device, dtype=unet.model.dtype)
+
+             pred_latents = unet.model(latent_input, timesteps, encoder_hidden_states=audio_feature).sample
+             pred_latents = pred_latents.to(device=device, dtype=vae.vae.dtype)
+             recon_frames = vae.decode_latents(pred_latents)
+
+             if recon_frames is None or len(recon_frames) == 0:
+                 logging.warning(f"Skipping empty VAE output batch {i}.")
+                 continue
+
+             # --- FPS calculation and progress-bar display ---
+             current_batch_end_time = time.perf_counter()
+             batch_processing_time = current_batch_end_time - current_batch_start_time
+
+             if batch_processing_time > 0:
+                 current_fps = num_frames_in_batch / batch_processing_time
+             else:
+                 current_fps = 0.0
+
+             fps_color = 'green' if current_fps >= target_fps * 0.9 else ('yellow' if current_fps >= target_fps * 0.5 else 'red')
+             fps_string = colored(f"{current_fps:.2f} FPS", fps_color, attrs=['bold'])
+
+             vae_pbar.set_postfix_str(f"FPS: {fps_string} | Batch: {batch_processing_time:.2f}s", refresh=True)
+
+             # --- A/V sync: cut the PCM slice that spans this batch's frames ---
+             audio_duration_of_batch = num_frames_in_batch / target_fps
+             num_audio_samples_for_batch = int(audio_duration_of_batch * gst_audio_pipeline.sample_rate)
+
+             start_audio_idx = audio_samples_sent
+             end_audio_idx = min(start_audio_idx + num_audio_samples_for_batch, total_audio_samples)
+             audio_chunk_pcm = full_audio_pcm[start_audio_idx:end_audio_idx]
+             audio_samples_sent = end_audio_idx
+
+             try:
+                 # Queue item layout: (list of frames, PCM slice for those frames)
+                 vae_to_blend_queue.put((list(recon_frames), audio_chunk_pcm), timeout=5.0)
+             except queue.Full:
+                 logging.error("VAE-to-Blend queue is full. Consumer can't keep up. Halting generation.")
+                 break
+
+     except Exception:
+         logging.critical(f"CRITICAL ERROR in inference run '{run_id}'.", exc_info=True)
+     finally:
+         logging.info(f"\n--- [{run_id}] Final Cleanup ---")
+
+         # FIX: the sentinel put() used to block without a timeout, which hung
+         # this method forever when the queue was full and the consumer had
+         # already exited. Retry with a timeout only while a consumer exists.
+         if frame_processor_thread is not None:
+             sentinel_deadline = time.time() + 30.0
+             while frame_processor_thread.is_alive() and time.time() < sentinel_deadline:
+                 try:
+                     vae_to_blend_queue.put(None, timeout=1.0)
+                     break
+                 except queue.Full:
+                     continue
+
+         # Wait for consumer thread to finish
+         if frame_processor_thread and frame_processor_thread.is_alive():
+             logging.info("Waiting for frame processor thread to finish...")
+             frame_processor_thread.join(timeout=30.0)
+             if frame_processor_thread.is_alive():
+                 logging.warning("Frame processor thread did not terminate gracefully.")
+
+         # Stop GStreamer pipelines
+         if gst_video_pipeline:
+             gst_video_pipeline.stop()
+         if gst_audio_pipeline:
+             gst_audio_pipeline.stop()
+
+         elapsed = time.time() - start_time
+         logging.info(f">>> Inference run '{run_id}' finished in {elapsed:.2f}s. <<<")
+
+
+
+ def process_and_send_frames(self, vae_to_blend_q, gst_video, gst_audio, total_frames_planned, frame_skip_threshold_val):
+     """
+     Consumer thread body: blend, pace and interleave video frames with audio.
+
+     For every queued batch, each frame is blended via _blend_single_frame and
+     sent to the video pipeline, immediately followed by its share of the
+     batch's audio samples. A pacer sleeps so that frame slots are not emitted
+     faster than ``gst_video.fps``; it never speeds anything up.
+
+     Backlog skipping is evaluated once, before the main loop starts: while
+     the queue holds more than ``frame_skip_threshold_val`` items, whole
+     batches are discarded and counted as processed.
+
+     Args:
+         vae_to_blend_q (queue.Queue): Source of ``(frames, audio_chunk)``
+             tuples; ``None`` is the shutdown sentinel.
+         gst_video: Video pipeline object; ``fps``, ``width``, ``height`` and
+             ``send_frame`` are used.
+         gst_audio: Audio pipeline object; ``send_audio`` is used.
+         total_frames_planned (int): The loop ends once this many frames have
+             been processed or skipped.
+         frame_skip_threshold_val (int): Queue size above which start-up
+             batches are dropped.
+     """
+     # --- Pacer state ---
+     target_fps = gst_video.fps
+     frame_duration = 1.0 / target_fps
+     start_time = None  # Set when the first real batch arrives.
+     frames_sent = 0
+     # FIX: the pacer used to be keyed on frames_sent, so a frame whose blend
+     # failed still emitted its audio but did not advance the clock. Pacing is
+     # now driven by frame slots, blended or not.
+     frame_slots_elapsed = 0
+     # FIX: the log said audio sends were stopped after a failure, but every
+     # later frame retried (and logged) again. This flag makes it true.
+     audio_pipe_ok = True
+     total_frames_processed = 0
+     total_frames_skipped = 0
+
+     # --- Start-up backlog skipping (runs once, before the main loop) ---
+     while vae_to_blend_q.qsize() > frame_skip_threshold_val:
+         try:
+             skipped_item = vae_to_blend_q.get_nowait()
+         except queue.Empty:
+             break
+         vae_to_blend_q.task_done()
+         if skipped_item is None:
+             # FIX: do not swallow the shutdown sentinel while draining.
+             logging.info("Received sentinel. Frame processor shutting down.")
+             return
+         if isinstance(skipped_item, tuple) and len(skipped_item[0]) > 0:
+             num_skipped_in_batch = len(skipped_item[0])
+             total_frames_skipped += num_skipped_in_batch
+             total_frames_processed += num_skipped_in_batch
+             logging.warning(
+                 colored(f"Queue overloaded ({vae_to_blend_q.qsize()}). Skipping a batch of {num_skipped_in_batch} frames to catch up.", 'yellow')
+             )
+
+     # --- Main processing loop ---
+     while total_frames_processed < total_frames_planned:
+         try:
+             batch_data = vae_to_blend_q.get(block=True, timeout=10.0)
+         except queue.Empty:
+             logging.error("Timeout waiting for frames from VAE producer. Ending processor thread.")
+             break
+
+         if batch_data is None:
+             logging.info("Received sentinel. Frame processor shutting down.")
+             break
+
+         # The pacer clock starts when the first batch is actually available.
+         if start_time is None:
+             start_time = time.perf_counter()
+             logging.info("First data batch received. Starting real-time pacer clock.")
+
+         vae_frames, audio_chunk = batch_data
+         num_frames_in_batch = len(vae_frames)
+         if num_frames_in_batch == 0:
+             vae_to_blend_q.task_done()
+             continue
+
+         # Even split of the batch audio; the last frame takes the remainder.
+         samples_per_frame = len(audio_chunk) // num_frames_in_batch
+         audio_cursor = 0
+
+         for i, frame in enumerate(vae_frames):
+             # --- Real-time pacer: wait until this slot's deadline ---
+             next_frame_target_time = start_time + (frame_slots_elapsed + 1) * frame_duration
+             sleep_duration = next_frame_target_time - time.perf_counter()
+             if sleep_duration > 0:
+                 time.sleep(sleep_duration)
+             frame_slots_elapsed += 1
+
+             # 1. Blend the generated face onto the background frame
+             blended_frame = self._blend_single_frame(frame, gst_video.width, gst_video.height)
+
+             if blended_frame is not None:
+                 # 2. Send the final video frame to the video pipeline
+                 if gst_video.send_frame(blended_frame):
+                     frames_sent += 1
+                 else:
+                     logging.error("Video pipe broken. Halting all frame processing.")
+                     vae_to_blend_q.task_done()
+                     return
+
+             # 3. Determine the audio slice belonging to this frame
+             start_idx = audio_cursor
+             end_idx = (audio_cursor + samples_per_frame) if (i < num_frames_in_batch - 1) else len(audio_chunk)
+             audio_slice = audio_chunk[start_idx:end_idx]
+             audio_cursor = end_idx
+
+             # 4. Send the audio slice unless the audio pipe already failed
+             if audio_pipe_ok and audio_slice.size > 0:
+                 if not gst_audio.send_audio(audio_slice):
+                     logging.error("Audio pipe broken. Stopping audio sends for this run.")
+                     audio_pipe_ok = False
+
+         total_frames_processed += num_frames_in_batch
+         vae_to_blend_q.task_done()
+
+     logging.info(colored(f"--- Frame processor finished. Sent: {frames_sent}, Skipped: {total_frames_skipped} ---", "green"))
+
+ def _blend_single_frame(self, res_frame, target_width, target_height):
+     """
+     Blend one generated face frame onto the cached background frame.
+
+     The background and mask for the current cycle position are taken from
+     ``self.body_pil_cycle`` / ``self.mask_pil_cycle``; the face box and the
+     crop box come from ``self.coord_list_cycle`` / ``self.mask_coords_list_cycle``.
+     ``self.idx`` is advanced by one on success and on a caught error.
+
+     Args:
+         res_frame: Generated face image as an array; it is cast to uint8 and
+             resized to the face box. Channel order is reversed before pasting
+             (presumably BGR -> RGB; confirm against the VAE decoder output).
+         target_width: Output frame width expected by the caller.
+         target_height: Output frame height expected by the caller.
+
+     Returns:
+         The blended frame as a uint8 array resized to the target size if
+         needed, or None when no coordinates are loaded or blending failed.
+     """
+     try:
+         # 1. --- Look up the pre-processed data for this cycle position ---
+         cycle_len = len(self.coord_list_cycle)
+         if cycle_len == 0:
+             return None  # Nothing loaded; stay quiet to avoid per-frame log spam
+
+         current_idx = self.idx % cycle_len
+
+         # FIX: paste() modifies the image in place. Working on the cached
+         # object meant every pass through the cycle blended onto a background
+         # that already contained an earlier generated face. Work on a copy.
+         body_pil = self.body_pil_cycle[current_idx].copy()
+         mask_pil = self.mask_pil_cycle[current_idx]
+
+         face_bbox = self.coord_list_cycle[current_idx]
+         crop_box = tuple(int(p) for p in self.mask_coords_list_cycle[current_idx])
+
+         # 2. --- Resize the generated face to the face box ---
+         face_w = int(face_bbox[2] - face_bbox[0])
+         face_h = int(face_bbox[3] - face_bbox[1])
+         res_frame_resized = cv2.resize(res_frame.astype(np.uint8), (face_w, face_h), interpolation=cv2.INTER_LINEAR)
+
+         # Reverse the channel axis before handing the array to PIL.
+         face_pil = Image.fromarray(res_frame_resized[:, :, ::-1])
+
+         # 3. --- Paste face into the crop region, then the region into the body ---
+         x_s, y_s, _, _ = crop_box
+         x_f, y_f = int(face_bbox[0]), int(face_bbox[1])
+
+         face_large_pil = body_pil.crop(crop_box)
+         face_large_pil.paste(face_pil, (x_f - x_s, y_f - y_s))
+         body_pil.paste(face_large_pil, (x_s, y_s), mask_pil)
+
+         # 4. --- Back to a numpy array with the channel axis reversed again ---
+         final_frame = np.array(body_pil, dtype=np.uint8)[:, :, ::-1]
+
+         # 5. --- Match the size expected by the video pipeline ---
+         if final_frame.shape[0] != target_height or final_frame.shape[1] != target_width:
+             final_frame = cv2.resize(final_frame, (target_width, target_height), interpolation=cv2.INTER_LINEAR)
+
+         # FIX: advance the index only after all fallible work, so a failure in
+         # the final resize no longer increments it twice (here and in except).
+         self.idx += 1
+         return final_frame
+
+     except Exception as e:
+         # Log only every 100th index to keep the hot loop quiet.
+         if self.idx % 100 == 0:
+             logging.warning(f"Error blending frame at index {self.idx}: {e}")
+         self.idx += 1
+         return None
+
+# --- Main Application Entry Point ---
+inference_lock = threading.Lock()  # Ensures only one inference run processes audio at a time
+audio_queue = queue.Queue(maxsize=10)  # Raw audio payloads handed from the file watcher to the worker
+
+# Environment captured from PowerShell, loaded at module scope.
+# env_from_file starts as a copy of the current process environment and is
+# then overlaid with KEY=VALUE lines from the file below, when it exists.
+POWERSHELL_ENV_FILE = "C:\\temp\\powershell_env.txt"
+
+env_from_file = os.environ.copy()
+
+if os.path.exists(POWERSHELL_ENV_FILE):
+    try:
+        # FIX: 'utf-8-sig' strips a leading byte-order mark if present (plain
+        # UTF-8 still decodes the same), so the first key no longer starts
+        # with '\ufeff'.
+        with open(POWERSHELL_ENV_FILE, 'r', encoding='utf-8-sig') as f:
+            for line in f:
+                key, separator, value = line.strip().partition('=')
+                # FIX: a line that starts with '=' used to create an empty
+                # variable name; such lines are now ignored.
+                if separator and key:
+                    env_from_file[key] = value
+        logging.info(f"Successfully loaded environment from {POWERSHELL_ENV_FILE}.")
+    except Exception as e:
+        logging.warning(f"Failed to load environment from {POWERSHELL_ENV_FILE}: {e}", exc_info=True)
+else:
+    logging.warning(f"PowerShell environment file not found at {POWERSHELL_ENV_FILE}. Proceeding with inherited environment.")
+# --- End of module-level environment loading ---
+
+
+def inference_worker(avatar_instance):
+    """
+    Drain ``audio_queue`` and run one avatar inference per dequeued payload.
+
+    The loop ends when ``None`` is dequeued. If ``inference_lock`` cannot be
+    taken immediately the payload is dropped, but it is still marked done on
+    the queue. Exceptions raised by the inference call are logged and do not
+    stop the worker.
+    """
+    logging.info("🚀 Inference worker started. Waiting for audio data...")
+    while True:
+        payload = audio_queue.get()  # Blocks until a payload or the None sentinel arrives
+        if payload is None:
+            logging.info("Inference worker received shutdown signal. Exiting.")
+            return
+
+        # Non-blocking attempt: a busy lock means this payload is skipped.
+        lock_taken = inference_lock.acquire(blocking=False)
+        if lock_taken:
+            try:
+                logging.info(f"Inference lock acquired. Processing {len(payload)} bytes of audio.")
+                # The avatar's inference method manages the streaming pipelines itself.
+                avatar_instance.inference(payload, TARGET_FPS)
+            except Exception as e:
+                logging.critical(f"Unhandled exception during avatar inference: {e}", exc_info=True)
+            finally:
+                inference_lock.release()
+                audio_queue.task_done()
+                logging.info("Inference lock released.")
+        else:
+            logging.warning("Inference already in progress. Skipping newly received audio to prevent backlog.")
+            audio_queue.task_done()
+
+def file_watcher():
+    """
+    Poll ``STREAM_PIPE_PATH`` and enqueue its bytes whenever its mtime advances.
+
+    Only paths for which ``os.path.isfile`` is true are read. Non-empty content
+    is put on ``audio_queue``; empty content is logged and skipped. In both
+    cases the observed modification time is remembered so the same write is
+    not handled twice. The function loops forever.
+    """
+    logging.info(f"👀 Starting file watcher for: {STREAM_PIPE_PATH}")
+    newest_seen_mtime = 0  # Modification time of the last write already handled
+    logging.info("✅ File watcher thread is actively monitoring.")
+
+    while True:
+        try:
+            if os.path.isfile(STREAM_PIPE_PATH):
+                mtime_now = os.path.getmtime(STREAM_PIPE_PATH)
+                if mtime_now > newest_seen_mtime:
+                    logging.info("New audio file detected. Reading content...")
+                    # Short grace period for the writer to finish.
+                    time.sleep(0.1)
+                    with open(STREAM_PIPE_PATH, "rb") as stream_file:
+                        payload = stream_file.read()
+
+                    if not payload:
+                        logging.warning("File modification detected, but file was empty. Skipping.")
+                    else:
+                        audio_queue.put(payload)  # Hand over to the inference worker
+                    # Recorded for empty files too, so they are not re-read.
+                    newest_seen_mtime = mtime_now
+            time.sleep(0.2)  # Polling interval
+        except FileNotFoundError:
+            # The file vanished between the check and the read; start fresh so
+            # the next file is picked up regardless of its timestamp.
+            newest_seen_mtime = 0
+            time.sleep(0.5)
+        except Exception as e:
+            logging.error(f"Error in file watcher loop: {e}", exc_info=True)
+            time.sleep(2)
+
+
+if __name__ == "__main__":
+    # Script entry point: parse options, verify ffmpeg, load the models,
+    # build the Avatar, then run the watcher and worker threads until
+    # Ctrl+C or until one of them dies.
+    logging.info("🎬 Starting MuseTalk Realtime Stream Sync Application (v1.5 compatible)...")
+
+    # 1. Argument parsing. Every default is the module-level value of the
+    #    same setting, so omitting a flag keeps that value.
+    parser = argparse.ArgumentParser(description="MuseTalk Real-Time Streaming Script")
+    parser.add_argument("--version", type=str, default=MUSE_VERSION, choices=["v1", "v15"], help="MuseTalk version (from .env MUSE_VERSION)")
+    parser.add_argument("--ffmpeg_path", type=str, default=FFMPEG_PATH, help="Path to ffmpeg executable (from .env FFMPEG_PATH)")
+    parser.add_argument("--gpu_id", type=int, default=GPU_ID, help="GPU ID to use (from .env GPU_ID)")
+    parser.add_argument("--vae_type", type=str, default=VAE_TYPE, help="Type of VAE model (from .env VAE_TYPE)")
+    parser.add_argument("--unet_config", type=str, default=UNET_CONFIG_PATH, help="Path to UNet configuration file (from .env UNET_CONFIG)")
+    parser.add_argument("--unet_model_path", type=str, default=UNET_MODEL_PATH, help="Path to UNet model weights (from .env UNET_MODEL_PATH)")
+    parser.add_argument("--whisper_dir", type=str, default=WHISPER_DIR, help="Directory containing Whisper model (from .env WHISPER_DIR)")
+    parser.add_argument("--result_dir", default=RESULT_DIR, help="Directory for output results (from .env RESULT_DIR)")
+    parser.add_argument("--extra_margin", type=int, default=EXTRA_MARGIN, help="Extra margin for face cropping (from .env EXTRA_MARGIN)")
+    parser.add_argument("--fps", type=int, default=TARGET_FPS, help="Video frames per second (from .env TARGET_FPS)")
+    parser.add_argument("--audio_padding_length_left", type=int, default=AUDIO_PADDING_LENGTH_LEFT, help="Left padding length for audio (from .env AUDIO_PADDING_LEFT)")
+    parser.add_argument("--audio_padding_length_right", type=int, default=AUDIO_PADDING_LENGTH_RIGHT, help="Right padding length for audio (from .env AUDIO_PADDING_RIGHT)")
+    parser.add_argument("--batch_size", type=int, default=BATCH_SIZE, help="Batch size for inference (from .env BATCH_SIZE)")
+    parser.add_argument("--parsing_mode", default=PARSING_MODE, help="Face blending parsing mode (from .env PARSING_MODE)")
+    parser.add_argument("--left_cheek_width", type=int, default=LEFT_CHEEK_WIDTH, help="Width of left cheek region (from .env LEFT_CHEEK_WIDTH)")
+    parser.add_argument("--right_cheek_width", type=int, default=RIGHT_CHEEK_WIDTH, help="Width of right cheek region (from .env RIGHT_CHEEK_WIDTH)")
+    parser.add_argument("--avatar_config_path", type=str, default=AVATAR_CONFIG_PATH, help="Path to avatar configuration YAML (from .env AVATAR_CONFIG_PATH)")
+    parser.add_argument("--avatar_id_to_use", type=str, default=AVATAR_ID_TO_USE, help="ID of the avatar to use from config (from .env AVATAR_ID_TO_USE)")
+    parser.add_argument("--stream_pipe_path", type=str, default=STREAM_PIPE_PATH, help="Path to watch for audio input (from .env STREAM_PIPE_PATH)")
+    parser.add_argument("--gstreamer_launch_path", type=str, default=GSTREAMER_LAUNCH_PATH, help="Full path to gst-launch-1.0.exe (from .env GSTREAMER_LAUNCH_PATH)")
+
+    args = parser.parse_args()
+
+    # FIX: the parsed options were never read, so every command-line flag was
+    # silently ignored. Rebind the module-level settings that the code below
+    # and the worker functions read at run time.
+    # NOTE(review): anything derived from these values at import time is not
+    # recomputed here - confirm none of the overridable settings is.
+    MUSE_VERSION = args.version
+    FFMPEG_PATH = args.ffmpeg_path
+    VAE_TYPE = args.vae_type
+    UNET_CONFIG_PATH = args.unet_config
+    UNET_MODEL_PATH = args.unet_model_path
+    WHISPER_DIR = args.whisper_dir
+    RESULT_DIR = args.result_dir
+    EXTRA_MARGIN = args.extra_margin
+    TARGET_FPS = args.fps
+    AUDIO_PADDING_LENGTH_LEFT = args.audio_padding_length_left
+    AUDIO_PADDING_LENGTH_RIGHT = args.audio_padding_length_right
+    BATCH_SIZE = args.batch_size
+    PARSING_MODE = args.parsing_mode
+    LEFT_CHEEK_WIDTH = args.left_cheek_width
+    RIGHT_CHEEK_WIDTH = args.right_cheek_width
+    AVATAR_CONFIG_PATH = args.avatar_config_path
+    AVATAR_ID_TO_USE = args.avatar_id_to_use
+    STREAM_PIPE_PATH = args.stream_pipe_path
+    GSTREAMER_LAUNCH_PATH = args.gstreamer_launch_path
+    if args.gpu_id != GPU_ID:
+        # 'device' is created before this block runs, so this flag cannot
+        # take effect here; say so instead of ignoring it silently.
+        logging.warning(f"--gpu_id {args.gpu_id} differs from GPU_ID {GPU_ID}; the device was already selected, so the command-line value is ignored.")
+
+    logging.info(f"INFO: GStreamer launch path resolved to: {GSTREAMER_LAUNCH_PATH}")
+
+    # 2. Configure ffmpeg path and verify
+    if not fast_check_ffmpeg():
+        logging.info("Attempting to add ffmpeg to PATH...")
+        # os.pathsep is ';' on Windows and ':' elsewhere.
+        os.environ["PATH"] = f"{FFMPEG_PATH}{os.pathsep}{os.environ.get('PATH', '')}"
+        if not fast_check_ffmpeg():
+            logging.critical("❌ Critical: Unable to find ffmpeg even after attempting to add to PATH. Please ensure ffmpeg is properly installed and accessible. Exiting.")
+            sys.exit(1)
+
+    # 3. Report the computing device ('device' is defined at module level)
+    logging.info(f"✅ Selected device: {device}")
+    if torch.cuda.is_available():
+        try:
+            gpu_name = torch.cuda.get_device_name(device)
+            gpu_capability = torch.cuda.get_device_capability(device)
+            logging.info(f"GPU Name: {gpu_name}")
+            logging.info(f"CUDA Capability: {gpu_capability}")
+        except Exception as e:
+            logging.warning(f"Warning: Could not retrieve detailed GPU information: {e}")
+
+    # 4. Load all models. The names bound here (vae, unet, pe, timesteps,
+    #    audio_processor, weight_dtype, whisper, fp) are module globals that
+    #    the Avatar methods read.
+    logging.info("Loading MuseTalk models (VAE, UNet, Positional Encoding)...")
+    try:
+        vae, unet, pe = load_all_model(
+            unet_model_path=UNET_MODEL_PATH,
+            vae_type=VAE_TYPE,
+            unet_config=UNET_CONFIG_PATH,
+            device=device
+        )
+        # NOTE: a torch.compile() pass over unet/vae/pe previously sat here
+        # inside a string literal (dead code) and has been removed.
+        timesteps = torch.tensor([0], device=device)  # Shared timestep tensor passed to the UNet
+
+        # Half precision on the selected device for all three networks.
+        pe = pe.half().to(device)
+        vae.vae = vae.vae.half().to(device)
+        unet.model = unet.model.half().to(device)
+
+        logging.info("Initializing AudioProcessor and loading Whisper model...")
+        audio_processor = AudioProcessor(feature_extractor_path=WHISPER_DIR)
+        weight_dtype = unet.model.dtype  # Whisper runs in the UNet's dtype
+        whisper = WhisperModel.from_pretrained(WHISPER_DIR)
+        whisper = whisper.to(device=device, dtype=weight_dtype).eval()
+        whisper.requires_grad_(False)  # Freeze Whisper model parameters
+
+        logging.info("Initializing FaceParsing model...")
+        if MUSE_VERSION == "v15":
+            fp = FaceParsing(
+                left_cheek_width=LEFT_CHEEK_WIDTH,
+                right_cheek_width=RIGHT_CHEEK_WIDTH
+            )
+        else:  # v1 fallback
+            fp = FaceParsing()
+
+        logging.info("✅ All MuseTalk 1.5 core models loaded and configured successfully.")
+    except Exception:
+        logging.critical("❌ Fatal error loading MuseTalk models or setting up device/precision. Check model paths, CUDA, and dependencies.", exc_info=True)
+        sys.exit(1)
+
+    main_avatar = None
+    try:
+        # 5. Validate configuration paths and avatar ID
+        if not all([AVATAR_CONFIG_PATH, AVATAR_ID_TO_USE, STREAM_PIPE_PATH]):
+            raise ValueError("A required configuration path (AVATAR_CONFIG_PATH, AVATAR_ID_TO_USE, or STREAM_PIPE_PATH) is missing. Check your .env file or command-line arguments.")
+
+        logging.info(f"Attempting to load Avatar ID '{AVATAR_ID_TO_USE}' from config file '{AVATAR_CONFIG_PATH}'...")
+        config = OmegaConf.load(AVATAR_CONFIG_PATH)
+        if AVATAR_ID_TO_USE not in config:
+            raise ValueError(f"Avatar ID '{AVATAR_ID_TO_USE}' was not found as a key in the YAML file '{AVATAR_CONFIG_PATH}'")
+
+        avatar_config = config[AVATAR_ID_TO_USE]
+
+        # 6. Instantiate the Avatar with the loaded configuration
+        main_avatar = Avatar(
+            avatar_id=AVATAR_ID_TO_USE,
+            video_path=avatar_config.video_path,
+            bbox_shift=avatar_config.bbox_shift,
+            batch_size=BATCH_SIZE,
+            preparation=avatar_config.preparation,
+            version_str=MUSE_VERSION,
+            extra_margin=EXTRA_MARGIN,
+            parsing_mode=PARSING_MODE,
+        )
+        logging.debug("DEBUG: Avatar instance successfully created. Proceeding to thread initialization.")
+    except Exception:
+        logging.critical("❌ CRITICAL ERROR during Avatar setup. Check your avatar config YAML and video paths.", exc_info=True)
+        sys.exit(1)
+
+    # 7. Start the file watcher and inference worker threads
+    logging.info("Starting file watcher and inference worker threads...")
+    inference_thread = threading.Thread(target=inference_worker, args=(main_avatar,), daemon=True, name="InferenceWorker")
+    watcher_thread = threading.Thread(target=file_watcher, daemon=True, name="FileWatcher")
+
+    inference_thread.start()
+    watcher_thread.start()
+
+    # 8. Keep the main thread alive and handle shutdown
+    logging.info("✅ Application is running. Press Ctrl+C to shut down gracefully.")
+    try:
+        while True:
+            # Leave the loop as soon as either worker thread has died.
+            if not watcher_thread.is_alive():
+                logging.error("File watcher thread died unexpectedly. Initiating shutdown.")
+                break
+            if not inference_thread.is_alive():
+                logging.error("Inference worker thread died unexpectedly. Initiating shutdown.")
+                break
+            time.sleep(1.0)
+    except KeyboardInterrupt:
+        logging.info("\n🛑 KeyboardInterrupt received. Initiating graceful shutdown...")
+    finally:
+        # FIX: an unbounded put() could block shutdown forever when the queue
+        # was full and the inference worker was no longer consuming.
+        try:
+            audio_queue.put(None, timeout=2.0)
+        except queue.Full:
+            logging.warning("Audio queue is full; could not deliver the shutdown sentinel to the inference worker.")
+
+        logging.info("Attempting to join worker threads for clean shutdown...")
+
+        # Both threads are daemons, so a join timeout does not block exit.
+        if watcher_thread.is_alive():
+            watcher_thread.join(timeout=2)
+            if watcher_thread.is_alive():
+                logging.warning("File watcher thread did not terminate gracefully. It might be stuck.")
+
+        # The inference thread gets longer because it may be mid-run.
+        if inference_thread.is_alive():
+            inference_thread.join(timeout=30)
+            if inference_thread.is_alive():
+                logging.warning("Inference worker thread did not terminate gracefully. It might be stuck.")
+
+        logging.info("✅ Application shutdown complete.")
\ No newline at end of file
diff --git a/scripts/realtime_stream_sync.py b/scripts/realtime_stream_sync.py
new file mode 100644
index 00000000..2b8e15a2
--- /dev/null
+++ b/scripts/realtime_stream_sync.py
@@ -0,0 +1,915 @@
+# -*- coding: utf-8 -*-
+import ffmpeg
+import argparse
+import os
+import concurrent.futures # For ThreadPoolExecutor in Avatar preparation
+import threading
+import queue
+import io
+from omegaconf import OmegaConf
+import subprocess
+import numpy as np
+import cv2
+import torch
+import glob
+import pickle
+import sys
+from tqdm import tqdm
+import copy
+import json
+import traceback # For detailed error printing
+import time
+from PIL import Image
+import tempfile
+import logging
+from dotenv import load_dotenv
+# --- Platform-specific imports for future use if needed ---
+if sys.platform == "win32":
+    # Best-effort tuning: try to raise this process's priority on Windows.
+    # Any failure is reported with print() and then ignored; print() is used
+    # because logging.basicConfig() has not been called yet at this point.
+    try:
+        import psutil
+        # Optional: Set high process priority on Windows
+        p = psutil.Process(os.getpid())
+        p.nice(psutil.HIGH_PRIORITY_CLASS)
+        print("INFO: Process priority set to HIGH on Windows.")
+    except ImportError:
+        # psutil is optional; without it the priority is left unchanged.
+        print("Warning: psutil not found. Cannot set process priority.")
+    except Exception as e:
+        print(f"Warning: Could not set process priority: {e}")
+
+
+# Root logger configuration used by everything below: INFO level, with a
+# timestamp and the level name in front of every message.
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+# --- MuseTalk Specific Imports (Ensure these are in your PYTHONPATH) ---
+try:
+ from musetalk.utils.utils import get_file_type, get_video_fps, datagen
+ from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
+ from musetalk.utils.blending import get_image_prepare_material
+ from musetalk.utils.utils import load_all_model
+except ImportError as e:
+ logging.critical(f"Error importing MuseTalk utilities: {e}. Ensure the library is installed and in your PYTHONPATH.", exc_info=True)
+ sys.exit(1)
+
+import shutil
+
+# --- Configuration & Global Variables ---
+# Recommended to be set via environment variables or a config file
+# Each value is read from the process environment once, when this module is
+# imported, and falls back to the default shown when the variable is unset.
+# NOTE(review): load_dotenv() is only called inside the __main__ block, i.e.
+# after these lines have run, so values that exist only in a .env file do not
+# reach these module-level constants -- confirm that this is intended.
+TARGET_FPS = int(os.getenv("TARGET_FPS", "25"))
+FRAME_SKIP_THRESHOLD = int(os.getenv("FRAME_SKIP_THRESHOLD", "3")) # Drop frames if queue size exceeds this
+AVATAR_CONFIG_PATH = os.getenv("AVATAR_CONFIG_PATH", "configs/inference/realtime.yaml")
+AVATAR_ID_TO_USE = os.getenv("AVATAR_ID_TO_USE", "default_avatar_id") # CHANGE THIS
+STREAM_PIPE_PATH = os.getenv("STREAM_PIPE_PATH", "./hot_file.opus") # File to watch for audio
+
+# --- PyTorch Device Setup ---
+# Use CUDA whenever PyTorch reports it as available; otherwise run on the CPU.
+cuda_available = torch.cuda.is_available()
+device = torch.device("cuda") if cuda_available else torch.device("cpu")
+logging.info("--- PyTorch Device Information ---")
+if not cuda_available:
+    logging.info("❌ CUDA (GPU) not available. Using CPU.")
+else:
+    # The GPU name is purely informational, so a lookup failure is only a warning.
+    try:
+        gpu_name = torch.cuda.get_device_name(0)
+        logging.info(f"✅ CUDA (GPU) detected: {gpu_name}")
+    except Exception as e:
+        logging.warning(f"⚠️ Could not retrieve GPU name: {e}")
+logging.info(f"✅ Selected device: {device}")
+logging.info("-------------------------------")
+
+# --- Load Models ---
+# The four objects returned by load_all_model() are kept as module globals and
+# used by the Avatar class. A failure here is fatal: the script exits with
+# status 1 after logging the traceback.
+logging.info("Loading models...")
+try:
+    audio_processor, vae, unet, pe = load_all_model()
+except Exception as e:
+    logging.critical("Fatal error loading models.", exc_info=True)
+    sys.exit(1)
+else:
+    logging.info("✅ Models loaded successfully.")
+
+# --- Set Model Precision (FP16) ---
+# `timesteps` is passed to every UNet call in Avatar.inference, so it is
+# created outside the best-effort block below: a failed half() conversion is
+# only a warning and must not leave this name undefined.
+timesteps = torch.tensor([0], device=device)
+logging.info("Setting model precision to half (FP16) where applicable...")
+try:
+    # Each model is moved to `device` and converted only if it exposes half().
+    if hasattr(pe, 'half'): pe = pe.to(device).half()
+    if hasattr(vae, 'vae') and hasattr(vae.vae, 'half'): vae.vae = vae.vae.to(device).half()
+    if hasattr(unet, 'model') and hasattr(unet.model, 'half'): unet.model = unet.model.to(device).half()
+    logging.info("✅ Model precision set.")
+except Exception as e:
+    logging.warning(f"⚠️ Error setting model precision: {e}. Performance may be affected.", exc_info=True)
+
+
+# --- FFmpeg Audio Reader ---
+class FFmpegAudioReader:
+    """Uses FFmpeg to read an audio file (from path or bytes) and convert it to raw PCM.
+
+    A ``str`` source is handed to FFmpeg as an input path. Any other source is
+    piped to FFmpeg's stdin unchanged (the caller in this file passes ``bytes``).
+    """
+
+    def __init__(self, audio_source):
+        self.audio_source = audio_source
+        # Only a str is treated as a path on disk; everything else is in-memory data.
+        self.is_file_path = isinstance(audio_source, str)
+
+    def read_full_audio(self):
+        """Reads the entire audio source and converts it to PCM s16le, 48kHz, Stereo.
+
+        Returns:
+            A NumPy ``int16`` array of shape ``(frames, 2)`` viewing FFmpeg's
+            output, or ``None`` when FFmpeg fails or yields no complete frame.
+        """
+        logging.info(f"Reading and converting audio from {'file' if self.is_file_path else 'memory'}...")
+        target_sr, target_ac, target_format = 48000, 2, "s16le"
+
+        if self.is_file_path:
+            input_filename, input_data = self.audio_source, None
+        else:
+            # If source is bytes, we'll pipe it to ffmpeg's stdin
+            input_filename, input_data = 'pipe:0', self.audio_source
+
+        try:
+            # Use ffmpeg-python for a cleaner interface
+            out, err = (
+                ffmpeg
+                .input(input_filename)
+                .output('pipe:', format=target_format, ac=target_ac, ar=target_sr)
+                .run(capture_stdout=True, capture_stderr=True, input=input_data)
+            )
+            if err:
+                logging.debug(f"FFmpeg stderr: {err.decode(errors='ignore')}")
+        except ffmpeg.Error as e:
+            logging.error(f"❌ FFmpeg error during audio conversion: {e.stderr.decode(errors='ignore') if e.stderr else 'Unknown FFmpeg error'}")
+            return None
+        except Exception as e:
+            logging.error(f"❌ Unexpected error during FFmpeg execution: {e}", exc_info=True)
+            return None
+
+        if not out:
+            logging.error("❌ Failed to read audio: FFmpeg produced no PCM data!")
+            return None
+
+        # A truncated stream can end in the middle of a sample frame. Drop the
+        # dangling bytes instead of letting frombuffer()/reshape() raise.
+        sample_size = np.dtype(np.int16).itemsize
+        frame_size = sample_size * target_ac
+        usable_bytes = len(out) - (len(out) % frame_size)
+        if usable_bytes != len(out):
+            logging.warning(f"⚠️ Discarding {len(out) - usable_bytes} trailing byte(s) of incomplete PCM data.")
+        if usable_bytes == 0:
+            logging.error("❌ Failed to read audio: FFmpeg produced no complete PCM frame!")
+            return None
+
+        audio_data = np.frombuffer(out, dtype=np.int16, count=usable_bytes // sample_size).reshape(-1, target_ac)
+        logging.info(f"✅ Read and converted audio: {len(audio_data)} samples at {target_sr}Hz, {target_ac}ch.")
+        return audio_data
+
+# --- GStreamer Classes ---
+class GStreamerPipeline:
+    """Manages the GStreamer video pipeline subprocess.
+
+    The constructor launches the pipeline immediately. Raw BGR frames written
+    to the child's stdin are converted to NV12, encoded with ``nvh265enc`` and
+    sent as RTP to ``host:port`` over UDP. If start-up fails, ``self.process``
+    is ``None`` afterwards, so callers should check it before sending frames.
+
+    Relies on the module-level names ``GSTREAMER_LAUNCH_PATH`` and
+    ``_log_subprocess_output``, which are not defined in this class.
+    """
+
+    def __init__(self, width=1280, height=720, fps=TARGET_FPS, host="127.0.0.1", port=5000):
+        self.width, self.height, self.fps, self.host, self.port = width, height, fps, host, port
+        self.process = None
+        self.stdout_thread = None
+        self.stderr_thread = None
+
+        # Raw BGR in on stdin -> NV12 -> NVENC H.265 -> RTP over UDP.
+        pipeline_str = (
+            f"fdsrc fd=0 do-timestamp=true is-live=true ! videoparse format=bgr width={self.width} height={self.height} framerate={self.fps}/1 ! "
+            "queue ! videoconvert ! videorate ! "
+            f"video/x-raw,format=NV12 ! " # Assuming NV12 is desired for nvh265enc
+            "queue ! "
+            # Reverting bitrate to 4000 as per your old working code, and keep preset for testing
+            f"nvh265enc preset=low-latency-hq rc-mode=cbr bitrate=4000 gop-size=30 ! "
+            "h265parse ! rtph265pay pt=96 config-interval=1 ! "
+            f"udpsink host={self.host} port={self.port} sync=true async=false"
+        )
+        logging.info(f"Starting GStreamer video pipeline with resolution {self.width}x{self.height}@{self.fps}fps...")
+
+        # Prepare environment variables for the subprocess
+        env_vars = os.environ.copy()
+        env_vars['GST_DEBUG'] = '3' # Set GStreamer debug level
+
+        try:
+            # The command is handed to the shell as a single string, so the
+            # launcher is resolved through PATH unless GSTREAMER_LAUNCH_PATH is
+            # already a full path.
+            # NOTE(review): host/port are interpolated into a shell command;
+            # they must never come from untrusted input.
+            command_to_run = f"{GSTREAMER_LAUNCH_PATH} -v {pipeline_str}"
+
+            logging.debug(f"DEBUG: GStreamer VIDEO command (old way): {command_to_run}")
+
+            self.process = subprocess.Popen(
+                command_to_run,
+                stdin=subprocess.PIPE,
+                stdout=subprocess.PIPE, # Capture stdout
+                stderr=subprocess.PIPE, # Capture stderr
+                shell=True, # Keep shell=True
+                bufsize=0,
+                env=env_vars
+            )
+            logging.info(f"✅ GStreamer video process launched (PID: {self.process.pid}).")
+
+            # --- Start threads to read stdout and stderr asynchronously ---
+            self.stdout_thread = threading.Thread(
+                target=_log_subprocess_output,
+                args=(self.process.stdout, logging.info, "GST_VIDEO_STDOUT"),
+                daemon=True
+            )
+            self.stderr_thread = threading.Thread(
+                target=_log_subprocess_output,
+                args=(self.process.stderr, logging.error, "GST_VIDEO_STDERR"),
+                daemon=True
+            )
+            self.stdout_thread.start()
+            self.stderr_thread.start()
+
+        except Exception as e:
+            logging.error(f"❌ Failed to start GStreamer video pipeline: {e}", exc_info=True)
+            # Popen may already have succeeded when a later step failed (for
+            # example while starting the log threads). stop() shuts that child
+            # down instead of orphaning it and leaves self.process as None.
+            try:
+                self.stop()
+            except Exception:
+                logging.warning("⚠️ Cleanup after failed GStreamer video start-up also failed.", exc_info=True)
+            self.process = None
+
+    def send_frame(self, frame):
+        """Sends a NumPy array frame to the GStreamer pipeline's stdin.
+
+        Returns:
+            True when the frame was written, False otherwise. A broken pipe
+            also stops the pipeline.
+        """
+        if not self.process or self.process.stdin.closed:
+            logging.error("❌ GStreamer video pipeline process is not running or stdin is closed.")
+            return False
+        try:
+            # No copy is made when the frame is already C-contiguous uint8;
+            # anything else is converted so the byte layout matches videoparse.
+            frame = np.ascontiguousarray(frame, dtype=np.uint8)
+            self.process.stdin.write(frame.tobytes())
+            self.process.stdin.flush()
+            return True
+        except (BrokenPipeError, OSError):
+            logging.error("❌ GStreamer video pipeline: Broken pipe. The process may have crashed.")
+            self.stop()
+            return False
+        except Exception as e:
+            logging.error(f"❌ Error pushing video frame: {e}", exc_info=True)
+            return False
+
+    def stop(self):
+        """Stops the GStreamer video subprocess gracefully.
+
+        Closes stdin, terminates the child, escalates to kill() after three
+        seconds, then gives each log thread one second to finish. Does nothing
+        when no process is running, so it is safe to call repeatedly.
+        """
+        if not self.process:
+            return
+
+        logging.info(f"Stopping GStreamer video pipeline (PID: {self.process.pid})...")
+        # Detach first so that concurrent or repeated calls see "not running".
+        proc_to_stop, self.process = self.process, None
+        if proc_to_stop.stdin and not proc_to_stop.stdin.closed:
+            try: proc_to_stop.stdin.close()
+            except Exception: pass
+        proc_to_stop.terminate()
+        try:
+            proc_to_stop.wait(timeout=3)
+            logging.info(f"✅ GStreamer video process terminated.")
+        except subprocess.TimeoutExpired:
+            logging.warning(f"⚠️ GStreamer video process did not terminate gracefully, killing...")
+            proc_to_stop.kill()
+            proc_to_stop.wait()  # Reap the killed child so it does not linger as a zombie
+
+        if self.stdout_thread and self.stdout_thread.is_alive():
+            logging.info("Joining GST_VIDEO_STDOUT thread...")
+            self.stdout_thread.join(timeout=1)
+        if self.stderr_thread and self.stderr_thread.is_alive():
+            logging.info("Joining GST_VIDEO_STDERR thread...")
+            self.stderr_thread.join(timeout=1)
+        logging.info("GStreamer video pipeline stop complete.")
+
+
+class GStreamerAudio:
+    """Manages the GStreamer audio pipeline subprocess.
+
+    The constructor launches the pipeline immediately. Interleaved S16LE PCM
+    written to the child's stdin is Opus-encoded and sent as RTP to
+    ``host:port`` over UDP. If start-up fails, ``self.process`` is ``None``
+    afterwards, so callers should check it before sending audio.
+
+    Relies on the module-level names ``GSTREAMER_LAUNCH_PATH`` and
+    ``_log_subprocess_output``, which are not defined in this class.
+    """
+
+    def __init__(self, host="127.0.0.1", port=5001, sample_rate=48000, channels=2):
+        self.host, self.port, self.sample_rate, self.channels = host, port, sample_rate, channels
+        self.process = None
+        self.stdout_thread = None
+        self.stderr_thread = None
+
+        pipeline_str = (
+            f"fdsrc fd=0 do-timestamp=true is-live=true ! "
+            "queue ! "
+            f"audio/x-raw,format=S16LE,channels={self.channels},rate={self.sample_rate},layout=interleaved ! "
+            "audioconvert ! audioresample ! "
+            "opusenc bitrate=96000 ! rtpopuspay pt=97 ! "
+            f"udpsink host={self.host} port={self.port} sync=true"
+        )
+        logging.info("Starting GStreamer audio pipeline...")
+
+        env_vars = os.environ.copy()
+        env_vars['GST_DEBUG'] = '3'
+
+        try:
+            # The command is handed to the shell as a single string.
+            # NOTE(review): host/port are interpolated into a shell command;
+            # they must never come from untrusted input.
+            command_to_run = f"{GSTREAMER_LAUNCH_PATH} -v {pipeline_str}"
+
+            logging.debug(f"DEBUG: GStreamer AUDIO command (old way): {command_to_run}")
+
+            self.process = subprocess.Popen(
+                command_to_run,
+                stdin=subprocess.PIPE,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+                shell=True,
+                bufsize=0,
+                env=env_vars
+            )
+            logging.info(f"✅ GStreamer audio process launched (PID: {self.process.pid}).")
+
+            # --- Start threads to read stdout and stderr asynchronously ---
+            self.stdout_thread = threading.Thread(
+                target=_log_subprocess_output,
+                args=(self.process.stdout, logging.info, "GST_AUDIO_STDOUT"),
+                daemon=True
+            )
+            self.stderr_thread = threading.Thread(
+                target=_log_subprocess_output,
+                args=(self.process.stderr, logging.error, "GST_AUDIO_STDERR"),
+                daemon=True
+            )
+            self.stdout_thread.start()
+            self.stderr_thread.start()
+
+        except Exception as e:
+            logging.error(f"❌ Failed to start GStreamer audio pipeline: {e}", exc_info=True)
+            # Popen may already have succeeded when a later step failed (for
+            # example while starting the log threads). stop() shuts that child
+            # down instead of orphaning it and leaves self.process as None.
+            try:
+                self.stop()
+            except Exception:
+                logging.warning("⚠️ Cleanup after failed GStreamer audio start-up also failed.", exc_info=True)
+            self.process = None
+
+    def _log_stream(self, stream, prefix):
+        """Log every line read from *stream* at INFO level, then close the stream.
+
+        Not called anywhere in this class; kept for existing callers, if any.
+        """
+        try:
+            for line_bytes in iter(stream.readline, b''):
+                logging.info(f"[{prefix}]: {line_bytes.decode(errors='ignore').strip()}")
+        finally:
+            stream.close()
+
+    # send_audio() and stop() used to be defined twice in this class. Python
+    # keeps only the last definition, so the fuller versions were silently
+    # replaced by reduced copies. Only one definition of each remains.
+
+    def send_audio(self, audio_data_pcm):
+        """Sends raw PCM audio data to the GStreamer pipeline's stdin.
+
+        Returns:
+            True when the chunk was written, False otherwise. A broken pipe
+            also stops the pipeline.
+        """
+        if not self.process or self.process.stdin.closed:
+            logging.error("❌ GStreamer audio pipeline process is not running or stdin is closed.")
+            return False
+        try:
+            self.process.stdin.write(audio_data_pcm.tobytes())
+            self.process.stdin.flush()
+            return True
+        except (BrokenPipeError, OSError):
+            logging.error("❌ GStreamer audio pipeline: Broken pipe.")
+            self.stop()
+            return False
+        except Exception as e:
+            logging.error(f"❌ Error pushing audio chunk: {e}", exc_info=True)
+            return False
+
+    def stop(self):
+        """Stops the GStreamer audio subprocess gracefully.
+
+        Closes stdin, terminates the child, escalates to kill() after three
+        seconds, then gives each log thread one second to finish. Does nothing
+        when no process is running, so it is safe to call repeatedly.
+        """
+        if not self.process:
+            return
+
+        logging.info(f"Stopping GStreamer audio pipeline (PID: {self.process.pid})...")
+        # Detach first so that concurrent or repeated calls see "not running".
+        proc_to_stop, self.process = self.process, None
+        if proc_to_stop.stdin and not proc_to_stop.stdin.closed:
+            try: proc_to_stop.stdin.close()
+            except Exception: pass
+        proc_to_stop.terminate()
+        try:
+            proc_to_stop.wait(timeout=3)
+            logging.info("✅ GStreamer audio process terminated.")
+        except subprocess.TimeoutExpired:
+            logging.warning("⚠️ GStreamer audio process did not terminate gracefully, killing...")
+            proc_to_stop.kill()
+            proc_to_stop.wait()  # Reap the killed child so it does not linger as a zombie
+
+        if self.stdout_thread and self.stdout_thread.is_alive():
+            logging.info("Joining GST_AUDIO_STDOUT thread...")
+            self.stdout_thread.join(timeout=1)
+        if self.stderr_thread and self.stderr_thread.is_alive():
+            logging.info("Joining GST_AUDIO_STDERR thread...")
+            self.stderr_thread.join(timeout=1)
+        logging.info("GStreamer audio pipeline stop complete.")
+
+
+# --- Helper Functions ---
+def video2imgs(vid_path, save_path):
+    """Write every frame of a video into *save_path* as sequentially numbered PNGs.
+
+    Files are named with the zero-based frame index padded to eight digits
+    (``00000000.png``, ``00000001.png``, ...). If the video cannot be opened,
+    an error is logged and nothing is written.
+
+    Args:
+        vid_path: Path of the video file to decode.
+        save_path: Directory that receives the PNG files.
+    """
+    logging.info(f"Extracting frames from {vid_path} to {save_path}...")
+    cap = cv2.VideoCapture(vid_path)
+    if not cap.isOpened():
+        logging.error(f"Error: Could not open video file: {vid_path}")
+        return
+    count = 0
+    try:
+        while True:
+            ret, frame = cap.read()
+            if not ret:
+                break
+            frame_path = os.path.join(save_path, f"{str(count).zfill(8)}.png")
+            # imwrite() reports failure through its return value; surface it
+            # instead of silently leaving a gap in the numbered sequence.
+            if not cv2.imwrite(frame_path, frame):
+                logging.warning(f"Could not write frame {count} to {frame_path}")
+            count += 1
+    finally:
+        # Release the capture even when decoding or writing raises.
+        cap.release()
+    logging.info(f"Finished extracting {count} frames.")
+
+def osmakedirs(path_list):
+    """Create each directory in *path_list*, including any missing parents.
+
+    Directories that already exist are left as they are.
+    """
+    for directory in path_list:
+        os.makedirs(directory, exist_ok=True)
+
+# --- Avatar Class ---
+class Avatar:
+    """Reference material for one avatar plus the streaming lip-sync loop.
+
+    On construction the avatar either builds its material from ``video_path``
+    (frames, face boxes, VAE latents, blending masks) or reloads it from
+    ``./results/avatars/<avatar_id>``. ``inference`` then turns one audio clip
+    into video frames and streams frames and audio through GStreamer.
+
+    Uses the module-level ``audio_processor``, ``vae``, ``unet``, ``pe``,
+    ``timesteps`` and ``device``.
+    """
+
+    @torch.no_grad()
+    def __init__(self, avatar_id, video_path, bbox_shift, batch_size, preparation):
+        """Store the settings, derive the on-disk layout and load or build the data.
+
+        Args:
+            avatar_id: Identifier; also the directory name under ./results/avatars.
+            video_path: Video file or directory of images used when preparing.
+            bbox_shift: Passed through to get_landmark_and_bbox().
+            batch_size: Number of frames per UNet/VAE batch during inference.
+            preparation: True to (re)build the material, False to only load it.
+        """
+        logging.info(f"Initializing Avatar: {avatar_id}")
+        self.avatar_id = str(avatar_id)
+        self.video_path = video_path
+        self.bbox_shift = bbox_shift
+        self.batch_size = batch_size
+        self.preparation = preparation
+
+        # Define paths
+        self.avatar_base_path = os.path.join("./results/avatars", self.avatar_id)
+        self.full_imgs_path = os.path.join(self.avatar_base_path, "full_imgs")
+        self.mask_out_path = os.path.join(self.avatar_base_path, "masks")
+        self.coords_path = os.path.join(self.avatar_base_path, "coords.pkl")
+        self.latents_out_path = os.path.join(self.avatar_base_path, "latents.pt")
+        self.mask_coords_path = os.path.join(self.avatar_base_path, "mask_coords.pkl")
+        self.avatar_info_path = os.path.join(self.avatar_base_path, "avatar_info.json")
+
+        # Initialize data stores (parallel lists, one entry per reference frame)
+        self.input_latent_list_cycle = []
+        self.coord_list_cycle = []
+        self.frame_list_cycle = []
+        self.mask_coords_list_cycle = []
+        self.mask_list_cycle = []
+
+        self.idx = 0 # Tracks current position in the reference video cycle
+
+        self.init_avatar_data()
+        logging.info(f"✅ Avatar '{self.avatar_id}' initialized with {len(self.frame_list_cycle)} reference frames.")
+
+    def init_avatar_data(self):
+        """Build the avatar material or load it from disk, as `preparation` dictates.
+
+        When preparation is requested and data already exists, the user is asked
+        on stdin whether to delete and rebuild it; any answer other than "y"
+        loads the existing data.
+        """
+        # If preparation is needed, do it. Otherwise, load existing data.
+        if self.preparation:
+            if os.path.exists(self.avatar_base_path):
+                # Prompt user before overwriting existing data
+                response = input(f"Avatar '{self.avatar_id}' data exists. Re-create all material? (y/n): ").strip().lower()
+                if response == "y":
+                    logging.info(f"User chose to re-create. Removing: {self.avatar_base_path}")
+                    shutil.rmtree(self.avatar_base_path)
+                    self._prepare_material_core()
+                else:
+                    logging.info("Loading existing data as per user request.")
+                    self._reload_prepared_data()
+            else:
+                self._prepare_material_core()
+        else:
+            logging.info("Preparation=False. Loading existing prepared data...")
+            self._reload_prepared_data()
+
+    def _reload_prepared_data(self):
+        """Load latents, coordinates, frames and masks written by a previous preparation.
+
+        Raises:
+            SystemExit: If anything cannot be loaded or the lists differ in length.
+        """
+        logging.info(f"Reloading prepared data from: {self.avatar_base_path}")
+        try:
+            # Load all data from disk
+            # NOTE(review): torch.load() and pickle.load() can execute arbitrary
+            # code from the file; only load avatar directories you trust.
+            loaded_latents = torch.load(self.latents_out_path, map_location='cpu')
+            self.input_latent_list_cycle = list(loaded_latents) if isinstance(loaded_latents, torch.Tensor) else loaded_latents
+            with open(self.coords_path, 'rb') as f: self.coord_list_cycle = pickle.load(f)
+            with open(self.mask_coords_path, 'rb') as f: self.mask_coords_list_cycle = pickle.load(f)
+
+            # Read corresponding images and masks
+            num_items = len(self.coord_list_cycle)
+            if num_items == 0: raise ValueError("Loaded coordinate data is empty.")
+
+            frame_files = [os.path.join(self.full_imgs_path, f"{str(i).zfill(8)}.png") for i in range(num_items)]
+            self.frame_list_cycle = read_imgs(frame_files)
+            mask_files = [os.path.join(self.mask_out_path, f"{str(i).zfill(8)}.png") for i in range(num_items)]
+            self.mask_list_cycle = read_imgs(mask_files)
+
+            # Validate that all lists have the same, non-zero length
+            data_map = {
+                "Latents": self.input_latent_list_cycle, "Coords": self.coord_list_cycle,
+                "Frames": self.frame_list_cycle, "Masks": self.mask_list_cycle, "MaskCoords": self.mask_coords_list_cycle
+            }
+            if not all(len(lst) == num_items for lst in data_map.values()):
+                lengths = {name: len(lst) for name, lst in data_map.items()}
+                raise ValueError(f"Data lists have mismatched lengths after loading: {lengths}")
+
+        except Exception as e:
+            logging.critical(f"Error reloading prepared data. You may need to run with preparation=True.", exc_info=True)
+            raise SystemExit(f"Exiting: Failed to reload data for {self.avatar_id}.")
+
+    @torch.no_grad()
+    def _prepare_material_core(self):
+        """Build all avatar material from `video_path` and persist it.
+
+        Steps: extract or copy the frames, detect face boxes, encode each face
+        crop with the VAE and build its blending mask, mirror the valid frames
+        into a forward+reverse cycle, then write everything to disk.
+
+        Raises:
+            FileNotFoundError: If `video_path` is neither a file nor a directory.
+            RuntimeError: If no frame survives processing.
+        """
+        logging.info(f"--- Preparing new material for avatar: {self.avatar_id} ---")
+        osmakedirs([self.avatar_base_path, self.full_imgs_path, self.mask_out_path])
+        with open(self.avatar_info_path, "w") as f: json.dump({"avatar_id": self.avatar_id}, f)
+
+        # 1. Extract frames from video or copy from image folder
+        if os.path.isfile(self.video_path):
+            video2imgs(self.video_path, self.full_imgs_path)
+        elif os.path.isdir(self.video_path):
+            # Handles copying and renaming from a folder of images
+            source_files = sorted([f for f in os.listdir(self.video_path) if f.lower().endswith(('.png','.jpg','.jpeg'))])
+            for i, filename in enumerate(tqdm(source_files, desc="Copying frames")):
+                shutil.copy(os.path.join(self.video_path, filename), os.path.join(self.full_imgs_path, f"{i:08d}.png"))
+        else:
+            raise FileNotFoundError(f"video_path '{self.video_path}' is not a valid file or directory.")
+
+        # 2. Get landmarks and filter out invalid frames
+        source_images = sorted(glob.glob(os.path.join(self.full_imgs_path, '*.png')))
+        initial_coords, initial_frames = get_landmark_and_bbox(source_images, self.bbox_shift)
+
+        # 3. Process valid frames: VAE encoding and mask generation
+        valid_latents, valid_coords, valid_frames, valid_masks, valid_mask_coords = [], [], [], [], []
+        coord_ph_val = coord_placeholder
+
+        for i, (bbox, frame) in enumerate(tqdm(zip(initial_coords, initial_frames), total=len(initial_coords), desc="VAE Encoding & Masking")):
+            # Frames whose box is missing or equal to the placeholder are skipped.
+            if bbox is None or np.array_equal(bbox, coord_ph_val) or frame is None:
+                continue
+
+            x1c, y1c, x2c, y2c = bbox
+            crop = frame[int(y1c):int(y2c), int(x1c):int(x2c)]
+            if crop.size == 0: continue
+
+            try:
+                # VAE Encoding
+                resized_crop = cv2.resize(crop, (256, 256), interpolation=cv2.INTER_LANCZOS4)
+                latents = vae.get_latents_for_unet(resized_crop).cpu()
+
+                # Mask Generation
+                mask, crop_box = get_image_prepare_material(frame, bbox)
+                if mask is None or crop_box is None: continue
+
+                # Add all valid data together so the five lists stay aligned
+                valid_latents.append(latents)
+                valid_coords.append(bbox)
+                valid_frames.append(frame)
+                valid_masks.append(mask)
+                valid_mask_coords.append(crop_box)
+            except Exception as e:
+                logging.warning(f"Skipping frame {i} due to processing error: {e}")
+
+        if not valid_frames:
+            raise RuntimeError("No valid frames survived the preparation process.")
+
+        # 4. Create looping cycle (forward and reverse) and save all data
+        self.frame_list_cycle = valid_frames + valid_frames[::-1]
+        self.coord_list_cycle = valid_coords + valid_coords[::-1]
+        self.input_latent_list_cycle = valid_latents + valid_latents[::-1]
+        self.mask_list_cycle = valid_masks + valid_masks[::-1]
+        self.mask_coords_list_cycle = valid_mask_coords + valid_mask_coords[::-1]
+
+        # Overwrite content with the final, filtered, and cycled data
+        shutil.rmtree(self.full_imgs_path); os.makedirs(self.full_imgs_path)
+        shutil.rmtree(self.mask_out_path); os.makedirs(self.mask_out_path)
+        for i, (frame, mask) in enumerate(tqdm(zip(self.frame_list_cycle, self.mask_list_cycle), total=len(self.frame_list_cycle), desc="Saving final cycle data")):
+            cv2.imwrite(os.path.join(self.full_imgs_path, f"{i:08d}.png"), frame)
+            cv2.imwrite(os.path.join(self.mask_out_path, f"{i:08d}.png"), mask)
+
+        with open(self.coords_path, 'wb') as f: pickle.dump(self.coord_list_cycle, f)
+        with open(self.mask_coords_path, 'wb') as f: pickle.dump(self.mask_coords_list_cycle, f)
+        torch.save(torch.stack(self.input_latent_list_cycle), self.latents_out_path)
+
+        logging.info(f"--- Material prep complete. Final cycle length: {len(self.frame_list_cycle)} frames. ---")
+
+    @torch.no_grad()
+    def inference(self, audio_source, target_fps):
+        """Generate lip-synced frames for one audio clip and stream them with the audio.
+
+        This is the producer: it runs the UNet/VAE per batch and queues each
+        batch of frames together with the matching slice of PCM audio. A consumer
+        thread (process_and_send_frames) blends and sends them. Errors are
+        logged, not raised; the GStreamer pipelines are always stopped at the end.
+
+        Args:
+            audio_source: Path to an audio file (str) or the encoded audio bytes.
+            target_fps: Frame rate used for feature chunking and A/V slicing.
+        """
+        # This is the main inference producer loop
+        run_id = f"stream_{int(time.time())}"
+        logging.info(f"🎬 Starting inference run ID: {run_id}")
+
+        # This queue passes generated frames and audio to the processing/sending thread
+        vae_to_blend_queue = queue.Queue(maxsize=self.batch_size * 2)
+        gst_video_pipeline, gst_audio_pipeline, frame_processor_thread = None, None, None
+        start_time = time.time()
+
+        try:
+            # 1. Setup GStreamer pipelines
+            gst_video_pipeline = GStreamerPipeline(fps=target_fps)
+            gst_audio_pipeline = GStreamerAudio()
+            if not gst_video_pipeline.process or not gst_audio_pipeline.process:
+                raise RuntimeError("GStreamer pipeline(s) failed to initialize.")
+
+            # 2. Process audio input
+            audio_reader = FFmpegAudioReader(audio_source)
+            full_audio_pcm = audio_reader.read_full_audio()
+            if full_audio_pcm is None or full_audio_pcm.size == 0:
+                raise ValueError("PCM audio is empty after FFmpeg conversion.")
+
+            # Use a temporary file for feature extraction if audio source is in memory
+            feature_extraction_path = audio_source if isinstance(audio_source, str) else None
+            temp_file_path = None
+            if not feature_extraction_path:
+                temp_file_handle = tempfile.NamedTemporaryFile(delete=False, suffix=".opus")
+                temp_file_path = temp_file_handle.name
+                try:
+                    temp_file_handle.write(audio_source)
+                finally:
+                    temp_file_handle.close() # Close handle so audio_processor can open it
+                feature_extraction_path = temp_file_path
+
+            try:
+                whisper_feature = audio_processor.audio2feat(feature_extraction_path)
+                whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature, fps=target_fps)
+            finally:
+                # Remove the temp file even when feature extraction fails, so
+                # failed runs do not pile up files in the temp directory.
+                if temp_file_path:
+                    try:
+                        os.unlink(temp_file_path)
+                    except OSError as unlink_error:
+                        logging.warning(f"Could not remove temporary audio file {temp_file_path}: {unlink_error}")
+
+            num_frames_to_generate = len(whisper_chunks)
+            if num_frames_to_generate == 0:
+                logging.warning("No frames to generate based on audio features. Skipping.")
+                return
+
+            num_vae_batches = (num_frames_to_generate + self.batch_size - 1) // self.batch_size
+            logging.info(f"Audio processed. Planning {num_frames_to_generate} frames in {num_vae_batches} VAE batches.")
+
+            # 3. Start the consumer thread (frame processing and sending)
+            self.idx = 0 # Reset reference frame index for each run
+            frame_processor_thread = threading.Thread(
+                target=self.process_and_send_frames,
+                args=(vae_to_blend_queue, gst_video_pipeline, gst_audio_pipeline, num_frames_to_generate),
+                daemon=True, name=f"FrameProcessor_{run_id}"
+            )
+            frame_processor_thread.start()
+
+            # 4. Main Generation Loop (Producer)
+            data_gen = datagen(whisper_chunks, self.input_latent_list_cycle, self.batch_size)
+            total_audio_samples = len(full_audio_pcm)
+            audio_samples_sent = 0
+
+            for i, batch_data in enumerate(tqdm(data_gen, total=num_vae_batches, desc=f"VAE/UNET [{run_id}]")):
+                if not frame_processor_thread.is_alive():
+                    logging.error(f"Frame processor thread died unexpectedly. Halting generation.")
+                    break
+                if not batch_data or len(batch_data) != 2: continue
+
+                whisper_batch, latent_batch = batch_data
+
+                # --- Core AI Inference ---
+                audio_feature = pe(torch.from_numpy(whisper_batch).to(device, dtype=unet.model.dtype))
+                # The 'datagen' utility already stacks the latents into a batch tensor.
+                # We just need to move this tensor to the correct device and set its data type.
+                if not isinstance(latent_batch, torch.Tensor):
+                    raise TypeError(f"The 'datagen' utility should yield a Tensor, but got {type(latent_batch)}")
+                latent_input = latent_batch.to(device, dtype=unet.model.dtype)
+                pred_latents = unet.model(latent_input, timesteps, encoder_hidden_states=audio_feature).sample
+                vae_output = vae.decode_latents(pred_latents) # Returns list/np.array of frames
+                # --- End Core AI ---
+
+                if vae_output is None or len(vae_output) == 0: continue
+
+                # --- DYNAMIC A/V SYNC ---
+                # Calculate how much audio this batch of video frames represents
+                num_frames_in_batch = len(vae_output)
+                audio_duration_of_batch = num_frames_in_batch / target_fps
+                num_audio_samples_for_batch = int(audio_duration_of_batch * gst_audio_pipeline.sample_rate)
+
+                # Slice the exact audio chunk from the full PCM data
+                start_audio_idx = audio_samples_sent
+                end_audio_idx = min(start_audio_idx + num_audio_samples_for_batch, total_audio_samples)
+                audio_chunk_pcm = full_audio_pcm[start_audio_idx:end_audio_idx]
+                audio_samples_sent = end_audio_idx
+                # --- End A/V SYNC ---
+
+                try:
+                    # Put the generated frames and their corresponding audio chunk into the queue
+                    vae_to_blend_queue.put((list(vae_output), audio_chunk_pcm), timeout=5.0)
+                except queue.Full:
+                    logging.error(f"VAE-to-Blend queue is full. Consumer can't keep up. Halting generation.")
+                    break
+
+        except Exception as e:
+            logging.critical(f"CRITICAL ERROR in inference run '{run_id}'.", exc_info=True)
+        finally:
+            logging.info(f"\n--- [{run_id}] Final Cleanup ---")
+            # Signal the consumer thread to finish, but only while it is still
+            # running, and never with an unbounded put(): a dead or stalled
+            # consumer behind a full queue would otherwise block here forever.
+            if frame_processor_thread and frame_processor_thread.is_alive():
+                try:
+                    vae_to_blend_queue.put(None, timeout=5.0) # Sentinel value
+                except queue.Full:
+                    logging.warning("Could not queue the shutdown sentinel: VAE-to-Blend queue is still full.")
+                except Exception:
+                    logging.warning("Could not queue the shutdown sentinel.", exc_info=True)
+
+                # Wait for consumer thread to finish its work
+                logging.info("Waiting for frame processor thread to finish...")
+                frame_processor_thread.join(timeout=15.0)
+
+            # Stop GStreamer pipelines
+            if gst_video_pipeline: gst_video_pipeline.stop()
+            if gst_audio_pipeline: gst_audio_pipeline.stop()
+
+            elapsed = time.time() - start_time
+            logging.info(f">>> Inference run '{run_id}' finished in {elapsed:.2f}s. <<<")
+
+    def process_and_send_frames(self, vae_to_blend_q, gst_video, gst_audio, total_frames_planned):
+        """Consumer loop: blend each generated face into its reference frame and stream it.
+
+        Runs until `total_frames_planned` frames were sent or skipped, the
+        ``None`` sentinel arrives, no batch arrives for 10 seconds, or a video
+        frame cannot be sent.
+
+        Args:
+            vae_to_blend_q: Queue of ``(frames, audio_chunk)`` tuples, or ``None``.
+            gst_video: Video pipeline providing width, height and send_frame().
+            gst_audio: Audio pipeline providing send_audio().
+            total_frames_planned: Number of frames the producer intends to emit.
+        """
+        # This is the consumer loop
+        total_frames_processed = 0
+        sentinel_received = False
+        while total_frames_processed < total_frames_planned:
+
+            # --- FRAME SKIPPING LOGIC ---
+            # If the queue is getting too full, processing is falling behind.
+            # Discard older frames to catch up to the most recent ones.
+            while vae_to_blend_q.qsize() > FRAME_SKIP_THRESHOLD:
+                try:
+                    logging.warning(f"Queue high ({vae_to_blend_q.qsize()}). Skipping a frame batch to catch up.")
+                    skipped_item = vae_to_blend_q.get_nowait()
+                except queue.Empty:
+                    break # Queue emptied, proceed normally
+                vae_to_blend_q.task_done()
+                if skipped_item is None:
+                    # The shutdown sentinel must end the loop, not be discarded.
+                    sentinel_received = True
+                    break
+                skipped_frames = len(skipped_item[0])
+                total_frames_processed += skipped_frames # Add skipped frames to total
+                # Keep the reference-frame index in step with the frames that
+                # were dropped. Assumes datagen() pairs the i-th generated frame
+                # with the i-th entry of the latent cycle -- TODO confirm.
+                self.idx += skipped_frames
+            # --- END FRAME SKIPPING ---
+
+            if sentinel_received:
+                logging.info("Received sentinel. Frame processor shutting down.")
+                break
+            if total_frames_processed >= total_frames_planned:
+                # Skipping alone reached the planned total; nothing left to wait for.
+                break
+
+            try:
+                # Block and wait for the next item from the generator
+                batch_data = vae_to_blend_q.get(block=True, timeout=10.0)
+            except queue.Empty:
+                logging.error("Timeout waiting for frames from VAE. Ending processor thread.")
+                break
+
+            if batch_data is None: # Sentinel value means we're done
+                vae_to_blend_q.task_done()
+                logging.info("Received sentinel. Frame processor shutting down.")
+                break
+
+            vae_frames, audio_chunk = batch_data
+
+            # Process each frame in the batch (simplified, no inner parallelization)
+            for frame in vae_frames:
+                blended_frame = self._blend_single_frame(frame, gst_video.width, gst_video.height)
+                if blended_frame is not None:
+                    if not gst_video.send_frame(blended_frame):
+                        logging.error("Failed to send video frame to GStreamer. Stopping video sends.")
+                        vae_to_blend_q.task_done()
+                        return # Exit thread if pipe is broken
+
+            # After sending all video frames for the batch, send the corresponding audio
+            if audio_chunk is not None and audio_chunk.size > 0:
+                if not gst_audio.send_audio(audio_chunk):
+                    logging.error("Failed to send audio chunk to GStreamer.")
+
+            total_frames_processed += len(vae_frames)
+            vae_to_blend_q.task_done()
+
+        logging.info("--- Frame processor finished its work. ---")
+
+    def _blend_single_frame(self, res_frame, target_width, target_height):
+        """Paste one generated face into the current reference frame and resize it.
+
+        The reference frame, face box and mask come from the cycle position
+        ``self.idx % cycle_len``. ``self.idx`` is advanced on every path except
+        when the cycle is empty.
+
+        Args:
+            res_frame: Generated face image for this frame.
+            target_width: Width of the returned frame.
+            target_height: Height of the returned frame.
+
+        Returns:
+            The blended frame resized to ``(target_width, target_height)``, or
+            ``None`` when the frame cannot be blended.
+        """
+        try:
+            cycle_len = len(self.coord_list_cycle)
+            if cycle_len == 0:
+                logging.error("Cannot blend frame, avatar cycle data is empty.")
+                return None
+
+            current_idx = self.idx % cycle_len
+
+            # Retrieve all necessary data for the current reference frame
+            bbox = self.coord_list_cycle[current_idx]
+            ori_frame = self.frame_list_cycle[current_idx]
+            mask = self.mask_list_cycle[current_idx]
+
+            # Check for corrupt/missing data before processing
+            if mask is None or ori_frame is None or bbox is None:
+                logging.warning(f"Skipping frame blend at index {self.idx} due to missing or corrupt reference data for cycle index {current_idx}.")
+                self.idx += 1
+                return None
+
+            x, y, x1, y1 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
+            face_w, face_h = x1 - x, y1 - y
+
+            if face_w <= 0 or face_h <= 0:
+                self.idx += 1
+                return None
+
+            resized_face = cv2.resize(res_frame.astype(np.uint8), (face_w, face_h), interpolation=cv2.INTER_LINEAR)
+
+            # The stored mask is resized to the face box when their sizes differ.
+            if mask.shape[:2] != (face_h, face_w):
+                mask = cv2.resize(mask, (face_w, face_h), interpolation=cv2.INTER_LINEAR)
+
+            # Reduce the mask to a single-channel alpha in [0, 1] with a trailing axis.
+            if len(mask.shape) == 3 and mask.shape[2] == 3:
+                # A 3-channel BGR mask is converted to grayscale first.
+                alpha_mask = (cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) / 255.0)[..., np.newaxis]
+            else:
+                # Otherwise, assume it's already the 1-channel grayscale image we need.
+                alpha_mask = (mask / 255.0)[..., np.newaxis]
+
+            blended_frame = ori_frame.copy()
+            face_region = blended_frame[y:y1, x:x1]
+
+            # Ensure the alpha mask can be broadcast to the face region shape for blending
+            if face_region.shape != alpha_mask.shape:
+                alpha_mask = np.repeat(alpha_mask, 3, axis=2)
+
+            blended_face = face_region.astype(np.float32) * (1.0 - alpha_mask) + resized_face.astype(np.float32) * alpha_mask
+            blended_frame[y:y1, x:x1] = blended_face.astype(np.uint8)
+
+            self.idx += 1
+            return cv2.resize(blended_frame, (target_width, target_height), interpolation=cv2.INTER_LINEAR)
+
+        except Exception as e:
+            logging.error(f"Error blending frame at index {self.idx}: {e}", exc_info=True)
+            self.idx += 1
+            return None
+
+
+# --- Main Application Logic ---
+# Acquired (non-blocking) by inference_worker around each Avatar.inference() call.
+inference_lock = threading.Lock()
+# Hand-off from file_watcher (producer) to inference_worker (consumer). Holds
+# at most 10 pending audio payloads; a None item tells the worker to exit.
+audio_queue = queue.Queue(maxsize=10)
+
+def inference_worker(avatar, fps):
+    """Worker thread that waits for audio data from a queue and runs inference.
+
+    Blocks on the module-level ``audio_queue`` and exits when it receives
+    ``None``. Every item taken from the queue, including the shutdown
+    sentinel, is acknowledged with ``task_done()`` so that ``audio_queue.join()``
+    cannot hang.
+
+    Args:
+        avatar: Object whose ``inference(audio_data, fps)`` method is called.
+        fps: Target frame rate forwarded to ``avatar.inference``.
+    """
+    logging.info("🚀 Inference worker started. Waiting for audio data...")
+    while True:
+        audio_data = audio_queue.get()
+        if audio_data is None: # Shutdown signal
+            audio_queue.task_done()
+            break
+
+        # Non-blocking acquire: if another inference holds the lock, this
+        # payload is dropped rather than queued behind it.
+        if not inference_lock.acquire(blocking=False):
+            logging.warning("Inference already in progress. Skipping newly received audio.")
+            audio_queue.task_done()
+            continue
+
+        try:
+            logging.info(f"Inference lock acquired. Processing {len(audio_data)} bytes of audio.")
+            avatar.inference(audio_data, fps)
+        finally:
+            inference_lock.release()
+            audio_queue.task_done()
+            logging.info("Inference lock released.")
+
+def file_watcher(pipe_path):
+ """Monitors a file for changes and puts its content into the audio queue."""
+ logging.info(f"👀 Starting file watcher for: {pipe_path}")
+ last_processed_mtime = 0
+ while True:
+ try:
+ if os.path.isfile(pipe_path):
+ current_mtime = os.path.getmtime(pipe_path)
+ if current_mtime > last_processed_mtime:
+ logging.info("New audio file detected. Reading content...")
+ # Add a small delay and retry to handle partially written files
+ time.sleep(0.1)
+ with open(pipe_path, "rb") as f:
+ opus_data = f.read()
+
+ if opus_data:
+ audio_queue.put(opus_data)
+ last_processed_mtime = current_mtime
+ else:
+ logging.warning("File modification detected, but file was empty.")
+ last_processed_mtime = current_mtime
+ time.sleep(0.2) # Poll interval
+ except FileNotFoundError:
+ # This is okay, the file might be deleted and recreated
+ last_processed_mtime = 0
+ time.sleep(0.5)
+ except Exception as e:
+ logging.error(f"Error in file watcher loop.", exc_info=True)
+ time.sleep(2)
+
+
+if __name__ == "__main__":
+ logging.info("🎬 Starting Realtime Stream Sync Application...")
+
+ # 1. Load environment variables from .env file
+ # This must be at the top of the execution block.
+ from dotenv import load_dotenv
+ load_dotenv()
+
+ # 2. Read configuration exclusively from the loaded environment variables
+ AVATAR_CONFIG_PATH = os.getenv("AVATAR_CONFIG_PATH")
+ AVATAR_ID_TO_USE = os.getenv("AVATAR_ID_TO_USE")
+ TARGET_FPS = int(os.getenv("TARGET_FPS", "25"))
+ STREAM_PIPE_PATH = os.getenv("STREAM_PIPE_PATH")
+
+ main_avatar = None
+ try:
+ # 3. Validate that the required variables were found in the .env file
+ if not all([AVATAR_CONFIG_PATH, AVATAR_ID_TO_USE, STREAM_PIPE_PATH]):
+ raise ValueError("A required variable (AVATAR_CONFIG_PATH, AVATAR_ID_TO_USE, or STREAM_PIPE_PATH) is missing from your .env file.")
+
+ logging.info(f"Attempting to load Avatar ID '{AVATAR_ID_TO_USE}' from config file '{AVATAR_CONFIG_PATH}'...")
+
+ # 4. Load the YAML config and select the specified avatar
+ config = OmegaConf.load(AVATAR_CONFIG_PATH)
+ if AVATAR_ID_TO_USE not in config:
+ raise ValueError(f"Avatar ID '{AVATAR_ID_TO_USE}' was not found as a key in the YAML file '{AVATAR_CONFIG_PATH}'")
+
+ avatar_config = config[AVATAR_ID_TO_USE]
+
+ # 5. Instantiate the Avatar with the loaded configuration
+ main_avatar = Avatar(
+ avatar_id=AVATAR_ID_TO_USE,
+ video_path=avatar_config.video_path,
+ bbox_shift=avatar_config.bbox_shift,
+ batch_size=avatar_config.get("batch_size", 4),
+ preparation=avatar_config.preparation
+ )
+ except Exception as e:
+ logging.critical("❌ CRITICAL ERROR during setup.", exc_info=True)
+ sys.exit(1)
+
+ # 6. Start the watcher and inference worker threads
+ inference_thread = threading.Thread(target=inference_worker, args=(main_avatar, TARGET_FPS), daemon=True)
+ watcher_thread = threading.Thread(target=file_watcher, args=(STREAM_PIPE_PATH,), daemon=True)
+
+ inference_thread.start()
+ watcher_thread.start()
+
+ # 7. Keep the main thread alive and handle graceful shutdown
+ logging.info("✅ Application is running. Press Ctrl+C to shut down.")
+ try:
+ while True:
+ if not watcher_thread.is_alive() or not inference_thread.is_alive():
+ logging.error("A critical worker thread has died. Shutting down.")
+ break
+ time.sleep(1.0)
+ except KeyboardInterrupt:
+ logging.info("\n🛑 KeyboardInterrupt received. Initiating graceful shutdown...")
+ finally:
+ audio_queue.put(None)
+ if watcher_thread.is_alive(): watcher_thread.join(timeout=2)
+ if inference_thread.is_alive(): inference_thread.join(timeout=30)
+ logging.info("✅ Application shutdown complete.")
\ No newline at end of file
diff --git a/scripts/test.py b/scripts/test.py
new file mode 100644
index 00000000..92cbf9ae
--- /dev/null
+++ b/scripts/test.py
@@ -0,0 +1,54 @@
+import ffmpeg
+import numpy as np
+import pyaudio
+import time
+
+def load_audio_ffmpeg(audio_path, sample_rate=48000):
+ """ 用 FFmpeg 读取音频并转换为 int16 PCM """
+ out, _ = (
+ ffmpeg.input(audio_path)
+ .output('pipe:', format='s16le', acodec='pcm_s16le', ac=1, ar=sample_rate)
+ .run(capture_stdout=True)
+ )
+ audio_data = np.frombuffer(out, dtype=np.int16)
+ return audio_data, sample_rate
+
+def split_audio(audio_data, num_chunks):
+ """ 将音频数据均分为 num_chunks 份 """
+ chunk_size = len(audio_data) // num_chunks
+ audio_chunks = [audio_data[i * chunk_size: (i + 1) * chunk_size] for i in range(num_chunks)]
+ return audio_chunks
+
+def play_audio_chunk(audio_chunk, sample_rate=48000):
+ """ 播放单个音频片段 """
+ p = pyaudio.PyAudio()
+ stream = p.open(format=pyaudio.paInt16, channels=1, rate=sample_rate, output=True)
+
+ print(f"🎧 播放音频片段: {len(audio_chunk)} samples")
+ stream.write(audio_chunk.tobytes()) # 确保是 int16 PCM
+ stream.stop_stream()
+ stream.close()
+ p.terminate()
+ print("✅ 音频播放完成")
+
+def main():
+ audio_path = "input.wav" # 你的音频文件路径
+ num_chunks = 10 # 切分成 10 段
+
+ # 读取音频
+ print("🎵 读取音频中...")
+ audio_data, sample_rate = load_audio_ffmpeg("/home/fan370/Documents/MuseTalk/data/audio/sun.wav")
+ print(f"📊 音频数据大小: {audio_data.shape}, 采样率: {sample_rate} Hz")
+
+ # 切割音频
+ print("✂️ 正在切割音频...")
+ audio_chunks = split_audio(audio_data, num_chunks)
+
+ # 播放音频
+ for i, chunk in enumerate(audio_chunks):
+ print(f"🎧 播放第 {i+1}/{len(audio_chunks)} 个音频片段")
+ play_audio_chunk(chunk, sample_rate)
+ time.sleep(0.1) # 避免过快播放
+
+# Run the playback demo only when this file is executed as a script.
+if __name__ == "__main__":
+ main()
diff --git a/setup-gstreamer-env.ps1 b/setup-gstreamer-env.ps1
new file mode 100644
index 00000000..4bbf0e64
--- /dev/null
+++ b/setup-gstreamer-env.ps1
@@ -0,0 +1,17 @@
+# Sets the process environment variables (PYTHONPATH, PATH, GI_TYPELIB_PATH,
+# GST_PLUGIN_PATH) for the GStreamer install rooted at $GST_ROOT, then runs a
+# Python one-liner that imports gi and initialises GStreamer as a smoke test.
+# NOTE(review): backslash is not an escape character in PowerShell strings, so
+# the doubled backslashes below end up literally in the values - confirm the
+# consumers of these variables tolerate "\\" inside paths.
+
+# Path to GStreamer installation
+$GST_ROOT = "E:\\gstreamer\\1.0\\msvc_x86_64"
+
+# Add GStreamer's Python module directory to PYTHONPATH
+$env:PYTHONPATH = "$GST_ROOT\\lib\\python3\\site-packages;$env:PYTHONPATH"
+
+# Add GStreamer binary directory to PATH
+$env:PATH = "$GST_ROOT\\bin;$env:PATH"
+
+# Set GI_TYPELIB_PATH to find GObject Introspection typelibs
+$env:GI_TYPELIB_PATH = "$GST_ROOT\\lib\\girepository-1.0"
+
+# Set GST_PLUGIN_PATH to find GStreamer plugins
+$env:GST_PLUGIN_PATH = "$GST_ROOT\\lib\\gstreamer-1.0"
+
+# Run a simple test to verify it works
+python -c "import gi; gi.require_version('Gst', '1.0'); from gi.repository import Gst; Gst.init(None); print(f'GStreamer version: {Gst.version_string()}'); print('GStreamer successfully initialized!')"
diff --git a/setup-pygstreamer-env.ps1 b/setup-pygstreamer-env.ps1
new file mode 100644
index 00000000..56f6d0a9
--- /dev/null
+++ b/setup-pygstreamer-env.ps1
@@ -0,0 +1,23 @@
+# Writes setup-gstreamer-env.ps1 into the current directory from the
+# here-string below and then runs it. Inside the expandable here-string each
+# "$" is escaped with a backtick so that it is written out literally and only
+# expanded later, when the generated script runs.
+
+# 1. Create a bridge script named 'setup-gstreamer-env.ps1'
+@"
+# Path to GStreamer installation
+`$GST_ROOT = "E:\\gstreamer\\1.0\\msvc_x86_64"
+
+# Add GStreamer's Python module directory to PYTHONPATH
+`$env:PYTHONPATH = "`$GST_ROOT\\lib\\python3\\site-packages;`$env:PYTHONPATH"
+
+# Add GStreamer binary directory to PATH
+`$env:PATH = "`$GST_ROOT\\bin;`$env:PATH"
+
+# Set GI_TYPELIB_PATH to find GObject Introspection typelibs
+`$env:GI_TYPELIB_PATH = "`$GST_ROOT\\lib\\girepository-1.0"
+
+# Set GST_PLUGIN_PATH to find GStreamer plugins
+`$env:GST_PLUGIN_PATH = "`$GST_ROOT\\lib\\gstreamer-1.0"
+
+# Run a simple test to verify it works
+python -c "import gi; gi.require_version('Gst', '1.0'); from gi.repository import Gst; Gst.init(None); print(f'GStreamer version: {Gst.version_string()}'); print('GStreamer successfully initialized!')"
+"@ | Out-File -FilePath setup-gstreamer-env.ps1 -Encoding utf8
+
+# 2. Run the script to configure environment
+# (the generated script is overwritten on every run of this file)
+./setup-gstreamer-env.ps1
diff --git a/setup-pygstreamer.ps1 b/setup-pygstreamer.ps1
new file mode 100644
index 00000000..9e3011c0
--- /dev/null
+++ b/setup-pygstreamer.ps1
@@ -0,0 +1,34 @@
+# Create this script as setup-pygstreamer.ps1
+#
+# Prepares the process environment for GStreamer, installs two Python packages
+# with pip, then writes and runs test_gst.py to check that GStreamer can be
+# initialised from Python.
+
+# Clear any existing paths that might conflict
+# (drops every PATH entry containing 'msys64'; String.Contains is case-sensitive)
+$env:PATH = ($env:PATH -split ';' | Where-Object { -not $_.Contains('msys64') }) -join ';'
+
+# Add GStreamer paths
+$GST_ROOT = "E:\gstreamer\1.0\msvc_x86_64"
+$env:PATH = "$GST_ROOT\bin;$env:PATH"
+$env:GI_TYPELIB_PATH = "$GST_ROOT\lib\girepository-1.0"
+$env:PKG_CONFIG_PATH = "$GST_ROOT\lib\pkgconfig"
+$env:GST_PLUGIN_PATH = "$GST_ROOT\lib\gstreamer-1.0"
+
+# Install PyGObjects' Windows binaries
+# NOTE(review): going by its name, PyGObject-stubs looks like a type-stub
+# package rather than the gi runtime; this script neither installs PyGObject
+# itself nor sets PYTHONPATH, so confirm where the `gi` module imported by the
+# test below is expected to come from.
+pip install --no-cache-dir PyGObject-stubs
+pip install --no-cache-dir pycairo
+
+# Don't forget to install torch/CUDA if needed for MuseTalk
+# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
+
+# Run a test script to verify
+# (single-quoted here-string: the Python source is written out verbatim)
+$TEST_SCRIPT = @'
+import gi
+gi.require_version("Gst", "1.0")
+from gi.repository import Gst
+
+Gst.init(None)
+print(f"GStreamer version: {Gst.version_string()}")
+print("GStreamer successfully initialized!")
+'@
+
+Write-Host "Creating test script..."
+$TEST_SCRIPT | Out-File -FilePath test_gst.py -Encoding utf8
+Write-Host "Running test script..."
+python test_gst.py
diff --git a/test-custom-opencv.py b/test-custom-opencv.py
new file mode 100644
index 00000000..d84165db
--- /dev/null
+++ b/test-custom-opencv.py
@@ -0,0 +1,34 @@
+# test-custom-opencv.py
+# Smoke test: prints which OpenCV build was imported, whether its build
+# information mentions GStreamer, and whether a test GStreamer pipeline can be
+# opened and yields a frame.
+# NOTE(review): the only helper visible in this patch is "use-custom-opencv.py"
+# (hyphenated), which cannot be imported under the name `use_custom_opencv` and
+# only calls setup_opencv_environment() when run as a script - confirm that a
+# module with the underscore name exists and that this import has the intended
+# effect.
+import use_custom_opencv
+
+# Now import OpenCV
+import cv2
+import os
+
+# Print OpenCV details
+print(f"OpenCV version: {cv2.__version__}")
+print(f"OpenCV path: {os.path.dirname(cv2.__file__)}")
+
+# Check for GStreamer support
+# (keeps the lines of the build report that contain "GStreamer:")
+build_info = cv2.getBuildInformation()
+gstreamer_line = [line for line in build_info.splitlines()
+ if "GStreamer:" in line]
+print(f"GStreamer support: {gstreamer_line[0] if gstreamer_line else 'Not found'}")
+
+# Test a simple GStreamer pipeline
+try:
+ pipeline = "videotestsrc ! videoconvert ! appsink"
+ cap = cv2.VideoCapture(pipeline, cv2.CAP_GSTREAMER)
+
+ if cap.isOpened():
+ print("GStreamer pipeline working!")
+ # Read a single frame to prove the pipeline actually delivers data
+ ret, frame = cap.read()
+ if ret:
+ print(f"Successfully read frame with shape: {frame.shape}")
+ else:
+ print("Failed to read frame")
+ cap.release()
+ else:
+ print("Failed to open GStreamer pipeline")
+except Exception as e:
+ print(f"Error testing GStreamer: {e}")
\ No newline at end of file
diff --git a/test-gst.py b/test-gst.py
new file mode 100644
index 00000000..eb237ac2
--- /dev/null
+++ b/test-gst.py
@@ -0,0 +1,21 @@
+import os
+import sys
+
+# Add DLL directory
+opencv_bin = r"D:\tencent\devel\cv\opencv-4.5.5\build\install\x64\vc16\bin"
+os.add_dll_directory(opencv_bin)
+
+# Find the exact .pyd file
+pyd_dir = r"D:\tencent\devel\cv\opencv-4.5.5\build\lib\python3\Release"
+print("Looking for cv2.pyd in:", pyd_dir)
+for file in os.listdir(pyd_dir):
+ if file.endswith(".pyd"):
+ print(f"Found: {file}")
+
+# Add to path and try import
+sys.path.insert(0, pyd_dir)
+try:
+ import cv2
+ print("Success!")
+except ImportError as e:
+ print(f"Failed: {e}")
diff --git a/test_opencv_import.py b/test_opencv_import.py
new file mode 100644
index 00000000..a40c6a14
--- /dev/null
+++ b/test_opencv_import.py
@@ -0,0 +1,31 @@
+
+import os
+import sys
+
+# Add paths
+opencv_python_path = r"D:\tencent\devel\cv\opencv-4.5.5\build\lib\python3\Release"
+opencv_bin = r"D:\tencent\devel\cv\opencv-4.5.5\build\install\x64\vc16\bin"
+gst_bin = r"E:\gstreamer\1.0\msvc_x86_64\bin"
+
+# Add to Python path
+if opencv_python_path not in sys.path:
+ sys.path.insert(0, opencv_python_path)
+
+# Add to PATH
+os.environ["PATH"] = opencv_bin + os.pathsep + os.environ["PATH"]
+os.environ["PATH"] = gst_bin + os.pathsep + os.environ["PATH"]
+
+# Add DLL directories
+if sys.version_info >= (3, 8):
+ os.add_dll_directory(opencv_bin)
+ os.add_dll_directory(gst_bin)
+
+# Try importing OpenCV
+try:
+ print("Attempting to import cv2...")
+ import cv2
+ print(f"Success! OpenCV version: {cv2.__version__}")
+except Exception as e:
+ print(f"Error importing cv2: {e}")
+ import traceback
+ traceback.print_exc()
diff --git a/use-custom-opencv.py b/use-custom-opencv.py
new file mode 100644
index 00000000..13ba8a10
--- /dev/null
+++ b/use-custom-opencv.py
@@ -0,0 +1,67 @@
+import os
+import sys
+import site
+import platform
+
+def setup_opencv_environment():
+ """Configure environment for custom OpenCV with all required DLLs"""
+ print(f"Setting up OpenCV environment for Python {sys.version}")
+
+ # Paths to OpenCV and its dependencies
+ opencv_python_path = r"D:\tencent\devel\cv\opencv-4.5.5\build\lib\python3\Release"
+ opencv_bin = r"D:\tencent\devel\cv\opencv-4.5.5\build\install\x64\vc16\bin"
+ opencv_lib = r"D:\tencent\devel\cv\opencv-4.5.5\build\install\x64\vc16\lib"
+ gst_bin = r"E:\gstreamer\1.0\msvc_x86_64\bin"
+
+ # Add OpenCV module path to Python path
+ if opencv_python_path not in sys.path:
+ sys.path.insert(0, opencv_python_path)
+ print(f"Added to Python path: {opencv_python_path}")
+
+ # Add DLL directories to PATH
+ paths_to_add = [opencv_bin, opencv_lib, gst_bin]
+ for path in paths_to_add:
+ if os.path.exists(path):
+ if path not in os.environ["PATH"]:
+ os.environ["PATH"] = path + os.pathsep + os.environ["PATH"]
+ print(f"Added to PATH: {path}")
+ else:
+ print(f"WARNING: Path does not exist: {path}")
+
+ # Use add_dll_directory for Python 3.8+
+ if sys.version_info >= (3, 8):
+ for path in paths_to_add:
+ if os.path.exists(path):
+ try:
+ os.add_dll_directory(path)
+ print(f"Added DLL directory: {path}")
+ except Exception as e:
+ print(f"ERROR adding DLL directory {path}: {e}")
+
+ # Print architecture information
+ print(f"Python architecture: {platform.architecture()[0]}")
+
+ return True
+
+# Execute if run directly
+if __name__ == "__main__":
+ success = setup_opencv_environment()
+
+ if success:
+ try:
+ import cv2
+ print(f"\nOpenCV successfully loaded!")
+ print(f"OpenCV version: {cv2.__version__}")
+ print(f"OpenCV path: {os.path.dirname(cv2.__file__)}")
+
+ # Check for GStreamer support
+ build_info = cv2.getBuildInformation()
+ gstreamer_line = [line for line in build_info.splitlines()
+ if "GStreamer:" in line]
+ print(f"GStreamer support: {gstreamer_line[0] if gstreamer_line else 'Not found'}")
+ except Exception as e:
+ print(f"\nERROR loading OpenCV: {e}")
+
+ # Provide detailed error information
+ import traceback
+ traceback.print_exc()
\ No newline at end of file