diff --git a/.gitignore b/.gitignore index c83750f5..8154f988 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,21 @@ *.pyc .ipynb_checkpoints results/ +data/audio/*.wav +data/video/*.mp4 +data/video/*.mov +data/video/ma/*.mp4 +avatars.zip +avatars/ +avator_1_skipImage.zip +avator_1_skipImage/ +cuda_11.7.0_515.43.04_linux.run +eaudio --start +ffmpeg-7.0.2-amd64-static/ +ffmpeg-release-amd64-static.tar.xz +musetalk.zip +scripts.zip +venv_PhotonSync/ models/ **/__pycache__/ *.py[cod] @@ -15,4 +30,4 @@ ffmprobe* ffplay* debug exp_out -.gradio \ No newline at end of file +.gradio diff --git a/PhotonSync_Diagram.md b/PhotonSync_Diagram.md new file mode 100644 index 00000000..cbfbaea8 --- /dev/null +++ b/PhotonSync_Diagram.md @@ -0,0 +1,52 @@ +graph TD + %% Define Styles + classDef process fill:#cde4ff,stroke:#6699ff,stroke-width:2px; + classDef data fill:#e6ffcc,stroke:#99cc66,stroke-width:2px,rx:10px,ry:10px; + classDef component fill:#fff2cc,stroke:#ffcc66,stroke-width:2px; + classDef io fill:#f2f2f2,stroke:#333,stroke-width:2px,stroke-dasharray: 5 5; + classDef hardware fill:#e0e0e0,stroke:#666,stroke-width:2px,rx:5px,ry:5px; + classDef title fill:#ffffff,stroke:#ffffff,font-weight:bold,font-size:18px; + + %% One-Time Preparation Phase + prep_title("One-Time Preparation
(一次性素材准备)"):::title + prep_video(Input Video/Images
输入视频/图像) ==> prep_frames(Extract Frames
提取帧) + prep_frames --> prep_landmark(Get Face BBox & Landmarks
获取人脸框和关键点) + prep_landmark -- Face Coords (人脸坐标) --> avatar_data + prep_landmark -- Cropped Face (裁剪的人脸) --> prep_vae(VAE Encoder
VAE编码器) + prep_landmark -- Full Frame (完整帧) --> prep_parse(Face Parsing
人脸解析) + prep_vae --> prep_latents(Latent Vectors
潜向量) + prep_parse --> prep_masks(Blending Masks
融合蒙版) + prep_latents & prep_masks --> avatar_data(Avatar Data Storage
虚拟人数据存储
Frames, Coords, Latents, Masks) + + %% Sender Application Phase + sender_title("PhotonSync Sender Application
(发送端应用)"):::title + photon_gpt[PhotonGPT Audio Input
PhotonGPT音频输入] --> audio_proc(Audio Feature Extraction
音频特征提取
Whisper) + photon_gpt --> audio_enc(Audio Encoding
音频编码
GStreamerAudio / opusenc) + + audio_proc -- Audio Features (音频特征) --> rt_unet + avatar_data -- Pre-calculated Latents (预计算潜向量) --> rt_unet + + rt_unet(UNet Inference
UNet推理
Generate Lip-Synced Latents
生成口型同步的潜向量) --> rt_vae(VAE Decoder
VAE解码器
Latents to Image Frame
潜向量转图像帧) + rt_vae -- Generated Face Frame (生成的面部帧) --> rt_blend + + avatar_data -- Original Frame, Mask, Coords (原始帧、蒙版、坐标) --> rt_blend + rt_blend(Real-time Blending
实时融合
Combine face and background
合并面部与背景) -- Final Video Frame (最终视频帧) --> video_enc(Video Encoding
视频编码
GStreamerPipeline / nvh264enc) + + video_enc -- H.264 RTP Stream --> network((Network
网络)) + audio_enc -- Opus RTP Stream --> network + + %% Receiver Phase + receiver_title("Holobot Receiver
(接收端)"):::title + network -- Video Stream (视频流) --> vid_receiver(Video UDP Source
视频UDP源) + network -- Audio Stream (音频流) --> aud_receiver(Audio UDP Source
音频UDP源) + + vid_receiver --> vid_jitter(Video Jitter Buffer
视频抖动缓冲) + vid_jitter --> vid_depay(Video RTP Depayload
视频RTP解包) + vid_depay --> vid_parse(H.264 Parse
H.264解析) + vid_parse --> vid_dec(NVDEC Decode
NVDEC解码
GPU) + vid_dec --> vid_sink(Video Sink
视频接收器
d3d11videosink) + + aud_receiver --> aud_jitter(Audio Jitter Buffer
音频抖动缓冲) + aud_jitter --> aud_depay(Audio RTP Depayload
音频RTP解包) + aud_depay --> aud_parse(Opus Parse
Opus解析) + aud_parse --> aud_dec \ No newline at end of file diff --git a/Photonsync.ps1 b/Photonsync.ps1 new file mode 100644 index 00000000..7894afb7 --- /dev/null +++ b/Photonsync.ps1 @@ -0,0 +1,2 @@ +python -m scripts.realtime_stream_gst_15 --inference_config configs/inference/realtime.yaml --skip_save_images + diff --git a/configs/inference/realtime-stable.yaml b/configs/inference/realtime-stable.yaml new file mode 100644 index 00000000..ea67de19 --- /dev/null +++ b/configs/inference/realtime-stable.yaml @@ -0,0 +1,7 @@ +avator_1: + preparation: False + bbox_shift: 5 + video_path: "data/video/sun.mp4" + audio_clips: + audio_0: "data/audio/sun.wav" + audio_1: "data/audio/yongen.wav" diff --git a/configs/inference/realtime.yaml b/configs/inference/realtime.yaml index 9319e987..c740a758 100644 --- a/configs/inference/realtime.yaml +++ b/configs/inference/realtime.yaml @@ -1,10 +1,8 @@ -avator_1: - preparation: True # your can set it to False if you want to use the existing avator, it will save time - bbox_shift: 5 - video_path: "data/video/yongen.mp4" +avatar_3: + preparation: False + bbox_shift: 0 + batch_size: 16 + video_path: "data/video/aiden-glasses-processed.mp4" audio_clips: - audio_0: "data/audio/yongen.wav" - audio_1: "data/audio/eng.wav" - - - + audio_0: "data/audio/sun.wav" + audio_1: "data/audio/yongen.wav" diff --git a/data/video/musk.png b/data/video/musk.png new file mode 100644 index 00000000..bca50eef Binary files /dev/null and b/data/video/musk.png differ diff --git a/data/video/younglook.7z b/data/video/younglook.7z new file mode 100644 index 00000000..9d8a028e Binary files /dev/null and b/data/video/younglook.7z differ diff --git a/download-gi-deps.py b/download-gi-deps.py new file mode 100644 index 00000000..79f647ec --- /dev/null +++ b/download-gi-deps.py @@ -0,0 +1,40 @@ +import os +import urllib.request +import zipfile +import shutil + +# URLs for missing dependencies +dll_sources = { + "z.dll": "https://github.com/winlibs/zlib/releases/download/zlib-1.3/zlib-1.3-msvc-x64.zip", + "intl-8.dll": "https://github.com/mlocati/gettext-iconv-windows/releases/download/v0.21-v1.16/gettext0.21-iconv1.16-static-64.zip" +} + +download_dir = "gtk_deps_download" +os.makedirs(download_dir, exist_ok=True) + +# Download and extract dependencies +for dll_name, url in dll_sources.items(): + zip_path = os.path.join(download_dir, f"{dll_name}.zip") + print(f"Downloading {url}") + urllib.request.urlretrieve(url, zip_path) + + print(f"Extracting {zip_path}") + with zipfile.ZipFile(zip_path, 'r') as zip_ref: + zip_ref.extractall(download_dir) + +# Copy DLLs to GTK bin directory +gtk_bin = r'C:\gtk\bin' +for root, dirs, files in os.walk(download_dir): + for file in files: + if file.lower().endswith('.dll'): + source = os.path.join(root, file) + dest = os.path.join(gtk_bin, file) + print(f"Copying {source} to {dest}") + shutil.copy2(source, dest) + + # Also create lib* version + lib_dest = os.path.join(gtk_bin, f"lib{file}") + print(f"Creating lib version at {lib_dest}") + shutil.copy2(source, lib_dest) + +print("Done installing dependencies!") diff --git a/find -dll-dep.py b/find -dll-dep.py new file mode 100644 index 00000000..fe822fc9 --- /dev/null +++ b/find -dll-dep.py @@ -0,0 +1,36 @@ +import os +import sys +import ctypes + +# Add the OpenCV bin directory to the DLL search path +opencv_bin = r"D:\tencent\devel\cv\opencv-4.5.5\build\install\x64\vc16\bin" +os.add_dll_directory(opencv_bin) + +# Add the directory containing the cv2.pyd file to the Python path +opencv_pyd_dir = r"D:\tencent\devel\cv\opencv-4.5.5\build\lib\python3\Release" +sys.path.insert(0, opencv_pyd_dir) + +# Pre-load only the essential DLLs (skip highgui) +essential_dlls = [ + "opencv_core455.dll", + "opencv_imgproc455.dll", + "opencv_imgcodecs455.dll", + "opencv_videoio455.dll", + "opencv_flann455.dll", + "opencv_features2d455.dll" +] + +for dll in essential_dlls: + try: + dll_path = os.path.join(opencv_bin, dll) + ctypes.CDLL(dll_path) + print(f"Successfully pre-loaded {dll}") + except Exception as e: + print(f"Failed to load {dll}: {e}") + +# Now try importing cv2 +try: + import cv2 + print(f"\nSuccess! OpenCV version: {cv2.__version__}") +except ImportError as e: + print(f"\nStill failed: {e}") diff --git a/fix_gtk.py b/fix_gtk.py new file mode 100644 index 00000000..ca706b55 --- /dev/null +++ b/fix_gtk.py @@ -0,0 +1,71 @@ +import os +import sys +import ctypes +from ctypes import windll +import glob +import platform + +def inspect_gtk_installation(): + print(f"\n#### GTK Installation Info ####") + gtk_bin = r'C:\gtk\bin' + gtk_lib = r'C:\gtk\lib' + + # Add both directories to PATH + os.environ['PATH'] = f"{gtk_bin};{gtk_lib};{os.environ['PATH']}" + print(f"Python version: {platform.python_version()}") + + # Set additional environment variables + os.environ['GI_TYPELIB_PATH'] = r'C:\gtk\lib\girepository-1.0' + print(f"GI_TYPELIB_PATH: {os.environ.get('GI_TYPELIB_PATH', 'Not set')}") + + # Find all DLLs in GTK directories + bin_dlls = glob.glob(os.path.join(gtk_bin, "*.dll")) + lib_dlls = glob.glob(os.path.join(gtk_lib, "*.dll")) + print(f"Found {len(bin_dlls)} DLLs in {gtk_bin}") + print(f"Found {len(lib_dlls)} DLLs in {gtk_lib}") + + # Critical DLLs that need to be loaded in the right order + critical_dlls = [ + os.path.join(gtk_bin, "glib-2.0-0.dll"), + os.path.join(gtk_bin, "gobject-2.0-0.dll"), + os.path.join(gtk_bin, "gmodule-2.0-0.dll"), + os.path.join(gtk_bin, "girepository-1.0-1.dll"), + os.path.join(gtk_bin, "gio-2.0-0.dll"), + os.path.join(gtk_bin, "ffi-8.dll"), + os.path.join(gtk_bin, "z.dll"), + os.path.join(gtk_bin, "libintl-8.dll") + ] + + # Try to load critical DLLs first + print("\n#### Loading Critical DLLs ####") + for dll in critical_dlls: + if os.path.exists(dll): + try: + windll.LoadLibrary(dll) + print(f"✓ Loaded {os.path.basename(dll)}") + except Exception as e: + print(f"✗ Failed loading {os.path.basename(dll)}: {e}") + else: + print(f"! Missing {os.path.basename(dll)}") + + # Try importing gi + print("\n#### Testing PyGObject Import ####") + try: + import gi + print(f"✓ Successfully imported gi ({gi.__file__})") + return True + except ImportError as e: + print(f"✗ Failed to import gi: {e}") + return False + +if __name__ == "__main__": + success = inspect_gtk_installation() + + if not success: + print("\n#### Troubleshooting Tips ####") + print("1. Your PyGObject installation might not be compatible with your GTK installation.") + print("2. Try reinstalling PyGObject with:") + print(" pip uninstall pygobject") + print(" pip install pygobject") + print("3. If that doesn't work, try installing from an alternate source:") + print(" pip install --no-binary :all: pygobject") diff --git a/holobot-receiver.ps1 b/holobot-receiver.ps1 new file mode 100644 index 00000000..9addd689 --- /dev/null +++ b/holobot-receiver.ps1 @@ -0,0 +1,44 @@ +<# +.SYNOPSIS + Launches the GStreamer pipeline to decode and play the MuseTalk real-time A/V stream. + +.DESCRIPTION + This script uses a "zero-copy" video pipeline. The video frame is decoded on the GPU + and rendered directly to the screen using d3d11videosink without ever being copied + to system RAM, providing the lowest possible latency and highest performance. + +.NOTES + - Requires GStreamer 1.0 (with msvc_x86_64 and nvcodec packages) to be installed and in the system's PATH. + - Run this script from a PowerShell terminal. + - To stop the stream, press Ctrl+C. +#> + +# --- Configuration --- +$videoPort = 5000 +$audioPort = 5001 + +# --- User Feedback --- +Write-Host "🚀 Launching GStreamer Decoder (Zero-Copy Video Pipeline)..." -ForegroundColor Green +Write-Host " - Listening for VIDEO on UDP port: $videoPort" +Write-Host " - Listening for AUDIO on UDP port: $audioPort" +Write-Host " (Press Ctrl+C to stop the stream)" +Write-Host "" + + +# --- Launch GStreamer Pipeline --- +gst-launch-1.0 -v ` + udpsrc port=$videoPort caps="application/x-rtp, media=video, clock-rate=90000, encoding-name=H264, payload=96" ` + ! rtpjitterbuffer latency=200 ` + ! queue ` + ! rtph264depay ` + ! h264parse ` + ! nvh264dec ` + ! d3d11videosink sync=true qos=true max-lateness=200000000 ` +` + udpsrc port=$audioPort caps="application/x-rtp, media=audio, clock-rate=48000, encoding-name=OPUS, payload=97" ` + ! rtpjitterbuffer latency=375 drop-on-latency=true ` + ! rtpopusdepay ` + ! opusdec plc=true ` + ! audioconvert ` + ! audioresample ` + ! wasapisink sync=true \ No newline at end of file diff --git a/musetalk/utils/blending.py b/musetalk/utils/blending.py old mode 100755 new mode 100644 diff --git a/musetalk/utils/face_parsing/resnet.py b/musetalk/utils/face_parsing/resnet.py index e2e5d87e..a306abb7 100755 --- a/musetalk/utils/face_parsing/resnet.py +++ b/musetalk/utils/face_parsing/resnet.py @@ -80,7 +80,7 @@ def forward(self, x): return feat8, feat16, feat32 def init_weight(self, model_path): - state_dict = torch.load(model_path) #modelzoo.load_url(resnet18_url) + state_dict = torch.load(model_path, weights_only=False) #modelzoo.load_url(resnet18_url) self_state_dict = self.state_dict() for k, v in state_dict.items(): if 'fc' in k: continue diff --git a/requirements-pip.txt b/requirements-pip.txt new file mode 100644 index 00000000..9f628804 Binary files /dev/null and b/requirements-pip.txt differ diff --git a/requirements.txt b/requirements.txt index e87aa41d..b2f95cd3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ diffusers==0.30.2 accelerate==0.28.0 -numpy==1.23.5 -tensorflow==2.12.0 -tensorboard==2.12.0 +numpy +tensorflow +tensorboard opencv-python==4.9.0.80 soundfile==0.12.1 transformers==4.39.2 diff --git a/scripts/.env b/scripts/.env new file mode 100644 index 00000000..c1e10c3a --- /dev/null +++ b/scripts/.env @@ -0,0 +1,27 @@ +# .env file for realtime_stream_sync watcher + +# --- Required --- +# Full path to the WAV file that PhotonGPT writes +#WATCHED_WAV_FILE_PATH=E:\devel\NanoAR\HoloBot\PhotonGPT\latest_response.opus + +# Path to the avatar configuration YAML file used by realtime_stream_sync +AVATAR_CONFIG_PATH=configs/inference/realtime.yaml # Or your actual path + +# The specific Avatar ID from the config file to use for lipsyncing +AVATAR_ID_TO_USE=avatar_3 # Replace with your actual avatar ID from the YAML + +# --- Optional (Defaults will be used if not set) --- +# Target FPS for the GStreamer output (should match video source) +TARGET_FPS=25 + +# Frame skipping parameters - lower values = more aggressive +FRAME_SKIP_THRESHOLD=1000 +OVERLOAD_MULTIPLIER=10.0 +MAX_FRAMES_TO_SKIP=0 + +#FRAME_SKIP_THRESHOLD=2 +#OVERLOAD_MULTIPLIER=1.2 +#MAX_FRAMES_TO_SKIP=1 + +STREAM_PIPE_PATH=D:\photon_audio.opus +GSTREAMER_LAUNCH_PATH=C:\gstreamer\1.0\msvc_x86_64\bin\gst-launch-1.0.exe diff --git a/scripts/Downloads - Shortcut.lnk b/scripts/Downloads - Shortcut.lnk new file mode 100644 index 00000000..e1e17cb8 Binary files /dev/null and b/scripts/Downloads - Shortcut.lnk differ diff --git a/scripts/audiostream.py b/scripts/audiostream.py new file mode 100644 index 00000000..d57d0180 --- /dev/null +++ b/scripts/audiostream.py @@ -0,0 +1,122 @@ +import gi +import numpy as np +import subprocess +import time +import ffmpeg + +gi.require_version('Gst', '1.0') +from gi.repository import Gst, GLib, GObject + + +class GStreamerAudio: + """ 使用 GStreamer 进行音频推流 """ + + def __init__(self): + Gst.init(None) + + # 创建 GStreamer 管道 + self.pipeline = Gst.parse_launch( + "appsrc name=audio_source format=time is-live=true " + "caps=audio/x-raw,format=S16LE,channels=2,rate=48000,layout=interleaved ! " + "queue ! audioconvert ! queue ! audioresample ! " + "queue ! opusenc ! queue ! rtpopuspay ! " + "udpsink host=127.0.0.1 port=5001 sync=false" + ) + + self.appsrc = self.pipeline.get_by_name("audio_source") + self.appsrc.set_property("blocksize", 65536) # 增大 blocksize 避免数据阻塞 + self.appsrc.set_property("format", Gst.Format.TIME) + + self.pipeline.set_state(Gst.State.PLAYING) + + def send_audio(self, audio_data): + """ 发送 PCM 音频数据到 GStreamer """ + buffer = Gst.Buffer.new_allocate(None, len(audio_data.tobytes()), None) + buffer.fill(0, audio_data.tobytes()) + self.appsrc.emit("push-buffer", buffer) + print(f"✅ 推送音频: {len(audio_data)} samples") + + def stop(self): + """ 关闭音频推流 """ + self.pipeline.set_state(Gst.State.NULL) + print("✅ GStreamer 音频推流已关闭") + + +class FFmpegAudioReader: + """ 使用 FFmpeg 读取整个音频文件,并转换为 PCM """ + + def __init__(self, audio_file): + self.audio_file = audio_file + probe = ffmpeg.probe(audio_file) + self.sample_rate = int(probe['streams'][0]['sample_rate']) + self.channels = int(probe['streams'][0]['channels']) + + def read_full_audio(self): + """ 使用 FFmpeg 读取整个音频文件 """ + process = subprocess.Popen( + ["ffmpeg", "-i", self.audio_file, "-f", "s16le", "-ac", "2", "-ar", "48000", "-"], + stdout=subprocess.PIPE, stderr=subprocess.DEVNULL + ) + raw_data = process.stdout.read() + process.stdout.close() + process.wait() + + if not raw_data: + print("❌ 读取音频文件失败!") + return None + + audio_data = np.frombuffer(raw_data, dtype=np.int16).reshape(-1, 2) + print(f"✅ 读取完整音频,共 {len(audio_data)} samples") + return audio_data + + +def split_audio(audio_data, num_chunks): + """ 将音频数据分割成 num_chunks 份 """ + if num_chunks <= 1: + return [audio_data] + + chunk_size = len(audio_data) // num_chunks + chunks = [audio_data[i * chunk_size: (i + 1) * chunk_size] for i in range(num_chunks)] + + # 如果有剩余样本,加到最后一个 chunk + remainder = len(audio_data) % num_chunks + if remainder > 0: + chunks[-1] = np.vstack((chunks[-1], audio_data[-remainder:])) + + print(f"✅ 音频已分割为 {num_chunks} 份,每份约 {chunk_size} samples") + return chunks + + +def main(audio_file, split_chunks=1): + """ 读取完整音频文件,并用 GStreamer 推流 """ + audio_reader = FFmpegAudioReader(audio_file) + audio_data = audio_reader.read_full_audio() + + if audio_data is None: + return + + # 分割音频 + audio_chunks = split_audio(audio_data, split_chunks) + + gst_audio = GStreamerAudio() + try: + for i, chunk in enumerate(audio_chunks): + gst_audio.send_audio(chunk) + print(f"📡 推送第 {i+1}/{split_chunks} 份音频...") + time.sleep(0.1) # 控制发送速率,避免堵塞 + except KeyboardInterrupt: + print("\n⏹️ 手动停止音频流") + finally: + gst_audio.stop() + + +if __name__ == "__main__": + import sys + if len(sys.argv) < 2: + print("❌ 用法: python audiostream.py <音频文件路径> [分块数量]") + sys.exit(1) + + audio_file = sys.argv[1] + split_chunks = int(sys.argv[2]) if len(sys.argv) > 2 else 1 # 默认为 1,不分割 + + main(audio_file, split_chunks) diff --git a/scripts/realtime_stream.py b/scripts/realtime_stream.py new file mode 100644 index 00000000..16263498 --- /dev/null +++ b/scripts/realtime_stream.py @@ -0,0 +1,460 @@ +import ffmpeg +import argparse +import os +import pyaudio +import gi +gi.require_version('Gst', '1.0') +from gi.repository import Gst, GLib +import threading +import queue +from omegaconf import OmegaConf +import soundfile as sf +import wave +import subprocess +import numpy as np +import cv2 +import torch +import glob +import pickle +import sys +from tqdm import tqdm +import copy +import json +from musetalk.utils.utils import get_file_type,get_video_fps,datagen +from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder +from musetalk.utils.blending import get_image,get_image_prepare_material,get_image_blending +from musetalk.utils.utils import load_all_model +import shutil + +import threading +import queue + +import time + +# load model weights +audio_processor, vae, unet, pe = load_all_model() +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +timesteps = torch.tensor([0], device=device) +pe = pe.half() +vae.vae = vae.vae.half() +unet.model = unet.model.half() + +class GStreamerAudio: + """ 使用 GStreamer 进行音频推流 """ + + def __init__(self): + Gst.init(None) + + # 创建 GStreamer 管道 + self.pipeline = Gst.parse_launch( + "appsrc name=audio_source format=time is-live=true " + "caps=audio/x-raw,format=S16LE,channels=2,rate=48000,layout=interleaved ! " + "queue ! audioconvert ! queue ! audioresample ! " + "queue ! opusenc ! queue ! rtpopuspay ! " + "udpsink host=127.0.0.1 port=5001 sync=false" + ) + + self.appsrc = self.pipeline.get_by_name("audio_source") + self.appsrc.set_property("blocksize", 65536) # 增加 blocksize,避免数据阻塞 + self.appsrc.set_property("format", Gst.Format.TIME) + + self.pipeline.set_state(Gst.State.PLAYING) + + def send_audio(self, audio_data): + """ 一次性发送完整的 PCM 音频数据到 GStreamer """ + buffer = Gst.Buffer.new_allocate(None, len(audio_data.tobytes()), None) + buffer.fill(0, audio_data.tobytes()) + self.appsrc.emit("push-buffer", buffer) + print(f"✅ 已推送完整音频,共 {len(audio_data)} samples") + + def stop(self): + """ 关闭音频推流 """ + self.pipeline.set_state(Gst.State.NULL) + print("✅ GStreamer 音频推流已关闭") + +class FFmpegAudioReader: + """ 使用 FFmpeg 读取整个音频文件,并转换为 PCM """ + + def __init__(self, audio_file): + self.audio_file = audio_file + probe = ffmpeg.probe(audio_file) + self.sample_rate = int(probe['streams'][0]['sample_rate']) + self.channels = int(probe['streams'][0]['channels']) + + def read_full_audio(self): + """ 使用 FFmpeg 读取整个音频文件 """ + process = subprocess.Popen( + ["ffmpeg", "-i", self.audio_file, "-f", "s16le", "-ac", "2", "-ar", "48000", "-"], + stdout=subprocess.PIPE, stderr=subprocess.DEVNULL + ) + raw_data = process.stdout.read() + process.stdout.close() + process.wait() + + if not raw_data: + print("❌ 读取音频文件失败!") + return None + + audio_data = np.frombuffer(raw_data, dtype=np.int16).reshape(-1, 2) + print(f"✅ 读取完整音频,共 {len(audio_data)} samples") + return audio_data + +def split_audio(audio_data, num_chunks): + """ 将音频数据分割成 num_chunks 份 """ + if num_chunks <= 1: + return [audio_data] + + chunk_size = len(audio_data) // num_chunks + chunks = [audio_data[i * chunk_size: (i + 1) * chunk_size] for i in range(num_chunks)] + + # 如果有剩余样本,加到最后一个 chunk + remainder = len(audio_data) % num_chunks + if remainder > 0: + chunks[-1] = np.vstack((chunks[-1], audio_data[-remainder:])) + + print(f"✅ 音频已分割为 {num_chunks} 份,每份约 {chunk_size} samples") + return chunks + +class GStreamerPipeline: + def __init__(self, width=640, height=480, fps=25, host="127.0.0.1", port=5000): + self.width = width + self.height = height + self.fps = fps + self.host = host + self.port = port + + # 🎬 GStreamer 推流管道 + self.GSTREAMER_PIPELINE = ( + "appsrc ! videoconvert ! video/x-raw ! " + "queue ! x264enc bitrate=8000 tune=zerolatency ! " + "rtph264pay ! udpsink host=127.0.0.1 port=5000" + ) + + + # 使用 OpenCV 初始化 GStreamer 视频写入 + self.video_writer = cv2.VideoWriter(self.GSTREAMER_PIPELINE, cv2.CAP_GSTREAMER, 0, 25, (640, 480), True) + + if not self.video_writer.isOpened(): + raise RuntimeError("❌ GStreamer 推流初始化失败!请检查 GStreamer 是否安装") + + def send_frame(self, frame): + """ 发送一帧到 GStreamer """ + if frame is None: + return + + frame = cv2.resize(frame, (self.width, self.height)) # 确保帧大小匹配 + frame = frame.astype(np.uint8) # 转换为 uint8 格式 + self.video_writer.write(frame) # 推流 + + def stop(self): + """ 关闭 GStreamer 推流 """ + self.video_writer.release() + + + +def video2imgs(vid_path, save_path, ext = '.png',cut_frame = 10000000): + cap = cv2.VideoCapture(vid_path) + count = 0 + while True: + if count > cut_frame: + break + ret, frame = cap.read() + if ret: + cv2.imwrite(f"{save_path}/{count:08d}.png", frame) + count += 1 + else: + break + +def osmakedirs(path_list): + for path in path_list: + os.makedirs(path) if not os.path.exists(path) else None + + +@torch.no_grad() +class Avatar: + def __init__(self, avatar_id, video_path, bbox_shift, batch_size, preparation): + self.avatar_id = avatar_id + self.video_path = video_path + self.bbox_shift = bbox_shift + self.avatar_path = f"./results/avatars/{avatar_id}" + self.full_imgs_path = f"{self.avatar_path}/full_imgs" + self.coords_path = f"{self.avatar_path}/coords.pkl" + self.latents_out_path= f"{self.avatar_path}/latents.pt" + self.video_out_path = f"{self.avatar_path}/vid_output/" + self.mask_out_path =f"{self.avatar_path}/mask" + self.mask_coords_path =f"{self.avatar_path}/mask_coords.pkl" + self.avatar_info_path = f"{self.avatar_path}/avator_info.json" + self.avatar_info = { + "avatar_id":avatar_id, + "video_path":video_path, + "bbox_shift":bbox_shift + } + self.preparation = preparation + self.batch_size = batch_size + self.idx = 0 + self.init() + + def init(self): + if self.preparation: + if os.path.exists(self.avatar_path): + response = input(f"{self.avatar_id} exists, Do you want to re-create it ? (y/n)") + if response.lower() == "y": + shutil.rmtree(self.avatar_path) + print("*********************************") + print(f" creating avator: {self.avatar_id}") + print("*********************************") + osmakedirs([self.avatar_path,self.full_imgs_path,self.video_out_path,self.mask_out_path]) + self.prepare_material() + else: + self.input_latent_list_cycle = torch.load(self.latents_out_path) + with open(self.coords_path, 'rb') as f: + self.coord_list_cycle = pickle.load(f) + input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]')) + input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0])) + self.frame_list_cycle = read_imgs(input_img_list) + with open(self.mask_coords_path, 'rb') as f: + self.mask_coords_list_cycle = pickle.load(f) + input_mask_list = glob.glob(os.path.join(self.mask_out_path, '*.[jpJP][pnPN]*[gG]')) + input_mask_list = sorted(input_mask_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0])) + self.mask_list_cycle = read_imgs(input_mask_list) + else: + print("*********************************") + print(f" creating avator: {self.avatar_id}") + print("*********************************") + osmakedirs([self.avatar_path,self.full_imgs_path,self.video_out_path,self.mask_out_path]) + self.prepare_material() + else: + if not os.path.exists(self.avatar_path): + print(f"{self.avatar_id} does not exist, you should set preparation to True") + sys.exit() + + with open(self.avatar_info_path, "r") as f: + avatar_info = json.load(f) + + if avatar_info['bbox_shift'] != self.avatar_info['bbox_shift']: + response = input(f" 【bbox_shift】 is changed, you need to re-create it ! (c/continue)") + if response.lower() == "c": + shutil.rmtree(self.avatar_path) + print("*********************************") + print(f" creating avator: {self.avatar_id}") + print("*********************************") + osmakedirs([self.avatar_path,self.full_imgs_path,self.video_out_path,self.mask_out_path]) + self.prepare_material() + else: + sys.exit() + else: + self.input_latent_list_cycle = torch.load(self.latents_out_path) + with open(self.coords_path, 'rb') as f: + self.coord_list_cycle = pickle.load(f) + input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]')) + input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0])) + self.frame_list_cycle = read_imgs(input_img_list) + with open(self.mask_coords_path, 'rb') as f: + self.mask_coords_list_cycle = pickle.load(f) + input_mask_list = glob.glob(os.path.join(self.mask_out_path, '*.[jpJP][pnPN]*[gG]')) + input_mask_list = sorted(input_mask_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0])) + self.mask_list_cycle = read_imgs(input_mask_list) + + def prepare_material(self): + print("preparing data materials ... ...") + with open(self.avatar_info_path, "w") as f: + json.dump(self.avatar_info, f) + + if os.path.isfile(self.video_path): + video2imgs(self.video_path, self.full_imgs_path, ext = 'png') + else: + print(f"copy files in {self.video_path}") + files = os.listdir(self.video_path) + files.sort() + files = [file for file in files if file.split(".")[-1]=="png"] + for filename in files: + shutil.copyfile(f"{self.video_path}/{filename}", f"{self.full_imgs_path}/{filename}") + input_img_list = sorted(glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]'))) + + print("extracting landmarks...") + coord_list, frame_list = get_landmark_and_bbox(input_img_list, self.bbox_shift) + input_latent_list = [] + idx = -1 + # maker if the bbox is not sufficient + coord_placeholder = (0.0,0.0,0.0,0.0) + for bbox, frame in zip(coord_list, frame_list): + idx = idx + 1 + if bbox == coord_placeholder: + continue + x1, y1, x2, y2 = bbox + crop_frame = frame[y1:y2, x1:x2] + resized_crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4) + latents = vae.get_latents_for_unet(resized_crop_frame) + input_latent_list.append(latents) + + self.frame_list_cycle = frame_list + frame_list[::-1] + self.coord_list_cycle = coord_list + coord_list[::-1] + self.input_latent_list_cycle = input_latent_list + input_latent_list[::-1] + self.mask_coords_list_cycle = [] + self.mask_list_cycle = [] + + for i,frame in enumerate(tqdm(self.frame_list_cycle)): + cv2.imwrite(f"{self.full_imgs_path}/{str(i).zfill(8)}.png",frame) + + face_box = self.coord_list_cycle[i] + mask,crop_box = get_image_prepare_material(frame,face_box) + cv2.imwrite(f"{self.mask_out_path}/{str(i).zfill(8)}.png",mask) + self.mask_coords_list_cycle += [crop_box] + self.mask_list_cycle.append(mask) + + with open(self.mask_coords_path, 'wb') as f: + pickle.dump(self.mask_coords_list_cycle, f) + + with open(self.coords_path, 'wb') as f: + pickle.dump(self.coord_list_cycle, f) + + torch.save(self.input_latent_list_cycle, os.path.join(self.latents_out_path)) + # + + def process_frames(self, + res_frame_queue, + video_len, + skip_save_images): + print(video_len) + while True: + if self.idx>=video_len-1: + break + try: + start = time.time() + res_frame = res_frame_queue.get(block=True, timeout=1) + except queue.Empty: + continue + + bbox = self.coord_list_cycle[self.idx%(len(self.coord_list_cycle))] + ori_frame = copy.deepcopy(self.frame_list_cycle[self.idx%(len(self.frame_list_cycle))]) + x1, y1, x2, y2 = bbox + try: + res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1)) + except: + continue + mask = self.mask_list_cycle[self.idx%(len(self.mask_list_cycle))] + mask_crop_box = self.mask_coords_list_cycle[self.idx%(len(self.mask_coords_list_cycle))] + #combine_frame = get_image(ori_frame,res_frame,bbox) + combine_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box) + + if skip_save_images is False: + cv2.imwrite(f"{self.avatar_path}/tmp/{str(self.idx).zfill(8)}.png",combine_frame) + self.idx = self.idx + 1 + + def inference(self, + audio_path, + out_vid_name, + fps, + skip_save_images): + os.makedirs(self.avatar_path+'/tmp',exist_ok =True) + print("start inference") + + gst_pipeline = GStreamerPipeline(width=640, height=480, fps=fps, host="127.0.0.1", port=5000) + audio_sender = GStreamerAudio() + ############################################## extract audio feature ############################################## + start_time = time.time() + whisper_feature = audio_processor.audio2feat(audio_path) + whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps) + + total_iters = int(np.ceil(float(len(whisper_chunks)) / self.batch_size)) + audio_reader = FFmpegAudioReader(audio_path) + audio_data = audio_reader.read_full_audio() + audio_chunks = split_audio(audio_data, total_iters) + # for i, chunk in enumerate(audio_chunks): + # print(f"🎧 播放第 {i+1}/{len(audio_chunks)} 个音频片段") + # play_audio_chunk(chunk, original_sample_rate) + # print(f"processing audio:{audio_path} costs {(time.time() - start_time) * 1000}ms") + + ############################################## inference batch by batch ############################################## + video_num = len(whisper_chunks) + gen = datagen(whisper_chunks, self.input_latent_list_cycle, self.batch_size) + + frame_count = 0 + start_time = time.time() # 记录第一帧推理开始时间 + + for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=int(np.ceil(float(video_num) / self.batch_size)))): + # 处理音频特征 + audio_feature_batch = torch.from_numpy(whisper_batch) + audio_feature_batch = audio_feature_batch.to(device=unet.device, dtype=unet.model.dtype) + audio_feature_batch = pe(audio_feature_batch) + + # 处理 `latent_batch` + latent_batch = latent_batch.to(dtype=unet.model.dtype) + + # 运行 UNet 生成嘴型动画 + pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample + recon = vae.decode_latents(pred_latents) + + audio_sender.send_audio(audio_chunks[i]) + # 逐帧推送到 GStreamer + for j, res_frame in enumerate(recon): + frame_count += 1 + print("✅ 正在推送视频帧...") + gst_pipeline.send_frame(res_frame) # **顺序推送** + # 计算 FPS + # elapsed_time = time.time() - start_time + # if elapsed_time > 0: + # fps_estimate = frame_count / elapsed_time + # print(f"当前直播帧率: {fps_estimate:.2f} FPS", end='\r') + + ############################################## + # Step 4: 结束后计算最终 FPS + ############################################## + total_elapsed_time = time.time() - start_time + print(f"\nelapsed_time: {total_elapsed_time:.2f} s") + print(f"\nframe_count: {frame_count:.2f} s") + avg_fps = frame_count / total_elapsed_time if total_elapsed_time > 0 else 0 + print(f"\n最终计算得到的平均帧率: {avg_fps:.2f} FPS") + gst_pipeline.stop() + audio_sender.stop() + + + + +if __name__ == "__main__": + ''' + This script is used to simulate online chatting and applies necessary pre-processing such as face detection and face parsing in advance. During online chatting, only UNet and the VAE decoder are involved, which makes MuseTalk real-time. + ''' + + parser = argparse.ArgumentParser() + parser.add_argument("--inference_config", + type=str, + default="configs/inference/realtime.yaml", + ) + parser.add_argument("--fps", + type=int, + default=25, + ) + parser.add_argument("--batch_size", + type=int, + default=4, + ) + parser.add_argument("--skip_save_images", + action="store_true", + help="Whether skip saving images for better generation speed calculation", + ) + + args = parser.parse_args() + + inference_config = OmegaConf.load(args.inference_config) + print(inference_config) + + + for avatar_id in inference_config: + data_preparation = inference_config[avatar_id]["preparation"] + video_path = inference_config[avatar_id]["video_path"] + bbox_shift = inference_config[avatar_id]["bbox_shift"] + avatar = Avatar( + avatar_id = avatar_id, + video_path = video_path, + bbox_shift = bbox_shift, + batch_size = args.batch_size, + preparation= data_preparation) + + audio_clips = inference_config[avatar_id]["audio_clips"] + for audio_num, audio_path in audio_clips.items(): + print("Inferring using:",audio_path) + avatar.inference(audio_path, + audio_num, + args.fps, + args.skip_save_images) diff --git a/scripts/realtime_stream_gst_15.py b/scripts/realtime_stream_gst_15.py new file mode 100644 index 00000000..b92f52ec --- /dev/null +++ b/scripts/realtime_stream_gst_15.py @@ -0,0 +1,1455 @@ +import argparse +import os +from omegaconf import OmegaConf +import numpy as np +import cv2 +import torch +import glob +import pickle +import sys +from tqdm import tqdm +import copy +import json +import shutil +import threading +import queue +import time +from termcolor import colored # Add this import +import subprocess +import io +import concurrent.futures +import traceback +import logging +from dotenv import load_dotenv # Keep this at the very top! +from PIL import Image +import tempfile +import ffmpeg +import soundfile as sf # Added for sf.write, assumed to be imported +import cpuinfo # For detailed CPU information + +print("Script started!") + +# --- Load environment variables FIRST to make them available globally --- +load_dotenv() +logging.info("Environment variables loaded from .env file.") + +# --- Global Configuration from .env or Hardcoded Fallbacks --- +# These values will now be sourced directly from the .env file if set, +# otherwise they will fall back to the provided default. +# This makes them accessible globally without needing 'args.' prefix. + +# GStreamer related paths and parameters +GSTREAMER_LAUNCH_PATH = os.getenv("GSTREAMER_LAUNCH_PATH", "gst-launch-1.0") # Global path for gst-launch +# NEW - Add this sanitization immediately after +if "GSTREAMER_LAUNCH_PATH=" in GSTREAMER_LAUNCH_PATH: + # Remove the erroneous key prefix if it exists + GSTREAMER_LAUNCH_PATH = GSTREAMER_LAUNCH_PATH.split("GSTREAMER_LAUNCH_PATH=", 1)[-1] + logging.warning(f"⚠️ Cleaned malformed GSTREAMER_LAUNCH_PATH value. Now: {GSTREAMER_LAUNCH_PATH}") + +# Also add debug logging to see what was actually loaded +logging.info(f"🔍 Final GSTREAMER_LAUNCH_PATH: '{GSTREAMER_LAUNCH_PATH}'") +STREAM_PIPE_PATH = os.getenv("STREAM_PIPE_PATH", "./hot_file.opus") +TARGET_FPS = int(os.getenv("TARGET_FPS", "25")) # Use for general FPS calculations +FRAME_SKIP_THRESHOLD = int(os.getenv("FRAME_SKIP_THRESHOLD", "3")) # Use for consumer queue management +OVERLOAD_MULTIPLIER = float(os.getenv("OVERLOAD_MULTIPLIER", "1.2")) +MAX_FRAMES_TO_SKIP = int(os.getenv("MAX_FRAMES_TO_SKIP", "1")) + +# MuseTalk Model Paths (primarily from .env) +MUSE_VERSION = os.getenv("MUSE_VERSION", "v15") +FFMPEG_PATH = os.getenv("FFMPEG_PATH", "./ffmpeg-4.4-amd64-static/") +GPU_ID = int(os.getenv("GPU_ID", "0")) +VAE_TYPE = os.getenv("VAE_TYPE", "sd-vae") +UNET_CONFIG_PATH = os.getenv("UNET_CONFIG", "./models/musetalkV15/musetalk.json") +UNET_MODEL_PATH = os.getenv("UNET_MODEL_PATH", "./models/musetalkV15/unet.pth") +WHISPER_DIR = os.getenv("WHISPER_DIR", "./models/whisper") +RESULT_DIR = os.getenv("RESULT_DIR", './results') + +# Avatar/Inference specific parameters (primarily from .env) +EXTRA_MARGIN = int(os.getenv("EXTRA_MARGIN", "10")) +AUDIO_PADDING_LENGTH_LEFT = int(os.getenv("AUDIO_PADDING_LEFT", "2")) +AUDIO_PADDING_LENGTH_RIGHT = int(os.getenv("AUDIO_PADDING_RIGHT", "2")) +BATCH_SIZE = int(os.getenv("BATCH_SIZE", "4")) +PARSING_MODE = os.getenv("PARSING_MODE", 'jaw') +LEFT_CHEEK_WIDTH = int(os.getenv("LEFT_CHEEK_WIDTH", "90")) +RIGHT_CHEEK_WIDTH = int(os.getenv("RIGHT_CHEEK_WIDTH", "90")) +AVATAR_CONFIG_PATH = os.getenv("AVATAR_CONFIG_PATH", "configs/inference/realtime.yaml") +AVATAR_ID_TO_USE = os.getenv("AVATAR_ID_TO_USE", "default_avatar_id") + + +# --- Set up basic logging --- +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', force=True) # Add force=True if Python >= 3.8 +print("Logging configured!") + +# --- PyTorch Device Setup & System Information Block [CORRECTED] --- +logging.info("--- System & Hardware Information ---") + +# PyTorch Version +try: + torch_version = torch.__version__ + logging.info(colored(f"PyTorch Version: {torch_version}", 'cyan')) +except Exception as e: + logging.warning(f"Could not determine PyTorch version: {e}") + +# CPU Information +try: + cpu_info = cpuinfo.get_cpu_info() + cpu_brand = cpu_info.get('brand_raw', 'N/A') + cpu_hz = cpu_info.get('hz_actual_friendly', 'N/A') + logging.info(colored(f"CPU: {cpu_brand} @ {cpu_hz}", 'cyan')) +except Exception as e: + logging.warning(f"Could not determine CPU info: {e}") + + +# --- FIX: Define the 'device' object here, before it is used --- +cuda_available = torch.cuda.is_available() +if cuda_available: + device = torch.device(f"cuda:{GPU_ID}") +else: + device = torch.device("cpu") +# ------------------------------------------------------------- + + +# CUDA and GPU Information +if cuda_available: + try: + # Get general device properties + properties = torch.cuda.get_device_properties(device) + gpu_name = properties.name + total_vram_gb = properties.total_memory / (1024**3) + + # Get current memory usage + free_vram_bytes, total_vram_bytes = torch.cuda.mem_get_info(device) + used_vram_gb = (total_vram_bytes - free_vram_bytes) / (1024**3) + available_vram_gb = free_vram_bytes / (1024**3) + + # CUDA Version + cuda_version = torch.version.cuda + + # Log formatted, color-coded information + logging.info(colored(f"CUDA Version: {cuda_version}", 'green')) + logging.info(colored(f"GPU: {gpu_name} (ID: {GPU_ID})", 'green', attrs=['bold'])) + logging.info(colored(f"VRAM: {used_vram_gb:.2f} GB Used / {available_vram_gb:.2f} GB Available / {total_vram_gb:.2f} GB Total", 'green')) + + except Exception as e: + logging.warning(colored(f"⚠️ Could not retrieve detailed GPU/CUDA information: {e}", 'yellow')) +else: + logging.info(colored("GPU: ❌ CUDA (GPU) not available. Using CPU.", 'red', attrs=['bold'])) + +# Selected Device +logging.info(f"Selected Device: {device}") +logging.info("---------------------------------------") + +def get_optimal_gstreamer_pipeline(gpu_name, width, height, fps, host, port, bitrate=5000): + """ + Returns the optimal GStreamer pipeline string based on detected GPU. + + Args: + gpu_name (str): Name of the GPU from torch.cuda.get_device_properties() + width, height, fps: Video parameters + host, port: Streaming destination + bitrate: Target bitrate in kbps + + Returns: + str: GStreamer pipeline string optimized for the detected GPU + """ + + # Detect RTX 50-series (Blackwell) - these need D3D11 pipeline + if "RTX 50" in gpu_name or "5090" in gpu_name or "5080" in gpu_name: + logging.info(f"🎮 Detected RTX 50-series GPU ({gpu_name}). Using D3D11 hardware acceleration.") + return ( + f"fdsrc fd=0 do-timestamp=true is-live=true ! " + f"videoparse format=bgr width={width} height={height} framerate={fps}/1 ! " + "queue max-size-buffers=1 leaky=downstream ! " + "videoconvert ! video/x-raw,format=NV12 ! " + "d3d11upload ! " + "queue max-size-buffers=1 leaky=downstream ! " + # FIXED: Use correct mfh264enc properties + f"mfh264enc rc-mode=cbr bitrate={bitrate * 1000} low-latency=true max-bitrate={bitrate * 1000} ! " + "h264parse ! rtph264pay pt=96 config-interval=1 ! " + f"udpsink host={host} port={port} sync=true async=false" + ) + + + # RTX 40-series and earlier - use CUDA pipeline + elif "RTX" in gpu_name or "GTX" in gpu_name or "Tesla" in gpu_name or "Quadro" in gpu_name: + logging.info(f"🎮 Detected NVIDIA GPU with CUDA support ({gpu_name}). Using CUDA pipeline.") + return ( + f"fds qrc fd=0 do-timestamp=true is-live=true ! " + f"videoparse format=bgr width={width} height={height} framerate={fps}/1 ! " + "queue max-size-buffers=1 leaky=downstream ! " + "cudaupload ! " + "queue max-size-buffers=1 leaky=downstream ! cudaconvert ! videorate ! " + "video/x-raw(memory:CUDAMemory),format=NV12 ! " + "queue max-size-buffers=1 leaky=downstream ! " + f"nvh264enc preset=p1 tune=ultra-low-latency zerolatency=true rc-mode=cbr bitrate={bitrate} ! " + "h264parse ! rtph264pay pt=96 config-interval=1 ! " + f"udpsink host={host} port={port} sync=true async=false" + ) + + # Fallback to CPU encoding for unknown/non-NVIDIA GPUs + else: + logging.warning(f"⚠️ Unknown GPU type ({gpu_name}). Falling back to CPU-based x264 encoding.") + return ( + f"fdsrc fd=0 do-timestamp=true is-live=true ! " + f"videoparse format=bgr width={width} height={height} framerate={fps}/1 ! " + "queue max-size-buffers=1 leaky=downstream ! " + "videoconvert ! video/x-raw,format=I420 ! " + "x264enc tune=zerolatency speed-preset=ultrafast bitrate=5000 ! " + "h264parse ! rtph264pay pt=96 config-interval=1 ! " + f"udpsink host={host} port={port} sync=true async=false" + ) + +# --- Platform-specific imports for enhanced functionality --- +if sys.platform == "win32": + try: + import psutil + p = psutil.Process(os.getpid()) + p.nice(psutil.HIGH_PRIORITY_CLASS) + logging.info("INFO: Process priority set to HIGH on Windows (if psutil installed and permitted).") + except ImportError: + logging.warning("Warning: psutil not found. Cannot set process priority.") + except Exception as e: + logging.warning(f"Warning: Could not set process priority: {e}") + + +# --- MuseTalk Specific Imports (Ensure these are in your PYTHONPATH) --- +try: + from transformers import WhisperModel + from musetalk.utils.face_parsing import FaceParsing + from musetalk.utils.utils import datagen, load_all_model + from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs + from musetalk.utils.blending import get_image_prepare_material, get_image_blending + from musetalk.utils.audio_processor import AudioProcessor +except ImportError as e: + logging.critical(f"Error importing MuseTalk utilities: {e}. Ensure the library is installed and 'musetalk' package is in your PYTHONPATH.", exc_info=True) + sys.exit(1) + +# --- Global Models and Device (initialized in __main__) --- +# These will be initialized once and made globally accessible for use by Avatar class and worker threads. +# They are declared here to be accessible, but will be assigned values in main. +vae = None +unet = None +pe = None +timesteps = None +audio_processor = None +whisper = None +fp = None # FaceParsing instance +weight_dtype = torch.float16 # Default to FP16 as per MuseTalk 1.5 optimizations + +# --- Utility Functions --- +def fast_check_ffmpeg(): + """Checks if ffmpeg is accessible from the system's PATH.""" + try: + subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True, timeout=5) + return True + except (subprocess.CalledProcessError, FileNotFoundError, subprocess.TimeoutExpired): + return False + +def video2imgs(vid_path, save_path): + """Extracts frames from a video file and saves them as PNG images.""" + logging.info(f"Extracting frames from {vid_path} to {save_path}...") + cap = cv2.VideoCapture(vid_path) + if not cap.isOpened(): + logging.error(f"Error: Could not open video file: {vid_path}") + return + count = 0 + while True: + ret, frame = cap.read() + if not ret: break + cv2.imwrite(os.path.join(save_path, f"{str(count).zfill(8)}.png"), frame) + count += 1 + cap.release() + logging.info(f"Finished extracting {count} frames.") + +def osmakedirs(path_list): + """Creates directories if they don't exist, using exist_ok=True for robustness.""" + for path in path_list: + os.makedirs(path, exist_ok=True) + +def _log_subprocess_output(pipe, logger_func, prefix): + """Reads output from a pipe line by line and logs it.""" + try: + for line_bytes in iter(pipe.readline, b''): + logger_func(f"[{prefix}]: {line_bytes.decode(errors='ignore').rstrip()}") + except ValueError: # Pipe might close during readline + pass + except Exception as e: + logging.error(f"Error reading from {prefix} pipe: {e}", exc_info=True) + finally: + if pipe and not pipe.closed: + pipe.close() + +# --- FFmpeg Audio Reader (User's custom class for flexible audio input) --- +class FFmpegAudioReader: + """Uses FFmpeg to read an audio source (file path or bytes) and convert it to raw PCM.""" + def __init__(self, audio_source): + self.audio_source = audio_source + self.is_file_path = isinstance(audio_source, str) + + def read_full_audio(self): + """Reads the entire audio source and converts it to PCM s16le, 48kHz, Stereo.""" + logging.info(f"Reading and converting audio from {'file' if self.is_file_path else 'memory'}...") + target_sr, target_ac, target_format = 48000, 2, "s16le" # Target format for GStreamer audio pipeline + input_data = None + + ffmpeg_input_args = {} + if not self.is_file_path: + input_filename = 'pipe:0' # Read from stdin if bytes are provided + input_data = self.audio_source + else: + input_filename = self.audio_source + + try: + out, err = ( + ffmpeg + .input(input_filename, **ffmpeg_input_args) + .output('pipe:', format=target_format, ac=target_ac, ar=target_sr) + .run(capture_stdout=True, capture_stderr=True, input=input_data, quiet=True) + ) + if err: + logging.debug(f"FFmpeg stderr: {err.decode(errors='ignore')}") + except ffmpeg.Error as e: + logging.error(f"❌ FFmpeg error during audio conversion: {e.stderr.decode(errors='ignore') if e.stderr else 'Unknown FFmpeg error'}") + return None + except Exception as e: + logging.error(f"❌ Unexpected error during FFmpeg execution for audio: {e}", exc_info=True) + return None + + if not out: + logging.error("❌ Failed to read audio: FFmpeg produced no PCM data!") + return None + + audio_data = np.frombuffer(out, dtype=np.int16).reshape(-1, target_ac) + logging.info(f"✅ Read and converted audio: {len(audio_data)} samples at {target_sr}Hz, {target_ac}ch.") + return audio_data + +# --- GStreamer Classes (User's custom classes for real-time streaming) --- +class GStreamerPipeline: + """Manages the GStreamer video pipeline subprocess.""" + def __init__(self, width=1280, height=720, fps=TARGET_FPS, host="127.0.0.1", port=5000): + self.width, self.height, self.fps, self.host, self.port = width, height, fps, host, port + self.process = None + self.stdout_thread = None + self.stderr_thread = None + self.is_running = False + + # --- DYNAMIC PIPELINE SELECTION --- + # Get GPU name from global detection + gpu_name = "" + if torch.cuda.is_available(): + try: + gpu_name = torch.cuda.get_device_name(device) + except: + gpu_name = "Unknown" + + # Generate optimal pipeline for detected hardware + pipeline_str = get_optimal_gstreamer_pipeline( + gpu_name=gpu_name, + width=self.width, + height=self.height, + fps=self.fps, + host=self.host, + port=self.port, + bitrate=5000 + ) + + logging.info(f"Attempting to start GStreamer video pipeline ({self.width}x{self.height}@{self.fps}fps) to {self.host}:{self.port}...") + env_vars = os.environ.copy() + + + + try: + # NEW - FIXED + command_to_run = f'"{GSTREAMER_LAUNCH_PATH}" -v {pipeline_str}' + + logging.info(f"DEBUG: GStreamer VIDEO command (Popen string): {command_to_run}") + + self.process = subprocess.Popen( + command_to_run, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=True, + bufsize=0, + env=env_vars + ) + + # Check if process actually started (pid is assigned) + if self.process.pid: + logging.info(colored(f"✅ GStreamer video process launched (PID: {self.process.pid}).", 'green', attrs=['bold'])) + self.is_running = True + else: + logging.error(colored(f"❌ GStreamer video process failed to launch, PID not assigned.", 'red', attrs=['bold'])) + self.process = None # Ensure it's None if not truly launched + self.is_running = False + + # Start threads to read stdout and stderr asynchronously *only if process launched* + if self.is_running: + self.stdout_thread = threading.Thread( + target=_log_subprocess_output, + args=(self.process.stdout, logging.info, "GST_VIDEO_STDOUT"), + daemon=True + ) + self.stderr_thread = threading.Thread( + target=_log_subprocess_output, + args=(self.process.stderr, logging.error, "GST_VIDEO_STDERR"), + daemon=True + ) + self.stdout_thread.start() + self.stderr_thread.start() + + except Exception as e: + logging.error(colored(f"❌ Failed to start GStreamer video pipeline: {e}", 'red', attrs=['bold']), exc_info=True) + self.process = None + self.is_running = False + + def send_frame(self, frame): + """Sends a NumPy array frame to the GStreamer pipeline's stdin.""" + if not self.process or not self.is_running or self.process.stdin.closed: # Check is_running here too + logging.debug("GStreamer video pipeline not running or stdin closed. Cannot send frame.") + return False + try: + if not frame.flags['C_CONTIGUOUS']: + frame = np.ascontiguousarray(frame, dtype=np.uint8) + self.process.stdin.write(frame.tobytes()) + self.process.stdin.flush() + logging.debug(colored("Sent video frame to GStreamer.", 'cyan')) # Log successful send + return True + except (BrokenPipeError, OSError): + logging.error(colored("❌ GStreamer video pipeline: Broken pipe. The process may have crashed.", 'red', attrs=['bold'])) + self.stop() + return False + except Exception as e: + logging.error(colored(f"❌ Error pushing video frame: {e}", 'red', attrs=['bold']), exc_info=True) + return False + + def stop(self): + """Stops the GStreamer video subprocess gracefully.""" + if self.process: + logging.info(f"Stopping GStreamer video pipeline (PID: {self.process.pid})...") + proc_to_stop, self.process = self.process, None + if proc_to_stop.stdin and not proc_to_stop.stdin.closed: + try: proc_to_stop.stdin.close() + except Exception: pass + proc_to_stop.terminate() + try: + proc_to_stop.wait(timeout=3) + logging.info(colored(f"✅ GStreamer video process terminated.", 'green')) + except subprocess.TimeoutExpired: + logging.warning(colored(f"⚠️ GStreamer video process did not terminate gracefully, killing...", 'yellow')) + proc_to_stop.kill() + + if self.stdout_thread and self.stdout_thread.is_alive(): + self.stdout_thread.join(timeout=1) + if self.stderr_thread and self.stderr_thread.is_alive(): + self.stderr_thread.join(timeout=1) + logging.info("GStreamer video pipeline stop complete.") + + +class GStreamerAudio: + """Manages the GStreamer audio pipeline subprocess.""" + def __init__(self, host="127.0.0.1", port=5001, sample_rate=48000, channels=2): + self.host, self.port, self.sample_rate, self.channels = host, port, sample_rate, channels + self.process = None + self.stdout_thread = None + self.stderr_thread = None + self.is_running = False # New flag + + pipeline_str = ( + f"fdsrc fd=0 do-timestamp=true is-live=true ! " + "queue max-size-buffers=2 leaky=downstream ! " + f"audio/x-raw,format=S16LE,channels={self.channels},rate={self.sample_rate},layout=interleaved ! " + "audioconvert ! audioresample ! " + "opusenc bitrate=64000 ! rtpopuspay pt=97 ! " + f"udpsink host={self.host} port={self.port} sync=true" + ) + logging.info(f"Attempting to start GStreamer audio pipeline ({self.sample_rate}Hz, {self.channels}ch) to {self.host}:{self.port}...") + + env_vars = os.environ.copy() + env_vars['GST_DEBUG'] = '3' + + # In GStreamerAudio.__init__(), around line 378: + try: + command_to_run = f'"{GSTREAMER_LAUNCH_PATH}" -v {pipeline_str}' + logging.debug(f"DEBUG: GStreamer AUDIO command (Popen string): {command_to_run}") + + self.process = subprocess.Popen( + command_to_run, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=True, + bufsize=0, + env=env_vars + ) + + + if self.process.pid: + logging.info(colored(f"✅ GStreamer audio process launched (PID: {self.process.pid}).", 'green', attrs=['bold'])) + self.is_running = True + else: + logging.error(colored(f"❌ GStreamer audio process failed to launch, PID not assigned.", 'red', attrs=['bold'])) + self.process = None + self.is_running = False + + if self.is_running: + self.stdout_thread = threading.Thread( + target=_log_subprocess_output, + args=(self.process.stdout, logging.info, "GST_AUDIO_STDOUT"), + daemon=True + ) + self.stderr_thread = threading.Thread( + target=_log_subprocess_output, + args=(self.process.stderr, logging.error, "GST_AUDIO_STDERR"), + daemon=True + ) + self.stdout_thread.start() + self.stderr_thread.start() + + except Exception as e: + logging.error(colored(f"❌ Failed to start GStreamer audio pipeline: {e}", 'red', attrs=['bold']), exc_info=True) + self.process = None + self.is_running = False + + def send_audio(self, audio_data_pcm): + """Sends raw PCM audio data to the GStreamer pipeline's stdin.""" + if not self.process or not self.is_running or self.process.stdin.closed: # Check is_running here too + logging.debug("GStreamer audio pipeline not running or stdin closed. Cannot send audio.") + return False + try: + self.process.stdin.write(audio_data_pcm.tobytes()) + self.process.stdin.flush() + logging.debug(colored("Sent audio chunk to GStreamer.", 'cyan')) # Log successful send + return True + except (BrokenPipeError, OSError): + logging.error(colored("❌ GStreamer audio pipeline: Broken pipe.", 'red', attrs=['bold'])) + self.stop() + return False + except Exception as e: + logging.error(colored(f"❌ Error pushing audio chunk: {e}", 'red', attrs=['bold']), exc_info=True) + return False + + def stop(self): + """Stops the GStreamer audio subprocess gracefully.""" + if self.process: + logging.info(f"Stopping GStreamer audio pipeline (PID: {self.process.pid})...") + proc_to_stop, self.process = self.process, None + if proc_to_stop.stdin and not proc_to_stop.stdin.closed: + try: proc_to_stop.stdin.close() + except Exception: pass + proc_to_stop.terminate() + try: + proc_to_stop.wait(timeout=3) + logging.info(colored("✅ GStreamer audio process terminated.", 'green')) + except subprocess.TimeoutExpired: + logging.warning(colored("⚠️ GStreamer audio process did not terminate gracefully, killing...", 'yellow')) + proc_to_stop.kill() + + if self.stdout_thread and self.stdout_thread.is_alive(): + self.stdout_thread.join(timeout=1) + if self.stderr_thread and self.stderr_thread.is_alive(): + self.stderr_thread.join(timeout=1) + logging.info("GStreamer audio pipeline stop complete.") + + +# --- Avatar Class (Unified logic from both original and user's script) --- +@torch.no_grad() +class Avatar: + def __init__(self, avatar_id, video_path, bbox_shift, batch_size, preparation, version_str, extra_margin=0, parsing_mode='jaw'): + logging.info(f"Initializing Avatar: {avatar_id}") + self.avatar_id = str(avatar_id) + self.video_path = video_path + self.bbox_shift = bbox_shift + self.batch_size = batch_size + self.preparation = preparation + self.version_str = version_str + self.extra_margin = extra_margin + self.parsing_mode = parsing_mode + + # Define paths (adapted for v15 structure based on version_str) + if self.version_str == "v15": + self.avatar_base_path = os.path.join(RESULT_DIR, self.version_str, "avatars", self.avatar_id) + else: # v1 + self.avatar_base_path = os.path.join(RESULT_DIR, "avatars", self.avatar_id) + + self.full_imgs_path = os.path.join(self.avatar_base_path, "full_imgs") + self.mask_out_path = os.path.join(self.avatar_base_path, "masks") + self.coords_path = os.path.join(self.avatar_base_path, "coords.pkl") + self.latents_out_path = os.path.join(self.avatar_base_path, "latents.pt") + self.mask_coords_path = os.path.join(self.avatar_base_path, "mask_coords.pkl") + self.avatar_info_path = os.path.join(self.avatar_base_path, "avatar_info.json") + + self.input_latent_list_cycle = [] + self.coord_list_cycle = [] + self.frame_list_cycle = [] + self.mask_coords_list_cycle = [] + self.mask_list_cycle = [] + + self.idx = 0 + + self.init_avatar_data() + # --- NEW OPTIMIZATION: PRE-PROCESSING STEP --- + logging.info("Pre-processing avatar data for optimized performance...") + + # Pre-convert frames from BGR NumPy arrays to RGB PIL Images + self.body_pil_cycle = [ + Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + for frame in self.frame_list_cycle + ] + + # Pre-convert masks to grayscale PIL Images + self.mask_pil_cycle = [ + Image.fromarray(mask).convert("L") + for mask in self.mask_list_cycle + ] + + logging.info("✅ Avatar data pre-processing complete.") + logging.info(f"✅ Avatar '{self.avatar_id}' initialized with {len(self.frame_list_cycle)} reference frames.") + + def init_avatar_data(self): + """Initializes or reloads avatar data based on 'preparation' flag.""" + if self.preparation: + if os.path.exists(self.avatar_base_path): + response = input(f"Avatar '{self.avatar_id}' data exists. Re-create all material? (y/n): ").strip().lower() + if response == "y": + logging.info(f"User chose to re-create. Removing: {self.avatar_base_path}") + shutil.rmtree(self.avatar_base_path) + self._prepare_material_core() + else: + logging.info("Loading existing data as per user request.") + self._reload_prepared_data() + else: + self._prepare_material_core() + else: + logging.info("Preparation=False. Loading existing prepared data...") + self._reload_prepared_data() + + def _reload_prepared_data(self): + """Loads pre-processed avatar data from disk.""" + logging.info(f"Reloading prepared data from: {self.avatar_base_path}") + try: + loaded_latents = torch.load(self.latents_out_path, map_location='cpu') + self.input_latent_list_cycle = list(loaded_latents) if isinstance(loaded_latents, torch.Tensor) else loaded_latents + with open(self.coords_path, 'rb') as f: self.coord_list_cycle = pickle.load(f) + with open(self.mask_coords_path, 'rb') as f: self.mask_coords_list_cycle = pickle.load(f) + + num_items = len(self.coord_list_cycle) + if num_items == 0: raise ValueError("Loaded coordinate data is empty.") + + frame_files = [os.path.join(self.full_imgs_path, f"{str(i).zfill(8)}.png") for i in range(num_items)] + self.frame_list_cycle = read_imgs(frame_files) + mask_files = [os.path.join(self.mask_out_path, f"{str(i).zfill(8)}.png") for i in range(num_items)] + self.mask_list_cycle = read_imgs(mask_files) + + data_map = { + "Latents": self.input_latent_list_cycle, "Coords": self.coord_list_cycle, + "Frames": self.frame_list_cycle, "Masks": self.mask_list_cycle, "MaskCoords": self.mask_coords_list_cycle + } + if not all(len(lst) == num_items for lst in data_map.values()): + lengths = {name: len(lst) for name, lst in data_map.items()} + raise ValueError(f"Data lists have mismatched lengths after loading: {lengths}") + + with open(self.avatar_info_path, "r") as f: + avatar_info_loaded = json.load(f) + if avatar_info_loaded.get('bbox_shift') != self.bbox_shift or \ + avatar_info_loaded.get('version') != self.version_str or \ + avatar_info_loaded.get('extra_margin') != self.extra_margin or \ + avatar_info_loaded.get('parsing_mode') != self.parsing_mode: + + logging.warning("Avatar config (bbox_shift, version, extra_margin, or parsing_mode) has changed.") + response = input(f"Config change detected. Re-create avatar materials? (y/n) (Old: {avatar_info_loaded}, New: {{'bbox_shift': {self.bbox_shift}, 'version': '{self.version_str}', 'extra_margin': {self.extra_margin}, 'parsing_mode': '{self.parsing_mode}'}}): ").strip().lower() + if response == "y": + logging.info("User chose to re-create due to config change.") + shutil.rmtree(self.avatar_base_path) + self._prepare_material_core() + else: + logging.info("Continuing with old avatar data despite config change. This might lead to unexpected results.") + + except FileNotFoundError: + logging.critical(f"❌ Prepared data not found for avatar '{self.avatar_id}'. You must run with preparation=True first.", exc_info=True) + raise SystemExit(f"Exiting: Prepared data not found for {self.avatar_id}.") + except Exception as e: + logging.critical(f"❌ Error reloading prepared data for avatar '{self.avatar_id}'. You may need to run with preparation=True.", exc_info=True) + raise SystemExit(f"Exiting: Failed to reload data for {self.avatar_id}.") + + @torch.no_grad() + def _prepare_material_core(self): + logging.info(f"--- Preparing new material for avatar: {self.avatar_id} ---") + osmakedirs([self.avatar_base_path, self.full_imgs_path, self.mask_out_path]) + + # Store avatar info (unchanged) + avatar_info_data = { + "avatar_id": self.avatar_id, + "video_path": self.video_path, + "bbox_shift": self.bbox_shift, + "version": self.version_str, + "extra_margin": self.extra_margin, + "parsing_mode": self.parsing_mode + } + + with open(self.avatar_info_path, "w") as f: + json.dump(avatar_info_data, f) + + # 1. Extract frames from video or copy from image folder + if os.path.isfile(self.video_path): + logging.debug(f"DEBUG: Starting video frame extraction from {self.video_path} to {self.full_imgs_path}...") + video2imgs(self.video_path, self.full_imgs_path) + elif os.path.isdir(self.video_path): + logging.debug(f"DEBUG: Starting image frame copying from {self.video_path} to {self.full_imgs_path}...") + source_files = sorted([f for f in os.listdir(self.video_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]) + for i, filename in enumerate(tqdm(source_files, desc="Copying frames")): + shutil.copy(os.path.join(self.video_path, filename), os.path.join(self.full_imgs_path, f"{i:08d}.png")) + else: + raise FileNotFoundError(f"video_path '{self.video_path}' is not a valid file or directory.") + + logging.debug("DEBUG: Frame extraction/copy completed. Proceeding to landmark extraction.") + + # 2. Get landmarks and filter out invalid frames + source_images = sorted(glob.glob(os.path.join(self.full_imgs_path, '*.png'))) + logging.debug(f"DEBUG: Starting face landmark and bbox extraction using {len(source_images)} images...") + + initial_coords, initial_frames = [], [] + try: + # Call the original get_landmark_and_bbox + initial_coords, initial_frames = get_landmark_and_bbox(source_images, self.bbox_shift) + + # Count valid/invalid frames + valid_frame_count = sum(1 for bbox in initial_coords if bbox != (0.0, 0.0, 0.0, 0.0)) + logging.info(f"DEBUG: get_landmark_and_bbox returned {len(initial_coords)} frames.") + logging.info(f"DEBUG: Detected {valid_frame_count} valid frames (non-placeholder bbox).") + + if valid_frame_count == 0: + logging.error("CRITICAL: No valid bounding boxes were detected in any frame. Check video content, bbox_shift, or face detection model.") + # Save first few problematic frames for visual inspection + debug_output_dir = os.path.join(self.avatar_base_path, "debug_invalid_frames") + os.makedirs(debug_output_dir, exist_ok=True) + saved_debug_frames = 0 + for i, (bbox, frame) in enumerate(zip(initial_coords, initial_frames)): + if bbox == (0.0, 0.0, 0.0, 0.0): # This is the placeholder for no face detected + if saved_debug_frames < 20: # Save up to 20 debug frames + debug_frame_path = os.path.join(debug_output_dir, f"invalid_frame_{i:08d}.png") + cv2.imwrite(debug_frame_path, frame) + logging.debug(f"Saved debug image for invalid frame {i} to {debug_frame_path}") + saved_debug_frames += 1 + else: + break # Stop saving debug frames after 20 + + raise RuntimeError("No valid frames survived the preparation process.") + + except Exception as e: + logging.critical(f"❌ Error during get_landmark_and_bbox: {e}", exc_info=True) + raise # Re-raise to ensure the main error propagates + + logging.debug("DEBUG: Face landmark and bbox extraction completed. Starting VAE encoding and mask generation.") + + # 3. Process valid frames: VAE encoding and mask generation + valid_latents, valid_coords, valid_frames, valid_masks, valid_mask_coords = [], [], [], [], [] + logging.debug("DEBUG: Entering VAE Encoding & Masking loop.") + + for i, (bbox, frame) in enumerate(tqdm(zip(initial_coords, initial_frames), total=len(initial_coords), desc="VAE Encoding & Masking")): + if bbox == (0.0, 0.0, 0.0, 0.0): + continue + + # --- THE FIX: Do NOT convert bbox to integers. Keep them as floats. --- + x1, y1, x2, y2 = bbox + + # For V15, the enlarged bbox is used for BOTH the parser and the VAE. + if self.version_str == "v15": + y2_adjusted = y2 + self.extra_margin + y2_adjusted = min(y2_adjusted, frame.shape[0]) + final_bbox = [x1, y1, x2, y2_adjusted] + else: + final_bbox = [x1, y1, x2, y2] + + # --- Step 1: Face Parsing with original float precision --- + if self.version_str == "v15": + mode = self.parsing_mode + else: + mode = "raw" + + # This call now uses the original float coordinates and should succeed. + parsing_result = get_image_prepare_material(frame, final_bbox, fp=fp, mode=mode) + + # The robustness check remains as a safeguard for genuinely bad frames. + if parsing_result is None: + logging.warning(f"DIAGNOSTIC: Frame {i} is a genuinely problematic frame and was skipped.") + continue + + mask, mask_crop_box = parsing_result + + # --- Step 2: VAE Encoding --- + # Cropping with floats is fine; array slicing will implicitly convert to integers. + vae_x1, vae_y1, vae_x2, vae_y2 = final_bbox + crop_frame = frame[int(vae_y1):int(vae_y2), int(vae_x1):int(vae_x2)] + + if crop_frame.size == 0: + logging.warning(f"Skipping frame {i} as VAE cropped frame is empty. Bbox: {final_bbox}") + continue + + resized_crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4) + latents = vae.get_latents_for_unet(resized_crop_frame) + + # --- Step 3: Append all data from the successful frame --- + valid_latents.append(latents) + valid_coords.append(final_bbox) + valid_frames.append(frame) + valid_masks.append(mask) + valid_mask_coords.append(mask_crop_box) + + logging.debug("DEBUG: VAE Encoding & Masking completed.") + + if not valid_frames: + logging.error("All frames failed the parsing step. This could indicate a problem with the source video or a deeper issue with the face parsing model's environment/dependencies.") + raise RuntimeError("No valid frames survived the preparation process.") + + + + # 4. Create looping cycle (forward and reverse) and save all data + self.frame_list_cycle = valid_frames + valid_frames[::-1] + self.coord_list_cycle = valid_coords + valid_coords[::-1] + self.input_latent_list_cycle = valid_latents + valid_latents[::-1] + self.mask_list_cycle = valid_masks + valid_masks[::-1] + self.mask_coords_list_cycle = valid_mask_coords + valid_mask_coords[::-1] + + logging.debug("DEBUG: Saving final cycle data to disk.") + shutil.rmtree(self.full_imgs_path, ignore_errors=True) + os.makedirs(self.full_imgs_path) + shutil.rmtree(self.mask_out_path, ignore_errors=True) + os.makedirs(self.mask_out_path) + + for i, (frame, mask) in enumerate(tqdm(zip(self.frame_list_cycle, self.mask_list_cycle), total=len(self.frame_list_cycle), desc="Saving final cycle data")): + cv2.imwrite(os.path.join(self.full_imgs_path, f"{i:08d}.png"), frame) + cv2.imwrite(os.path.join(self.mask_out_path, f"{i:08d}.png"), mask) + + with open(self.coords_path, 'wb') as f: + pickle.dump(self.coord_list_cycle, f) + with open(self.mask_coords_path, 'wb') as f: + pickle.dump(self.mask_coords_list_cycle, f) + torch.save(torch.stack(self.input_latent_list_cycle), self.latents_out_path) + + logging.info(f"--- Material prep complete. Final cycle length: {len(self.frame_list_cycle)} frames. ---") + logging.debug("DEBUG: Avatar material preparation function finished.") + + @torch.no_grad() + def inference(self, audio_source, target_fps): + """ + Main inference loop for real-time streaming. + Processes audio, generates frames, and sends them via GStreamer. + """ + run_id = f"stream_{int(time.time())}" + logging.info(f"🎬 Starting inference run ID: {run_id}") + + vae_to_blend_queue = queue.Queue(maxsize=self.batch_size * 2) + gst_video_pipeline, gst_audio_pipeline, frame_processor_thread = None, None, None + start_time = time.time() + + try: + # 1. Setup GStreamer pipelines + if self.frame_list_cycle: + h, w, _ = self.frame_list_cycle[0].shape + gst_video_pipeline = GStreamerPipeline(width=w, height=h, fps=target_fps) + else: + logging.warning("Avatar reference frames not loaded, using default GStreamer resolution (1280x720).") + gst_video_pipeline = GStreamerPipeline(width=1280, height=720, fps=target_fps) + + gst_audio_pipeline = GStreamerAudio(sample_rate=48000) + if not gst_video_pipeline.is_running or not gst_audio_pipeline.is_running: # CHECK is_running flag + raise RuntimeError("GStreamer pipeline(s) failed to initialize. Check GStreamer installation and plugins.") + + # --- NEW: PIPELINE WARM-UP / PRE-ROLL --- + # This is the critical fix for the initial scrambled/sped-up quarter-second. + # We send a few silent, black frames to stabilize the entire pipeline + # (Python -> OS buffers -> GStreamer) before the real content starts. + logging.info("Warming up streaming pipeline with pre-roll frames...") + frame_duration = 1.0 / target_fps + + # Create a single black frame to reuse. + black_frame = np.zeros((h, w, 3), dtype=np.uint8) + + # Pre-roll for a quarter-second to stabilize the connection. + num_preroll_frames = int(target_fps / 4) + + for _ in range(num_preroll_frames): + if not gst_video_pipeline.send_frame(black_frame): + raise RuntimeError("GStreamer video pipe broke during pre-roll.") + # We use a simple sleep here for the pre-roll, as precision is less critical + # than just giving the pipeline time to initialize. + time.sleep(frame_duration) + logging.info("Pre-roll complete. Starting main content stream.") + # --- END NEW SECTION --- + + # 2. Process audio input + audio_reader = FFmpegAudioReader(audio_source) + full_audio_pcm = audio_reader.read_full_audio() + if full_audio_pcm is None or full_audio_pcm.size == 0: + raise ValueError("PCM audio is empty after FFmpeg conversion. Cannot proceed with inference.") + + feature_extraction_path = audio_source if isinstance(audio_source, str) else None + temp_file_handle = None + if not feature_extraction_path: + temp_file_handle = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") + # import soundfile as sf # Already imported globally now + sf.write(temp_file_handle.name, full_audio_pcm, gst_audio_pipeline.sample_rate) + feature_extraction_path = temp_file_handle.name + temp_file_handle.close() + + whisper_input_features, librosa_length = audio_processor.get_audio_feature(feature_extraction_path, weight_dtype=weight_dtype) + whisper_chunks = audio_processor.get_whisper_chunk( + whisper_input_features, + device, + weight_dtype, + whisper, + librosa_length, + fps=target_fps, + audio_padding_length_left=AUDIO_PADDING_LENGTH_LEFT, + audio_padding_length_right=AUDIO_PADDING_LENGTH_RIGHT, + ) + + if temp_file_handle: + os.unlink(feature_extraction_path) + + num_frames_to_generate = len(whisper_chunks) + if num_frames_to_generate == 0: + logging.warning("No frames to generate based on audio features. Skipping inference.") + return + + num_vae_batches = (num_frames_to_generate + self.batch_size - 1) // self.batch_size + logging.info(f"Audio processed. Planning {num_frames_to_generate} frames in {num_vae_batches} VAE/UNet batches.") + + # 3. Start the consumer thread + self.idx = 0 + frame_processor_thread = threading.Thread( + target=self.process_and_send_frames, + args=(vae_to_blend_queue, gst_video_pipeline, gst_audio_pipeline, num_frames_to_generate, FRAME_SKIP_THRESHOLD), + daemon=True, name=f"FrameProcessor_{run_id}" + ) + frame_processor_thread.start() + + # 4. Main Generation Loop (Producer: VAE/UNet inference) + data_gen = datagen(whisper_chunks, self.input_latent_list_cycle, self.batch_size) + total_audio_samples = len(full_audio_pcm) + audio_samples_sent = 0 + + vae_pbar = tqdm(data_gen, total=num_vae_batches, desc=f"VAE/UNET [{run_id}]", unit="batch") + for i, batch_data in enumerate(vae_pbar): + current_batch_start_time = time.perf_counter() # Mark start of current batch processing + + if not frame_processor_thread.is_alive(): + logging.error(f"Frame processor thread died unexpectedly. Halting generation.") + break + if not batch_data or len(batch_data) != 2: + logging.warning(f"Skipping malformed batch {i}.") + continue + + whisper_batch, latent_batch = batch_data + num_frames_in_batch = len(latent_batch) + + # --- Core AI Inference (THIS WAS THE MISSING PART) --- + audio_feature = pe(whisper_batch.to(device, dtype=weight_dtype)) + latent_input = latent_batch.to(device, dtype=unet.model.dtype) + + pred_latents = unet.model(latent_input, timesteps, encoder_hidden_states=audio_feature).sample + pred_latents = pred_latents.to(device=device, dtype=vae.vae.dtype) + recon_frames = vae.decode_latents(pred_latents) + + if recon_frames is None or len(recon_frames) == 0: + logging.warning(f"Skipping empty VAE output batch {i}.") + continue + + # --- FPS Calculation and Display --- + current_batch_end_time = time.perf_counter() + batch_processing_time = current_batch_end_time - current_batch_start_time + + if batch_processing_time > 0: + current_fps = num_frames_in_batch / batch_processing_time + else: + current_fps = 0.0 + + fps_color = 'green' if current_fps >= target_fps * 0.9 else ('yellow' if current_fps >= target_fps * 0.5 else 'red') + fps_string = colored(f"{current_fps:.2f} FPS", fps_color, attrs=['bold']) + + vae_pbar.set_postfix_str(f"FPS: {fps_string} | Batch: {batch_processing_time:.2f}s", refresh=True) + + # --- DYNAMIC A/V SYNC --- + audio_duration_of_batch = num_frames_in_batch / target_fps + num_audio_samples_for_batch = int(audio_duration_of_batch * gst_audio_pipeline.sample_rate) + + start_audio_idx = audio_samples_sent + end_audio_idx = min(start_audio_idx + num_audio_samples_for_batch, total_audio_samples) + audio_chunk_pcm = full_audio_pcm[start_audio_idx:end_audio_idx] + audio_samples_sent = end_audio_idx + + try: + # Put both the recon_frames (list) and audio_chunk_pcm (numpy array) into the queue + vae_to_blend_queue.put((list(recon_frames), audio_chunk_pcm), timeout=5.0) + except queue.Full: + logging.error(f"VAE-to-Blend queue is full. Consumer can't keep up. Halting generation.") + break + + except Exception as e: + logging.critical(f"CRITICAL ERROR in inference run '{run_id}'.", exc_info=True) + finally: + logging.info(f"\n--- [{run_id}] Final Cleanup ---") + if 'vae_to_blend_queue' in locals(): + try: vae_to_blend_queue.put(None) # Sentinel value for graceful shutdown + except Exception: pass + + # Wait for consumer thread to finish + if frame_processor_thread and frame_processor_thread.is_alive(): + logging.info("Waiting for frame processor thread to finish...") + frame_processor_thread.join(timeout=30.0) + if frame_processor_thread.is_alive(): + logging.warning("Frame processor thread did not terminate gracefully.") + + # Stop GStreamer pipelines + if gst_video_pipeline: gst_video_pipeline.stop() + if gst_audio_pipeline: gst_audio_pipeline.stop() + + elapsed = time.time() - start_time + logging.info(f">>> Inference run '{run_id}' finished in {elapsed:.2f}s. <<<") + + + + def process_and_send_frames(self, vae_to_blend_q, gst_video, gst_audio, total_frames_planned, frame_skip_threshold_val): + ''' + [ENHANCED VERSION] Consumer thread: blends, paces, interleaves audio, and sends to GStreamer. + + This function is the heart of the real-time playback logic. It performs three critical tasks: + 1. **Real-Time Pacing:** Ensures the stream does not play faster than the target FPS on fast hardware, + while having no negative performance impact on slower hardware. + 2. **Dynamic Frame Skipping:** If the AI model (producer) generates frames much faster than they can be + streamed, this function will intelligently drop the oldest batches to "catch up" to real-time, + preventing runaway latency. + 3. **Frame-by-Frame A/V Interleaving:** To ensure smooth audio playback without choppiness, this function + sends a single video frame and then IMMEDIATELY sends its corresponding small slice of audio. + This tight interleaving prevents audio buffer underruns on the receiving end. + + Args: + vae_to_blend_q (queue.Queue): The queue from which to get generated face frames and audio chunks. + gst_video (GStreamerPipeline): The GStreamer video pipeline manager. + gst_audio (GStreamerAudio): The GStreamer audio pipeline manager. + total_frames_planned (int): The total number of frames expected for the entire audio clip. + frame_skip_threshold_val (int): The queue size at which we start dropping frames to reduce latency. + ''' + # --- PACER INITIALIZATION --- + target_fps = gst_video.fps + frame_duration = 1.0 / target_fps + start_time = None # <<< FIX: Initialize start_time to None. + frames_sent = 0 + # -------------------------- + total_frames_processed = 0 + total_frames_skipped = 0 + + # --- DYNAMIC FRAME SKIPPING --- + # (This section remains unchanged) + while vae_to_blend_q.qsize() > frame_skip_threshold_val: + try: + skipped_item = vae_to_blend_q.get_nowait() + if skipped_item and isinstance(skipped_item, tuple) and len(skipped_item[0]) > 0: + num_skipped_in_batch = len(skipped_item[0]) + total_frames_skipped += num_skipped_in_batch + total_frames_processed += num_skipped_in_batch + logging.warning( + colored(f"Queue overloaded ({vae_to_blend_q.qsize()}). Skipping a batch of {num_skipped_in_batch} frames to catch up.", 'yellow') + ) + vae_to_blend_q.task_done() + except queue.Empty: + break + + # --- MAIN PROCESSING LOOP --- + while total_frames_processed < total_frames_planned: + try: + batch_data = vae_to_blend_q.get(block=True, timeout=10.0) + except queue.Empty: + logging.error("Timeout waiting for frames from VAE producer. Ending processor thread.") + break + + if batch_data is None: + logging.info("Received sentinel. Frame processor shutting down.") + break + + # <<< FIX: Start the pacer's clock at the exact moment the first batch arrives. + if start_time is None: + start_time = time.perf_counter() + logging.info("First data batch received. Starting real-time pacer clock.") + # <<< END FIX + + vae_frames, audio_chunk = batch_data + num_frames_in_batch = len(vae_frames) + if num_frames_in_batch == 0: + vae_to_blend_q.task_done() + continue + + samples_per_frame = len(audio_chunk) // num_frames_in_batch + audio_cursor = 0 + + for i, frame in enumerate(vae_frames): + # --- REAL-TIME PACER --- + next_frame_target_time = start_time + (frames_sent + 1) * frame_duration + sleep_duration = next_frame_target_time - time.perf_counter() + if sleep_duration > 0: + time.sleep(sleep_duration) + # --- END PACER --- + + # 1. Blend the generated face onto the background video frame + blended_frame = self._blend_single_frame(frame, gst_video.width, gst_video.height) + + if blended_frame is not None: + # 2. Send the final video frame to the video pipeline + if gst_video.send_frame(blended_frame): + frames_sent += 1 + else: + logging.error("Video pipe broken. Halting all frame processing.") + vae_to_blend_q.task_done() + return + + # 3. Determine the precise audio slice for THIS video frame + start_idx = audio_cursor + end_idx = (audio_cursor + samples_per_frame) if (i < num_frames_in_batch - 1) else len(audio_chunk) + audio_slice = audio_chunk[start_idx:end_idx] + audio_cursor = end_idx + + # 4. Send the corresponding audio slice + if audio_slice.size > 0: + if not gst_audio.send_audio(audio_slice): + logging.error("Audio pipe broken. Stopping audio sends for this run.") + + total_frames_processed += num_frames_in_batch + vae_to_blend_q.task_done() + + logging.info(colored(f"--- Frame processor finished. Sent: {frames_sent}, Skipped: {total_frames_skipped} ---", "green")) + + def _blend_single_frame(self, res_frame, target_width, target_height): + """ + [OPTIMIZED] Blends a single generated face frame onto the original background. + This version uses pre-converted PIL objects to reduce real-time overhead. + """ + try: + # 1. --- Retrieve pre-processed data for the current frame --- + cycle_len = len(self.coord_list_cycle) + if cycle_len == 0: + return None # No need to log error every frame + + current_idx = self.idx % cycle_len + + # Retrieve the pre-calculated and pre-converted PIL images + body_pil = self.body_pil_cycle[current_idx] + mask_pil = self.mask_pil_cycle[current_idx] + + # These are still needed from the original lists + face_bbox = self.coord_list_cycle[current_idx] + crop_box = self.mask_coords_list_cycle[current_idx] + + # 2. --- Perform necessary real-time conversions and operations --- + + # Resize the AI-generated face (this must be done in real-time) + face_w = int(face_bbox[2] - face_bbox[0]) + face_h = int(face_bbox[3] - face_bbox[1]) + + # OPTIMIZATION: Switched to a faster interpolation method. + # INTER_LINEAR is a great balance of speed and quality. + res_frame_resized = cv2.resize(res_frame.astype(np.uint8), (face_w, face_h), interpolation=cv2.INTER_LINEAR) + + # This is the only BGR->RGB conversion left in the hot-loop + face_pil = Image.fromarray(res_frame_resized[:, :, ::-1]) + + # 3. --- Replicate the Original Library's Paste Logic --- + x_s, y_s, _, _ = [int(p) for p in crop_box] + x_f, y_f, _, _ = [int(p) for p in face_bbox] + + # This crop operation is faster now as it's on an already-loaded PIL image + face_large_pil = body_pil.crop(tuple(int(p) for p in crop_box)) + + paste_position = (x_f - x_s, y_f - y_s) + face_large_pil.paste(face_pil, paste_position) + body_pil.paste(face_large_pil, (x_s, y_s), mask_pil) + + # 4. --- Convert back to numpy for GStreamer --- + # This final conversion is unavoidable + final_frame = np.array(body_pil, dtype=np.uint8)[:, :, ::-1] + + # 5. --- Increment index and return the final frame --- + self.idx += 1 + + # This final resize for GStreamer is also unavoidable + if final_frame.shape[0] != target_height or final_frame.shape[1] != target_width: + return cv2.resize(final_frame, (target_width, target_height), interpolation=cv2.INTER_LINEAR) + else: + return final_frame + + except Exception as e: + # Reduced logging level for performance + if self.idx % 100 == 0: # Log only every 100 frames to avoid spam + logging.warning(f"Error blending frame at index {self.idx}: {e}") + self.idx += 1 + return None + +# --- Main Application Entry Point --- +inference_lock = threading.Lock() # Ensures only one inference run processes audio at a time +audio_queue = queue.Queue(maxsize=10) # Queue for raw audio data received by file watcher + +# NEW: Load environment from powershell_env.txt (GLOBAL SCOPE) +POWERSHELL_ENV_FILE = "C:\\temp\\powershell_env.txt" + +env_from_file = os.environ.copy() + +if os.path.exists(POWERSHELL_ENV_FILE): + try: + with open(POWERSHELL_ENV_FILE, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if '=' in line: + key, value = line.split('=', 1) + env_from_file[key] = value + logging.info(f"Successfully loaded environment from {POWERSHELL_ENV_FILE}.") + except Exception as e: + logging.warning(f"Failed to load environment from {POWERSHELL_ENV_FILE}: {e}", exc_info=True) +else: + logging.warning(f"PowerShell environment file not found at {POWERSHELL_ENV_FILE}. Proceeding with inherited environment.") +# --- END NEW GLOBAL BLOCK --- + + +def inference_worker(avatar_instance): + """ + Worker thread that waits for audio data from a queue and initiates an inference run. + Uses a non-blocking lock to skip audio if inference is already in progress. + """ + logging.info("🚀 Inference worker started. Waiting for audio data...") + while True: + audio_data = audio_queue.get() # Blocks until audio is available or None is received + if audio_data is None: # Shutdown signal + logging.info("Inference worker received shutdown signal. Exiting.") + break + + # Try to acquire lock non-blocking. If busy, skip this audio chunk. + if not inference_lock.acquire(blocking=False): + logging.warning("Inference already in progress. Skipping newly received audio to prevent backlog.") + audio_queue.task_done() # Mark this item as done even if skipped + continue # Continue to next item in queue + + try: + logging.info(f"Inference lock acquired. Processing {len(audio_data)} bytes of audio.") + # Delegate to Avatar's inference method, which now handles GStreamer internally + avatar_instance.inference(audio_data, TARGET_FPS) + except Exception as e: + logging.critical(f"Unhandled exception during avatar inference: {e}", exc_info=True) + finally: + inference_lock.release() # Release lock whether successful or not + audio_queue.task_done() # Mark this item as done + logging.info("Inference lock released.") + +def file_watcher(): + """ + Monitors a specified file for changes (e.g., new audio data written to it). + When a change is detected, reads the file's content and puts it into the audio queue. + Designed for use with a "hot file" or named pipe for continuous audio input. + """ + logging.info(f"👀 Starting file watcher for: {STREAM_PIPE_PATH}") # Use global STREAM_PIPE_PATH + last_processed_mtime = 0 # Tracks last modification time to detect new data + file_watcher_started_log = False # Flag to log only once + + while True: + if not file_watcher_started_log: + logging.info("✅ File watcher thread is actively monitoring.") + file_watcher_started_log = True + + try: + if os.path.isfile(STREAM_PIPE_PATH): + current_mtime = os.path.getmtime(STREAM_PIPE_PATH) + if current_mtime > last_processed_mtime: + logging.info("New audio file detected. Reading content...") + # Small delay to ensure file is fully written by external process + time.sleep(0.1) + with open(STREAM_PIPE_PATH, "rb") as f: + incoming_audio_data = f.read() + + if incoming_audio_data: + audio_queue.put(incoming_audio_data) # Add to queue for inference worker + last_processed_mtime = current_mtime # Update mtime only if data was processed + else: + logging.warning("File modification detected, but file was empty. Skipping.") + last_processed_mtime = current_mtime # Still update mtime to avoid re-processing empty file + time.sleep(0.2) # Polling interval for file changes + except FileNotFoundError: + # This is normal if the file is deleted and recreated (e.g., for continuous streaming) + last_processed_mtime = 0 # Reset mtime to ensure next file is processed + time.sleep(0.5) # Wait a bit longer before re-checking for file + except Exception as e: + logging.error(f"Error in file watcher loop: {e}", exc_info=True) + time.sleep(2) + + +if __name__ == "__main__": + logging.info("🎬 Starting MuseTalk Realtime Stream Sync Application (v1.5 compatible)...") + + # 1. Load environment variables from .env file (must be at the top) + # load_dotenv() # Already done globally at the very top of script + + # --- environment loading (now outside of __main__ but used by GStreamer classes) --- + # The `env_from_file` variable is already defined globally now. + + # 2. Argument Parsing (incorporating 1.5 defaults and user's .env/cmd-line options) + parser = argparse.ArgumentParser(description="MuseTalk Real-Time Streaming Script") + parser.add_argument("--version", type=str, default=MUSE_VERSION, choices=["v1", "v15"], help="MuseTalk version (from .env MUSE_VERSION)") + parser.add_argument("--ffmpeg_path", type=str, default=FFMPEG_PATH, help="Path to ffmpeg executable (from .env FFMPEG_PATH)") + parser.add_argument("--gpu_id", type=int, default=GPU_ID, help="GPU ID to use (from .env GPU_ID)") + parser.add_argument("--vae_type", type=str, default=VAE_TYPE, help="Type of VAE model (from .env VAE_TYPE)") + parser.add_argument("--unet_config", type=str, default=UNET_CONFIG_PATH, help="Path to UNet configuration file (from .env UNET_CONFIG)") + parser.add_argument("--unet_model_path", type=str, default=UNET_MODEL_PATH, help="Path to UNet model weights (from .env UNET_MODEL_PATH)") + parser.add_argument("--whisper_dir", type=str, default=WHISPER_DIR, help="Directory containing Whisper model (from .env WHISPER_DIR)") + parser.add_argument("--result_dir", default=RESULT_DIR, help="Directory for output results (from .env RESULT_DIR)") + parser.add_argument("--extra_margin", type=int, default=EXTRA_MARGIN, help="Extra margin for face cropping (from .env EXTRA_MARGIN)") + parser.add_argument("--fps", type=int, default=TARGET_FPS, help="Video frames per second (from .env TARGET_FPS)") + parser.add_argument("--audio_padding_length_left", type=int, default=AUDIO_PADDING_LENGTH_LEFT, help="Left padding length for audio (from .env AUDIO_PADDING_LEFT)") + parser.add_argument("--audio_padding_length_right", type=int, default=AUDIO_PADDING_LENGTH_RIGHT, help="Right padding length for audio (from .env AUDIO_PADDING_RIGHT)") + parser.add_argument("--batch_size", type=int, default=BATCH_SIZE, help="Batch size for inference (from .env BATCH_SIZE)") + parser.add_argument("--parsing_mode", default=PARSING_MODE, help="Face blending parsing mode (from .env PARSING_MODE)") + parser.add_argument("--left_cheek_width", type=int, default=LEFT_CHEEK_WIDTH, help="Width of left cheek region (from .env LEFT_CHEEK_WIDTH)") + parser.add_argument("--right_cheek_width", type=int, default=RIGHT_CHEEK_WIDTH, help="Width of right cheek region (from .env RIGHT_CHEEK_WIDTH)") + parser.add_argument("--avatar_config_path", type=str, default=AVATAR_CONFIG_PATH, help="Path to avatar configuration YAML (from .env AVATAR_CONFIG_PATH)") + parser.add_argument("--avatar_id_to_use", type=str, default=AVATAR_ID_TO_USE, help="ID of the avatar to use from config (from .env AVATAR_ID_TO_USE)") + parser.add_argument("--stream_pipe_path", type=str, default=STREAM_PIPE_PATH, help="Path to watch for audio input (from .env STREAM_PIPE_PATH)") + parser.add_argument("--gstreamer_launch_path", type=str, default=GSTREAMER_LAUNCH_PATH, help="Full path to gst-launch-1.0.exe (from .env GSTREAMER_LAUNCH_PATH)") + + args = parser.parse_args() + logging.info(f"INFO: GStreamer launch path resolved to: {GSTREAMER_LAUNCH_PATH}") + + # 3. Configure ffmpeg path and verify + if not fast_check_ffmpeg(): + logging.info("Attempting to add ffmpeg to PATH...") + path_separator = ';' if sys.platform == 'win32' else ':' + os.environ["PATH"] = f"{FFMPEG_PATH}{path_separator}{os.environ['PATH']}" + if not fast_check_ffmpeg(): + logging.critical("❌ Critical: Unable to find ffmpeg even after attempting to add to PATH. Please ensure ffmpeg is properly installed and accessible. Exiting.") + sys.exit(1) # Exit if ffmpeg is not found + + # 4. Set computing device and print GPU info + # 'device' is already set globally based on GPU_ID from .env + logging.info(f"✅ Selected device: {device}") + if torch.cuda.is_available(): + try: + gpu_name = torch.cuda.get_device_name(device) + gpu_capability = torch.cuda.get_device_capability(device) + logging.info(f"GPU Name: {gpu_name}") + logging.info(f"CUDA Capability: {gpu_capability}") + except Exception as e: + logging.warning(f"Warning: Could not retrieve detailed GPU information: {e}") + + # 5. Load all MuseTalk models (following MuseTalk 1.5's new loading pattern) + logging.info("Loading MuseTalk models (VAE, UNet, Positional Encoding)...") + try: + vae, unet, pe = load_all_model( + unet_model_path=UNET_MODEL_PATH, + vae_type=VAE_TYPE, + unet_config=UNET_CONFIG_PATH, + device=device + ) + # --- NEW OPTIMIZATION --- + ''' + logging.info("Applying torch.compile() for maximum performance...") + if torch.__version__ >= "2.0": + unet.model = torch.compile(unet.model, mode="reduce-overhead") + vae.vae = torch.compile(vae.vae, mode="reduce-overhead") + pe = torch.compile(pe) + logging.info("✅ Models successfully compiled.") + else: + logging.warning("torch.compile() requires PyTorch 2.0+. Skipping.") + ''' + # ------------------------ + timesteps = torch.tensor([0], device=device) # Initialize global timesteps tensor + + pe = pe.half().to(device) + vae.vae = vae.vae.half().to(device) + unet.model = unet.model.half().to(device) + + logging.info("Initializing AudioProcessor and loading Whisper model...") + audio_processor = AudioProcessor(feature_extractor_path=WHISPER_DIR) + weight_dtype = unet.model.dtype # Use model's dtype for Whisper (usually FP16 after .half()) + whisper = WhisperModel.from_pretrained(WHISPER_DIR) + whisper = whisper.to(device=device, dtype=weight_dtype).eval() + whisper.requires_grad_(False) # Freeze Whisper model parameters + + logging.info("Initializing FaceParsing model...") + if MUSE_VERSION == "v15": + fp = FaceParsing( + left_cheek_width=LEFT_CHEEK_WIDTH, + right_cheek_width=RIGHT_CHEEK_WIDTH + ) + else: # v1 fallback + fp = FaceParsing() + + logging.info("✅ All MuseTalk 1.5 core models loaded and configured successfully.") + except Exception as e: + logging.critical("❌ Fatal error loading MuseTalk models or setting up device/precision. Check model paths, CUDA, and dependencies.", exc_info=True) + sys.exit(1) + + main_avatar = None + try: + # 6. Validate configuration paths and avatar ID + if not all([AVATAR_CONFIG_PATH, AVATAR_ID_TO_USE, STREAM_PIPE_PATH]): + raise ValueError("A required configuration path (AVATAR_CONFIG_PATH, AVATAR_ID_TO_USE, or STREAM_PIPE_PATH) is missing. Check your .env file or command-line arguments.") + + logging.info(f"Attempting to load Avatar ID '{AVATAR_ID_TO_USE}' from config file '{AVATAR_CONFIG_PATH}'...") + config = OmegaConf.load(AVATAR_CONFIG_PATH) + if AVATAR_ID_TO_USE not in config: + raise ValueError(f"Avatar ID '{AVATAR_ID_TO_USE}' was not found as a key in the YAML file '{AVATAR_CONFIG_PATH}'") + + avatar_config = config[AVATAR_ID_TO_USE] + + # 7. Instantiate the Avatar with the loaded configuration + main_avatar = Avatar( + avatar_id=AVATAR_ID_TO_USE, + video_path=avatar_config.video_path, + bbox_shift=avatar_config.bbox_shift, + batch_size=BATCH_SIZE, # Use global BATCH_SIZE + preparation=avatar_config.preparation, + version_str=MUSE_VERSION, # Use global MUSE_VERSION + extra_margin=EXTRA_MARGIN, # Use global EXTRA_MARGIN + parsing_mode=PARSING_MODE, # Use global PARSING_MODE + ) + logging.debug("DEBUG: Avatar instance successfully created. Proceeding to thread initialization.") + except Exception as e: + logging.critical("❌ CRITICAL ERROR during Avatar setup. Check your avatar config YAML and video paths.", exc_info=True) + sys.exit(1) + + # 8. Start the file watcher and inference worker threads + logging.info("Starting file watcher and inference worker threads...") + inference_thread = threading.Thread(target=inference_worker, args=(main_avatar,), daemon=True, name="InferenceWorker") + watcher_thread = threading.Thread(target=file_watcher, daemon=True, name="FileWatcher") + + inference_thread.start() + watcher_thread.start() + + # 9. Keep the main thread alive and handle graceful shutdown + logging.info("✅ Application is running. Press Ctrl+C to shut down gracefully.") + try: + while True: + # Check if critical worker threads are still alive + if not watcher_thread.is_alive(): + logging.error("File watcher thread died unexpectedly. Initiating shutdown.") + break + if not inference_thread.is_alive(): + logging.error("Inference worker thread died unexpectedly. Initiating shutdown.") + break + time.sleep(1.0) # Main thread sleeps, allowing worker threads to run + except KeyboardInterrupt: + logging.info("\n🛑 KeyboardInterrupt received. Initiating graceful shutdown...") + finally: + # Signal worker threads to stop by putting a sentinel value into the queue + audio_queue.put(None) + + logging.info("Attempting to join worker threads for clean shutdown...") + + # Give watcher thread a moment to finish + if watcher_thread.is_alive(): + watcher_thread.join(timeout=2) + if watcher_thread.is_alive(): + logging.warning("File watcher thread did not terminate gracefully. It might be stuck.") + + # Give inference thread more time as it might be processing a batch + if inference_thread.is_alive(): + inference_thread.join(timeout=30) + if inference_thread.is_alive(): + logging.warning("Inference worker thread did not terminate gracefully. It might be stuck.") + + logging.info("✅ Application shutdown complete.") \ No newline at end of file diff --git a/scripts/realtime_stream_sync.py b/scripts/realtime_stream_sync.py new file mode 100644 index 00000000..2b8e15a2 --- /dev/null +++ b/scripts/realtime_stream_sync.py @@ -0,0 +1,915 @@ +# -*- coding: utf-8 -*- +import ffmpeg +import argparse +import os +import concurrent.futures # For ThreadPoolExecutor in Avatar preparation +import threading +import queue +import io +from omegaconf import OmegaConf +import subprocess +import numpy as np +import cv2 +import torch +import glob +import pickle +import sys +from tqdm import tqdm +import copy +import json +import traceback # For detailed error printing +import time +from PIL import Image +import tempfile +import logging +from dotenv import load_dotenv +# --- Platform-specific imports for future use if needed --- +if sys.platform == "win32": + try: + import psutil + # Optional: Set high process priority on Windows + p = psutil.Process(os.getpid()) + p.nice(psutil.HIGH_PRIORITY_CLASS) + print("INFO: Process priority set to HIGH on Windows.") + except ImportError: + print("Warning: psutil not found. Cannot set process priority.") + except Exception as e: + print(f"Warning: Could not set process priority: {e}") + + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + +# --- MuseTalk Specific Imports (Ensure these are in your PYTHONPATH) --- +try: + from musetalk.utils.utils import get_file_type, get_video_fps, datagen + from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder + from musetalk.utils.blending import get_image_prepare_material + from musetalk.utils.utils import load_all_model +except ImportError as e: + logging.critical(f"Error importing MuseTalk utilities: {e}. Ensure the library is installed and in your PYTHONPATH.", exc_info=True) + sys.exit(1) + +import shutil + +# --- Configuration & Global Variables --- +# Recommended to be set via environment variables or a config file +TARGET_FPS = int(os.getenv("TARGET_FPS", "25")) +FRAME_SKIP_THRESHOLD = int(os.getenv("FRAME_SKIP_THRESHOLD", "3")) # Drop frames if queue size exceeds this +AVATAR_CONFIG_PATH = os.getenv("AVATAR_CONFIG_PATH", "configs/inference/realtime.yaml") +AVATAR_ID_TO_USE = os.getenv("AVATAR_ID_TO_USE", "default_avatar_id") # CHANGE THIS +STREAM_PIPE_PATH = os.getenv("STREAM_PIPE_PATH", "./hot_file.opus") # File to watch for audio + +# --- PyTorch Device Setup --- +cuda_available = torch.cuda.is_available() +device = torch.device("cuda" if cuda_available else "cpu") +logging.info("--- PyTorch Device Information ---") +if cuda_available: + try: + gpu_name = torch.cuda.get_device_name(0) + logging.info(f"✅ CUDA (GPU) detected: {gpu_name}") + except Exception as e: + logging.warning(f"⚠️ Could not retrieve GPU name: {e}") +else: + logging.info("❌ CUDA (GPU) not available. Using CPU.") +logging.info(f"✅ Selected device: {device}") +logging.info("-------------------------------") + +# --- Load Models --- +logging.info("Loading models...") +try: + audio_processor, vae, unet, pe = load_all_model() + logging.info("✅ Models loaded successfully.") +except Exception as e: + logging.critical("Fatal error loading models.", exc_info=True) + sys.exit(1) + +# --- Set Model Precision (FP16) --- +logging.info("Setting model precision to half (FP16) where applicable...") +try: + if hasattr(pe, 'half'): pe = pe.to(device).half() + if hasattr(vae, 'vae') and hasattr(vae.vae, 'half'): vae.vae = vae.vae.to(device).half() + if hasattr(unet, 'model') and hasattr(unet.model, 'half'): unet.model = unet.model.to(device).half() + timesteps = torch.tensor([0], device=device) + logging.info("✅ Model precision set.") +except Exception as e: + logging.warning(f"⚠️ Error setting model precision: {e}. Performance may be affected.", exc_info=True) + + +# --- FFmpeg Audio Reader --- +class FFmpegAudioReader: + """Uses FFmpeg to read an audio file (from path or bytes) and convert it to raw PCM.""" + def __init__(self, audio_source): + self.audio_source = audio_source + self.is_file_path = isinstance(audio_source, str) + + def read_full_audio(self): + """Reads the entire audio source and converts it to PCM s16le, 48kHz, Stereo.""" + logging.info(f"Reading and converting audio from {'file' if self.is_file_path else 'memory'}...") + target_sr, target_ac, target_format = 48000, 2, "s16le" + input_data = None + + ffmpeg_input_args = {} + if not self.is_file_path: + # If source is bytes, we'll pipe it to ffmpeg's stdin + input_filename = 'pipe:0' + input_data = self.audio_source + else: + input_filename = self.audio_source + + try: + # Use ffmpeg-python for a cleaner interface + out, err = ( + ffmpeg + .input(input_filename, **ffmpeg_input_args) + .output('pipe:', format=target_format, ac=target_ac, ar=target_sr) + .run(capture_stdout=True, capture_stderr=True, input=input_data) + ) + if err: + logging.debug(f"FFmpeg stderr: {err.decode(errors='ignore')}") + except ffmpeg.Error as e: + logging.error(f"❌ FFmpeg error during audio conversion: {e.stderr.decode(errors='ignore') if e.stderr else 'Unknown FFmpeg error'}") + return None + except Exception as e: + logging.error(f"❌ Unexpected error during FFmpeg execution: {e}", exc_info=True) + return None + + if not out: + logging.error("❌ Failed to read audio: FFmpeg produced no PCM data!") + return None + + audio_data = np.frombuffer(out, dtype=np.int16).reshape(-1, target_ac) + logging.info(f"✅ Read and converted audio: {len(audio_data)} samples at {target_sr}Hz, {target_ac}ch.") + return audio_data + +# --- GStreamer Classes --- +class GStreamerPipeline: + """Manages the GStreamer video pipeline subprocess.""" + def __init__(self, width=1280, height=720, fps=TARGET_FPS, host="127.0.0.1", port=5000): + self.width, self.height, self.fps, self.host, self.port = width, height, fps, host, port + self.process = None + self.stdout_thread = None + self.stderr_thread = None + + # A robust pipeline with leaky queues (for backpressure) and nvenc for performance + pipeline_str = ( + f"fdsrc fd=0 do-timestamp=true is-live=true ! videoparse format=bgr width={self.width} height={self.height} framerate={self.fps}/1 ! " + "queue ! videoconvert ! videorate ! " + f"video/x-raw,format=NV12 ! " # Assuming NV12 is desired for nvh265enc + "queue ! " + # Reverting bitrate to 4000 as per your old working code, and keep preset for testing + f"nvh265enc preset=low-latency-hq rc-mode=cbr bitrate=4000 gop-size=30 ! " + "h265parse ! rtph265pay pt=96 config-interval=1 ! " + f"udpsink host={self.host} port={self.port} sync=true async=false" + ) + logging.info(f"Starting GStreamer video pipeline with resolution {self.width}x{self.height}@{self.fps}fps...") + + # Prepare environment variables for the subprocess + env_vars = os.environ.copy() + env_vars['GST_DEBUG'] = '3' # Set GStreamer debug level + + try: + # Construct the command string without explicit quoting for GSTREAMER_LAUNCH_PATH here. + # This relies on the shell's PATH to find gst-launch-1.0. + # If GSTREAMER_LAUNCH_PATH itself has a full path, this will override PATH. + command_to_run = f"{GSTREAMER_LAUNCH_PATH} -v {pipeline_str}" + + logging.debug(f"DEBUG: GStreamer VIDEO command (old way): {command_to_run}") + + self.process = subprocess.Popen( + command_to_run, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, # Capture stdout + stderr=subprocess.PIPE, # Capture stderr + shell=True, # Keep shell=True + bufsize=0, + env=env_vars + ) + logging.info(f"✅ GStreamer video process launched (PID: {self.process.pid}).") + + # --- Start threads to read stdout and stderr asynchronously --- + self.stdout_thread = threading.Thread( + target=_log_subprocess_output, + args=(self.process.stdout, logging.info, "GST_VIDEO_STDOUT"), + daemon=True + ) + self.stderr_thread = threading.Thread( + target=_log_subprocess_output, + args=(self.process.stderr, logging.error, "GST_VIDEO_STDERR"), + daemon=True + ) + self.stdout_thread.start() + self.stderr_thread.start() + + except Exception as e: + logging.error(f"❌ Failed to start GStreamer video pipeline: {e}", exc_info=True) + self.process = None + + def send_frame(self, frame): + """Sends a NumPy array frame to the GStreamer pipeline's stdin.""" + if not self.process or self.process.stdin.closed: + logging.error("❌ GStreamer video pipeline process is not running or stdin is closed.") + return False + try: + if not frame.flags['C_CONTIGUOUS']: + frame = np.ascontiguousarray(frame, dtype=np.uint8) + self.process.stdin.write(frame.tobytes()) + self.process.stdin.flush() + return True + except (BrokenPipeError, OSError): + logging.error("❌ GStreamer video pipeline: Broken pipe. The process may have crashed.") + self.stop() + return False + except Exception as e: + logging.error(f"❌ Error pushing video frame: {e}", exc_info=True) + return False + + def stop(self): + """Stops the GStreamer video subprocess gracefully.""" + if self.process: + logging.info(f"Stopping GStreamer video pipeline (PID: {self.process.pid})...") + proc_to_stop, self.process = self.process, None + if proc_to_stop.stdin and not proc_to_stop.stdin.closed: + try: proc_to_stop.stdin.close() + except Exception: pass + proc_to_stop.terminate() + try: + proc_to_stop.wait(timeout=3) + logging.info(f"✅ GStreamer video process terminated.") + except subprocess.TimeoutExpired: + logging.warning(f"⚠️ GStreamer video process did not terminate gracefully, killing...") + proc_to_stop.kill() + + if self.stdout_thread and self.stdout_thread.is_alive(): + logging.info("Joining GST_VIDEO_STDOUT thread...") + self.stdout_thread.join(timeout=1) + if self.stderr_thread and self.stderr_thread.is_alive(): + logging.info("Joining GST_VIDEO_STDERR thread...") + self.stderr_thread.join(timeout=1) + logging.info("GStreamer video pipeline stop complete.") + + +class GStreamerAudio: + """Manages the GStreamer audio pipeline subprocess.""" + def __init__(self, host="127.0.0.1", port=5001, sample_rate=48000, channels=2): + self.host, self.port, self.sample_rate, self.channels = host, port, sample_rate, channels + self.process = None + self.stdout_thread = None + self.stderr_thread = None + + pipeline_str = ( + f"fdsrc fd=0 do-timestamp=true is-live=true ! " + "queue ! " + f"audio/x-raw,format=S16LE,channels={self.channels},rate={self.sample_rate},layout=interleaved ! " + "audioconvert ! audioresample ! " + "opusenc bitrate=96000 ! rtpopuspay pt=97 ! " + f"udpsink host={self.host} port={self.port} sync=true" + ) + logging.info("Starting GStreamer audio pipeline...") + + env_vars = os.environ.copy() + env_vars['GST_DEBUG'] = '3' + + try: + # Construct the command string without explicit quoting for GSTREAMER_LAUNCH_PATH here. + command_to_run = f"{GSTREAMER_LAUNCH_PATH} -v {pipeline_str}" + + logging.debug(f"DEBUG: GStreamer AUDIO command (old way): {command_to_run}") + + self.process = subprocess.Popen( + command_to_run, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=True, + bufsize=0, + env=env_vars + ) + logging.info(f"✅ GStreamer audio process launched (PID: {self.process.pid}).") + + # --- Start threads to read stdout and stderr asynchronously --- + self.stdout_thread = threading.Thread( + target=_log_subprocess_output, + args=(self.process.stdout, logging.info, "GST_AUDIO_STDOUT"), + daemon=True + ) + self.stderr_thread = threading.Thread( + target=_log_subprocess_output, + args=(self.process.stderr, logging.error, "GST_AUDIO_STDERR"), + daemon=True + ) + self.stdout_thread.start() + self.stderr_thread.start() + + except Exception as e: + logging.error(f"❌ Failed to start GStreamer audio pipeline: {e}", exc_info=True) + self.process = None + + def send_audio(self, audio_data_pcm): + """Sends raw PCM audio data to the GStreamer pipeline's stdin.""" + if not self.process or self.process.stdin.closed: + logging.error("❌ GStreamer audio pipeline process is not running or stdin is closed.") + return False + try: + self.process.stdin.write(audio_data_pcm.tobytes()) + self.process.stdin.flush() + return True + except (BrokenPipeError, OSError): + logging.error("❌ GStreamer audio pipeline: Broken pipe.") + self.stop() + return False + except Exception as e: + logging.error(f"❌ Error pushing audio chunk: {e}", exc_info=True) + return False + + def stop(self): + """Stops the GStreamer audio subprocess gracefully.""" + if self.process: + logging.info(f"Stopping GStreamer audio pipeline (PID: {self.process.pid})...") + proc_to_stop, self.process = self.process, None + if proc_to_stop.stdin and not proc_to_stop.stdin.closed: + try: proc_to_stop.stdin.close() + except Exception: pass + proc_to_stop.terminate() + try: + proc_to_stop.wait(timeout=3) + logging.info("✅ GStreamer audio process terminated.") + except subprocess.TimeoutExpired: + logging.warning("⚠️ GStreamer audio process did not terminate gracefully, killing...") + proc_to_stop.kill() + + if self.stdout_thread and self.stdout_thread.is_alive(): + logging.info("Joining GST_AUDIO_STDOUT thread...") + self.stdout_thread.join(timeout=1) + if self.stderr_thread and self.stderr_thread.is_alive(): + logging.info("Joining GST_AUDIO_STDERR thread...") + self.stderr_thread.join(timeout=1) + logging.info("GStreamer audio pipeline stop complete.") + + + def _log_stream(self, stream, prefix): + try: + for line_bytes in iter(stream.readline, b''): + logging.info(f"[{prefix}]: {line_bytes.decode(errors='ignore').strip()}") + finally: + stream.close() + + def send_audio(self, audio_data_pcm): + if not self.process or self.process.stdin.closed: + return False + try: + self.process.stdin.write(audio_data_pcm.tobytes()) + self.process.stdin.flush() + return True + except (BrokenPipeError, OSError): + logging.error("❌ GStreamer audio pipeline: Broken pipe.") + self.stop() + return False + except Exception as e: + logging.error(f"❌ Error pushing audio chunk: {e}", exc_info=True) + return False + + def stop(self): + if self.process: + logging.info(f"Stopping GStreamer audio pipeline (PID: {self.process.pid})...") + proc_to_stop, self.process = self.process, None + if proc_to_stop.stdin and not proc_to_stop.stdin.closed: + try: proc_to_stop.stdin.close() + except Exception: pass + proc_to_stop.terminate() + try: + proc_to_stop.wait(timeout=2) + logging.info("✅ GStreamer audio process terminated.") + except subprocess.TimeoutExpired: + logging.warning("⚠️ GStreamer audio process did not terminate gracefully, killing...") + proc_to_stop.kill() + +# --- Helper Functions --- +def video2imgs(vid_path, save_path): + logging.info(f"Extracting frames from {vid_path} to {save_path}...") + cap = cv2.VideoCapture(vid_path) + if not cap.isOpened(): + logging.error(f"Error: Could not open video file: {vid_path}") + return + count = 0 + while True: + ret, frame = cap.read() + if not ret: break + cv2.imwrite(os.path.join(save_path, f"{str(count).zfill(8)}.png"), frame) + count += 1 + cap.release() + logging.info(f"Finished extracting {count} frames.") + +def osmakedirs(path_list): + for path in path_list: + os.makedirs(path, exist_ok=True) + +# --- Avatar Class --- +class Avatar: + @torch.no_grad() + def __init__(self, avatar_id, video_path, bbox_shift, batch_size, preparation): + logging.info(f"Initializing Avatar: {avatar_id}") + self.avatar_id = str(avatar_id) + self.video_path = video_path + self.bbox_shift = bbox_shift + self.batch_size = batch_size + self.preparation = preparation + + # Define paths + self.avatar_base_path = os.path.join("./results/avatars", self.avatar_id) + self.full_imgs_path = os.path.join(self.avatar_base_path, "full_imgs") + self.mask_out_path = os.path.join(self.avatar_base_path, "masks") + self.coords_path = os.path.join(self.avatar_base_path, "coords.pkl") + self.latents_out_path = os.path.join(self.avatar_base_path, "latents.pt") + self.mask_coords_path = os.path.join(self.avatar_base_path, "mask_coords.pkl") + self.avatar_info_path = os.path.join(self.avatar_base_path, "avatar_info.json") + + # Initialize data stores + self.input_latent_list_cycle = [] + self.coord_list_cycle = [] + self.frame_list_cycle = [] + self.mask_coords_list_cycle = [] + self.mask_list_cycle = [] + + self.idx = 0 # Tracks current position in the reference video cycle + + self.init_avatar_data() + logging.info(f"✅ Avatar '{self.avatar_id}' initialized with {len(self.frame_list_cycle)} reference frames.") + + def init_avatar_data(self): + # If preparation is needed, do it. Otherwise, load existing data. + if self.preparation: + if os.path.exists(self.avatar_base_path): + # Prompt user before overwriting existing data + response = input(f"Avatar '{self.avatar_id}' data exists. Re-create all material? (y/n): ").strip().lower() + if response == "y": + logging.info(f"User chose to re-create. Removing: {self.avatar_base_path}") + shutil.rmtree(self.avatar_base_path) + self._prepare_material_core() + else: + logging.info("Loading existing data as per user request.") + self._reload_prepared_data() + else: + self._prepare_material_core() + else: + logging.info("Preparation=False. Loading existing prepared data...") + self._reload_prepared_data() + + def _reload_prepared_data(self): + logging.info(f"Reloading prepared data from: {self.avatar_base_path}") + try: + # Load all data from disk + loaded_latents = torch.load(self.latents_out_path, map_location='cpu') + self.input_latent_list_cycle = list(loaded_latents) if isinstance(loaded_latents, torch.Tensor) else loaded_latents + with open(self.coords_path, 'rb') as f: self.coord_list_cycle = pickle.load(f) + with open(self.mask_coords_path, 'rb') as f: self.mask_coords_list_cycle = pickle.load(f) + + # Read corresponding images and masks + num_items = len(self.coord_list_cycle) + if num_items == 0: raise ValueError("Loaded coordinate data is empty.") + + frame_files = [os.path.join(self.full_imgs_path, f"{str(i).zfill(8)}.png") for i in range(num_items)] + self.frame_list_cycle = read_imgs(frame_files) + mask_files = [os.path.join(self.mask_out_path, f"{str(i).zfill(8)}.png") for i in range(num_items)] + self.mask_list_cycle = read_imgs(mask_files) + + # Validate that all lists have the same, non-zero length + data_map = { + "Latents": self.input_latent_list_cycle, "Coords": self.coord_list_cycle, + "Frames": self.frame_list_cycle, "Masks": self.mask_list_cycle, "MaskCoords": self.mask_coords_list_cycle + } + if not all(len(lst) == num_items for lst in data_map.values()): + lengths = {name: len(lst) for name, lst in data_map.items()} + raise ValueError(f"Data lists have mismatched lengths after loading: {lengths}") + + except Exception as e: + logging.critical(f"Error reloading prepared data. You may need to run with preparation=True.", exc_info=True) + raise SystemExit(f"Exiting: Failed to reload data for {self.avatar_id}.") + + @torch.no_grad() + def _prepare_material_core(self): + logging.info(f"--- Preparing new material for avatar: {self.avatar_id} ---") + osmakedirs([self.avatar_base_path, self.full_imgs_path, self.mask_out_path]) + with open(self.avatar_info_path, "w") as f: json.dump({"avatar_id": self.avatar_id}, f) + + # 1. Extract frames from video or copy from image folder + if os.path.isfile(self.video_path): + video2imgs(self.video_path, self.full_imgs_path) + elif os.path.isdir(self.video_path): + # Handles copying and renaming from a folder of images + source_files = sorted([f for f in os.listdir(self.video_path) if f.lower().endswith(('.png','.jpg','.jpeg'))]) + for i, filename in enumerate(tqdm(source_files, desc="Copying frames")): + shutil.copy(os.path.join(self.video_path, filename), os.path.join(self.full_imgs_path, f"{i:08d}.png")) + else: + raise FileNotFoundError(f"video_path '{self.video_path}' is not a valid file or directory.") + + # 2. Get landmarks and filter out invalid frames + source_images = sorted(glob.glob(os.path.join(self.full_imgs_path, '*.png'))) + initial_coords, initial_frames = get_landmark_and_bbox(source_images, self.bbox_shift) + + # 3. Process valid frames: VAE encoding and mask generation + valid_latents, valid_coords, valid_frames, valid_masks, valid_mask_coords = [], [], [], [], [] + coord_ph_val = coord_placeholder + + for i, (bbox, frame) in enumerate(tqdm(zip(initial_coords, initial_frames), total=len(initial_coords), desc="VAE Encoding & Masking")): + if bbox is None or np.array_equal(bbox, coord_ph_val) or frame is None: + continue + + x1c, y1c, x2c, y2c = bbox + crop = frame[int(y1c):int(y2c), int(x1c):int(x2c)] + if crop.size == 0: continue + + try: + # VAE Encoding + resized_crop = cv2.resize(crop, (256, 256), interpolation=cv2.INTER_LANCZOS4) + latents = vae.get_latents_for_unet(resized_crop).cpu() + + # Mask Generation + mask, crop_box = get_image_prepare_material(frame, bbox) + if mask is None or crop_box is None: continue + + # Add all valid data together + valid_latents.append(latents) + valid_coords.append(bbox) + valid_frames.append(frame) + valid_masks.append(mask) + valid_mask_coords.append(crop_box) + except Exception as e: + logging.warning(f"Skipping frame {i} due to processing error: {e}") + + if not valid_frames: + raise RuntimeError("No valid frames survived the preparation process.") + + # 4. Create looping cycle (forward and reverse) and save all data + self.frame_list_cycle = valid_frames + valid_frames[::-1] + self.coord_list_cycle = valid_coords + valid_coords[::-1] + self.input_latent_list_cycle = valid_latents + valid_latents[::-1] + self.mask_list_cycle = valid_masks + valid_masks[::-1] + self.mask_coords_list_cycle = valid_mask_coords + valid_mask_coords[::-1] + + # Overwrite content with the final, filtered, and cycled data + shutil.rmtree(self.full_imgs_path); os.makedirs(self.full_imgs_path) + shutil.rmtree(self.mask_out_path); os.makedirs(self.mask_out_path) + for i, (frame, mask) in enumerate(tqdm(zip(self.frame_list_cycle, self.mask_list_cycle), total=len(self.frame_list_cycle), desc="Saving final cycle data")): + cv2.imwrite(os.path.join(self.full_imgs_path, f"{i:08d}.png"), frame) + cv2.imwrite(os.path.join(self.mask_out_path, f"{i:08d}.png"), mask) + + with open(self.coords_path, 'wb') as f: pickle.dump(self.coord_list_cycle, f) + with open(self.mask_coords_path, 'wb') as f: pickle.dump(self.mask_coords_list_cycle, f) + torch.save(torch.stack(self.input_latent_list_cycle), self.latents_out_path) + + logging.info(f"--- Material prep complete. Final cycle length: {len(self.frame_list_cycle)} frames. ---") + + @torch.no_grad() + def inference(self, audio_source, target_fps): + # This is the main inference producer loop + run_id = f"stream_{int(time.time())}" + logging.info(f"🎬 Starting inference run ID: {run_id}") + + # This queue passes generated frames and audio to the processing/sending thread + vae_to_blend_queue = queue.Queue(maxsize=self.batch_size * 2) + gst_video_pipeline, gst_audio_pipeline, frame_processor_thread = None, None, None + start_time = time.time() + + try: + # 1. Setup GStreamer pipelines + gst_video_pipeline = GStreamerPipeline(fps=target_fps) + gst_audio_pipeline = GStreamerAudio() + if not gst_video_pipeline.process or not gst_audio_pipeline.process: + raise RuntimeError("GStreamer pipeline(s) failed to initialize.") + + # 2. Process audio input + audio_reader = FFmpegAudioReader(audio_source) + full_audio_pcm = audio_reader.read_full_audio() + if full_audio_pcm is None or full_audio_pcm.size == 0: + raise ValueError("PCM audio is empty after FFmpeg conversion.") + + # Use a temporary file for feature extraction if audio source is in memory + feature_extraction_path = audio_source if isinstance(audio_source, str) else None + temp_file_handle = None + if not feature_extraction_path: + temp_file_handle = tempfile.NamedTemporaryFile(delete=False, suffix=".opus") + temp_file_handle.write(audio_source) + feature_extraction_path = temp_file_handle.name + temp_file_handle.close() # Close handle so audio_processor can open it + + whisper_feature = audio_processor.audio2feat(feature_extraction_path) + whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature, fps=target_fps) + + if temp_file_handle: + os.unlink(feature_extraction_path) + + num_frames_to_generate = len(whisper_chunks) + if num_frames_to_generate == 0: + logging.warning("No frames to generate based on audio features. Skipping.") + return + + num_vae_batches = (num_frames_to_generate + self.batch_size - 1) // self.batch_size + logging.info(f"Audio processed. Planning {num_frames_to_generate} frames in {num_vae_batches} VAE batches.") + + # 3. Start the consumer thread (frame processing and sending) + self.idx = 0 # Reset reference frame index for each run + frame_processor_thread = threading.Thread( + target=self.process_and_send_frames, + args=(vae_to_blend_queue, gst_video_pipeline, gst_audio_pipeline, num_frames_to_generate), + daemon=True, name=f"FrameProcessor_{run_id}" + ) + frame_processor_thread.start() + + # 4. Main Generation Loop (Producer) + data_gen = datagen(whisper_chunks, self.input_latent_list_cycle, self.batch_size) + total_audio_samples = len(full_audio_pcm) + audio_samples_sent = 0 + + for i, batch_data in enumerate(tqdm(data_gen, total=num_vae_batches, desc=f"VAE/UNET [{run_id}]")): + if not frame_processor_thread.is_alive(): + logging.error(f"Frame processor thread died unexpectedly. Halting generation.") + break + if not batch_data or len(batch_data) != 2: continue + + whisper_batch, latent_batch = batch_data + + # --- Core AI Inference --- + audio_feature = pe(torch.from_numpy(whisper_batch).to(device, dtype=unet.model.dtype)) + # The 'datagen' utility already stacks the latents into a batch tensor. + # We just need to move this tensor to the correct device and set its data type. + if not isinstance(latent_batch, torch.Tensor): + raise TypeError(f"The 'datagen' utility should yield a Tensor, but got {type(latent_batch)}") + latent_input = latent_batch.to(device, dtype=unet.model.dtype) + pred_latents = unet.model(latent_input, timesteps, encoder_hidden_states=audio_feature).sample + vae_output = vae.decode_latents(pred_latents) # Returns list/np.array of frames + # --- End Core AI --- + + if vae_output is None or len(vae_output) == 0: continue + + # --- DYNAMIC A/V SYNC --- + # Calculate how much audio this batch of video frames represents + num_frames_in_batch = len(vae_output) + audio_duration_of_batch = num_frames_in_batch / target_fps + num_audio_samples_for_batch = int(audio_duration_of_batch * gst_audio_pipeline.sample_rate) + + # Slice the exact audio chunk from the full PCM data + start_audio_idx = audio_samples_sent + end_audio_idx = min(start_audio_idx + num_audio_samples_for_batch, total_audio_samples) + audio_chunk_pcm = full_audio_pcm[start_audio_idx:end_audio_idx] + audio_samples_sent = end_audio_idx + # --- End A/V SYNC --- + + try: + # Put the generated frames and their corresponding audio chunk into the queue + vae_to_blend_queue.put((list(vae_output), audio_chunk_pcm), timeout=5.0) + except queue.Full: + logging.error(f"VAE-to-Blend queue is full. Consumer can't keep up. Halting generation.") + break + + except Exception as e: + logging.critical(f"CRITICAL ERROR in inference run '{run_id}'.", exc_info=True) + finally: + logging.info(f"\n--- [{run_id}] Final Cleanup ---") + # Signal consumer thread to finish + if 'vae_to_blend_queue' in locals(): + try: vae_to_blend_queue.put(None) # Sentinel value + except Exception: pass + + # Wait for consumer thread to finish its work + if frame_processor_thread and frame_processor_thread.is_alive(): + logging.info("Waiting for frame processor thread to finish...") + frame_processor_thread.join(timeout=15.0) + + # Stop GStreamer pipelines + if gst_video_pipeline: gst_video_pipeline.stop() + if gst_audio_pipeline: gst_audio_pipeline.stop() + + elapsed = time.time() - start_time + logging.info(f">>> Inference run '{run_id}' finished in {elapsed:.2f}s. <<<") + + def process_and_send_frames(self, vae_to_blend_q, gst_video, gst_audio, total_frames_planned): + # This is the consumer loop + total_frames_processed = 0 + while total_frames_processed < total_frames_planned: + + # --- FRAME SKIPPING LOGIC --- + # If the queue is getting too full, processing is falling behind. + # Discard older frames to catch up to the most recent ones. + while vae_to_blend_q.qsize() > FRAME_SKIP_THRESHOLD: + try: + logging.warning(f"Queue high ({vae_to_blend_q.qsize()}). Skipping a frame batch to catch up.") + skipped_item = vae_to_blend_q.get_nowait() + if skipped_item is not None: + total_frames_processed += len(skipped_item[0]) # Add skipped frames to total + vae_to_blend_q.task_done() + except queue.Empty: + break # Queue emptied, proceed normally + # --- END FRAME SKIPPING --- + + try: + # Block and wait for the next item from the generator + batch_data = vae_to_blend_q.get(block=True, timeout=10.0) + except queue.Empty: + logging.error("Timeout waiting for frames from VAE. Ending processor thread.") + break + + if batch_data is None: # Sentinel value means we're done + logging.info("Received sentinel. Frame processor shutting down.") + break + + vae_frames, audio_chunk = batch_data + + # Process each frame in the batch (simplified, no inner parallelization) + for frame in vae_frames: + blended_frame = self._blend_single_frame(frame, gst_video.width, gst_video.height) + if blended_frame is not None: + if not gst_video.send_frame(blended_frame): + logging.error("Failed to send video frame to GStreamer. Stopping video sends.") + vae_to_blend_q.task_done() + return # Exit thread if pipe is broken + + # After sending all video frames for the batch, send the corresponding audio + if audio_chunk is not None and audio_chunk.size > 0: + if not gst_audio.send_audio(audio_chunk): + logging.error("Failed to send audio chunk to GStreamer.") + + total_frames_processed += len(vae_frames) + vae_to_blend_q.task_done() + + logging.info("--- Frame processor finished its work. ---") + + def _blend_single_frame(self, res_frame, target_width, target_height): + try: + cycle_len = len(self.coord_list_cycle) + if cycle_len == 0: + logging.error("Cannot blend frame, avatar cycle data is empty.") + return None + + current_idx = self.idx % cycle_len + + # Retrieve all necessary data for the current reference frame + bbox = self.coord_list_cycle[current_idx] + ori_frame = self.frame_list_cycle[current_idx] + mask = self.mask_list_cycle[current_idx] + + # Check for corrupt/missing data before processing + if mask is None or ori_frame is None or bbox is None: + logging.warning(f"Skipping frame blend at index {self.idx} due to missing or corrupt reference data for cycle index {current_idx}.") + self.idx += 1 + return None + + x, y, x1, y1 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]) + face_w, face_h = x1 - x, y1 - y + + if face_w <= 0 or face_h <= 0: + self.idx += 1 + return None + + resized_face = cv2.resize(res_frame.astype(np.uint8), (face_w, face_h), interpolation=cv2.INTER_LINEAR) + + if mask.shape[:2] != (face_h, face_w): + mask = cv2.resize(mask, (face_w, face_h), interpolation=cv2.INTER_LINEAR) + + # --- THIS IS THE CORRECTED LOGIC --- + # First, check if the mask is a 3-channel BGR image. + if len(mask.shape) == 3 and mask.shape[2] == 3: + # If it is, convert it to grayscale. + alpha_mask = (cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) / 255.0)[..., np.newaxis] + else: + # Otherwise, assume it's already the 1-channel grayscale image we need. + alpha_mask = (mask / 255.0)[..., np.newaxis] + # --- END OF CORRECTION --- + + blended_frame = ori_frame.copy() + face_region = blended_frame[y:y1, x:x1] + + # Ensure the alpha mask can be broadcast to the face region shape for blending + if face_region.shape != alpha_mask.shape: + alpha_mask = np.repeat(alpha_mask, 3, axis=2) + + blended_face = face_region.astype(np.float32) * (1.0 - alpha_mask) + resized_face.astype(np.float32) * alpha_mask + blended_frame[y:y1, x:x1] = blended_face.astype(np.uint8) + + self.idx += 1 + return cv2.resize(blended_frame, (target_width, target_height), interpolation=cv2.INTER_LINEAR) + + except Exception as e: + logging.error(f"Error blending frame at index {self.idx}: {e}", exc_info=True) + self.idx += 1 + return None + + +# --- Main Application Logic --- +inference_lock = threading.Lock() +audio_queue = queue.Queue(maxsize=10) + +def inference_worker(avatar, fps): + """Worker thread that waits for audio data from a queue and runs inference.""" + logging.info("🚀 Inference worker started. Waiting for audio data...") + while True: + audio_data = audio_queue.get() + if audio_data is None: # Shutdown signal + break + + if not inference_lock.acquire(blocking=False): + logging.warning("Inference already in progress. Skipping newly received audio.") + audio_queue.task_done() + continue + + try: + logging.info(f"Inference lock acquired. Processing {len(audio_data)} bytes of audio.") + avatar.inference(audio_data, fps) + finally: + inference_lock.release() + audio_queue.task_done() + logging.info("Inference lock released.") + +def file_watcher(pipe_path): + """Monitors a file for changes and puts its content into the audio queue.""" + logging.info(f"👀 Starting file watcher for: {pipe_path}") + last_processed_mtime = 0 + while True: + try: + if os.path.isfile(pipe_path): + current_mtime = os.path.getmtime(pipe_path) + if current_mtime > last_processed_mtime: + logging.info("New audio file detected. Reading content...") + # Add a small delay and retry to handle partially written files + time.sleep(0.1) + with open(pipe_path, "rb") as f: + opus_data = f.read() + + if opus_data: + audio_queue.put(opus_data) + last_processed_mtime = current_mtime + else: + logging.warning("File modification detected, but file was empty.") + last_processed_mtime = current_mtime + time.sleep(0.2) # Poll interval + except FileNotFoundError: + # This is okay, the file might be deleted and recreated + last_processed_mtime = 0 + time.sleep(0.5) + except Exception as e: + logging.error(f"Error in file watcher loop.", exc_info=True) + time.sleep(2) + + +if __name__ == "__main__": + logging.info("🎬 Starting Realtime Stream Sync Application...") + + # 1. Load environment variables from .env file + # This must be at the top of the execution block. + from dotenv import load_dotenv + load_dotenv() + + # 2. Read configuration exclusively from the loaded environment variables + AVATAR_CONFIG_PATH = os.getenv("AVATAR_CONFIG_PATH") + AVATAR_ID_TO_USE = os.getenv("AVATAR_ID_TO_USE") + TARGET_FPS = int(os.getenv("TARGET_FPS", "25")) + STREAM_PIPE_PATH = os.getenv("STREAM_PIPE_PATH") + + main_avatar = None + try: + # 3. Validate that the required variables were found in the .env file + if not all([AVATAR_CONFIG_PATH, AVATAR_ID_TO_USE, STREAM_PIPE_PATH]): + raise ValueError("A required variable (AVATAR_CONFIG_PATH, AVATAR_ID_TO_USE, or STREAM_PIPE_PATH) is missing from your .env file.") + + logging.info(f"Attempting to load Avatar ID '{AVATAR_ID_TO_USE}' from config file '{AVATAR_CONFIG_PATH}'...") + + # 4. Load the YAML config and select the specified avatar + config = OmegaConf.load(AVATAR_CONFIG_PATH) + if AVATAR_ID_TO_USE not in config: + raise ValueError(f"Avatar ID '{AVATAR_ID_TO_USE}' was not found as a key in the YAML file '{AVATAR_CONFIG_PATH}'") + + avatar_config = config[AVATAR_ID_TO_USE] + + # 5. Instantiate the Avatar with the loaded configuration + main_avatar = Avatar( + avatar_id=AVATAR_ID_TO_USE, + video_path=avatar_config.video_path, + bbox_shift=avatar_config.bbox_shift, + batch_size=avatar_config.get("batch_size", 4), + preparation=avatar_config.preparation + ) + except Exception as e: + logging.critical("❌ CRITICAL ERROR during setup.", exc_info=True) + sys.exit(1) + + # 6. Start the watcher and inference worker threads + inference_thread = threading.Thread(target=inference_worker, args=(main_avatar, TARGET_FPS), daemon=True) + watcher_thread = threading.Thread(target=file_watcher, args=(STREAM_PIPE_PATH,), daemon=True) + + inference_thread.start() + watcher_thread.start() + + # 7. Keep the main thread alive and handle graceful shutdown + logging.info("✅ Application is running. Press Ctrl+C to shut down.") + try: + while True: + if not watcher_thread.is_alive() or not inference_thread.is_alive(): + logging.error("A critical worker thread has died. Shutting down.") + break + time.sleep(1.0) + except KeyboardInterrupt: + logging.info("\n🛑 KeyboardInterrupt received. Initiating graceful shutdown...") + finally: + audio_queue.put(None) + if watcher_thread.is_alive(): watcher_thread.join(timeout=2) + if inference_thread.is_alive(): inference_thread.join(timeout=30) + logging.info("✅ Application shutdown complete.") \ No newline at end of file diff --git a/scripts/test.py b/scripts/test.py new file mode 100644 index 00000000..92cbf9ae --- /dev/null +++ b/scripts/test.py @@ -0,0 +1,54 @@ +import ffmpeg +import numpy as np +import pyaudio +import time + +def load_audio_ffmpeg(audio_path, sample_rate=48000): + """ 用 FFmpeg 读取音频并转换为 int16 PCM """ + out, _ = ( + ffmpeg.input(audio_path) + .output('pipe:', format='s16le', acodec='pcm_s16le', ac=1, ar=sample_rate) + .run(capture_stdout=True) + ) + audio_data = np.frombuffer(out, dtype=np.int16) + return audio_data, sample_rate + +def split_audio(audio_data, num_chunks): + """ 将音频数据均分为 num_chunks 份 """ + chunk_size = len(audio_data) // num_chunks + audio_chunks = [audio_data[i * chunk_size: (i + 1) * chunk_size] for i in range(num_chunks)] + return audio_chunks + +def play_audio_chunk(audio_chunk, sample_rate=48000): + """ 播放单个音频片段 """ + p = pyaudio.PyAudio() + stream = p.open(format=pyaudio.paInt16, channels=1, rate=sample_rate, output=True) + + print(f"🎧 播放音频片段: {len(audio_chunk)} samples") + stream.write(audio_chunk.tobytes()) # 确保是 int16 PCM + stream.stop_stream() + stream.close() + p.terminate() + print("✅ 音频播放完成") + +def main(): + audio_path = "input.wav" # 你的音频文件路径 + num_chunks = 10 # 切分成 10 段 + + # 读取音频 + print("🎵 读取音频中...") + audio_data, sample_rate = load_audio_ffmpeg("/home/fan370/Documents/MuseTalk/data/audio/sun.wav") + print(f"📊 音频数据大小: {audio_data.shape}, 采样率: {sample_rate} Hz") + + # 切割音频 + print("✂️ 正在切割音频...") + audio_chunks = split_audio(audio_data, num_chunks) + + # 播放音频 + for i, chunk in enumerate(audio_chunks): + print(f"🎧 播放第 {i+1}/{len(audio_chunks)} 个音频片段") + play_audio_chunk(chunk, sample_rate) + time.sleep(0.1) # 避免过快播放 + +if __name__ == "__main__": + main() diff --git a/setup-gstreamer-env.ps1 b/setup-gstreamer-env.ps1 new file mode 100644 index 00000000..4bbf0e64 --- /dev/null +++ b/setup-gstreamer-env.ps1 @@ -0,0 +1,17 @@ +# Path to GStreamer installation +$GST_ROOT = "E:\\gstreamer\\1.0\\msvc_x86_64" + +# Add GStreamer's Python module directory to PYTHONPATH +$env:PYTHONPATH = "$GST_ROOT\\lib\\python3\\site-packages;$env:PYTHONPATH" + +# Add GStreamer binary directory to PATH +$env:PATH = "$GST_ROOT\\bin;$env:PATH" + +# Set GI_TYPELIB_PATH to find GObject Introspection typelibs +$env:GI_TYPELIB_PATH = "$GST_ROOT\\lib\\girepository-1.0" + +# Set GST_PLUGIN_PATH to find GStreamer plugins +$env:GST_PLUGIN_PATH = "$GST_ROOT\\lib\\gstreamer-1.0" + +# Run a simple test to verify it works +python -c "import gi; gi.require_version('Gst', '1.0'); from gi.repository import Gst; Gst.init(None); print(f'GStreamer version: {Gst.version_string()}'); print('GStreamer successfully initialized!')" diff --git a/setup-pygstreamer-env.ps1 b/setup-pygstreamer-env.ps1 new file mode 100644 index 00000000..56f6d0a9 --- /dev/null +++ b/setup-pygstreamer-env.ps1 @@ -0,0 +1,23 @@ +# 1. Create a bridge script named 'setup-gstreamer-env.ps1' +@" +# Path to GStreamer installation +`$GST_ROOT = "E:\\gstreamer\\1.0\\msvc_x86_64" + +# Add GStreamer's Python module directory to PYTHONPATH +`$env:PYTHONPATH = "`$GST_ROOT\\lib\\python3\\site-packages;`$env:PYTHONPATH" + +# Add GStreamer binary directory to PATH +`$env:PATH = "`$GST_ROOT\\bin;`$env:PATH" + +# Set GI_TYPELIB_PATH to find GObject Introspection typelibs +`$env:GI_TYPELIB_PATH = "`$GST_ROOT\\lib\\girepository-1.0" + +# Set GST_PLUGIN_PATH to find GStreamer plugins +`$env:GST_PLUGIN_PATH = "`$GST_ROOT\\lib\\gstreamer-1.0" + +# Run a simple test to verify it works +python -c "import gi; gi.require_version('Gst', '1.0'); from gi.repository import Gst; Gst.init(None); print(f'GStreamer version: {Gst.version_string()}'); print('GStreamer successfully initialized!')" +"@ | Out-File -FilePath setup-gstreamer-env.ps1 -Encoding utf8 + +# 2. Run the script to configure environment +./setup-gstreamer-env.ps1 diff --git a/setup-pygstreamer.ps1 b/setup-pygstreamer.ps1 new file mode 100644 index 00000000..9e3011c0 --- /dev/null +++ b/setup-pygstreamer.ps1 @@ -0,0 +1,34 @@ +# Create this script as setup-pygstreamer.ps1 + +# Clear any existing paths that might conflict +$env:PATH = ($env:PATH -split ';' | Where-Object { -not $_.Contains('msys64') }) -join ';' + +# Add GStreamer paths +$GST_ROOT = "E:\gstreamer\1.0\msvc_x86_64" +$env:PATH = "$GST_ROOT\bin;$env:PATH" +$env:GI_TYPELIB_PATH = "$GST_ROOT\lib\girepository-1.0" +$env:PKG_CONFIG_PATH = "$GST_ROOT\lib\pkgconfig" +$env:GST_PLUGIN_PATH = "$GST_ROOT\lib\gstreamer-1.0" + +# Install PyGObjects' Windows binaries +pip install --no-cache-dir PyGObject-stubs +pip install --no-cache-dir pycairo + +# Don't forget to install torch/CUDA if needed for MuseTalk +# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 + +# Run a test script to verify +$TEST_SCRIPT = @' +import gi +gi.require_version("Gst", "1.0") +from gi.repository import Gst + +Gst.init(None) +print(f"GStreamer version: {Gst.version_string()}") +print("GStreamer successfully initialized!") +'@ + +Write-Host "Creating test script..." +$TEST_SCRIPT | Out-File -FilePath test_gst.py -Encoding utf8 +Write-Host "Running test script..." +python test_gst.py diff --git a/test-custom-opencv.py b/test-custom-opencv.py new file mode 100644 index 00000000..d84165db --- /dev/null +++ b/test-custom-opencv.py @@ -0,0 +1,34 @@ +# test-custom-opencv.py +import use_custom_opencv + +# Now import OpenCV +import cv2 +import os + +# Print OpenCV details +print(f"OpenCV version: {cv2.__version__}") +print(f"OpenCV path: {os.path.dirname(cv2.__file__)}") + +# Check for GStreamer support +build_info = cv2.getBuildInformation() +gstreamer_line = [line for line in build_info.splitlines() + if "GStreamer:" in line] +print(f"GStreamer support: {gstreamer_line[0] if gstreamer_line else 'Not found'}") + +# Test a simple GStreamer pipeline +try: + pipeline = "videotestsrc ! videoconvert ! appsink" + cap = cv2.VideoCapture(pipeline, cv2.CAP_GSTREAMER) + + if cap.isOpened(): + print("GStreamer pipeline working!") + ret, frame = cap.read() + if ret: + print(f"Successfully read frame with shape: {frame.shape}") + else: + print("Failed to read frame") + cap.release() + else: + print("Failed to open GStreamer pipeline") +except Exception as e: + print(f"Error testing GStreamer: {e}") \ No newline at end of file diff --git a/test-gst.py b/test-gst.py new file mode 100644 index 00000000..eb237ac2 --- /dev/null +++ b/test-gst.py @@ -0,0 +1,21 @@ +import os +import sys + +# Add DLL directory +opencv_bin = r"D:\tencent\devel\cv\opencv-4.5.5\build\install\x64\vc16\bin" +os.add_dll_directory(opencv_bin) + +# Find the exact .pyd file +pyd_dir = r"D:\tencent\devel\cv\opencv-4.5.5\build\lib\python3\Release" +print("Looking for cv2.pyd in:", pyd_dir) +for file in os.listdir(pyd_dir): + if file.endswith(".pyd"): + print(f"Found: {file}") + +# Add to path and try import +sys.path.insert(0, pyd_dir) +try: + import cv2 + print("Success!") +except ImportError as e: + print(f"Failed: {e}") diff --git a/test_opencv_import.py b/test_opencv_import.py new file mode 100644 index 00000000..a40c6a14 --- /dev/null +++ b/test_opencv_import.py @@ -0,0 +1,31 @@ + +import os +import sys + +# Add paths +opencv_python_path = r"D:\tencent\devel\cv\opencv-4.5.5\build\lib\python3\Release" +opencv_bin = r"D:\tencent\devel\cv\opencv-4.5.5\build\install\x64\vc16\bin" +gst_bin = r"E:\gstreamer\1.0\msvc_x86_64\bin" + +# Add to Python path +if opencv_python_path not in sys.path: + sys.path.insert(0, opencv_python_path) + +# Add to PATH +os.environ["PATH"] = opencv_bin + os.pathsep + os.environ["PATH"] +os.environ["PATH"] = gst_bin + os.pathsep + os.environ["PATH"] + +# Add DLL directories +if sys.version_info >= (3, 8): + os.add_dll_directory(opencv_bin) + os.add_dll_directory(gst_bin) + +# Try importing OpenCV +try: + print("Attempting to import cv2...") + import cv2 + print(f"Success! OpenCV version: {cv2.__version__}") +except Exception as e: + print(f"Error importing cv2: {e}") + import traceback + traceback.print_exc() diff --git a/use-custom-opencv.py b/use-custom-opencv.py new file mode 100644 index 00000000..13ba8a10 --- /dev/null +++ b/use-custom-opencv.py @@ -0,0 +1,67 @@ +import os +import sys +import site +import platform + +def setup_opencv_environment(): + """Configure environment for custom OpenCV with all required DLLs""" + print(f"Setting up OpenCV environment for Python {sys.version}") + + # Paths to OpenCV and its dependencies + opencv_python_path = r"D:\tencent\devel\cv\opencv-4.5.5\build\lib\python3\Release" + opencv_bin = r"D:\tencent\devel\cv\opencv-4.5.5\build\install\x64\vc16\bin" + opencv_lib = r"D:\tencent\devel\cv\opencv-4.5.5\build\install\x64\vc16\lib" + gst_bin = r"E:\gstreamer\1.0\msvc_x86_64\bin" + + # Add OpenCV module path to Python path + if opencv_python_path not in sys.path: + sys.path.insert(0, opencv_python_path) + print(f"Added to Python path: {opencv_python_path}") + + # Add DLL directories to PATH + paths_to_add = [opencv_bin, opencv_lib, gst_bin] + for path in paths_to_add: + if os.path.exists(path): + if path not in os.environ["PATH"]: + os.environ["PATH"] = path + os.pathsep + os.environ["PATH"] + print(f"Added to PATH: {path}") + else: + print(f"WARNING: Path does not exist: {path}") + + # Use add_dll_directory for Python 3.8+ + if sys.version_info >= (3, 8): + for path in paths_to_add: + if os.path.exists(path): + try: + os.add_dll_directory(path) + print(f"Added DLL directory: {path}") + except Exception as e: + print(f"ERROR adding DLL directory {path}: {e}") + + # Print architecture information + print(f"Python architecture: {platform.architecture()[0]}") + + return True + +# Execute if run directly +if __name__ == "__main__": + success = setup_opencv_environment() + + if success: + try: + import cv2 + print(f"\nOpenCV successfully loaded!") + print(f"OpenCV version: {cv2.__version__}") + print(f"OpenCV path: {os.path.dirname(cv2.__file__)}") + + # Check for GStreamer support + build_info = cv2.getBuildInformation() + gstreamer_line = [line for line in build_info.splitlines() + if "GStreamer:" in line] + print(f"GStreamer support: {gstreamer_line[0] if gstreamer_line else 'Not found'}") + except Exception as e: + print(f"\nERROR loading OpenCV: {e}") + + # Provide detailed error information + import traceback + traceback.print_exc() \ No newline at end of file