diff --git a/app/services/llm.py b/app/services/llm.py index 2e3b942..d054eb1 100644 --- a/app/services/llm.py +++ b/app/services/llm.py @@ -354,7 +354,7 @@ def _generate_response(prompt: str, llm_provider: str = None) -> str: return content.replace("\n", "") -def _generate_response_video(prompt: str, llm_provider_video: str, video_file: Union[str, File]) -> str: +def _generate_response_video(prompt: str, llm_provider_video: str, video_file: Union[str, TextIO]) -> str: """ 多模态能力大模型 """ diff --git a/app/test/test_gemini.py b/app/test/test_gemini.py index 5bba20e..aa96a39 100644 --- a/app/test/test_gemini.py +++ b/app/test/test_gemini.py @@ -5,8 +5,7 @@ os.environ["HTTP_PROXY"] = config.proxy.get("http") os.environ["HTTPS_PROXY"] = config.proxy.get("https") -genai.configure(api_key="AIzaSyBnKPxuPuBpZKGKuR_Sb9CwCIJYJF-N8DM") -# genai.configure(api_key="AIzaSyCm33aPRAZ_P29gTALv0tRerMJwY3zJrq0") +genai.configure(api_key="") model = genai.GenerativeModel("gemini-1.5-pro") diff --git a/webui/components/basic_settings.py b/webui/components/basic_settings.py index 960587a..adeca9e 100644 --- a/webui/components/basic_settings.py +++ b/webui/components/basic_settings.py @@ -116,7 +116,7 @@ def render_vision_llm_settings(tr): st.subheader(tr("Vision Model Settings")) # 视频分析模型提供商选择 - vision_providers = ['Gemini', 'NarratoAPI'] + vision_providers = ['Gemini', 'NarratoAPI(待发布)', 'QwenVL(待发布)'] saved_vision_provider = config.app.get("vision_llm_provider", "Gemini").lower() saved_provider_index = 0 @@ -141,7 +141,18 @@ def render_vision_llm_settings(tr): # 渲染视觉模型配置输入框 st_vision_api_key = st.text_input(tr("Vision API Key"), value=vision_api_key, type="password") - st_vision_base_url = st.text_input(tr("Vision Base URL"), value=vision_base_url) + + # 当选择 Gemini 时禁用 base_url 输入 + if vision_provider.lower() == 'gemini': + st_vision_base_url = st.text_input( + tr("Vision Base URL"), + value=vision_base_url, + disabled=True, + help=tr("Gemini API does not require a base URL") + ) + else: + st_vision_base_url = st.text_input(tr("Vision Base URL"), value=vision_base_url) + st_vision_model_name = st.text_input(tr("Vision Model Name"), value=vision_model_name) # 在配置输入框后添加测试按钮 @@ -204,7 +215,7 @@ def render_vision_llm_settings(tr): # tr("Vision Model API Key"), # value=config.app.get("narrato_vision_key", ""), # type="password", - # help="用于视频分析的模型 API Key" + # help="用于视频分析的模 API Key" # ) # # if narrato_vision_model: @@ -247,6 +258,76 @@ def render_vision_llm_settings(tr): # st.session_state['narrato_batch_size'] = narrato_batch_size +def test_text_model_connection(api_key, base_url, model_name, provider, tr): + """测试文本模型连接 + + Args: + api_key: API密钥 + base_url: 基础URL + model_name: 模型名称 + provider: 提供商名称 + + Returns: + bool: 连接是否成功 + str: 测试结果消息 + """ + import requests + + try: + # 构建统一的测试请求(遵循OpenAI格式) + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" + } + + # 如果没有指定base_url,使用默认值 + if not base_url: + if provider.lower() == 'openai': + base_url = "https://api.openai.com/v1" + elif provider.lower() == 'moonshot': + base_url = "https://api.moonshot.cn/v1" + elif provider.lower() == 'deepseek': + base_url = "https://api.deepseek.com/v1" + + # 构建测试URL + test_url = f"{base_url.rstrip('/')}/chat/completions" + + # 特殊处理Gemini + if provider.lower() == 'gemini': + import google.generativeai as genai + try: + genai.configure(api_key=api_key) + model = genai.GenerativeModel(model_name or 'gemini-pro') + model.generate_content("直接回复我文本'当前网络可用'") + return True, tr("Gemini model is available") + except Exception as e: + return False, f"{tr('Gemini model is not available')}: {str(e)}" + + # 构建测试消息 + test_data = { + "model": model_name, + "messages": [ + {"role": "user", "content": "直接回复我文本'当前网络可用'"} + ], + "max_tokens": 10 + } + + # 发送测试请求 + response = requests.post( + test_url, + headers=headers, + json=test_data, + timeout=10 + ) + + if response.status_code == 200: + return True, tr("Text model is available") + else: + return False, f"{tr('Text model is not available')}: HTTP {response.status_code}" + + except Exception as e: + return False, f"{tr('Connection failed')}: {str(e)}" + def render_text_llm_settings(tr): """渲染文案生成模型设置""" st.subheader(tr("Text Generation Model Settings")) @@ -279,6 +360,22 @@ def render_text_llm_settings(tr): st_text_base_url = st.text_input(tr("Text Base URL"), value=text_base_url) st_text_model_name = st.text_input(tr("Text Model Name"), value=text_model_name) + # 添加测试按钮 + if st.button(tr("Test Connection"), key="test_text_connection"): + with st.spinner(tr("Testing connection...")): + success, message = test_text_model_connection( + api_key=st_text_api_key, + base_url=st_text_base_url, + model_name=st_text_model_name, + provider=text_provider, + tr=tr + ) + + if success: + st.success(message) + else: + st.error(message) + # 保存文本模型配置 if st_text_api_key: config.app[f"text_{text_provider}_api_key"] = st_text_api_key diff --git a/webui/components/script_settings.py b/webui/components/script_settings.py index 8dfa10b..30c23d3 100644 --- a/webui/components/script_settings.py +++ b/webui/components/script_settings.py @@ -535,6 +535,7 @@ def update_progress(progress: float, message: str = ""): model_name=text_model, api_key=text_api_key, prompt=custom_prompt, + base_url=text_base_url or "", video_theme=st.session_state.get('video_theme', '') ) diff --git a/webui/i18n/zh.json b/webui/i18n/zh.json index 2e9f65c..68b968a 100644 --- a/webui/i18n/zh.json +++ b/webui/i18n/zh.json @@ -132,6 +132,8 @@ "NarratoAPI is available": "NarratoAPI 可用", "NarratoAPI is not available": "NarratoAPI 不可用", "Unsupported provider": "不支持的提供商", - "0: Keep the audio only, 1: Keep the original sound only, 2: Keep the original sound and audio": "0: 仅保留音频,1: 仅保留原声,2: 保留原声和音频" + "0: Keep the audio only, 1: Keep the original sound only, 2: Keep the original sound and audio": "0: 仅保留音频,1: 仅保留原声,2: 保留原声和音频", + "Text model is not available": "文案生成模型不可用", + "Text model is available": "文案生成模型可用" } }