3
3
import json
4
4
import logging
5
5
import random
6
+ import os
7
+ import mimetypes
6
8
from typing import Dict , List , Optional
7
9
from collections .abc import AsyncGenerator
8
10
@@ -193,6 +195,12 @@ def process_image_url(image_url_dict: dict) -> types.Part:
193
195
mime_type = url .split (":" )[1 ].split (";" )[0 ]
194
196
image_bytes = base64 .b64decode (url .split ("," , 1 )[1 ])
195
197
return types .Part .from_bytes (data = image_bytes , mime_type = mime_type )
198
+
199
def process_inline_data(inline_data_dict: dict) -> types.Part:
    """Build a ``types.Part`` from an inline-data dict (e.g. an audio clip).

    Args:
        inline_data_dict: Mapping with a required ``"mime_type"`` entry and
            an optional ``"data"`` entry holding the payload.

    Returns:
        The value produced by ``types.Part.from_bytes`` for the payload and
        MIME type.

    Raises:
        KeyError: If ``"mime_type"`` is missing from ``inline_data_dict``.
    """
    # TODO: handle video as well?
    mime_type = inline_data_dict["mime_type"]
    # The payload is handed to ``from_bytes``, so a missing, ``None`` or
    # empty-string value is normalised to empty *bytes* instead of a str.
    data = inline_data_dict.get("data") or b""
    return types.Part.from_bytes(data=data, mime_type=mime_type)
196
204
197
205
def append_or_extend (contents : list [types .Content ], part : list [types .Part ], content_cls : type [types .Content ]) -> None :
198
206
if contents and isinstance (contents [- 1 ], content_cls ):
@@ -212,12 +220,15 @@ def append_or_extend(contents: list[types.Content], part: list[types.Part], cont
212
220
213
221
if role == "user" :
214
222
if isinstance (content , list ):
215
- parts = [
216
- types .Part .from_text (text = item ["text" ] or " " )
217
- if item ["type" ] == "text"
218
- else process_image_url (item ["image_url" ])
219
- for item in content
220
- ]
223
+ parts = []
224
+ for item in content :
225
+ if item ["type" ] == "text" :
226
+ parts .append (types .Part .from_text (text = item ["text" ] or " " ))
227
+ elif item ["type" ] == "image_url" :
228
+ parts .append (process_image_url (item ["image_url" ]))
229
+ elif item ["type" ] == "inline_data" :
230
+ # 处理内联数据,如音频
231
+ parts .append (process_inline_data (item ["inline_data" ]))
221
232
else :
222
233
parts = [create_text_part (content )]
223
234
append_or_extend (gemini_contents , parts , types .UserContent )
@@ -447,13 +458,14 @@ async def text_chat(
447
458
prompt : str ,
448
459
session_id : str = None ,
449
460
image_urls : List [str ] = None ,
461
+ audio_urls : List [str ] = None ,
450
462
func_tool : FuncCall = None ,
451
463
contexts = [],
452
464
system_prompt = None ,
453
465
tool_calls_result = None ,
454
466
** kwargs ,
455
467
) -> LLMResponse :
456
- new_record = await self .assemble_context (prompt , image_urls )
468
+ new_record = await self .assemble_context (prompt , image_urls , audio_urls )
457
469
context_query = [* contexts , new_record ]
458
470
if system_prompt :
459
471
context_query .insert (0 , {"role" : "system" , "content" : system_prompt })
@@ -486,14 +498,15 @@ async def text_chat_stream(
486
498
self ,
487
499
prompt : str ,
488
500
session_id : str = None ,
489
- image_urls : List [str ] = [],
501
+ image_urls : List [str ] = None ,
502
+ audio_urls : List [str ] = None ,
490
503
func_tool : FuncCall = None ,
491
504
contexts = [],
492
505
system_prompt = None ,
493
506
tool_calls_result = None ,
494
507
** kwargs ,
495
508
) -> AsyncGenerator [LLMResponse , None ]:
496
- new_record = await self .assemble_context (prompt , image_urls )
509
+ new_record = await self .assemble_context (prompt , image_urls , audio_urls )
497
510
context_query = [* contexts , new_record ]
498
511
if system_prompt :
499
512
context_query .insert (0 , {"role" : "system" , "content" : system_prompt })
@@ -545,30 +558,55 @@ def set_key(self, key):
545
558
self .chosen_api_key = key
546
559
self ._init_client ()
547
560
548
async def assemble_context(self, text: str, image_urls: List[str] = None, audio_urls: List[str] = None):
    """Assemble the user message record for one request.

    Without any image or audio reference the record is a plain
    ``{"role": "user", "content": text}`` dict. Otherwise ``content`` is a
    list that starts with a text segment (a placeholder is used when
    ``text`` is empty) followed by one segment per successfully encoded
    image (``image_url``) and audio clip (``inline_data``).

    Args:
        text: The user prompt.
        image_urls: Image references; ``http...`` entries are downloaded
            first, ``file:///`` entries have that prefix removed, anything
            else is passed to the encoder unchanged.
        audio_urls: Audio references handed to ``self.encode_audio_data``.

    Returns:
        dict: The assembled user record.
    """
    # Nothing attached: keep the simple string form.
    if not image_urls and not audio_urls:
        return {"role": "user", "content": text}

    segments = [{"type": "text", "text": text or "[媒体内容 ]"}]

    # Images: resolve each reference to something the encoder can read.
    for image_ref in image_urls or ():
        if image_ref.startswith("http"):
            image_source = await download_image_by_url(image_ref)
        elif image_ref.startswith("file:///"):
            image_source = image_ref.replace("file:///", "")
        else:
            image_source = image_ref
        encoded_image = await self.encode_image_bs64(image_source)
        if not encoded_image:
            logger.warning(f"图片 {image_ref} 得到的结果为空,将忽略。")
            continue
        segments.append({"type": "image_url", "image_url": {"url": encoded_image}})

    # Audio: attach the raw payload together with its MIME type.
    for audio_ref in audio_urls or ():
        payload, payload_mime = await self.encode_audio_data(audio_ref)
        if not payload or not payload_mime:
            logger.warning(f"音频 {audio_ref} 处理失败,将忽略。")
            continue
        inline_part = {"mime_type": payload_mime, "data": payload}
        segments.append({"type": "inline_data", "inline_data": inline_part})

    return {"role": "user", "content": segments}
@@ -584,5 +622,41 @@ async def encode_image_bs64(self, image_url: str) -> str:
584
622
return "data:image/jpeg;base64," + image_bs64
585
623
return ""
586
624
625
async def encode_audio_data(self, audio_url: str) -> tuple:
    """Read a local audio file and return its raw bytes plus a MIME type.

    The MIME type is resolved in this order:

    1. the explicit extension table below,
    2. ``mimetypes.guess_type``,
    3. ``"audio/wav"`` as the last-resort default.

    The table is consulted first so that its values always win; otherwise
    ``mimetypes`` may answer with a platform-specific alias for the same
    extension and the table would never be reached.

    Args:
        audio_url: Path of the audio file; it is opened directly with
            ``open`` in binary mode.

    Returns:
        tuple: ``(audio_bytes, mime_type)`` on success, ``(None, None)``
        when reading fails (the error is logged, not raised).
    """
    known_audio_types = {
        ".wav": "audio/wav",
        ".mp3": "audio/mpeg",
        ".ogg": "audio/ogg",
        ".flac": "audio/flac",
        ".m4a": "audio/mp4",
    }
    try:
        # Read the file content as-is; no transcoding happens here.
        with open(audio_url, "rb") as f:
            audio_bytes = f.read()

        extension = os.path.splitext(audio_url)[1].lower()
        mime_type = (
            known_audio_types.get(extension)
            or mimetypes.guess_type(audio_url)[0]
            or "audio/wav"  # default
        )

        logger.info(f"音频文件处理成功: {audio_url} ,mime类型: {mime_type} ,大小: {len(audio_bytes)} 字节")
        return audio_bytes, mime_type
    except Exception as e:
        # Best effort: callers treat (None, None) as "skip this clip".
        logger.error(f"音频文件处理失败: {e} ")
        return None, None
660
+
587
661
async def terminate(self):
    """Log that the Google GenAI adapter has been terminated."""
    shutdown_notice = "Google GenAI 适配器已终止。"
    logger.info(shutdown_notice)
0 commit comments