Skip to content

Commit 17ce5f8

Browse files
committed
add more audio format support
1 parent bf7622e commit 17ce5f8

3 files changed

Lines changed: 30 additions & 14 deletions

File tree

interface/interface_audio.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
face_pts_mean = adjust_verts(face_pts_mean)
3333
teeth_verts_ = render_verts_[478:, :3]
3434
head_joint = np.array([out_size * 0.5, out_size * 3 / 4, -0.])
35-
def run_audio(img_path, wavpath, output_path, template_path = None):
35+
def run_audio(img_path, audio_path, output_path, template_path = None):
3636
img_primer_rgba, source_img, source_crop_pts, source_crop_pts_vt, source_crop_coords = face_process(img_path, out_size)
3737

3838
# print(source_img.shape)
@@ -57,7 +57,7 @@ def run_audio(img_path, wavpath, output_path, template_path = None):
5757
tensor_source_prompt = torch.from_numpy(source_prompt / 255.).float().permute(2, 0, 1).unsqueeze(0).to(
5858
device)
5959

60-
pts_audio_driving = audio_interface(wavpath)
60+
pts_audio_driving = audio_interface(audio_path)
6161
frame_num = len(pts_audio_driving)
6262
import uuid
6363
task_id = str(uuid.uuid1())
@@ -183,27 +183,26 @@ def run_audio(img_path, wavpath, output_path, template_path = None):
183183
videoWriter.write(frame[..., ::-1])
184184
videoWriter.release()
185185
val_video = output_path
186-
wav_path = wavpath
187186
os.system(
188-
"ffmpeg -i {} -i {} -c:v libx264 -pix_fmt yuv420p {}".format(save_path, wav_path, val_video))
187+
"ffmpeg -i {} -i {} -c:v libx264 -pix_fmt yuv420p {}".format(save_path, audio_path, val_video))
189188
os.remove(save_path)
190189
cv2.destroyAllWindows()
191190

192191
def main():
    """Command-line entry point for the audio-driven rendering script.

    Reads ``<img_path> <audio_path> <output_path> [template_path]`` from
    ``sys.argv`` and forwards them to ``run_audio``. ``template_path`` is
    optional and is passed as ``None`` when omitted.

    Exits with status 1 after printing a usage line when the number of
    command-line arguments is not 3 or 4 (``len(sys.argv)`` of 4 or 5).
    """
    # Validate the number of command-line arguments.
    if len(sys.argv) < 4 or len(sys.argv) > 5:
        print("Usage: python interface_audio.py <img_path> <audio_path> <output_path> <template_path>")
        sys.exit(1)  # Abort when the argument count is wrong.

    img_path = sys.argv[1]
    audio_path = sys.argv[2]
    output_path = sys.argv[3]
    # The fourth positional argument is optional.
    template_path = sys.argv[4] if len(sys.argv) == 5 else None

    # The label says "audio path" (not "wav path") because the input is no
    # longer restricted to WAV files.
    print(f"img path is set to: {img_path}, audio path is set to: {audio_path}, output path is set to: {output_path}")
    run_audio(img_path, audio_path, output_path, template_path)
207206

208207
if __name__ == "__main__":
209208
main()

interface/utils.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,23 @@ def rgb_face_process(img_primer_bgr, out_size):
112112
# mat_list, _, face_pts_mean_personal_primer = calc_face_mat(pts_driven, face_pts_mean)
113113
# return source_img, source_crop_pts
114114

115+
# Read an audio file into a mono float32 sample array.
def load_audio(file_path, target_sr=16000):
    """Decode an audio file to mono float32 samples at ``target_sr`` Hz.

    Args:
        file_path: Path of the audio file to decode. Which container/codec
            formats work depends on the decoding backends available to
            librosa in the running environment.
        target_sr: Sample rate of the returned signal in Hz. Defaults to
            16000, the rate this function has always produced.

    Returns:
        A 1-D, C-contiguous ``numpy.float32`` array of samples.

    Raises:
        ValueError: If ``target_sr`` is not a positive number.
    """
    if target_sr <= 0:
        raise ValueError(f"target_sr must be positive, got {target_sr!r}")

    # Imported here so the dependency is only needed when audio is loaded.
    import librosa
    import numpy as np

    # librosa.load down-mixes to mono (mono=True) and resamples to `sr` in a
    # single call, so a separate librosa.resample pass is unnecessary.
    samples, _ = librosa.load(file_path, sr=target_sr, mono=True)

    # Guarantee float32 in C order, matching the layout of the earlier
    # wavfile-based loader (astype(np.float32, order='C')).
    return np.ascontiguousarray(samples, dtype=np.float32)
130+
131+
def audio_interface(audio_path):
117132
global Audio2FeatureModel,PcaModel
118133
if Audio2FeatureModel is None:
119134
current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -124,10 +139,11 @@ def audio_interface(wavpath):
124139
Audio2FeatureModel.load_state_dict(torch.load(ckpt_path))
125140
Audio2FeatureModel = Audio2FeatureModel.to(device)
126141
Audio2FeatureModel.eval()
127-
rate, wav = wavfile.read(wavpath, mmap=False)
128-
129-
augmented_samples = wav
130-
augmented_samples2 = augmented_samples.astype(np.float32, order='C') / 32768.0
142+
# rate, wav = wavfile.read(wavpath, mmap=False)
143+
#
144+
# augmented_samples = wav
145+
# augmented_samples2 = augmented_samples.astype(np.float32, order='C') / 32768.0
146+
augmented_samples2 = load_audio(audio_path)
131147

132148
opts = knf.FbankOptions()
133149
opts.frame_opts.dither = 0

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@ tqdm
44
scikit-learn
55
glfw
66
PyOpenGL
7-
onnxruntime
7+
onnxruntime
8+
librosa

0 commit comments

Comments
 (0)