Skip to content

Commit a3d469d

Browse files
committed
✨feat: 添加对Gemini模型的音频处理支持,更新 ProviderRequest 以包含音频 URL 列表
1 parent dff15cf commit a3d469d

File tree

3 files changed

+122
-29
lines changed

3 files changed

+122
-29
lines changed

astrbot/core/pipeline/process_stage/method/llm_request.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
ResultContentType,
1515
MessageChain,
1616
)
17-
from astrbot.core.message.components import Image
17+
from astrbot.core.message.components import Image, Record
1818
from astrbot.core import logger
1919
from astrbot.core.utils.metrics import Metric
2020
from astrbot.core.provider.entities import (
@@ -77,16 +77,33 @@ async def process(
7777
)
7878

7979
else:
80-
req = ProviderRequest(prompt="", image_urls=[])
80+
req = ProviderRequest(prompt="", image_urls=[], audio_urls=[])
8181
if self.provider_wake_prefix:
8282
if not event.message_str.startswith(self.provider_wake_prefix):
8383
return
8484
req.prompt = event.message_str[len(self.provider_wake_prefix) :]
8585
req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
86+
87+
# 处理消息中的图片和音频
88+
has_audio = False
8689
for comp in event.message_obj.message:
8790
if isinstance(comp, Image):
8891
image_path = await comp.convert_to_file_path()
8992
req.image_urls.append(image_path)
93+
elif isinstance(comp, Record):
94+
# 处理音频消息
95+
audio_path = await comp.convert_to_file_path()
96+
logger.info(f"检测到音频消息,路径: {audio_path}")
97+
has_audio = True
98+
if hasattr(req, "audio_urls"):
99+
req.audio_urls.append(audio_path)
100+
else:
101+
# 为了兼容性,如果ProviderRequest没有audio_urls属性
102+
req.audio_urls = [audio_path]
103+
104+
# 如果只有音频没有文本,添加默认文本
105+
if not req.prompt and has_audio:
106+
req.prompt = "[用户发送的音频将其视为文本输入与其进行聊天]"
90107

91108
# 获取对话上下文
92109
conversation_id = await self.conv_manager.get_curr_conversation_id(

astrbot/core/provider/entities.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ class ProviderRequest:
9696
"""会话 ID"""
9797
image_urls: List[str] = None
9898
"""图片 URL 列表"""
99+
audio_urls: List[str] = None
100+
"""音频 URL 列表"""
99101
func_tool: FuncCall = None
100102
"""可用的函数工具"""
101103
contexts: List = None

astrbot/core/provider/sources/gemini_source.py

Lines changed: 101 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import json
44
import logging
55
import random
6+
import os
7+
import mimetypes
68
from typing import Dict, List, Optional
79
from collections.abc import AsyncGenerator
810

@@ -193,6 +195,12 @@ def process_image_url(image_url_dict: dict) -> types.Part:
193195
mime_type = url.split(":")[1].split(";")[0]
194196
image_bytes = base64.b64decode(url.split(",", 1)[1])
195197
return types.Part.from_bytes(data=image_bytes, mime_type=mime_type)
198+
199+
def process_inline_data(inline_data_dict: dict) -> types.Part:
200+
"""处理内联数据,如音频""" # TODO: 处理视频?
201+
mime_type = inline_data_dict["mime_type"]
202+
data = inline_data_dict.get("data", "")
203+
return types.Part.from_bytes(data=data, mime_type=mime_type)
196204

197205
def append_or_extend(contents: list[types.Content], part: list[types.Part], content_cls: type[types.Content]) -> None:
198206
if contents and isinstance(contents[-1], content_cls):
@@ -212,12 +220,15 @@ def append_or_extend(contents: list[types.Content], part: list[types.Part], cont
212220

213221
if role == "user":
214222
if isinstance(content, list):
215-
parts = [
216-
types.Part.from_text(text=item["text"] or " ")
217-
if item["type"] == "text"
218-
else process_image_url(item["image_url"])
219-
for item in content
220-
]
223+
parts = []
224+
for item in content:
225+
if item["type"] == "text":
226+
parts.append(types.Part.from_text(text=item["text"] or " "))
227+
elif item["type"] == "image_url":
228+
parts.append(process_image_url(item["image_url"]))
229+
elif item["type"] == "inline_data":
230+
# 处理内联数据,如音频
231+
parts.append(process_inline_data(item["inline_data"]))
221232
else:
222233
parts = [create_text_part(content)]
223234
append_or_extend(gemini_contents, parts, types.UserContent)
@@ -447,13 +458,14 @@ async def text_chat(
447458
prompt: str,
448459
session_id: str = None,
449460
image_urls: List[str] = None,
461+
audio_urls: List[str] = None,
450462
func_tool: FuncCall = None,
451463
contexts=[],
452464
system_prompt=None,
453465
tool_calls_result=None,
454466
**kwargs,
455467
) -> LLMResponse:
456-
new_record = await self.assemble_context(prompt, image_urls)
468+
new_record = await self.assemble_context(prompt, image_urls, audio_urls)
457469
context_query = [*contexts, new_record]
458470
if system_prompt:
459471
context_query.insert(0, {"role": "system", "content": system_prompt})
@@ -486,14 +498,15 @@ async def text_chat_stream(
486498
self,
487499
prompt: str,
488500
session_id: str = None,
489-
image_urls: List[str] = [],
501+
image_urls: List[str] = None,
502+
audio_urls: List[str] = None,
490503
func_tool: FuncCall = None,
491504
contexts=[],
492505
system_prompt=None,
493506
tool_calls_result=None,
494507
**kwargs,
495508
) -> AsyncGenerator[LLMResponse, None]:
496-
new_record = await self.assemble_context(prompt, image_urls)
509+
new_record = await self.assemble_context(prompt, image_urls, audio_urls)
497510
context_query = [*contexts, new_record]
498511
if system_prompt:
499512
context_query.insert(0, {"role": "system", "content": system_prompt})
@@ -545,30 +558,55 @@ def set_key(self, key):
545558
self.chosen_api_key = key
546559
self._init_client()
547560

548-
async def assemble_context(self, text: str, image_urls: List[str] = None):
561+
async def assemble_context(self, text: str, image_urls: List[str] = None, audio_urls: List[str] = None):
549562
"""
550563
组装上下文。
551564
"""
552-
if image_urls:
565+
has_media = (image_urls and len(image_urls) > 0) or (audio_urls and len(audio_urls) > 0)
566+
567+
if has_media:
553568
user_content = {
554569
"role": "user",
555-
"content": [{"type": "text", "text": text if text else "[图片]"}],
570+
"content": [{"type": "text", "text": text if text else "[媒体内容]"}],
556571
}
557-
for image_url in image_urls:
558-
if image_url.startswith("http"):
559-
image_path = await download_image_by_url(image_url)
560-
image_data = await self.encode_image_bs64(image_path)
561-
elif image_url.startswith("file:///"):
562-
image_path = image_url.replace("file:///", "")
563-
image_data = await self.encode_image_bs64(image_path)
564-
else:
565-
image_data = await self.encode_image_bs64(image_url)
566-
if not image_data:
567-
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
568-
continue
569-
user_content["content"].append(
570-
{"type": "image_url", "image_url": {"url": image_data}}
571-
)
572+
573+
# 处理图片
574+
if image_urls:
575+
for image_url in image_urls:
576+
if image_url.startswith("http"):
577+
image_path = await download_image_by_url(image_url)
578+
image_data = await self.encode_image_bs64(image_path)
579+
elif image_url.startswith("file:///"):
580+
image_path = image_url.replace("file:///", "")
581+
image_data = await self.encode_image_bs64(image_path)
582+
else:
583+
image_data = await self.encode_image_bs64(image_url)
584+
if not image_data:
585+
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
586+
continue
587+
user_content["content"].append(
588+
{"type": "image_url", "image_url": {"url": image_data}}
589+
)
590+
591+
# 处理音频
592+
if audio_urls:
593+
for audio_url in audio_urls:
594+
audio_bytes, mime_type = await self.encode_audio_data(audio_url)
595+
if not audio_bytes or not mime_type:
596+
logger.warning(f"音频 {audio_url} 处理失败,将忽略。")
597+
continue
598+
599+
# 添加音频数据
600+
user_content["content"].append(
601+
{
602+
"type": "inline_data",
603+
"inline_data": {
604+
"mime_type": mime_type,
605+
"data": audio_bytes
606+
}
607+
}
608+
)
609+
572610
return user_content
573611
else:
574612
return {"role": "user", "content": text}
@@ -584,5 +622,41 @@ async def encode_image_bs64(self, image_url: str) -> str:
584622
return "data:image/jpeg;base64," + image_bs64
585623
return ""
586624

625+
async def encode_audio_data(self, audio_url: str) -> tuple:
626+
"""
627+
读取音频文件并返回二进制数据
628+
629+
Returns:
630+
tuple: (音频二进制数据, MIME类型)
631+
"""
632+
try:
633+
# 直接读取文件二进制数据
634+
with open(audio_url, "rb") as f:
635+
audio_bytes = f.read()
636+
637+
# 推断 MIME 类型
638+
mime_type = mimetypes.guess_type(audio_url)[0]
639+
if not mime_type:
640+
# 根据文件扩展名确定 MIME 类型
641+
extension = os.path.splitext(audio_url)[1].lower()
642+
if extension == '.wav':
643+
mime_type = 'audio/wav'
644+
elif extension == '.mp3':
645+
mime_type = 'audio/mpeg'
646+
elif extension == '.ogg':
647+
mime_type = 'audio/ogg'
648+
elif extension == '.flac':
649+
mime_type = 'audio/flac'
650+
elif extension == '.m4a':
651+
mime_type = 'audio/mp4'
652+
else:
653+
mime_type = 'audio/wav' # 默认
654+
655+
logger.info(f"音频文件处理成功: {audio_url},mime类型: {mime_type},大小: {len(audio_bytes)} 字节")
656+
return audio_bytes, mime_type
657+
except Exception as e:
658+
logger.error(f"音频文件处理失败: {e}")
659+
return None, None
660+
587661
async def terminate(self):
588662
logger.info("Google GenAI 适配器已终止。")

0 commit comments

Comments
 (0)