24 changes: 15 additions & 9 deletions README.md
@@ -148,33 +148,38 @@ We also hope you note that we have not verified, maintained, or updated third-pa
To prepare the Python environment and install additional packages such as opencv, diffusers, mmcv, etc., please follow the steps below:

### Build environment
-We recommend Python 3.10 and CUDA 11.7. Set up your environment as follows:
+We recommend Python 3.10 and CUDA 11.7 or 12.x. Set up your environment as follows:

```shell
conda create -n MuseTalk python==3.10
conda activate MuseTalk
```

-### Install PyTorch 2.0.1
-Choose one of the following installation methods:
+### Install PyTorch
+Install PyTorch **before** other dependencies. Choose the build that matches your GPU and CUDA:

```shell
-# Option 1: Using pip
+# Option 1: CUDA 11.8 (most GPUs, e.g. RTX 20/30 series)
pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118

-# Option 2: Using conda
-conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia
+# Option 2: CUDA 12.8 (RTX 50 series / Blackwell, sm_120 support)
+pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
+
+# Option 3: CPU only (slower inference)
+pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
```

+**Note:** For NVIDIA RTX 5060 Ti / RTX 50 series (Blackwell), use the CUDA 12.8 index above. Older PyTorch builds do not include kernels for sm_120.
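
The note above can be turned into a small check. The following is a minimal sketch and not part of this PR: `pick_wheel_index` is a hypothetical helper that maps a CUDA compute capability (for example the tuple returned by `torch.cuda.get_device_capability()`) to one of the three index URLs listed above. The cut-off at capability 12.0 encodes only what the note states about sm_120; how other recent architectures should be mapped is not covered here.

```python
from typing import Optional, Tuple

# Index URLs copied from the install options above.
INDEX_URLS = {
    "cu118": "https://download.pytorch.org/whl/cu118",
    "cu128": "https://download.pytorch.org/whl/cu128",
    "cpu": "https://download.pytorch.org/whl/cpu",
}


def pick_wheel_index(capability: Optional[Tuple[int, int]]) -> str:
    """Return the wheel index URL for a (major, minor) compute capability.

    Pass None when no CUDA GPU is present. Hypothetical helper: the
    threshold assumes sm_120 (RTX 50 series) and newer need the cu128 build.
    """
    if capability is None:
        return INDEX_URLS["cpu"]
    if capability >= (12, 0):
        return INDEX_URLS["cu128"]
    return INDEX_URLS["cu118"]


if __name__ == "__main__":
    # (8, 6) is an RTX 30 series card, (12, 0) an RTX 50 series card.
    for cap in [(8, 6), (12, 0), None]:
        print(cap, "->", pick_wheel_index(cap))
```

The returned URL can then be passed to `pip install ... --index-url`.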

### Install Dependencies
-Install the remaining required packages:
+Install the remaining required packages (use `numpy<2` for opencv compatibility):

```shell
pip install -r requirements.txt
```

### Install MMLab Packages
-Install the MMLab ecosystem packages:
+Install the MMLab ecosystem packages via OpenMIM:

```bash
pip install --no-cache-dir -U openmim
@@ -184,6 +189,8 @@ mim install "mmdet==3.1.0"
mim install "mmpose==1.1.0"
```

+**Alternative (mmcv-lite):** If no pre-built mmcv wheel is available for your PyTorch/CUDA (e.g. some Windows + CUDA 12.8 setups), you can use `mmcv-lite` instead of full mmcv. The realtime inference code includes fallbacks for NMS and ROI align when using mmcv-lite.
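
For readers unfamiliar with what an NMS fallback has to do, here is a minimal pure-Python sketch of greedy non-maximum suppression. It illustrates the general technique only and is not the fallback shipped in this PR; the function names and the `(x1, y1, x2, y2)` box layout are assumptions.

```python
from typing import List, Sequence

Box = Sequence[float]  # (x1, y1, x2, y2), an assumed layout


def iou(a: Box, b: Box) -> float:
    """Intersection over union of two axis-aligned boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def nms(boxes: List[Box], scores: List[float], iou_threshold: float) -> List[int]:
    """Greedy NMS: return indices of kept boxes, highest score first."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep: List[int] = []
    for i in order:
        # Keep a box only if it does not overlap any already kept box too much.
        if all(iou(boxes[i], boxes[k]) <= iou_threshold for k in keep):
            keep.append(i)
    return keep


if __name__ == "__main__":
    boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (20, 20, 30, 30)]
    scores = [0.9, 0.8, 0.7]
    print(nms(boxes, scores, iou_threshold=0.5))  # prints [0, 2]
```

This version is quadratic in the number of boxes, which is usually acceptable for a handful of face detections; a compiled or vectorized operator is faster for large inputs.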

### Setup FFmpeg
1. [Download](https://github.com/BtbN/FFmpeg-Builds/releases) the ffmpeg-static package

@@ -324,7 +331,6 @@ Important notes for real-time inference:
2. After preparation, the avatar will generate videos using audio clips from `audio_clips`
3. The generation process can achieve 30fps+ on an NVIDIA Tesla V100
4. Set `preparation` to `False` for generating more videos with the same avatar
-
For faster generation without saving images, you can use:
```bash
python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images
@@ -21,7 +21,7 @@ def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path
         if not os.path.isfile(path_to_detector):
             model_weights = load_url(models_urls['s3fd'])
         else:
-            model_weights = torch.load(path_to_detector)
+            model_weights = torch.load(path_to_detector, weights_only=False)

         self.face_detector = s3fd()
         self.face_detector.load_state_dict(model_weights)
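
Passing `weights_only=False` tells `torch.load` to unpickle arbitrary Python objects. Some older checkpoints need this, but unpickling can execute code embedded in the file, so it should only be used on files you trust. (PyTorch 2.6 changed the default to `weights_only=True`, which is presumably why the argument is now passed explicitly.) One possible guard is sketched below, under the assumption that you have recorded a known-good SHA-256 digest for each checkpoint. `sha256_of`, `verify_checkpoint`, and the digest bookkeeping are hypothetical and not part of this PR.

```python
import hashlib
import tempfile
from pathlib import Path


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hash a file in chunks so large checkpoints need not fit in memory."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_checkpoint(path: Path, expected_sha256: str) -> Path:
    """Return path if its digest matches, otherwise raise ValueError."""
    actual = sha256_of(path)
    if actual != expected_sha256.lower():
        raise ValueError(f"{path}: digest {actual} does not match the expected value")
    return path


if __name__ == "__main__":
    # Demo with a throwaway file standing in for a checkpoint.
    with tempfile.TemporaryDirectory() as tmp:
        p = Path(tmp) / "demo.pth"
        p.write_bytes(b"not a real checkpoint")
        pinned = sha256_of(p)  # in practice, pinned ahead of time
        print(verify_checkpoint(p, pinned) == p)
```

A caller could then write `torch.load(verify_checkpoint(path, digest), weights_only=False)` so that the full unpickler only ever sees files whose contents were checked first.
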
4 changes: 2 additions & 2 deletions musetalk/utils/face_parsing/__init__.py
@@ -62,9 +62,9 @@ def model_init(self,
         net = BiSeNet(resnet_path)
         if torch.cuda.is_available():
             net.cuda()
-            net.load_state_dict(torch.load(model_pth))
+            net.load_state_dict(torch.load(model_pth, weights_only=False))
         else:
-            net.load_state_dict(torch.load(model_pth, map_location=torch.device('cpu')))
+            net.load_state_dict(torch.load(model_pth, map_location=torch.device('cpu'), weights_only=False))
         net.eval()
         return net

2 changes: 1 addition & 1 deletion musetalk/utils/face_parsing/resnet.py
@@ -80,7 +80,7 @@ def forward(self, x):
         return feat8, feat16, feat32

     def init_weight(self, model_path):
-        state_dict = torch.load(model_path) #modelzoo.load_url(resnet18_url)
+        state_dict = torch.load(model_path, weights_only=False) # legacy .tar format
         self_state_dict = self.state_dict()
         for k, v in state_dict.items():
             if 'fc' in k: continue
8 changes: 7 additions & 1 deletion requirements.txt
@@ -1,6 +1,11 @@
+# Install PyTorch first (see README for GPU/CPU and RTX 50 series options):
+# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128 # CUDA 12.8 (e.g. RTX 50)
+# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 # CUDA 11.8
+
diffusers==0.30.2
accelerate==0.28.0
-numpy==1.23.5
+# numpy<2 required for opencv-python compatibility
+numpy>=1.22,<2
tensorflow==2.12.0
tensorboard==2.12.0
opencv-python==4.9.0.80
@@ -18,3 +23,4 @@ imageio[ffmpeg]
omegaconf
ffmpeg-python
moviepy
+openmim
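
To check an existing environment against the new `numpy>=1.22,<2` pin without reinstalling, something like the following could be used. This is a minimal sketch: `release_tuple` and `numpy_pin_ok` are hypothetical helpers, and they compare only the leading numeric release components, so pre-release suffixes such as `rc1` are ignored rather than handled as pip would handle them.

```python
import re
from typing import Tuple


def release_tuple(version: str) -> Tuple[int, ...]:
    """Extract the leading numeric components, e.g. '1.26.4' -> (1, 26, 4)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def numpy_pin_ok(version: str) -> bool:
    """True if version satisfies the requirements.txt pin numpy>=1.22,<2."""
    return (1, 22) <= release_tuple(version) < (2,)


if __name__ == "__main__":
    for v in ["1.23.5", "1.26.4", "1.21.6", "2.0.0"]:
        print(v, numpy_pin_ok(v))
```

In a real environment the version string can be obtained with `importlib.metadata.version("numpy")`.
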
4 changes: 2 additions & 2 deletions scripts/realtime_inference.py
@@ -96,7 +96,7 @@ def init(self):
                     osmakedirs([self.avatar_path, self.full_imgs_path, self.video_out_path, self.mask_out_path])
                     self.prepare_material()
                 else:
-                    self.input_latent_list_cycle = torch.load(self.latents_out_path)
+                    self.input_latent_list_cycle = torch.load(self.latents_out_path, weights_only=False)
                     with open(self.coords_path, 'rb') as f:
                         self.coord_list_cycle = pickle.load(f)
                     input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
@@ -133,7 +133,7 @@ def init(self):
                 else:
                     sys.exit()
             else:
-                self.input_latent_list_cycle = torch.load(self.latents_out_path)
+                self.input_latent_list_cycle = torch.load(self.latents_out_path, weights_only=False)
                 with open(self.coords_path, 'rb') as f:
                     self.coord_list_cycle = pickle.load(f)
                 input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]'))