diff --git a/astrbot/builtin_stars/astrbot/long_term_memory.py b/astrbot/builtin_stars/astrbot/long_term_memory.py index 610995db2..74fa24ef0 100644 --- a/astrbot/builtin_stars/astrbot/long_term_memory.py +++ b/astrbot/builtin_stars/astrbot/long_term_memory.py @@ -17,13 +17,13 @@ class LongTermMemory: - def __init__(self, acm: AstrBotConfigManager, context: star.Context): + def __init__(self, acm: AstrBotConfigManager, context: star.Context) -> None: self.acm = acm self.context = context self.session_chats = defaultdict(list) """记录群成员的群聊记录""" - def cfg(self, event: AstrMessageEvent): + def cfg(self, event: AstrMessageEvent) -> dict: cfg = self.context.get_config(umo=event.unified_msg_origin) try: max_cnt = int(cfg["provider_ltm_settings"]["group_message_max_cnt"]) @@ -111,7 +111,7 @@ async def need_active_reply(self, event: AstrMessageEvent) -> bool: return False - async def handle_message(self, event: AstrMessageEvent): + async def handle_message(self, event: AstrMessageEvent) -> None: """仅支持群聊""" if event.get_message_type() == MessageType.GROUP_MESSAGE: datetime_str = datetime.datetime.now().strftime("%H:%M:%S") @@ -148,7 +148,7 @@ async def handle_message(self, event: AstrMessageEvent): if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]: self.session_chats[event.unified_msg_origin].pop(0) - async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest): + async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest) -> None: """当触发 LLM 请求前,调用此方法修改 req""" if event.unified_msg_origin not in self.session_chats: return @@ -171,7 +171,9 @@ async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest): ) req.system_prompt += chats_str - async def after_req_llm(self, event: AstrMessageEvent, llm_resp: LLMResponse): + async def after_req_llm( + self, event: AstrMessageEvent, llm_resp: LLMResponse + ) -> None: if event.unified_msg_origin not in self.session_chats: return diff --git a/astrbot/builtin_stars/astrbot/main.py b/astrbot/builtin_stars/astrbot/main.py index b3ea355b1..75952235d 100644 --- a/astrbot/builtin_stars/astrbot/main.py +++ b/astrbot/builtin_stars/astrbot/main.py @@ -1,4 +1,5 @@ import traceback +from collections.abc import AsyncGenerator from astrbot.api import star from astrbot.api.event import AstrMessageEvent, filter @@ -21,14 +22,16 @@ def __init__(self, context: star.Context) -> None: self.proc_llm_req = ProcessLLMRequest(self.context) - def ltm_enabled(self, event: AstrMessageEvent): + def ltm_enabled(self, event: AstrMessageEvent) -> bool: ltmse = self.context.get_config(umo=event.unified_msg_origin)[ "provider_ltm_settings" ] return ltmse["group_icl_enable"] or ltmse["active_reply"]["enable"] @filter.platform_adapter_type(filter.PlatformAdapterType.ALL) - async def on_message(self, event: AstrMessageEvent): + async def on_message( + self, event: AstrMessageEvent + ) -> AsyncGenerator[ProviderRequest, None]: """群聊记忆增强""" has_image_or_plain = False for comp in event.message_obj.message: @@ -89,7 +92,9 @@ async def on_message(self, event: AstrMessageEvent): logger.error(f"主动回复失败: {e}") @filter.on_llm_request() - async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest): + async def decorate_llm_req( + self, event: AstrMessageEvent, req: ProviderRequest + ) -> None: """在请求 LLM 前注入人格信息、Identifier、时间、回复内容等 System Prompt""" await self.proc_llm_req.process_llm_request(event, req) @@ -100,7 +105,9 @@ async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest): logger.error(f"ltm: {e}") @filter.on_llm_response() - async def record_llm_resp_to_ltm(self, event: AstrMessageEvent, resp: LLMResponse): + async def record_llm_resp_to_ltm( + self, event: AstrMessageEvent, resp: LLMResponse + ) -> None: """在 LLM 响应后记录对话""" if self.ltm and self.ltm_enabled(event): try: @@ -109,7 +116,7 @@ async def record_llm_resp_to_ltm(self, event: AstrMessageEvent, resp: LLMRespons logger.error(f"ltm: {e}") @filter.after_message_sent() - async def after_message_sent(self, event: AstrMessageEvent): + async def after_message_sent(self, event: AstrMessageEvent) -> None: """消息发送后处理""" if self.ltm and self.ltm_enabled(event): try: diff --git a/astrbot/builtin_stars/astrbot/process_llm_request.py b/astrbot/builtin_stars/astrbot/process_llm_request.py index 28d0a34f4..ec27341dc 100644 --- a/astrbot/builtin_stars/astrbot/process_llm_request.py +++ b/astrbot/builtin_stars/astrbot/process_llm_request.py @@ -12,7 +12,7 @@ class ProcessLLMRequest: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.ctx = context cfg = context.get_config() self.timezone = cfg.get("timezone") @@ -22,7 +22,7 @@ def __init__(self, context: star.Context): else: logger.info(f"Timezone set to: {self.timezone}") - async def _ensure_persona(self, req: ProviderRequest, cfg: dict, umo: str): + async def _ensure_persona(self, req: ProviderRequest, cfg: dict, umo: str) -> None: """确保用户人格已加载""" if not req.conversation: return @@ -78,7 +78,7 @@ async def _ensure_img_caption( req: ProviderRequest, cfg: dict, img_cap_prov_id: str, - ): + ) -> None: try: caption = await self._request_img_caption( img_cap_prov_id, @@ -118,7 +118,9 @@ async def _request_img_caption( f"Cannot get image caption because provider `{provider_id}` is not exist.", ) - async def process_llm_request(self, event: AstrMessageEvent, req: ProviderRequest): + async def process_llm_request( + self, event: AstrMessageEvent, req: ProviderRequest + ) -> None: """在请求 LLM 前注入人格信息、Identifier、时间、回复内容等 System Prompt""" cfg: dict = self.ctx.get_config(umo=event.unified_msg_origin)[ "provider_settings" diff --git a/astrbot/builtin_stars/builtin_commands/commands/admin.py b/astrbot/builtin_stars/builtin_commands/commands/admin.py index 83d4b5974..a4f46b603 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/admin.py +++ b/astrbot/builtin_stars/builtin_commands/commands/admin.py @@ -5,10 +5,10 @@ class AdminCommands: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context - async def op(self, event: AstrMessageEvent, admin_id: str = ""): + async def op(self, event: AstrMessageEvent, admin_id: str = "") -> None: """授权管理员。op """ if not admin_id: event.set_result( @@ -21,7 +21,7 @@ async def op(self, event: AstrMessageEvent, admin_id: str = ""): self.context.get_config().save_config() event.set_result(MessageEventResult().message("授权成功。")) - async def deop(self, event: AstrMessageEvent, admin_id: str = ""): + async def deop(self, event: AstrMessageEvent, admin_id: str = "") -> None: """取消授权管理员。deop """ if not admin_id: event.set_result( @@ -39,7 +39,7 @@ async def deop(self, event: AstrMessageEvent, admin_id: str = ""): MessageEventResult().message("此用户 ID 不在管理员名单内。"), ) - async def wl(self, event: AstrMessageEvent, sid: str = ""): + async def wl(self, event: AstrMessageEvent, sid: str = "") -> None: """添加白名单。wl """ if not sid: event.set_result( @@ -53,7 +53,7 @@ async def wl(self, event: AstrMessageEvent, sid: str = ""): cfg.save_config() event.set_result(MessageEventResult().message("添加白名单成功。")) - async def dwl(self, event: AstrMessageEvent, sid: str = ""): + async def dwl(self, event: AstrMessageEvent, sid: str = "") -> None: """删除白名单。dwl """ if not sid: event.set_result( @@ -70,7 +70,7 @@ async def dwl(self, event: AstrMessageEvent, sid: str = ""): except ValueError: event.set_result(MessageEventResult().message("此 SID 不在白名单内。")) - async def update_dashboard(self, event: AstrMessageEvent): + async def update_dashboard(self, event: AstrMessageEvent) -> None: """更新管理面板""" await event.send(MessageChain().message("正在尝试更新管理面板...")) await download_dashboard(version=f"v{VERSION}", latest=False) diff --git a/astrbot/builtin_stars/builtin_commands/commands/alter_cmd.py b/astrbot/builtin_stars/builtin_commands/commands/alter_cmd.py index 50007f6c0..ba31c3326 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/alter_cmd.py +++ b/astrbot/builtin_stars/builtin_commands/commands/alter_cmd.py @@ -11,10 +11,10 @@ class AlterCmdCommands(CommandParserMixin): - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context - async def update_reset_permission(self, scene_key: str, perm_type: str): + async def update_reset_permission(self, scene_key: str, perm_type: str) -> None: """更新reset命令在特定场景下的权限设置""" from astrbot.api import sp @@ -26,7 +26,7 @@ async def update_reset_permission(self, scene_key: str, perm_type: str): alter_cmd_cfg["astrbot"] = plugin_cfg await sp.global_put("alter_cmd", alter_cmd_cfg) - async def alter_cmd(self, event: AstrMessageEvent): + async def alter_cmd(self, event: AstrMessageEvent) -> None: token = self.parse_commands(event.message_str) if token.len < 3: await event.send( diff --git a/astrbot/builtin_stars/builtin_commands/commands/conversation.py b/astrbot/builtin_stars/builtin_commands/commands/conversation.py index de3d11ac8..15eab5fdb 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/conversation.py +++ b/astrbot/builtin_stars/builtin_commands/commands/conversation.py @@ -16,10 +16,10 @@ class ConversationCommands: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context - async def _get_current_persona_id(self, session_id): + async def _get_current_persona_id(self, session_id: str) -> str | None: curr = await self.context.conversation_manager.get_curr_conversation_id( session_id, ) @@ -33,7 +33,7 @@ async def _get_current_persona_id(self, session_id): return None return conv.persona_id - async def reset(self, message: AstrMessageEvent): + async def reset(self, message: AstrMessageEvent) -> None: """重置 LLM 会话""" umo = message.unified_msg_origin cfg = self.context.get_config(umo=message.unified_msg_origin) @@ -98,7 +98,7 @@ async def reset(self, message: AstrMessageEvent): message.set_result(MessageEventResult().message(ret)) - async def his(self, message: AstrMessageEvent, page: int = 1): + async def his(self, message: AstrMessageEvent, page: int = 1) -> None: """查看对话记录""" if not self.context.get_using_provider(message.unified_msg_origin): message.set_result( @@ -141,7 +141,7 @@ async def his(self, message: AstrMessageEvent, page: int = 1): message.set_result(MessageEventResult().message(ret).use_t2i(False)) - async def convs(self, message: AstrMessageEvent, page: int = 1): + async def convs(self, message: AstrMessageEvent, page: int = 1) -> None: """查看对话列表""" cfg = self.context.get_config(umo=message.unified_msg_origin) agent_runner_type = cfg["provider_settings"]["agent_runner_type"] @@ -216,7 +216,7 @@ async def convs(self, message: AstrMessageEvent, page: int = 1): message.set_result(MessageEventResult().message(ret).use_t2i(False)) return - async def new_conv(self, message: AstrMessageEvent): + async def new_conv(self, message: AstrMessageEvent) -> None: """创建新对话""" cfg = self.context.get_config(umo=message.unified_msg_origin) agent_runner_type = cfg["provider_settings"]["agent_runner_type"] @@ -242,7 +242,7 @@ async def new_conv(self, message: AstrMessageEvent): MessageEventResult().message(f"切换到新对话: 新对话({cid[:4]})。"), ) - async def groupnew_conv(self, message: AstrMessageEvent, sid: str = ""): + async def groupnew_conv(self, message: AstrMessageEvent, sid: str = "") -> None: """创建新群聊对话""" if sid: session = str( @@ -273,7 +273,7 @@ async def switch_conv( self, message: AstrMessageEvent, index: int | None = None, - ): + ) -> None: """通过 /ls 前面的序号切换对话""" if not isinstance(index, int): message.set_result( @@ -308,7 +308,7 @@ async def switch_conv( ), ) - async def rename_conv(self, message: AstrMessageEvent, new_name: str = ""): + async def rename_conv(self, message: AstrMessageEvent, new_name: str = "") -> None: """重命名对话""" if not new_name: message.set_result(MessageEventResult().message("请输入新的对话名称。")) @@ -319,7 +319,7 @@ async def rename_conv(self, message: AstrMessageEvent, new_name: str = ""): ) message.set_result(MessageEventResult().message("重命名对话成功。")) - async def del_conv(self, message: AstrMessageEvent): + async def del_conv(self, message: AstrMessageEvent) -> None: """删除当前对话""" cfg = self.context.get_config(umo=message.unified_msg_origin) is_unique_session = cfg["platform_settings"]["unique_session"] diff --git a/astrbot/builtin_stars/builtin_commands/commands/help.py b/astrbot/builtin_stars/builtin_commands/commands/help.py index 092fc59ec..d3de1eaf8 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/help.py +++ b/astrbot/builtin_stars/builtin_commands/commands/help.py @@ -8,10 +8,10 @@ class HelpCommand: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context - async def _query_astrbot_notice(self): + async def _query_astrbot_notice(self) -> str: try: async with aiohttp.ClientSession(trust_env=True) as session: async with session.get( @@ -34,7 +34,7 @@ async def _build_reserved_command_lines(self) -> list[str]: lines: list[str] = [] hidden_commands = {"set", "unset", "websearch"} - def walk(items: list[dict], indent: int = 0): + def walk(items: list[dict], indent: int = 0) -> None: for item in items: if not item.get("reserved") or not item.get("enabled"): continue @@ -62,7 +62,7 @@ def walk(items: list[dict], indent: int = 0): walk(commands) return lines - async def help(self, event: AstrMessageEvent): + async def help(self, event: AstrMessageEvent) -> None: """查看帮助""" notice = "" try: diff --git a/astrbot/builtin_stars/builtin_commands/commands/llm.py b/astrbot/builtin_stars/builtin_commands/commands/llm.py index 85977df40..ba9ba5c9b 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/llm.py +++ b/astrbot/builtin_stars/builtin_commands/commands/llm.py @@ -3,10 +3,10 @@ class LLMCommands: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context - async def llm(self, event: AstrMessageEvent): + async def llm(self, event: AstrMessageEvent) -> None: """开启/关闭 LLM""" cfg = self.context.get_config(umo=event.unified_msg_origin) enable = cfg["provider_settings"].get("enable", True) diff --git a/astrbot/builtin_stars/builtin_commands/commands/persona.py b/astrbot/builtin_stars/builtin_commands/commands/persona.py index 13a57f07f..1a5ddb848 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/persona.py +++ b/astrbot/builtin_stars/builtin_commands/commands/persona.py @@ -5,10 +5,10 @@ class PersonaCommands: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context - async def persona(self, message: AstrMessageEvent): + async def persona(self, message: AstrMessageEvent) -> None: l = message.message_str.split(" ") # noqa: E741 umo = message.unified_msg_origin diff --git a/astrbot/builtin_stars/builtin_commands/commands/plugin.py b/astrbot/builtin_stars/builtin_commands/commands/plugin.py index ab45efc11..49bee9462 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/plugin.py +++ b/astrbot/builtin_stars/builtin_commands/commands/plugin.py @@ -8,10 +8,10 @@ class PluginCommands: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context - async def plugin_ls(self, event: AstrMessageEvent): + async def plugin_ls(self, event: AstrMessageEvent) -> None: """获取已经安装的插件列表。""" parts = ["已加载的插件:\n"] for plugin in self.context.get_all_stars(): @@ -30,7 +30,7 @@ async def plugin_ls(self, event: AstrMessageEvent): MessageEventResult().message(f"{plugin_list_info}").use_t2i(False), ) - async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = ""): + async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = "") -> None: """禁用插件""" if DEMO_MODE: event.set_result(MessageEventResult().message("演示模式下无法禁用插件。")) @@ -43,7 +43,7 @@ async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = ""): await self.context._star_manager.turn_off_plugin(plugin_name) # type: ignore event.set_result(MessageEventResult().message(f"插件 {plugin_name} 已禁用。")) - async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = ""): + async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = "") -> None: """启用插件""" if DEMO_MODE: event.set_result(MessageEventResult().message("演示模式下无法启用插件。")) @@ -56,7 +56,7 @@ async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = ""): await self.context._star_manager.turn_on_plugin(plugin_name) # type: ignore event.set_result(MessageEventResult().message(f"插件 {plugin_name} 已启用。")) - async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = ""): + async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = "") -> None: """安装插件""" if DEMO_MODE: event.set_result(MessageEventResult().message("演示模式下无法安装插件。")) @@ -77,7 +77,7 @@ async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = ""): event.set_result(MessageEventResult().message(f"安装插件失败: {e}")) return - async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = ""): + async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = "") -> None: """获取插件帮助""" if not plugin_name: event.set_result( diff --git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index 60b81ebe5..11f99c8ac 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -5,19 +5,20 @@ from astrbot.api import star from astrbot.api.event import AstrMessageEvent, MessageEventResult from astrbot.core.provider.entities import ProviderType +from astrbot.core.provider.provider import Provider class ProviderCommands: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context def _log_reachability_failure( self, - provider, + provider: Provider, provider_capability_type: ProviderType | None, err_code: str, err_reason: str, - ): + ) -> None: """记录不可达原因到日志。""" meta = provider.meta() logger.warning( @@ -28,7 +29,9 @@ def _log_reachability_failure( err_reason, ) - async def _test_provider_capability(self, provider): + async def _test_provider_capability( + self, provider: Provider + ) -> tuple[bool, str | None, str | None]: """测试单个 provider 的可用性""" meta = provider.meta() provider_capability_type = meta.provider_type @@ -49,7 +52,7 @@ async def provider( event: AstrMessageEvent, idx: str | int | None = None, idx2: int | None = None, - ): + ) -> None: """查看或者切换 LLM Provider""" umo = event.unified_msg_origin cfg = self.context.get_config(umo).get("provider_settings", {}) @@ -228,7 +231,7 @@ async def model_ls( self, message: AstrMessageEvent, idx_or_name: int | str | None = None, - ): + ) -> None: """查看或者切换模型""" prov = self.context.get_using_provider(message.unified_msg_origin) if not prov: @@ -293,7 +296,7 @@ async def model_ls( MessageEventResult().message(f"切换模型到 {prov.get_model()}。"), ) - async def key(self, message: AstrMessageEvent, index: int | None = None): + async def key(self, message: AstrMessageEvent, index: int | None = None) -> None: prov = self.context.get_using_provider(message.unified_msg_origin) if not prov: message.set_result( diff --git a/astrbot/builtin_stars/builtin_commands/commands/setunset.py b/astrbot/builtin_stars/builtin_commands/commands/setunset.py index 79e5d5d1c..096698844 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/setunset.py +++ b/astrbot/builtin_stars/builtin_commands/commands/setunset.py @@ -3,10 +3,10 @@ class SetUnsetCommands: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context - async def set_variable(self, event: AstrMessageEvent, key: str, value: str): + async def set_variable(self, event: AstrMessageEvent, key: str, value: str) -> None: """设置会话变量""" uid = event.unified_msg_origin session_var = await sp.session_get(uid, "session_variables", {}) @@ -19,7 +19,7 @@ async def set_variable(self, event: AstrMessageEvent, key: str, value: str): ), ) - async def unset_variable(self, event: AstrMessageEvent, key: str): + async def unset_variable(self, event: AstrMessageEvent, key: str) -> None: """移除会话变量""" uid = event.unified_msg_origin session_var = await sp.session_get(uid, "session_variables", {}) diff --git a/astrbot/builtin_stars/builtin_commands/commands/sid.py b/astrbot/builtin_stars/builtin_commands/commands/sid.py index 4d95c5a60..e8bdbffb1 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/sid.py +++ b/astrbot/builtin_stars/builtin_commands/commands/sid.py @@ -7,10 +7,10 @@ class SIDCommand: """会话ID命令类""" - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context - async def sid(self, event: AstrMessageEvent): + async def sid(self, event: AstrMessageEvent) -> None: """获取消息来源信息""" sid = event.unified_msg_origin user_id = str(event.get_sender_id()) diff --git a/astrbot/builtin_stars/builtin_commands/commands/t2i.py b/astrbot/builtin_stars/builtin_commands/commands/t2i.py index 7766b342f..78d6b0df7 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/t2i.py +++ b/astrbot/builtin_stars/builtin_commands/commands/t2i.py @@ -7,10 +7,10 @@ class T2ICommand: """文本转图片命令类""" - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context - async def t2i(self, event: AstrMessageEvent): + async def t2i(self, event: AstrMessageEvent) -> None: """开关文本转图片""" config = self.context.get_config(umo=event.unified_msg_origin) if config["t2i"]: diff --git a/astrbot/builtin_stars/builtin_commands/commands/tool.py b/astrbot/builtin_stars/builtin_commands/commands/tool.py index 9a6c507e6..09b239b8c 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/tool.py +++ b/astrbot/builtin_stars/builtin_commands/commands/tool.py @@ -3,28 +3,28 @@ class ToolCommands: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context - async def tool_ls(self, event: AstrMessageEvent): + async def tool_ls(self, event: AstrMessageEvent) -> None: """查看函数工具列表""" event.set_result( MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。"), ) - async def tool_on(self, event: AstrMessageEvent, tool_name: str = ""): + async def tool_on(self, event: AstrMessageEvent, tool_name: str = "") -> None: """启用一个函数工具""" event.set_result( MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。"), ) - async def tool_off(self, event: AstrMessageEvent, tool_name: str = ""): + async def tool_off(self, event: AstrMessageEvent, tool_name: str = "") -> None: """停用一个函数工具""" event.set_result( MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。"), ) - async def tool_all_off(self, event: AstrMessageEvent): + async def tool_all_off(self, event: AstrMessageEvent) -> None: """停用所有函数工具""" event.set_result( MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。"), diff --git a/astrbot/builtin_stars/builtin_commands/commands/tts.py b/astrbot/builtin_stars/builtin_commands/commands/tts.py index dee8e31de..13049ac22 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/tts.py +++ b/astrbot/builtin_stars/builtin_commands/commands/tts.py @@ -8,10 +8,10 @@ class TTSCommand: """文本转语音命令类""" - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context - async def tts(self, event: AstrMessageEvent): + async def tts(self, event: AstrMessageEvent) -> None: """开关文本转语音(会话级别)""" umo = event.unified_msg_origin ses_tts = await SessionServiceManager.is_tts_enabled_for_session(umo) diff --git a/astrbot/builtin_stars/builtin_commands/main.py b/astrbot/builtin_stars/builtin_commands/main.py index 7809c4359..d19f27f15 100644 --- a/astrbot/builtin_stars/builtin_commands/main.py +++ b/astrbot/builtin_stars/builtin_commands/main.py @@ -37,108 +37,108 @@ def __init__(self, context: star.Context) -> None: self.sid_c = SIDCommand(self.context) @filter.command("help") - async def help(self, event: AstrMessageEvent): + async def help(self, event: AstrMessageEvent) -> None: """查看帮助""" await self.help_c.help(event) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("llm") - async def llm(self, event: AstrMessageEvent): + async def llm(self, event: AstrMessageEvent) -> None: """开启/关闭 LLM""" await self.llm_c.llm(event) @filter.command_group("tool") - def tool(self): + def tool(self) -> None: """函数工具管理""" @tool.command("ls") - async def tool_ls(self, event: AstrMessageEvent): + async def tool_ls(self, event: AstrMessageEvent) -> None: """查看函数工具列表""" await self.tool_c.tool_ls(event) @tool.command("on") - async def tool_on(self, event: AstrMessageEvent, tool_name: str): + async def tool_on(self, event: AstrMessageEvent, tool_name: str) -> None: """启用一个函数工具""" await self.tool_c.tool_on(event, tool_name) @tool.command("off") - async def tool_off(self, event: AstrMessageEvent, tool_name: str): + async def tool_off(self, event: AstrMessageEvent, tool_name: str) -> None: """停用一个函数工具""" await self.tool_c.tool_off(event, tool_name) @tool.command("off_all") - async def tool_all_off(self, event: AstrMessageEvent): + async def tool_all_off(self, event: AstrMessageEvent) -> None: """停用所有函数工具""" await self.tool_c.tool_all_off(event) @filter.command_group("plugin") - def plugin(self): + def plugin(self) -> None: """插件管理""" @plugin.command("ls") - async def plugin_ls(self, event: AstrMessageEvent): + async def plugin_ls(self, event: AstrMessageEvent) -> None: """获取已经安装的插件列表。""" await self.plugin_c.plugin_ls(event) @filter.permission_type(filter.PermissionType.ADMIN) @plugin.command("off") - async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = ""): + async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = "") -> None: """禁用插件""" await self.plugin_c.plugin_off(event, plugin_name) @filter.permission_type(filter.PermissionType.ADMIN) @plugin.command("on") - async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = ""): + async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = "") -> None: """启用插件""" await self.plugin_c.plugin_on(event, plugin_name) @filter.permission_type(filter.PermissionType.ADMIN) @plugin.command("get") - async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = ""): + async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = "") -> None: """安装插件""" await self.plugin_c.plugin_get(event, plugin_repo) @plugin.command("help") - async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = ""): + async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = "") -> None: """获取插件帮助""" await self.plugin_c.plugin_help(event, plugin_name) @filter.command("t2i") - async def t2i(self, event: AstrMessageEvent): + async def t2i(self, event: AstrMessageEvent) -> None: """开关文本转图片""" await self.t2i_c.t2i(event) @filter.command("tts") - async def tts(self, event: AstrMessageEvent): + async def tts(self, event: AstrMessageEvent) -> None: """开关文本转语音(会话级别)""" await self.tts_c.tts(event) @filter.command("sid") - async def sid(self, event: AstrMessageEvent): + async def sid(self, event: AstrMessageEvent) -> None: """获取会话 ID 和 管理员 ID""" await self.sid_c.sid(event) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("op") - async def op(self, event: AstrMessageEvent, admin_id: str = ""): + async def op(self, event: AstrMessageEvent, admin_id: str = "") -> None: """授权管理员。op """ await self.admin_c.op(event, admin_id) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("deop") - async def deop(self, event: AstrMessageEvent, admin_id: str): + async def deop(self, event: AstrMessageEvent, admin_id: str) -> None: """取消授权管理员。deop """ await self.admin_c.deop(event, admin_id) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("wl") - async def wl(self, event: AstrMessageEvent, sid: str = ""): + async def wl(self, event: AstrMessageEvent, sid: str = "") -> None: """添加白名单。wl """ await self.admin_c.wl(event, sid) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("dwl") - async def dwl(self, event: AstrMessageEvent, sid: str): + async def dwl(self, event: AstrMessageEvent, sid: str) -> None: """删除白名单。dwl """ await self.admin_c.dwl(event, sid) @@ -149,12 +149,12 @@ async def provider( event: AstrMessageEvent, idx: str | int | None = None, idx2: int | None = None, - ): + ) -> None: """查看或者切换 LLM Provider""" await self.provider_c.provider(event, idx, idx2) @filter.command("reset") - async def reset(self, message: AstrMessageEvent): + async def reset(self, message: AstrMessageEvent) -> None: """重置 LLM 会话""" await self.conversation_c.reset(message) @@ -164,74 +164,76 @@ async def model_ls( self, message: AstrMessageEvent, idx_or_name: int | str | None = None, - ): + ) -> None: """查看或者切换模型""" await self.provider_c.model_ls(message, idx_or_name) @filter.command("history") - async def his(self, message: AstrMessageEvent, page: int = 1): + async def his(self, message: AstrMessageEvent, page: int = 1) -> None: """查看对话记录""" await self.conversation_c.his(message, page) @filter.command("ls") - async def convs(self, message: AstrMessageEvent, page: int = 1): + async def convs(self, message: AstrMessageEvent, page: int = 1) -> None: """查看对话列表""" await self.conversation_c.convs(message, page) @filter.command("new") - async def new_conv(self, message: AstrMessageEvent): + async def new_conv(self, message: AstrMessageEvent) -> None: """创建新对话""" await self.conversation_c.new_conv(message) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("groupnew") - async def groupnew_conv(self, message: AstrMessageEvent, sid: str): + async def groupnew_conv(self, message: AstrMessageEvent, sid: str) -> None: """创建新群聊对话""" await self.conversation_c.groupnew_conv(message, sid) @filter.command("switch") - async def switch_conv(self, message: AstrMessageEvent, index: int | None = None): + async def switch_conv( + self, message: AstrMessageEvent, index: int | None = None + ) -> None: """通过 /ls 前面的序号切换对话""" await self.conversation_c.switch_conv(message, index) @filter.command("rename") - async def rename_conv(self, message: AstrMessageEvent, new_name: str): + async def rename_conv(self, message: AstrMessageEvent, new_name: str) -> None: """重命名对话""" await self.conversation_c.rename_conv(message, new_name) @filter.command("del") - async def del_conv(self, message: AstrMessageEvent): + async def del_conv(self, message: AstrMessageEvent) -> None: """删除当前对话""" await self.conversation_c.del_conv(message) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("key") - async def key(self, message: AstrMessageEvent, index: int | None = None): + async def key(self, message: AstrMessageEvent, index: int | None = None) -> None: """查看或者切换 Key""" await self.provider_c.key(message, index) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("persona") - async def persona(self, message: AstrMessageEvent): + async def persona(self, message: AstrMessageEvent) -> None: """查看或者切换 Persona""" await self.persona_c.persona(message) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("dashboard_update") - async def update_dashboard(self, event: AstrMessageEvent): + async def update_dashboard(self, event: AstrMessageEvent) -> None: """更新管理面板""" await self.admin_c.update_dashboard(event) @filter.command("set") - async def set_variable(self, event: AstrMessageEvent, key: str, value: str): + async def set_variable(self, event: AstrMessageEvent, key: str, value: str) -> None: await self.setunset_c.set_variable(event, key, value) @filter.command("unset") - async def unset_variable(self, event: AstrMessageEvent, key: str): + async def unset_variable(self, event: AstrMessageEvent, key: str) -> None: await self.setunset_c.unset_variable(event, key) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("alter_cmd", alias={"alter"}) - async def alter_cmd(self, event: AstrMessageEvent): + async def alter_cmd(self, event: AstrMessageEvent) -> None: """修改命令权限""" await self.alter_cmd_c.alter_cmd(event) diff --git a/astrbot/builtin_stars/python_interpreter/main.py b/astrbot/builtin_stars/python_interpreter/main.py index ec9d261b7..469bc3ead 100644 --- a/astrbot/builtin_stars/python_interpreter/main.py +++ b/astrbot/builtin_stars/python_interpreter/main.py @@ -6,6 +6,7 @@ import time import uuid from collections import defaultdict +from collections.abc import AsyncGenerator import aiodocker import aiohttp @@ -124,7 +125,7 @@ def __init__(self, context: star.Context) -> None: with open(PATH) as f: self.config = json.load(f) - async def initialize(self): + async def initialize(self) -> None: ok = await self.is_docker_available() if not ok: logger.info( @@ -134,7 +135,7 @@ async def initialize(self): # "astrbot-python-interpreter" # ) - async def file_upload(self, file_path: str): + async def file_upload(self, file_path: str) -> str: """上传图像文件到 S3""" ext = os.path.splitext(file_path)[1] S3_URL = "https://s3.neko.soulter.top/astrbot-s3" @@ -170,7 +171,7 @@ async def get_image_name(self) -> str: return f"{self.config['sandbox']['docker_mirror']}/{self.config['sandbox']['image']}" return self.config["sandbox"]["image"] - def _save_config(self): + def _save_config(self) -> None: with open(PATH, "w") as f: json.dump(self.config, f) @@ -202,7 +203,9 @@ async def tidy_code(self, code: str) -> str: return match.group(1) @filter.event_message_type(filter.EventMessageType.ALL) - async def on_message(self, event: AstrMessageEvent): + async def on_message( + self, event: AstrMessageEvent + ) -> AsyncGenerator[MessageEventResult, None]: """处理消息""" uid = event.get_sender_id() if uid not in self.user_waiting: @@ -239,7 +242,9 @@ async def on_message(self, event: AstrMessageEvent): del self.user_waiting[uid] @filter.on_llm_request() - async def on_llm_req(self, event: AstrMessageEvent, request: ProviderRequest): + async def on_llm_req( + self, event: AstrMessageEvent, request: ProviderRequest + ) -> None: if event.get_session_id() in self.user_file_msg_buffer: files = self.user_file_msg_buffer[event.get_session_id()] if not request.prompt: @@ -247,11 +252,13 @@ async def on_llm_req(self, event: AstrMessageEvent, request: ProviderRequest): request.prompt += f"\nUser provided files: {files}" @filter.command_group("pi") - def pi(self): + def pi(self) -> None: """代码执行器配置""" @pi.command("absdir") - async def pi_absdir(self, event: AstrMessageEvent, path: str = ""): + async def pi_absdir( + self, event: AstrMessageEvent, path: str = "" + ) -> AsyncGenerator[MessageEventResult, None]: """设置 Docker 宿主机绝对路径""" if not path: yield event.plain_result( @@ -263,7 +270,9 @@ async def pi_absdir(self, event: AstrMessageEvent, path: str = ""): yield event.plain_result(f"设置 Docker 宿主机绝对路径成功: {path}") @pi.command("mirror") - async def pi_mirror(self, event: AstrMessageEvent, url: str = ""): + async def pi_mirror( + self, event: AstrMessageEvent, url: str = "" + ) -> AsyncGenerator[MessageEventResult, None]: """Docker 镜像地址""" if not url: yield event.plain_result(f"""当前 Docker 镜像地址: {self.config["sandbox"]["docker_mirror"]}。 @@ -276,7 +285,9 @@ async def pi_mirror(self, event: AstrMessageEvent, url: str = ""): yield event.plain_result("设置 Docker 镜像地址成功。") @pi.command("repull") - async def pi_repull(self, event: AstrMessageEvent): + async def pi_repull( + self, event: AstrMessageEvent + ) -> AsyncGenerator[MessageEventResult, None]: """重新拉取沙箱镜像""" async with aiodocker.Docker() as docker: image_name = await self.get_image_name() @@ -289,7 +300,9 @@ async def pi_repull(self, event: AstrMessageEvent): yield event.plain_result("重新拉取沙箱镜像成功。") @pi.command("file") - async def pi_file(self, event: AstrMessageEvent): + async def pi_file( + self, event: AstrMessageEvent + ) -> AsyncGenerator[MessageEventResult, None]: """在规定秒数(60s)内上传一个文件""" uid = event.get_sender_id() self.user_waiting[uid] = time.time() @@ -303,7 +316,9 @@ async def pi_file(self, event: AstrMessageEvent): self.user_waiting.pop(uid) @pi.command("clear", alias=["clean"]) - async def pi_file_clean(self, event: AstrMessageEvent): + async def pi_file_clean( + self, event: AstrMessageEvent + ) -> AsyncGenerator[MessageEventResult, None]: """清理用户上传的文件""" uid = event.get_sender_id() if uid in self.user_waiting: @@ -317,7 +332,9 @@ async def pi_file_clean(self, event: AstrMessageEvent): ) @pi.command("list") - async def pi_file_list(self, event: AstrMessageEvent): + async def pi_file_list( + self, event: AstrMessageEvent + ) -> AsyncGenerator[MessageEventResult, None]: """列出用户上传的文件""" uid = event.get_sender_id() if uid in self.user_file_msg_buffer: @@ -331,7 +348,9 @@ async def pi_file_list(self, event: AstrMessageEvent): ) @llm_tool("python_interpreter") - async def python_interpreter(self, event: AstrMessageEvent): + async def python_interpreter( + self, event: AstrMessageEvent + ) -> AsyncGenerator[MessageEventResult | None, None]: """Use this tool only if user really want to solve a complex problem and the problem can be solved very well by Python code. For example, user can use this tool to solve math problems, edit image, docx, pptx, pdf, etc. """ @@ -507,7 +526,9 @@ async def python_interpreter(self, event: AstrMessageEvent): ) @pi.command("cleanfile") - async def pi_cleanfile(self, event: AstrMessageEvent): + async def pi_cleanfile( + self, event: AstrMessageEvent + ) -> AsyncGenerator[MessageEventResult, None]: """清理用户上传的文件""" for file in self.user_file_msg_buffer[event.get_session_id()]: try: diff --git a/astrbot/builtin_stars/python_interpreter/shared/api.py b/astrbot/builtin_stars/python_interpreter/shared/api.py index 287773fb0..cc12bd6d6 100644 --- a/astrbot/builtin_stars/python_interpreter/shared/api.py +++ b/astrbot/builtin_stars/python_interpreter/shared/api.py @@ -1,22 +1,22 @@ import os -def _get_magic_code(): +def _get_magic_code() -> str | None: """防止注入攻击""" return os.getenv("MAGIC_CODE") -def send_text(text: str): +def send_text(text: str) -> None: print(f"[ASTRBOT_TEXT_OUTPUT#{_get_magic_code()}]: {text}") -def send_image(image_path: str): +def send_image(image_path: str) -> None: if not os.path.exists(image_path): raise Exception(f"Image file not found: {image_path}") print(f"[ASTRBOT_IMAGE_OUTPUT#{_get_magic_code()}]: {image_path}") -def send_file(file_path: str): +def send_file(file_path: str) -> None: if not os.path.exists(file_path): raise Exception(f"File not found: {file_path}") print(f"[ASTRBOT_FILE_OUTPUT#{_get_magic_code()}]: {file_path}") diff --git a/astrbot/builtin_stars/reminder/main.py b/astrbot/builtin_stars/reminder/main.py index 62af7ae56..e8dea2242 100644 --- a/astrbot/builtin_stars/reminder/main.py +++ b/astrbot/builtin_stars/reminder/main.py @@ -3,6 +3,7 @@ import os import uuid import zoneinfo +from collections.abc import AsyncGenerator from apscheduler.schedulers.asyncio import AsyncIOScheduler from apscheduler.triggers.cron import CronTrigger @@ -38,7 +39,7 @@ def __init__(self, context: star.Context) -> None: self._init_scheduler() self.scheduler.start() - def _init_scheduler(self): + def _init_scheduler(self) -> None: """Initialize the scheduler.""" for group in self.reminder_data: for reminder in self.reminder_data[group]: @@ -72,7 +73,7 @@ def _init_scheduler(self): misfire_grace_time=60, ) - def check_is_outdated(self, reminder: dict): + def check_is_outdated(self, reminder: dict) -> bool: """Check if the reminder is outdated.""" if "datetime" in reminder: reminder_time = datetime.datetime.strptime( @@ -82,13 +83,13 @@ def check_is_outdated(self, reminder: dict): return reminder_time < datetime.datetime.now(self.timezone) return False - async def _save_data(self): + async def _save_data(self) -> None: """Save the reminder data.""" reminder_file = os.path.join(get_astrbot_data_path(), "astrbot-reminder.json") with open(reminder_file, "w", encoding="utf-8") as f: json.dump(self.reminder_data, f, ensure_ascii=False) - def _parse_cron_expr(self, cron_expr: str): + def _parse_cron_expr(self, cron_expr: str) -> dict: fields = cron_expr.split(" ") return { "minute": fields[0], @@ -106,7 +107,7 @@ async def reminder_tool( datetime_str: str | None = None, cron_expression: str | None = None, human_readable_cron: str | None = None, - ): + ) -> AsyncGenerator[MessageEventResult, None]: """Call this function when user is asking for setting a reminder. Args: @@ -178,10 +179,10 @@ async def reminder_tool( ) @filter.command_group("reminder") - def reminder(self): + def reminder(self) -> None: """待办提醒""" - async def get_upcoming_reminders(self, unified_msg_origin: str): + async def get_upcoming_reminders(self, unified_msg_origin: str) -> list: """Get upcoming reminders.""" reminders = self.reminder_data.get(unified_msg_origin, []) if not reminders: @@ -200,7 +201,9 @@ async def get_upcoming_reminders(self, unified_msg_origin: str): return upcoming_reminders @reminder.command("ls") - async def reminder_ls(self, event: AstrMessageEvent): + async def reminder_ls( + self, event: AstrMessageEvent + ) -> AsyncGenerator[MessageEventResult, None]: """List upcoming reminders.""" reminders = await self.get_upcoming_reminders(event.unified_msg_origin) if not reminders: @@ -218,7 +221,9 @@ async def reminder_ls(self, event: AstrMessageEvent): yield event.plain_result(reminder_str) @reminder.command("rm") - async def reminder_rm(self, event: AstrMessageEvent, index: int): + async def reminder_rm( + self, event: AstrMessageEvent, index: int + ) -> AsyncGenerator[MessageEventResult, None]: """Remove a reminder by index.""" reminders = await self.get_upcoming_reminders(event.unified_msg_origin) @@ -246,7 +251,7 @@ async def reminder_rm(self, event: AstrMessageEvent, index: int): await self._save_data() yield event.plain_result("成功删除待办事项:\n" + reminder["text"]) - async def _reminder_callback(self, unified_msg_origin: str, d: dict): + async def _reminder_callback(self, unified_msg_origin: str, d: dict) -> None: """The callback function of the reminder.""" logger.info(f"Reminder Activated: {d['text']}, created by {unified_msg_origin}") await self.context.send_message( @@ -260,7 +265,7 @@ async def _reminder_callback(self, unified_msg_origin: str, d: dict): ), ) - async def terminate(self): + async def terminate(self) -> None: self.scheduler.shutdown() await self._save_data() logger.info("Reminder plugin terminated.") diff --git a/astrbot/builtin_stars/session_controller/main.py b/astrbot/builtin_stars/session_controller/main.py index 9ea62ea30..82753d915 100644 --- a/astrbot/builtin_stars/session_controller/main.py +++ b/astrbot/builtin_stars/session_controller/main.py @@ -1,10 +1,13 @@ import copy +from collections.abc import AsyncGenerator from sys import maxsize import astrbot.api.message_components as Comp from astrbot.api import logger from astrbot.api.event import AstrMessageEvent, filter from astrbot.api.star import Context, Star +from astrbot.core.message.message_event_result import MessageEventResult +from astrbot.core.provider.entities import ProviderRequest from astrbot.core.utils.session_waiter import ( FILTERS, USER_SESSIONS, @@ -17,11 +20,11 @@ class Main(Star): """会话控制""" - def __init__(self, context: Context): + def __init__(self, context: Context) -> None: super().__init__(context) @filter.event_message_type(filter.EventMessageType.ALL, priority=maxsize) - async def handle_session_control_agent(self, event: AstrMessageEvent): + async def handle_session_control_agent(self, event: AstrMessageEvent) -> None: """会话控制代理""" for session_filter in FILTERS: session_id = session_filter.filter(event) @@ -30,7 +33,9 @@ async def handle_session_control_agent(self, event: AstrMessageEvent): event.stop_event() @filter.event_message_type(filter.EventMessageType.ALL, priority=maxsize - 1) - async def handle_empty_mention(self, event: AstrMessageEvent): + async def handle_empty_mention( + self, event: AstrMessageEvent + ) -> AsyncGenerator[MessageEventResult | ProviderRequest, None]: """实现了对只有一个 @ 的消息内容的处理""" try: messages = event.get_messages() @@ -91,7 +96,7 @@ async def handle_empty_mention(self, event: AstrMessageEvent): async def empty_mention_waiter( controller: SessionController, event: AstrMessageEvent, - ): + ) -> None: event.message_obj.message.insert( 0, Comp.At(qq=event.get_self_id(), name=event.get_self_id()), diff --git a/astrbot/builtin_stars/web_searcher/engines/__init__.py b/astrbot/builtin_stars/web_searcher/engines/__init__.py index 699438602..2c18d9884 100644 --- a/astrbot/builtin_stars/web_searcher/engines/__init__.py +++ b/astrbot/builtin_stars/web_searcher/engines/__init__.py @@ -48,7 +48,7 @@ def __init__(self) -> None: def _set_selector(self, selector: str) -> str: raise NotImplementedError - def _get_next_page(self, query: str): + async def _get_next_page(self, query: str) -> str: raise NotImplementedError async def _get_html(self, url: str, data: dict | None = None) -> str: diff --git a/astrbot/builtin_stars/web_searcher/engines/bing.py b/astrbot/builtin_stars/web_searcher/engines/bing.py index 7565e5df3..95b9bd419 100644 --- a/astrbot/builtin_stars/web_searcher/engines/bing.py +++ b/astrbot/builtin_stars/web_searcher/engines/bing.py @@ -7,7 +7,7 @@ def __init__(self) -> None: self.base_urls = ["https://cn.bing.com", "https://www.bing.com"] self.headers.update({"User-Agent": USER_AGENT_BING}) - def _set_selector(self, selector: str): + def _set_selector(self, selector: str) -> str: selectors = { "url": "div.b_attribution cite", "title": "h2", @@ -17,7 +17,7 @@ def _set_selector(self, selector: str): } return selectors[selector] - async def _get_next_page(self, query) -> str: + async def _get_next_page(self, query: str) -> str: # if self.page == 1: # await self._get_html(self.base_url) for base_url in self.base_urls: diff --git a/astrbot/builtin_stars/web_searcher/engines/sogo.py b/astrbot/builtin_stars/web_searcher/engines/sogo.py index f490f1106..dfd14bce1 100644 --- a/astrbot/builtin_stars/web_searcher/engines/sogo.py +++ b/astrbot/builtin_stars/web_searcher/engines/sogo.py @@ -13,7 +13,7 @@ def __init__(self) -> None: self.base_url = "https://www.sogou.com" self.headers["User-Agent"] = random.choice(USER_AGENTS) - def _set_selector(self, selector: str): + def _set_selector(self, selector: str) -> str: selectors = { "url": "h3 > a", "title": "h3", @@ -23,7 +23,7 @@ def _set_selector(self, selector: str): } return selectors[selector] - async def _get_next_page(self, query) -> str: + async def _get_next_page(self, query: str) -> str: url = f"{self.base_url}/web?query={query}" return await self._get_html(url, None) @@ -38,7 +38,7 @@ async def search(self, query: str, num_results: int) -> list[SearchResult]: result.url = await self._parse_url(result.url) return results - async def _parse_url(self, url) -> str: + async def _parse_url(self, url: str) -> str: html = await self._get_html(url) soup = BeautifulSoup(html, "html.parser") script = soup.find("script") diff --git a/astrbot/builtin_stars/web_searcher/main.py b/astrbot/builtin_stars/web_searcher/main.py index 4745cd0c0..5b4fd9f86 100644 --- a/astrbot/builtin_stars/web_searcher/main.py +++ b/astrbot/builtin_stars/web_searcher/main.py @@ -89,7 +89,7 @@ async def _process_search_result( async def _web_search_default( self, - query, + query: str, num_results: int = 5, ) -> list[SearchResult]: results = [] @@ -184,7 +184,7 @@ async def _extract_tavily(self, cfg: AstrBotConfig, payload: dict) -> list[dict] return results @filter.command("websearch") - async def websearch(self, event: AstrMessageEvent, oper: str | None = None): + async def websearch(self, event: AstrMessageEvent, oper: str | None = None) -> None: """网页搜索指令(已废弃)""" event.set_result( MessageEventResult().message( @@ -231,7 +231,7 @@ async def search_from_search_engine( return ret - async def ensure_baidu_ai_search_mcp(self, umo: str | None = None): + async def ensure_baidu_ai_search_mcp(self, umo: str | None = None) -> None: if self.baidu_initialized: return cfg = self.context.get_config(umo=umo) @@ -379,7 +379,7 @@ async def edit_web_search_tools( self, event: AstrMessageEvent, req: ProviderRequest, - ): + ) -> None: """Get the session conversation for the given event.""" cfg = self.context.get_config(umo=event.unified_msg_origin) prov_settings = cfg.get("provider_settings", {}) diff --git a/astrbot/cli/commands/cmd_conf.py b/astrbot/cli/commands/cmd_conf.py index a9bd40f00..64adf6fc4 100644 --- a/astrbot/cli/commands/cmd_conf.py +++ b/astrbot/cli/commands/cmd_conf.py @@ -104,7 +104,7 @@ def _save_config(config: dict[str, Any]) -> None: ) -def _set_nested_item(obj: dict[str, Any], path: str, value: Any) -> None: +def _set_nested_item(obj: dict[str, Any], path: str, value: object) -> None: """设置嵌套字典中的值""" parts = path.split(".") for part in parts[:-1]: @@ -118,7 +118,7 @@ def _set_nested_item(obj: dict[str, Any], path: str, value: Any) -> None: obj[parts[-1]] = value -def _get_nested_item(obj: dict[str, Any], path: str) -> Any: +def _get_nested_item(obj: dict[str, Any], path: str) -> object: """获取嵌套字典中的值""" parts = path.split(".") for part in parts: @@ -127,7 +127,7 @@ def _get_nested_item(obj: dict[str, Any], path: str) -> Any: @click.group(name="conf") -def conf(): +def conf() -> None: """配置管理命令 支持的配置项: @@ -149,7 +149,7 @@ def conf(): @conf.command(name="set") @click.argument("key") @click.argument("value") -def set_config(key: str, value: str): +def set_config(key: str, value: str) -> None: """设置配置项的值""" if key not in CONFIG_VALIDATORS: raise click.ClickException(f"不支持的配置项: {key}") @@ -178,7 +178,7 @@ def set_config(key: str, value: str): @conf.command(name="get") @click.argument("key", required=False) -def get_config(key: str | None = None): +def get_config(key: str | None = None) -> None: """获取配置项的值,不提供key则显示所有可配置项""" config = _load_config() diff --git a/astrbot/cli/commands/cmd_plug.py b/astrbot/cli/commands/cmd_plug.py index a1099de1d..aeaedc70b 100644 --- a/astrbot/cli/commands/cmd_plug.py +++ b/astrbot/cli/commands/cmd_plug.py @@ -5,6 +5,7 @@ import click from ..utils import ( + PluginInfo, PluginStatus, build_plug_list, check_astrbot_root, @@ -15,7 +16,7 @@ @click.group() -def plug(): +def plug() -> None: """插件管理""" @@ -28,7 +29,11 @@ def _get_data_path() -> Path: return (base / "data").resolve() -def display_plugins(plugins, title=None, color=None): +def display_plugins( + plugins: list[PluginInfo], + title: str | None = None, + color: int | tuple[int, int, int] | str | None = None, +) -> None: if title: click.echo(click.style(title, fg=color, bold=True)) @@ -45,7 +50,7 @@ def display_plugins(plugins, title=None, color=None): @plug.command() @click.argument("name") -def new(name: str): +def new(name: str) -> None: """创建新插件""" base_path = _get_data_path() plug_path = base_path / "plugins" / name @@ -100,7 +105,7 @@ def new(name: str): @plug.command() @click.option("--all", "-a", is_flag=True, help="列出未安装的插件") -def list(all: bool): +def list(all: bool) -> None: """列出插件""" base_path = _get_data_path() plugins = build_plug_list(base_path / "plugins") @@ -141,7 +146,7 @@ def list(all: bool): @plug.command() @click.argument("name") @click.option("--proxy", help="代理服务器地址") -def install(name: str, proxy: str | None): +def install(name: str, proxy: str | None) -> None: """安装插件""" base_path = _get_data_path() plug_path = base_path / "plugins" @@ -164,13 +169,13 @@ def install(name: str, proxy: str | None): @plug.command() @click.argument("name") -def remove(name: str): +def remove(name: str) -> None: """卸载插件""" base_path = _get_data_path() plugins = build_plug_list(base_path / "plugins") plugin = next((p for p in plugins if p["name"] == name), None) - if not plugin or not plugin.get("local_path"): + if not plugin or not plugin["local_path"]: raise click.ClickException(f"插件 {name} 不存在或未安装") plugin_path = plugin["local_path"] @@ -187,7 +192,7 @@ def remove(name: str): @plug.command() @click.argument("name", required=False) @click.option("--proxy", help="Github代理地址") -def update(name: str, proxy: str | None): +def update(name: str, proxy: str | None) -> None: """更新插件""" base_path = _get_data_path() plug_path = base_path / "plugins" @@ -225,7 +230,7 @@ def update(name: str, proxy: str | None): @plug.command() @click.argument("query") -def search(query: str): +def search(query: str) -> None: """搜索插件""" base_path = _get_data_path() plugins = build_plug_list(base_path / "plugins") diff --git a/astrbot/cli/commands/cmd_run.py b/astrbot/cli/commands/cmd_run.py index 9333f1b87..23665dff3 100644 --- a/astrbot/cli/commands/cmd_run.py +++ b/astrbot/cli/commands/cmd_run.py @@ -10,7 +10,7 @@ from ..utils import check_astrbot_root, check_dashboard, get_astrbot_root -async def run_astrbot(astrbot_root: Path): +async def run_astrbot(astrbot_root: Path) -> None: """运行 AstrBot""" from astrbot.core import LogBroker, LogManager, db_helper, logger from astrbot.core.initial_loader import InitialLoader diff --git a/astrbot/cli/utils/__init__.py b/astrbot/cli/utils/__init__.py index 3830682f0..6d6b02bf2 100644 --- a/astrbot/cli/utils/__init__.py +++ b/astrbot/cli/utils/__init__.py @@ -3,10 +3,17 @@ check_dashboard, get_astrbot_root, ) -from .plugin import PluginStatus, build_plug_list, get_git_repo, manage_plugin +from .plugin import ( + PluginInfo, + PluginStatus, + build_plug_list, + get_git_repo, + manage_plugin, +) from .version_comparator import VersionComparator __all__ = [ + "PluginInfo", "PluginStatus", "VersionComparator", "build_plug_list", diff --git a/astrbot/cli/utils/plugin.py b/astrbot/cli/utils/plugin.py index cd76a07c8..44542c33d 100644 --- a/astrbot/cli/utils/plugin.py +++ b/astrbot/cli/utils/plugin.py @@ -3,6 +3,7 @@ from enum import Enum from io import BytesIO from pathlib import Path +from typing import TypedDict from zipfile import ZipFile import click @@ -19,7 +20,17 @@ class PluginStatus(str, Enum): NOT_PUBLISHED = "未发布" -def get_git_repo(url: str, target_path: Path, proxy: str | None = None): +class PluginInfo(TypedDict): + name: str + desc: str + version: str + author: str + repo: str + status: PluginStatus + local_path: str | None + + +def get_git_repo(url: str, target_path: Path, proxy: str | None = None) -> None: """从 Git 仓库下载代码并解压到指定路径""" temp_dir = Path(tempfile.mkdtemp()) try: @@ -102,18 +113,18 @@ def load_yaml_metadata(plugin_dir: Path) -> dict: return {} -def build_plug_list(plugins_dir: Path) -> list: +def build_plug_list(plugins_dir: Path) -> list[PluginInfo]: """构建插件列表,包含本地和在线插件信息 Args: plugins_dir (Path): 插件目录路径 Returns: - list: 包含插件信息的字典列表 + list[PluginInfo]: 包含插件信息的字典列表 """ # 获取本地插件信息 - result = [] + result: list[PluginInfo] = [] if plugins_dir.exists(): for plugin_name in [d.name for d in plugins_dir.glob("*") if d.is_dir()]: plugin_dir = plugins_dir / plugin_name @@ -141,7 +152,7 @@ def build_plug_list(plugins_dir: Path) -> list: ) # 获取在线插件列表 - online_plugins = [] + online_plugins: list[PluginInfo] = [] try: with httpx.Client() as client: resp = client.get("https://api.soulter.top/astrbot/plugins") @@ -191,7 +202,7 @@ def build_plug_list(plugins_dir: Path) -> list: def manage_plugin( - plugin: dict, + plugin: PluginInfo, plugins_dir: Path, is_update: bool = False, proxy: str | None = None, @@ -209,7 +220,7 @@ def manage_plugin( repo_url = plugin["repo"] # 如果是更新且有本地路径,直接使用本地路径 - if is_update and plugin.get("local_path"): + if is_update and plugin["local_path"]: target_path = Path(plugin["local_path"]) else: target_path = plugins_dir / plugin_name diff --git a/astrbot/cli/utils/version_comparator.py b/astrbot/cli/utils/version_comparator.py index 0aaf8dcab..fbf15a612 100644 --- a/astrbot/cli/utils/version_comparator.py +++ b/astrbot/cli/utils/version_comparator.py @@ -15,7 +15,7 @@ def compare_version(v1: str, v2: str) -> int: v1 = v1.lower().replace("v", "") v2 = v2.lower().replace("v", "") - def split_version(version): + def split_version(version: str) -> tuple[list[int], list[int | str] | None]: match = re.match( r"^([0-9]+(?:\.[0-9]+)*)(?:-([0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?(?:\+(.+))?$", version, @@ -77,7 +77,7 @@ def split_version(version): return 0 # 数字部分和预发布标签都相同 @staticmethod - def _split_prerelease(prerelease): + def _split_prerelease(prerelease: str) -> list[int | str] | None: if not prerelease: return None parts = prerelease.split(".") diff --git a/astrbot/core/agent/handoff.py b/astrbot/core/agent/handoff.py index 85276540b..755cc45a6 100644 --- a/astrbot/core/agent/handoff.py +++ b/astrbot/core/agent/handoff.py @@ -1,8 +1,23 @@ +from __future__ import annotations + +from collections.abc import AsyncGenerator, Awaitable, Callable from typing import Generic +from typing_extensions import TypedDict, Unpack + +from astrbot.core.message.message_event_result import MessageEventResult + from .agent import Agent from .run_context import TContext -from .tool import FunctionTool +from .tool import FunctionTool, ParametersType + + +class HandoffInitKwargs(TypedDict, total=False): + handler: ( + Callable[..., Awaitable[str | None] | AsyncGenerator[MessageEventResult]] | None + ) + handler_module_path: str | None + active: bool class HandoffTool(FunctionTool, Generic[TContext]): @@ -11,9 +26,9 @@ class HandoffTool(FunctionTool, Generic[TContext]): def __init__( self, agent: Agent[TContext], - parameters: dict | None = None, - **kwargs, - ): + parameters: ParametersType | None = None, + **kwargs: Unpack[HandoffInitKwargs], + ) -> None: self.agent = agent super().__init__( name=f"transfer_to_{agent.name}", diff --git a/astrbot/core/agent/hooks.py b/astrbot/core/agent/hooks.py index d834240b7..74ca6335b 100644 --- a/astrbot/core/agent/hooks.py +++ b/astrbot/core/agent/hooks.py @@ -9,22 +9,22 @@ class BaseAgentRunHooks(Generic[TContext]): - async def on_agent_begin(self, run_context: ContextWrapper[TContext]): ... + async def on_agent_begin(self, run_context: ContextWrapper[TContext]) -> None: ... async def on_tool_start( self, run_context: ContextWrapper[TContext], tool: FunctionTool, tool_args: dict | None, - ): ... + ) -> None: ... async def on_tool_end( self, run_context: ContextWrapper[TContext], tool: FunctionTool, tool_args: dict | None, tool_result: mcp.types.CallToolResult | None, - ): ... + ) -> None: ... async def on_agent_done( self, run_context: ContextWrapper[TContext], llm_response: LLMResponse, - ): ... + ) -> None: ... diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index c5ff123b2..13ac2d7de 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -4,6 +4,7 @@ from datetime import timedelta from typing import Generic +from mcp.types import CallToolResult from tenacity import ( before_sleep_log, retry, @@ -108,7 +109,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: class MCPClient: - def __init__(self): + def __init__(self) -> None: # Initialize session and client objects self.session: mcp.ClientSession | None = None self.exit_stack = AsyncExitStack() @@ -126,7 +127,7 @@ def __init__(self): self._reconnect_lock = asyncio.Lock() # Lock for thread-safe reconnection self._reconnecting: bool = False # For logging and debugging - async def connect_to_server(self, mcp_server_config: dict, name: str): + async def connect_to_server(self, mcp_server_config: dict, name: str) -> None: """Connect to MCP server If `url` parameter exists: @@ -144,7 +145,7 @@ async def connect_to_server(self, mcp_server_config: dict, name: str): cfg = _prepare_config(mcp_server_config.copy()) - def logging_callback(msg: str): + def logging_callback(msg: str) -> None: # Handle MCP service error logs print(f"MCP Server {name} Error: {msg}") self.server_errlogs.append(msg) @@ -214,7 +215,7 @@ def logging_callback(msg: str): **cfg, ) - def callback(msg: str): + def callback(msg: str) -> None: # Handle MCP service error logs self.server_errlogs.append(msg) @@ -322,7 +323,7 @@ async def call_tool_with_reconnect( before_sleep=before_sleep_log(logger, logging.WARNING), reraise=True, ) - async def _call_with_retry(): + async def _call_with_retry() -> CallToolResult: if not self.session: raise ValueError("MCP session is not available for MCP function tools.") @@ -343,7 +344,7 @@ async def _call_with_retry(): return await _call_with_retry() - async def cleanup(self): + async def cleanup(self) -> None: """Clean up resources including old exit stacks from reconnections""" # Close current exit stack try: @@ -364,8 +365,12 @@ class MCPTool(FunctionTool, Generic[TContext]): """A function tool that calls an MCP service.""" def __init__( - self, mcp_tool: mcp.Tool, mcp_client: MCPClient, mcp_server_name: str, **kwargs - ): + self, + mcp_tool: mcp.Tool, + mcp_client: MCPClient, + mcp_server_name: str, + **kwargs: object, + ) -> None: super().__init__( name=mcp_tool.name, description=mcp_tool.description or "", @@ -376,7 +381,9 @@ def __init__( self.mcp_server_name = mcp_server_name async def call( - self, context: ContextWrapper[TContext], **kwargs + self, + context: ContextWrapper[TContext], + **kwargs: object, ) -> mcp.types.CallToolResult: return await self.mcp_client.call_tool_with_reconnect( tool_name=self.mcp_tool.name, diff --git a/astrbot/core/agent/message.py b/astrbot/core/agent/message.py index 582b1eef2..62c2428d2 100644 --- a/astrbot/core/agent/message.py +++ b/astrbot/core/agent/message.py @@ -1,10 +1,19 @@ # Inspired by MoonshotAI/kosong, credits to MoonshotAI/kosong authors for the original implementation. # License: Apache License 2.0 +import builtins from typing import Any, ClassVar, Literal, cast -from pydantic import BaseModel, GetCoreSchemaHandler, model_serializer, model_validator +from pydantic import ( + BaseModel, + GetCoreSchemaHandler, + SerializerFunctionWrapHandler, + model_serializer, + model_validator, +) +from pydantic.config import ConfigDict from pydantic_core import core_schema +from typing_extensions import Unpack class ContentPart(BaseModel): @@ -14,7 +23,7 @@ class ContentPart(BaseModel): type: Literal["text", "think", "image_url", "audio_url"] - def __init_subclass__(cls, **kwargs: Any) -> None: + def __init_subclass__(cls, **kwargs: Unpack[ConfigDict]) -> None: super().__init_subclass__(**kwargs) invalid_subclass_error_msg = f"ContentPart subclass {cls.__name__} must have a `type` field of type `str`" @@ -27,15 +36,15 @@ def __init_subclass__(cls, **kwargs: Any) -> None: @classmethod def __get_pydantic_core_schema__( - cls, source_type: Any, handler: GetCoreSchemaHandler + cls, source_type: builtins.type[BaseModel], handler: GetCoreSchemaHandler ) -> core_schema.CoreSchema: # If we're dealing with the base ContentPart class, use custom validation if cls.__name__ == "ContentPart": - def validate_content_part(value: Any) -> Any: + def validate_content_part(value: object) -> "ContentPart": # if it's already an instance of a ContentPart subclass, return it if hasattr(value, "__class__") and issubclass(value.__class__, cls): - return value + return cast("ContentPart", value) # if it's a dict with a type field, dispatch to the appropriate subclass if isinstance(value, dict) and "type" in value: @@ -74,7 +83,7 @@ class ThinkPart(ContentPart): encrypted: str | None = None """Encrypted thinking content, or signature.""" - def merge_in_place(self, other: Any) -> bool: + def merge_in_place(self, other: object) -> bool: if not isinstance(other, ThinkPart): return False if self.encrypted: @@ -145,7 +154,7 @@ class FunctionBody(BaseModel): """Extra metadata for the tool call.""" @model_serializer(mode="wrap") - def serialize(self, handler): + def serialize(self, handler: SerializerFunctionWrapHandler) -> dict: data = handler(self) if self.extra_content is None: data.pop("extra_content", None) @@ -179,7 +188,7 @@ class Message(BaseModel): """The ID of the tool call.""" @model_validator(mode="after") - def check_content_required(self): + def check_content_required(self) -> "Message": # assistant + tool_calls is not None: allow content to be None if self.role == "assistant" and self.tool_calls is not None: return self @@ -192,7 +201,7 @@ def check_content_required(self): return self @model_serializer(mode="wrap") - def serialize(self, handler): + def serialize(self, handler: SerializerFunctionWrapHandler) -> dict: data = handler(self) if self.tool_calls is None: data.pop("tool_calls", None) diff --git a/astrbot/core/agent/runners/base.py b/astrbot/core/agent/runners/base.py index 21e796433..a9ddc0e35 100644 --- a/astrbot/core/agent/runners/base.py +++ b/astrbot/core/agent/runners/base.py @@ -25,7 +25,7 @@ async def reset( self, run_context: ContextWrapper[TContext], agent_hooks: BaseAgentRunHooks[TContext], - **kwargs: T.Any, + **kwargs: object, ) -> None: """Reset the agent to its initial state. This method should be called before starting a new run. diff --git a/astrbot/core/agent/runners/coze/coze_agent_runner.py b/astrbot/core/agent/runners/coze/coze_agent_runner.py index a8300bb71..c2e8a948e 100644 --- a/astrbot/core/agent/runners/coze/coze_agent_runner.py +++ b/astrbot/core/agent/runners/coze/coze_agent_runner.py @@ -70,7 +70,7 @@ async def reset( self.file_id_cache: dict[str, dict[str, str]] = {} @override - async def step(self): + async def step(self) -> T.AsyncGenerator[AgentResponse, None]: """ 执行 Coze Agent 的一个步骤 """ @@ -113,7 +113,7 @@ async def step_until_done( async for resp in self.step(): yield resp - async def _execute_coze_request(self): + async def _execute_coze_request(self) -> T.AsyncGenerator[AgentResponse, None]: """执行 Coze 请求的核心逻辑""" prompt = self.req.prompt or "" session_id = self.req.session_id or "unknown" diff --git a/astrbot/core/agent/runners/coze/coze_api_client.py b/astrbot/core/agent/runners/coze/coze_api_client.py index e8f3a1e24..a5e62520b 100644 --- a/astrbot/core/agent/runners/coze/coze_api_client.py +++ b/astrbot/core/agent/runners/coze/coze_api_client.py @@ -10,12 +10,12 @@ class CozeAPIClient: - def __init__(self, api_key: str, api_base: str = "https://api.coze.cn"): + def __init__(self, api_key: str, api_base: str = "https://api.coze.cn") -> None: self.api_key = api_key self.api_base = api_base self.session = None - async def _ensure_session(self): + async def _ensure_session(self) -> aiohttp.ClientSession: """确保HTTP session存在""" if self.session is None: connector = aiohttp.TCPConnector( @@ -208,7 +208,7 @@ async def chat_messages( except Exception as e: raise Exception(f"Coze API 流式请求失败: {e!s}") - async def clear_context(self, conversation_id: str): + async def clear_context(self, conversation_id: str) -> dict: """清空会话上下文 Args: @@ -247,7 +247,7 @@ async def get_message_list( order: str = "desc", limit: int = 10, offset: int = 0, - ): + ) -> dict: """获取消息列表 Args: @@ -277,7 +277,7 @@ async def get_message_list( logger.error(f"获取Coze消息列表失败: {e!s}") raise Exception(f"获取Coze消息列表失败: {e!s}") - async def close(self): + async def close(self) -> None: """关闭会话""" if self.session: await self.session.close() @@ -288,7 +288,7 @@ async def close(self): import asyncio import os - async def test_coze_api_client(): + async def test_coze_api_client() -> None: api_key = os.getenv("COZE_API_KEY", "") bot_id = os.getenv("COZE_BOT_ID", "") client = CozeAPIClient(api_key=api_key) diff --git a/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py b/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py index 7a095a60b..f14fae0ca 100644 --- a/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py +++ b/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py @@ -67,7 +67,7 @@ async def reset( if isinstance(self.timeout, str): self.timeout = int(self.timeout) - def has_rag_options(self): + def has_rag_options(self) -> bool: """判断是否有 RAG 选项 Returns: @@ -82,7 +82,7 @@ def has_rag_options(self): return False @override - async def step(self): + async def step(self) -> T.AsyncGenerator[AgentResponse, None]: """ 执行 Dashscope Agent 的一个步骤 """ @@ -124,7 +124,7 @@ async def step_until_done( yield resp def _consume_sync_generator( - self, response: T.Any, response_queue: queue.Queue + self, response: T.Iterable[object], response_queue: queue.Queue ) -> None: """在线程中消费同步generator,将结果放入队列 @@ -278,7 +278,7 @@ async def _build_request_payload( return payload async def _handle_streaming_response( - self, response: T.Any, session_id: str + self, response: object, session_id: str ) -> T.AsyncGenerator[AgentResponse, None]: """处理流式响应 @@ -292,7 +292,7 @@ async def _handle_streaming_response( response_queue = queue.Queue() consumer_thread = threading.Thread( target=self._consume_sync_generator, - args=(response, response_queue), + args=(T.cast(T.Iterable[object], response), response_queue), daemon=True, ) consumer_thread.start() @@ -319,14 +319,14 @@ async def _handle_streaming_response( ( output_text, chunk_doc_refs, - response, + response_obj, ) = await self._process_stream_chunk(chunk, output_text) - if response: - if response.type == "err": - yield response + if response_obj: + if response_obj.type == "err": + yield response_obj return - yield response + yield response_obj if chunk_doc_refs: doc_references = chunk_doc_refs @@ -366,7 +366,7 @@ async def _handle_streaming_response( data=AgentResponseData(chain=chain), ) - async def _execute_dashscope_request(self): + async def _execute_dashscope_request(self) -> T.AsyncGenerator[AgentResponse, None]: """执行 Dashscope 请求的核心逻辑""" prompt = self.req.prompt or "" session_id = self.req.session_id or "unknown" diff --git a/astrbot/core/agent/runners/dify/dify_agent_runner.py b/astrbot/core/agent/runners/dify/dify_agent_runner.py index d9a8b7cd6..64f028dd9 100644 --- a/astrbot/core/agent/runners/dify/dify_agent_runner.py +++ b/astrbot/core/agent/runners/dify/dify_agent_runner.py @@ -63,7 +63,7 @@ async def reset( self.api_client = DifyAPIClient(self.api_key, self.api_base) @override - async def step(self): + async def step(self) -> T.AsyncGenerator[AgentResponse, None]: """ 执行 Dify Agent 的一个步骤 """ @@ -106,7 +106,7 @@ async def step_until_done( async for resp in self.step(): yield resp - async def _execute_dify_request(self): + async def _execute_dify_request(self) -> T.AsyncGenerator[AgentResponse, None]: """执行 Dify 请求的核心逻辑""" prompt = self.req.prompt or "" session_id = self.req.session_id or "unknown" @@ -285,7 +285,7 @@ async def parse_dify_result(self, chunk: dict | str) -> MessageChain: # Chat return MessageChain(chain=[Comp.Plain(chunk)]) - async def parse_file(item: dict): + async def parse_file(item: dict) -> object: match item["type"]: case "image": return Comp.Image(file=item["url"], url=item["url"]) diff --git a/astrbot/core/agent/runners/dify/dify_api_client.py b/astrbot/core/agent/runners/dify/dify_api_client.py index d9c6556cf..9e683d257 100644 --- a/astrbot/core/agent/runners/dify/dify_api_client.py +++ b/astrbot/core/agent/runners/dify/dify_api_client.py @@ -31,7 +31,7 @@ async def _stream_sse(resp: ClientResponse) -> AsyncGenerator[dict, None]: class DifyAPIClient: - def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1"): + def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1") -> None: self.api_key = api_key self.api_base = api_base self.session = ClientSession(trust_env=True) @@ -77,7 +77,7 @@ async def workflow_run( response_mode: str = "streaming", files: list[dict[str, Any]] | None = None, timeout: float = 60, - ): + ) -> AsyncGenerator[dict[str, Any], None]: if files is None: files = [] url = f"{self.api_base}/workflows/run" @@ -155,10 +155,10 @@ async def file_upload( raise Exception(f"Dify 文件上传失败:{resp.status}. {text}") return await resp.json() # {"id": "xxx", ...} - async def close(self): + async def close(self) -> None: await self.session.close() - async def get_chat_convs(self, user: str, limit: int = 20): + async def get_chat_convs(self, user: str, limit: int = 20) -> dict: # conversations. GET url = f"{self.api_base}/conversations" payload = { @@ -168,7 +168,7 @@ async def get_chat_convs(self, user: str, limit: int = 20): async with self.session.get(url, params=payload, headers=self.headers) as resp: return await resp.json() - async def delete_chat_conv(self, user: str, conversation_id: str): + async def delete_chat_conv(self, user: str, conversation_id: str) -> dict: # conversation. DELETE url = f"{self.api_base}/conversations/{conversation_id}" payload = { @@ -183,7 +183,7 @@ async def rename( name: str, user: str, auto_generate: bool = False, - ): + ) -> dict: # /conversations/:conversation_id/name url = f"{self.api_base}/conversations/{conversation_id}/name" payload = { diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 606163685..e20c46c4a 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -134,7 +134,7 @@ async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]: yield await self.provider.text_chat(**payload) @override - async def step(self): + async def step(self) -> T.AsyncGenerator[AgentResponse, None]: """Process a single step of the agent. This method should return the result of the step. """ diff --git a/astrbot/core/agent/tool.py b/astrbot/core/agent/tool.py index 7f30f44ef..03679a245 100644 --- a/astrbot/core/agent/tool.py +++ b/astrbot/core/agent/tool.py @@ -1,4 +1,4 @@ -from collections.abc import AsyncGenerator, Awaitable, Callable +from collections.abc import AsyncGenerator, Awaitable, Callable, Iterator from typing import Any, Generic import jsonschema @@ -58,10 +58,12 @@ class FunctionTool(ToolSchema, Generic[TContext]): You can ignore it when integrating with other frameworks. """ - def __repr__(self): + def __repr__(self) -> str: return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description})" - async def call(self, context: ContextWrapper[TContext], **kwargs) -> ToolExecResult: + async def call( + self, context: ContextWrapper[TContext], **kwargs: object + ) -> ToolExecResult: """Run the tool with the given arguments. The handler field has priority.""" raise NotImplementedError( "FunctionTool.call() must be implemented by subclasses or set a handler." @@ -82,7 +84,7 @@ def empty(self) -> bool: """Check if the tool set is empty.""" return len(self.tools) == 0 - def add_tool(self, tool: FunctionTool): + def add_tool(self, tool: FunctionTool) -> None: """Add a tool to the set.""" # 检查是否已存在同名工具 for i, existing_tool in enumerate(self.tools): @@ -91,7 +93,7 @@ def add_tool(self, tool: FunctionTool): return self.tools.append(tool) - def remove_tool(self, name: str): + def remove_tool(self, name: str) -> None: """Remove a tool by its name.""" self.tools = [tool for tool in self.tools if tool.name != name] @@ -109,7 +111,7 @@ def add_func( func_args: list, desc: str, handler: Callable[..., Awaitable[Any]], - ): + ) -> None: """Add a function tool to the set.""" params = { "type": "object", # hard-coded here @@ -129,7 +131,7 @@ def add_func( self.add_tool(_func) @deprecated(reason="Use remove_tool() instead", version="4.0.0") - def remove_func(self, name: str): + def remove_func(self, name: str) -> None: """Remove a function tool by its name.""" self.remove_tool(name) @@ -259,32 +261,34 @@ def convert_schema(schema: dict) -> dict: return declarations @deprecated(reason="Use openai_schema() instead", version="4.0.0") - def get_func_desc_openai_style(self, omit_empty_parameter_field: bool = False): + def get_func_desc_openai_style( + self, omit_empty_parameter_field: bool = False + ) -> list[dict]: return self.openai_schema(omit_empty_parameter_field) @deprecated(reason="Use anthropic_schema() instead", version="4.0.0") - def get_func_desc_anthropic_style(self): + def get_func_desc_anthropic_style(self) -> list[dict]: return self.anthropic_schema() @deprecated(reason="Use google_schema() instead", version="4.0.0") - def get_func_desc_google_genai_style(self): + def get_func_desc_google_genai_style(self) -> dict: return self.google_schema() def names(self) -> list[str]: """获取所有工具的名称列表""" return [tool.name for tool in self.tools] - def __len__(self): + def __len__(self) -> int: return len(self.tools) - def __bool__(self): + def __bool__(self) -> bool: return len(self.tools) > 0 - def __iter__(self): + def __iter__(self) -> Iterator: return iter(self.tools) - def __repr__(self): + def __repr__(self) -> str: return f"ToolSet(tools={self.tools})" - def __str__(self): + def __str__(self) -> str: return f"ToolSet(tools={self.tools})" diff --git a/astrbot/core/agent/tool_executor.py b/astrbot/core/agent/tool_executor.py index 2704119d4..f73d37b7b 100644 --- a/astrbot/core/agent/tool_executor.py +++ b/astrbot/core/agent/tool_executor.py @@ -1,17 +1,20 @@ from collections.abc import AsyncGenerator -from typing import Any, Generic +from typing import Generic import mcp from .run_context import ContextWrapper, TContext from .tool import FunctionTool +# 子类工具执行器的统一返回类型(yield 值) +ToolExecResult = mcp.types.CallToolResult | str | None + class BaseFunctionToolExecutor(Generic[TContext]): @classmethod async def execute( cls, - tool: FunctionTool, + tool: FunctionTool[TContext], run_context: ContextWrapper[TContext], - **tool_args, - ) -> AsyncGenerator[Any | mcp.types.CallToolResult, None]: ... + **tool_args: object, + ) -> AsyncGenerator[ToolExecResult, None]: ... diff --git a/astrbot/core/astr_agent_hooks.py b/astrbot/core/astr_agent_hooks.py index 9d85de0cc..5099185ba 100644 --- a/astrbot/core/astr_agent_hooks.py +++ b/astrbot/core/astr_agent_hooks.py @@ -1,5 +1,3 @@ -from typing import Any - from mcp.types import CallToolResult from astrbot.core.agent.hooks import BaseAgentRunHooks @@ -7,11 +5,16 @@ from astrbot.core.agent.tool import FunctionTool from astrbot.core.astr_agent_context import AstrAgentContext from astrbot.core.pipeline.context_utils import call_event_hook +from astrbot.core.provider.entities import LLMResponse from astrbot.core.star.star_handler import EventType class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]): - async def on_agent_done(self, run_context, llm_response): + async def on_agent_done( + self, + run_context: ContextWrapper[AstrAgentContext], + llm_response: LLMResponse, + ) -> None: # 执行事件钩子 if llm_response and llm_response.reasoning_content: # we will use this in result_decorate stage to inject reasoning content to chain @@ -28,10 +31,10 @@ async def on_agent_done(self, run_context, llm_response): async def on_tool_end( self, run_context: ContextWrapper[AstrAgentContext], - tool: FunctionTool[Any], - tool_args: dict | None, + tool: FunctionTool[AstrAgentContext], + tool_args: dict[str, object] | None, tool_result: CallToolResult | None, - ): + ) -> None: run_context.context.event.clear_result() diff --git a/astrbot/core/astr_agent_tool_exec.py b/astrbot/core/astr_agent_tool_exec.py index 5d40f48fa..26596f296 100644 --- a/astrbot/core/astr_agent_tool_exec.py +++ b/astrbot/core/astr_agent_tool_exec.py @@ -10,7 +10,7 @@ from astrbot.core.agent.mcp_client import MCPTool from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.agent.tool import FunctionTool, ToolSet -from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor +from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor, ToolExecResult from astrbot.core.astr_agent_context import AstrAgentContext from astrbot.core.message.message_event_result import ( CommandResult, @@ -22,7 +22,12 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): @classmethod - async def execute(cls, tool, run_context, **tool_args): + async def execute( + cls, + tool: HandoffTool | MCPTool | FunctionTool[AstrAgentContext], + run_context: ContextWrapper[AstrAgentContext], + **tool_args: object, + ) -> T.AsyncGenerator[ToolExecResult, None]: """执行函数调用。 Args: @@ -53,9 +58,16 @@ async def _execute_handoff( cls, tool: HandoffTool, run_context: ContextWrapper[AstrAgentContext], - **tool_args, - ): - input_ = tool_args.get("input") + **tool_args: object, + ) -> T.AsyncGenerator[ToolExecResult, None]: + # 将输入统一转换/缩窄为 str | None,满足下游调用的签名 + val = tool_args.get("input") + if val is None: + input_: str | None = None + elif isinstance(val, str): + input_ = val + else: + input_ = str(val) # make toolset for the agent tools = tool.agent.tools @@ -91,10 +103,10 @@ async def _execute_handoff( @classmethod async def _execute_local( cls, - tool: FunctionTool, + tool: FunctionTool[AstrAgentContext], run_context: ContextWrapper[AstrAgentContext], - **tool_args, - ): + **tool_args: object, + ) -> T.AsyncGenerator[ToolExecResult, None]: event = run_context.context.event if not event: raise ValueError("Event must be provided for local function tools.") @@ -173,10 +185,10 @@ async def _execute_local( @classmethod async def _execute_mcp( cls, - tool: FunctionTool, + tool: FunctionTool[AstrAgentContext], run_context: ContextWrapper[AstrAgentContext], - **tool_args, - ): + **tool_args: object, + ) -> T.AsyncGenerator[ToolExecResult, None]: res = await tool.call(run_context, **tool_args) if not res: return @@ -191,9 +203,9 @@ async def call_local_llm_tool( | T.AsyncGenerator[MessageEventResult | CommandResult | str | None, None], ], method_name: str, - *args, - **kwargs, -) -> T.AsyncGenerator[T.Any, None]: + *args: object, + **kwargs: object, +) -> T.AsyncGenerator[ToolExecResult, None]: """执行本地 LLM 工具的处理函数并处理其返回结果""" ready_to_call = None # 一个协程或者异步生成器 diff --git a/astrbot/core/astrbot_config_mgr.py b/astrbot/core/astrbot_config_mgr.py index 3a1353ce5..c2bfb1c37 100644 --- a/astrbot/core/astrbot_config_mgr.py +++ b/astrbot/core/astrbot_config_mgr.py @@ -36,7 +36,7 @@ def __init__( default_config: AstrBotConfig, ucr: UmopConfigRouter, sp: SharedPreferences, - ): + ) -> None: self.sp = sp self.ucr = ucr self.confs: dict[str, AstrBotConfig] = {} @@ -56,7 +56,7 @@ def _get_abconf_data(self) -> dict: ) return self.abconf_data - def _load_all_configs(self): + def _load_all_configs(self) -> None: """Load all configurations from the shared preferences.""" abconf_data = self._get_abconf_data() self.abconf_data = abconf_data diff --git a/astrbot/core/backup/exporter.py b/astrbot/core/backup/exporter.py index 51c4a4650..574b191b9 100644 --- a/astrbot/core/backup/exporter.py +++ b/astrbot/core/backup/exporter.py @@ -8,15 +8,18 @@ import json import os import zipfile +from collections.abc import Awaitable, Callable from datetime import datetime, timezone from pathlib import Path from typing import TYPE_CHECKING, Any from sqlalchemy import select +from sqlmodel import SQLModel from astrbot.core import logger from astrbot.core.config.default import VERSION from astrbot.core.db import BaseDatabase +from astrbot.core.knowledge_base.kb_helper import KBHelper from astrbot.core.utils.astrbot_path import ( get_astrbot_backups_path, get_astrbot_data_path, @@ -59,7 +62,7 @@ def __init__( main_db: BaseDatabase, kb_manager: "KnowledgeBaseManager | None" = None, config_path: str = CMD_CONFIG_FILE_PATH, - ): + ) -> None: self.main_db = main_db self.kb_manager = kb_manager self.config_path = config_path @@ -68,7 +71,8 @@ def __init__( async def export_all( self, output_dir: str | None = None, - progress_callback: Any | None = None, + progress_callback: Callable[[str, int, int, str], Awaitable[None]] + | None = None, ) -> str: """导出所有数据到 ZIP 文件 @@ -248,13 +252,18 @@ async def _export_kb_metadata(self) -> dict[str, list[dict]]: return export_data - async def _export_kb_documents(self, kb_helper: Any) -> dict[str, Any]: + async def _export_kb_documents(self, kb_helper: KBHelper) -> dict[str, Any]: """导出知识库的文档块数据""" try: + from astrbot.core.db.vec_db.base import BaseVecDB from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB - vec_db: FaissVecDB = kb_helper.vec_db - if not vec_db or not vec_db.document_storage: + vec_db: BaseVecDB = kb_helper.vec_db + if ( + not vec_db + or not isinstance(vec_db, FaissVecDB) + or not vec_db.document_storage + ): return {"documents": []} # 获取所有文档 @@ -272,7 +281,7 @@ async def _export_kb_documents(self, kb_helper: Any) -> dict[str, Any]: async def _export_faiss_index( self, zf: zipfile.ZipFile, - kb_helper: Any, + kb_helper: KBHelper, kb_id: str, ) -> None: """导出 FAISS 索引文件""" @@ -286,7 +295,7 @@ async def _export_faiss_index( logger.warning(f"导出 FAISS 索引失败: {e}") async def _export_kb_media_files( - self, zf: zipfile.ZipFile, kb_helper: Any, kb_id: str + self, zf: zipfile.ZipFile, kb_helper: KBHelper, kb_id: str ) -> None: """导出知识库的多媒体文件""" try: @@ -371,7 +380,7 @@ async def _export_attachments( except Exception as e: logger.warning(f"导出附件失败: {e}") - def _model_to_dict(self, record: Any) -> dict: + def _model_to_dict(self, record: SQLModel) -> dict: """将 SQLModel 实例转换为字典 这是数据库无关的序列化方式,支持未来迁移到其他数据库。 diff --git a/astrbot/core/backup/importer.py b/astrbot/core/backup/importer.py index f36a79cf5..985baf7c6 100644 --- a/astrbot/core/backup/importer.py +++ b/astrbot/core/backup/importer.py @@ -11,10 +11,11 @@ import os import shutil import zipfile +from collections.abc import Awaitable, Callable from dataclasses import dataclass, field from datetime import datetime from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from sqlalchemy import delete @@ -110,7 +111,7 @@ def to_dict(self) -> dict: class ImportResult: """导入结果""" - def __init__(self): + def __init__(self) -> None: self.success = True self.imported_tables: dict[str, int] = {} self.imported_files: dict[str, int] = {} @@ -161,7 +162,7 @@ def __init__( kb_manager: "KnowledgeBaseManager | None" = None, config_path: str = CMD_CONFIG_FILE_PATH, kb_root_dir: str = KB_PATH, - ): + ) -> None: self.main_db = main_db self.kb_manager = kb_manager self.config_path = config_path @@ -283,7 +284,8 @@ async def import_all( self, zip_path: str, mode: str = "replace", # "replace" 清空后导入 - progress_callback: Any | None = None, + progress_callback: Callable[[str, int, int, str], Awaitable[None]] + | None = None, ) -> ImportResult: """从 ZIP 文件导入所有数据 diff --git a/astrbot/core/config/astrbot_config.py b/astrbot/core/config/astrbot_config.py index 2208ee766..699b1966e 100644 --- a/astrbot/core/config/astrbot_config.py +++ b/astrbot/core/config/astrbot_config.py @@ -33,9 +33,8 @@ def __init__( config_path: str = ASTRBOT_CONFIG_PATH, default_config: dict = DEFAULT_CONFIG, schema: dict | None = None, - ): + ) -> None: super().__init__() - # 调用父类的 __setattr__ 方法,防止保存配置时将此属性写入配置文件 object.__setattr__(self, "config_path", config_path) object.__setattr__(self, "default_config", default_config) @@ -66,7 +65,7 @@ def _config_schema_to_default_config(self, schema: dict) -> dict: """将 Schema 转换成 Config""" conf = {} - def _parse_schema(schema: dict, conf: dict): + def _parse_schema(schema: dict, conf: dict) -> None: for k, v in schema.items(): if v["type"] not in DEFAULT_VALUE_MAP: raise TypeError( @@ -89,7 +88,9 @@ def _parse_schema(schema: dict, conf: dict): return conf - def check_config_integrity(self, refer_conf: dict, conf: dict, path=""): + def check_config_integrity( + self, refer_conf: dict, conf: dict, path: str = "" + ) -> bool: """检查配置完整性,如果有新的配置项或顺序不一致则返回 True""" has_new = False @@ -148,7 +149,7 @@ def check_config_integrity(self, refer_conf: dict, conf: dict, path=""): return has_new - def save_config(self, replace_config: dict | None = None): + def save_config(self, replace_config: dict | None = None) -> None: """将配置写入文件 如果传入 replace_config,则将配置替换为 replace_config @@ -158,20 +159,20 @@ def save_config(self, replace_config: dict | None = None): with open(self.config_path, "w", encoding="utf-8-sig") as f: json.dump(self, f, indent=2, ensure_ascii=False) - def __getattr__(self, item): + def __getattr__(self, item: str) -> object: try: return self[item] except KeyError: return None - def __delattr__(self, key): + def __delattr__(self, key: str) -> None: try: del self[key] self.save_config() except KeyError: raise AttributeError(f"没有找到 Key: '{key}'") - def __setattr__(self, key, value): + def __setattr__(self, key: str, value: object) -> None: self[key] = value def check_exist(self) -> bool: diff --git a/astrbot/core/conversation_mgr.py b/astrbot/core/conversation_mgr.py index a0a0c0e2f..70403c09f 100644 --- a/astrbot/core/conversation_mgr.py +++ b/astrbot/core/conversation_mgr.py @@ -16,7 +16,7 @@ class ConversationManager: """负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。""" - def __init__(self, db_helper: BaseDatabase): + def __init__(self, db_helper: BaseDatabase) -> None: self.session_conversations: dict[str, str] = {} self.db = db_helper self.save_interval = 60 # 每 60 秒保存一次 @@ -106,7 +106,9 @@ async def new_conversation( await sp.session_put(unified_msg_origin, "sel_conv_id", conv.conversation_id) return conv.conversation_id - async def switch_conversation(self, unified_msg_origin: str, conversation_id: str): + async def switch_conversation( + self, unified_msg_origin: str, conversation_id: str + ) -> None: """切换会话的对话 Args: @@ -121,7 +123,7 @@ async def delete_conversation( self, unified_msg_origin: str, conversation_id: str | None = None, - ): + ) -> None: """删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话 Args: @@ -138,7 +140,7 @@ async def delete_conversation( self.session_conversations.pop(unified_msg_origin, None) await sp.session_remove(unified_msg_origin, "sel_conv_id") - async def delete_conversations_by_user_id(self, unified_msg_origin: str): + async def delete_conversations_by_user_id(self, unified_msg_origin: str) -> None: """删除会话的所有对话 Args: @@ -224,7 +226,7 @@ async def get_filtered_conversations( page_size: int = 20, platform_ids: list[str] | None = None, search_query: str = "", - **kwargs, + **kwargs: object, ) -> tuple[list[Conversation], int]: """获取过滤后的对话列表. diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index 3a79e41c2..9cb1dfd83 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -39,7 +39,7 @@ def __init__(self) -> None: expire_on_commit=False, ) - async def initialize(self): + async def initialize(self) -> None: """初始化数据库连接""" @asynccontextmanager @@ -105,7 +105,7 @@ async def get_conversations( ... @abc.abstractmethod - async def get_conversation_by_id(self, cid: str) -> ConversationV2: + async def get_conversation_by_id(self, cid: str) -> ConversationV2 | None: """Get a specific conversation by its ID.""" ... @@ -125,7 +125,7 @@ async def get_filtered_conversations( page_size: int = 20, platform_ids: list[str] | None = None, search_query: str = "", - **kwargs, + **kwargs: object, ) -> tuple[list[ConversationV2], int]: """Get conversations filtered by platform IDs and search query.""" ... @@ -153,7 +153,7 @@ async def update_conversation( persona_id: str | None = None, content: list[dict] | None = None, token_usage: int | None = None, - ) -> None: + ) -> ConversationV2 | None: """Update a conversation's history.""" ... @@ -214,12 +214,12 @@ async def insert_attachment( path: str, type: str, mime_type: str, - ): + ) -> Attachment: """Insert a new attachment record.""" ... @abc.abstractmethod - async def get_attachment_by_id(self, attachment_id: str) -> Attachment: + async def get_attachment_by_id(self, attachment_id: str) -> Attachment | None: """Get an attachment by its ID.""" ... @@ -256,7 +256,7 @@ async def insert_persona( ... @abc.abstractmethod - async def get_persona_by_id(self, persona_id: str) -> Persona: + async def get_persona_by_id(self, persona_id: str) -> Persona | None: """Get a persona by its ID.""" ... @@ -293,7 +293,9 @@ async def insert_preference_or_update( ... @abc.abstractmethod - async def get_preference(self, scope: str, scope_id: str, key: str) -> Preference: + async def get_preference( + self, scope: str, scope_id: str, key: str + ) -> Preference | None: """Get a preference by scope ID and key.""" ... diff --git a/astrbot/core/db/migration/migra_3_to_4.py b/astrbot/core/db/migration/migra_3_to_4.py index 66b72d5cb..727d97b29 100644 --- a/astrbot/core/db/migration/migra_3_to_4.py +++ b/astrbot/core/db/migration/migra_3_to_4.py @@ -43,7 +43,7 @@ def get_platform_type( async def migration_conversation_table( db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]], -): +) -> None: db_helper_v3 = SQLiteV3DatabaseV3( db_path=DB_PATH.replace("data_v4.db", "data_v3.db"), ) @@ -101,7 +101,7 @@ async def migration_conversation_table( async def migration_platform_table( db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]], -): +) -> None: db_helper_v3 = SQLiteV3DatabaseV3( db_path=DB_PATH.replace("data_v4.db", "data_v3.db"), ) @@ -180,7 +180,7 @@ async def migration_platform_table( async def migration_webchat_data( db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]], -): +) -> None: """迁移 WebChat 的历史记录到新的 PlatformMessageHistory 表中""" db_helper_v3 = SQLiteV3DatabaseV3( db_path=DB_PATH.replace("data_v4.db", "data_v3.db"), @@ -236,7 +236,7 @@ async def migration_webchat_data( async def migration_persona_data( db_helper: BaseDatabase, astrbot_config: AstrBotConfig, -): +) -> None: """迁移 Persona 数据到新的表中。 旧的 Persona 数据存储在 preference 中,新的 Persona 数据存储在 persona 表中。 """ @@ -279,7 +279,7 @@ async def migration_persona_data( async def migration_preferences( db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]], -): +) -> None: # 1. global scope migration keys = [ "inactivated_llm_tools", diff --git a/astrbot/core/db/migration/migra_45_to_46.py b/astrbot/core/db/migration/migra_45_to_46.py index dc70026f9..58736ab51 100644 --- a/astrbot/core/db/migration/migra_45_to_46.py +++ b/astrbot/core/db/migration/migra_45_to_46.py @@ -3,7 +3,7 @@ from astrbot.core.umop_config_router import UmopConfigRouter -async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter): +async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter) -> None: abconf_data = acm.abconf_data if not isinstance(abconf_data, dict): diff --git a/astrbot/core/db/migration/migra_webchat_session.py b/astrbot/core/db/migration/migra_webchat_session.py index ff0b5ca6f..46025fc64 100644 --- a/astrbot/core/db/migration/migra_webchat_session.py +++ b/astrbot/core/db/migration/migra_webchat_session.py @@ -17,7 +17,7 @@ from astrbot.core.db.po import ConversationV2, PlatformMessageHistory, PlatformSession -async def migrate_webchat_session(db_helper: BaseDatabase): +async def migrate_webchat_session(db_helper: BaseDatabase) -> None: """Create PlatformSession records from platform_message_history. This migration extracts all unique user_ids from platform_message_history diff --git a/astrbot/core/db/migration/shared_preferences_v3.py b/astrbot/core/db/migration/shared_preferences_v3.py index 3abcb1a66..72802f057 100644 --- a/astrbot/core/db/migration/shared_preferences_v3.py +++ b/astrbot/core/db/migration/shared_preferences_v3.py @@ -8,13 +8,13 @@ class SharedPreferences: - def __init__(self, path=None): + def __init__(self, path: str | None = None) -> None: if path is None: path = os.path.join(get_astrbot_data_path(), "shared_preferences.json") self.path = path self._data = self._load_preferences() - def _load_preferences(self): + def _load_preferences(self) -> dict: if os.path.exists(self.path): try: with open(self.path) as f: @@ -23,24 +23,24 @@ def _load_preferences(self): os.remove(self.path) return {} - def _save_preferences(self): + def _save_preferences(self) -> None: with open(self.path, "w") as f: json.dump(self._data, f, indent=4, ensure_ascii=False) f.flush() - def get(self, key, default: _VT = None) -> _VT: + def get(self, key: str, default: _VT = None) -> _VT: return self._data.get(key, default) - def put(self, key, value): + def put(self, key: str, value: object) -> None: self._data[key] = value self._save_preferences() - def remove(self, key): + def remove(self, key: str) -> None: if key in self._data: del self._data[key] self._save_preferences() - def clear(self): + def clear(self) -> None: self._data.clear() self._save_preferences() diff --git a/astrbot/core/db/migration/sqlite_v3.py b/astrbot/core/db/migration/sqlite_v3.py index b1a780d48..b326ebb44 100644 --- a/astrbot/core/db/migration/sqlite_v3.py +++ b/astrbot/core/db/migration/sqlite_v3.py @@ -127,7 +127,7 @@ def _get_conn(self, db_path: str) -> sqlite3.Connection: conn.text_factory = str return conn - def _exec_sql(self, sql: str, params: tuple | None = None): + def _exec_sql(self, sql: str, params: tuple | None = None) -> None: conn = self.conn try: c = self.conn.cursor() @@ -144,7 +144,7 @@ def _exec_sql(self, sql: str, params: tuple | None = None): conn.commit() - def insert_platform_metrics(self, metrics: dict): + def insert_platform_metrics(self, metrics: dict) -> None: for k, v in metrics.items(): self._exec_sql( """ @@ -153,7 +153,7 @@ def insert_platform_metrics(self, metrics: dict): (k, v, int(time.time())), ) - def insert_llm_metrics(self, metrics: dict): + def insert_llm_metrics(self, metrics: dict) -> None: for k, v in metrics.items(): self._exec_sql( """ @@ -249,7 +249,7 @@ def get_conversation_by_user_id( return Conversation(*res) - def new_conversation(self, user_id: str, cid: str): + def new_conversation(self, user_id: str, cid: str) -> None: history = "[]" updated_at = int(time.time()) created_at = updated_at @@ -287,7 +287,7 @@ def get_conversations(self, user_id: str) -> list[Conversation]: ) return conversations - def update_conversation(self, user_id: str, cid: str, history: str): + def update_conversation(self, user_id: str, cid: str, history: str) -> None: """更新对话,并且同时更新时间""" updated_at = int(time.time()) self._exec_sql( @@ -297,7 +297,7 @@ def update_conversation(self, user_id: str, cid: str, history: str): (history, updated_at, user_id, cid), ) - def update_conversation_title(self, user_id: str, cid: str, title: str): + def update_conversation_title(self, user_id: str, cid: str, title: str) -> None: self._exec_sql( """ UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ? @@ -305,7 +305,9 @@ def update_conversation_title(self, user_id: str, cid: str, title: str): (title, user_id, cid), ) - def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str): + def update_conversation_persona_id( + self, user_id: str, cid: str, persona_id: str + ) -> None: self._exec_sql( """ UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ? @@ -313,7 +315,7 @@ def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str (persona_id, user_id, cid), ) - def delete_conversation(self, user_id: str, cid: str): + def delete_conversation(self, user_id: str, cid: str) -> None: self._exec_sql( """ DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ? diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index 7422a5cc2..4a61355b0 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -1,6 +1,11 @@ import asyncio import threading import typing as T + +try: + from typing import Unpack # type: ignore[attr-defined] +except ImportError: + from typing_extensions import Unpack from collections.abc import Awaitable, Callable from datetime import datetime, timedelta, timezone @@ -32,6 +37,30 @@ TxResult = T.TypeVar("TxResult") +class FilterKwargs(T.TypedDict, total=False): + message_types: list[str] + platforms: list[str] + exclude_ids: list[str] + exclude_platforms: list[str] + + +class UpdateKwargs(T.TypedDict, total=False): + plugin_name: str + module_path: str + original_command: str + resolved_command: str | None + enabled: bool + keep_original_alias: bool + conflict_key: str + resolution_strategy: str | None + note: str | None + extra_data: dict[str, object] | None + auto_managed: bool + status: str + resolution: str | None + auto_generated: bool + + class SQLiteDatabase(BaseDatabase): def __init__(self, db_path: str) -> None: self.db_path = db_path @@ -57,10 +86,10 @@ async def initialize(self) -> None: async def insert_platform_stats( self, - platform_id, - platform_type, - count=1, - timestamp=None, + platform_id: str, + platform_type: str, + count: int = 1, + timestamp: datetime | None = None, ) -> None: """Insert a new platform statistic record.""" async with self.get_db() as session: @@ -121,7 +150,9 @@ async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat # Conversation Management # ==== - async def get_conversations(self, user_id=None, platform_id=None): + async def get_conversations( + self, user_id: str | None = None, platform_id: str | None = None + ) -> list[ConversationV2]: async with self.get_db() as session: session: AsyncSession query = select(ConversationV2) @@ -134,16 +165,18 @@ async def get_conversations(self, user_id=None, platform_id=None): query = query.order_by(desc(ConversationV2.created_at)) result = await session.execute(query) - return result.scalars().all() + return result.scalars().all() # type: ignore - async def get_conversation_by_id(self, cid): + async def get_conversation_by_id(self, cid: str) -> ConversationV2 | None: async with self.get_db() as session: session: AsyncSession query = select(ConversationV2).where(ConversationV2.conversation_id == cid) result = await session.execute(query) return result.scalar_one_or_none() - async def get_all_conversations(self, page=1, page_size=20): + async def get_all_conversations( + self, page: int = 1, page_size: int = 20 + ) -> list[ConversationV2]: async with self.get_db() as session: session: AsyncSession offset = (page - 1) * page_size @@ -153,16 +186,16 @@ async def get_all_conversations(self, page=1, page_size=20): .offset(offset) .limit(page_size), ) - return result.scalars().all() + return result.scalars().all() # type:ignore async def get_filtered_conversations( self, - page=1, - page_size=20, - platform_ids=None, - search_query="", - **kwargs, - ): + page: int = 1, + page_size: int = 20, + platform_ids: list[str] | None = None, + search_query: str = "", + **kwargs: Unpack[FilterKwargs], + ) -> tuple[list[ConversationV2], int]: async with self.get_db() as session: session: AsyncSession # Build the base query with filters @@ -207,19 +240,19 @@ async def get_filtered_conversations( result = await session.execute(result_query) conversations = result.scalars().all() - return conversations, total + return conversations, total # type:ignore async def create_conversation( self, - user_id, - platform_id, - content=None, - title=None, - persona_id=None, - cid=None, - created_at=None, - updated_at=None, - ): + user_id: str, + platform_id: str, + content: list[dict] | None = None, + title: str | None = None, + persona_id: str | None = None, + cid: str | None = None, + created_at: datetime | None = None, + updated_at: datetime | None = None, + ) -> ConversationV2: kwargs = {} if cid: kwargs["conversation_id"] = cid @@ -242,8 +275,13 @@ async def create_conversation( return new_conversation async def update_conversation( - self, cid, title=None, persona_id=None, content=None, token_usage=None - ): + self, + cid: str, + title: str | None = None, + persona_id: str | None = None, + content: list[dict] | None = None, + token_usage: int | None = None, + ) -> ConversationV2 | None: async with self.get_db() as session: session: AsyncSession async with session.begin(): @@ -265,7 +303,7 @@ async def update_conversation( await session.execute(query) return await self.get_conversation_by_id(cid) - async def delete_conversation(self, cid): + async def delete_conversation(self, cid: str) -> None: async with self.get_db() as session: session: AsyncSession async with session.begin(): @@ -287,10 +325,10 @@ async def delete_conversations_by_user_id(self, user_id: str) -> None: async def get_session_conversations( self, - page=1, - page_size=20, - search_query=None, - platform=None, + page: int = 1, + page_size: int = 20, + search_query: str | None = None, + platform: str | None = None, ) -> tuple[list[dict], int]: """Get paginated session conversations with joined conversation and persona details.""" async with self.get_db() as session: @@ -396,12 +434,12 @@ async def get_session_conversations( async def insert_platform_message_history( self, - platform_id, - user_id, - content, - sender_id=None, - sender_name=None, - ): + platform_id: str, + user_id: str, + content: dict, + sender_id: str | None = None, + sender_name: str | None = None, + ) -> PlatformMessageHistory: """Insert a new platform message history record.""" async with self.get_db() as session: session: AsyncSession @@ -417,12 +455,9 @@ async def insert_platform_message_history( return new_history async def delete_platform_message_offset( - self, - platform_id, - user_id, - offset_sec=86400, - ): - """Delete platform message history records newer than the specified offset.""" + self, platform_id: str, user_id: str, offset_sec: int = 86400 + ) -> None: + """Delete platform message history records older than the specified offset.""" async with self.get_db() as session: session: AsyncSession async with session.begin(): @@ -437,12 +472,8 @@ async def delete_platform_message_offset( ) async def get_platform_message_history( - self, - platform_id, - user_id, - page=1, - page_size=20, - ): + self, platform_id: str, user_id: str, page: int = 1, page_size: int = 20 + ) -> list[PlatformMessageHistory]: """Get platform message history records.""" async with self.get_db() as session: session: AsyncSession @@ -456,7 +487,7 @@ async def get_platform_message_history( .order_by(desc(PlatformMessageHistory.created_at)) ) result = await session.execute(query.offset(offset).limit(page_size)) - return result.scalars().all() + return result.scalars().all() # type:ignore async def get_platform_message_history_by_id( self, message_id: int @@ -470,7 +501,9 @@ async def get_platform_message_history_by_id( result = await session.execute(query) return result.scalar_one_or_none() - async def insert_attachment(self, path, type, mime_type): + async def insert_attachment( + self, path: str, type: str, mime_type: str + ) -> Attachment: """Insert a new attachment record.""" async with self.get_db() as session: session: AsyncSession @@ -483,7 +516,7 @@ async def insert_attachment(self, path, type, mime_type): session.add(new_attachment) return new_attachment - async def get_attachment_by_id(self, attachment_id): + async def get_attachment_by_id(self, attachment_id: str) -> Attachment | None: """Get an attachment by its ID.""" async with self.get_db() as session: session: AsyncSession @@ -535,11 +568,11 @@ async def delete_attachments(self, attachment_ids: list[str]) -> int: async def insert_persona( self, - persona_id, - system_prompt, - begin_dialogs=None, - tools=None, - ): + persona_id: str, + system_prompt: str, + begin_dialogs: list[str] | None = None, + tools: list[str] | None = None, + ) -> Persona: """Insert a new persona record.""" async with self.get_db() as session: session: AsyncSession @@ -553,7 +586,7 @@ async def insert_persona( session.add(new_persona) return new_persona - async def get_persona_by_id(self, persona_id): + async def get_persona_by_id(self, persona_id: str) -> Persona | None: """Get a persona by its ID.""" async with self.get_db() as session: session: AsyncSession @@ -561,21 +594,21 @@ async def get_persona_by_id(self, persona_id): result = await session.execute(query) return result.scalar_one_or_none() - async def get_personas(self): + async def get_personas(self) -> list[Persona]: """Get all personas for a specific bot.""" async with self.get_db() as session: session: AsyncSession query = select(Persona) result = await session.execute(query) - return result.scalars().all() + return result.scalars().all() # type:ignore async def update_persona( self, - persona_id, - system_prompt=None, - begin_dialogs=None, - tools=NOT_GIVEN, - ): + persona_id: str, + system_prompt: str | None = None, + begin_dialogs: list[str] | None = None, + tools: list[str] | None = None, + ) -> Persona | None: """Update a persona's system prompt or begin dialogs.""" async with self.get_db() as session: session: AsyncSession @@ -594,7 +627,7 @@ async def update_persona( await session.execute(query) return await self.get_persona_by_id(persona_id) - async def delete_persona(self, persona_id): + async def delete_persona(self, persona_id: str) -> None: """Delete a persona by its ID.""" async with self.get_db() as session: session: AsyncSession @@ -603,7 +636,9 @@ async def delete_persona(self, persona_id): delete(Persona).where(col(Persona.persona_id) == persona_id), ) - async def insert_preference_or_update(self, scope, scope_id, key, value): + async def insert_preference_or_update( + self, scope: str, scope_id: str, key: str, value: dict + ) -> Preference: """Insert a new preference record or update if it exists.""" async with self.get_db() as session: session: AsyncSession @@ -627,7 +662,9 @@ async def insert_preference_or_update(self, scope, scope_id, key, value): session.add(new_preference) return existing_preference or new_preference - async def get_preference(self, scope, scope_id, key): + async def get_preference( + self, scope: str, scope_id: str, key: str + ) -> Preference | None: """Get a preference by key.""" async with self.get_db() as session: session: AsyncSession @@ -639,7 +676,9 @@ async def get_preference(self, scope, scope_id, key): result = await session.execute(query) return result.scalar_one_or_none() - async def get_preferences(self, scope, scope_id=None, key=None): + async def get_preferences( + self, scope: str, scope_id: str | None = None, key: str | None = None + ) -> list[Preference]: """Get all preferences for a specific scope ID or key.""" async with self.get_db() as session: session: AsyncSession @@ -649,9 +688,9 @@ async def get_preferences(self, scope, scope_id=None, key=None): if key is not None: query = query.where(Preference.key == key) result = await session.execute(query) - return result.scalars().all() + return result.scalars().all() # type:ignore - async def remove_preference(self, scope, scope_id, key): + async def remove_preference(self, scope: str, scope_id: str, key: str) -> None: """Remove a preference by scope ID and key.""" async with self.get_db() as session: session: AsyncSession @@ -665,7 +704,7 @@ async def remove_preference(self, scope, scope_id, key): ) await session.commit() - async def clear_preferences(self, scope, scope_id): + async def clear_preferences(self, scope: str, scope_id: str) -> None: """Clear all preferences for a specific scope ID.""" async with self.get_db() as session: session: AsyncSession @@ -692,7 +731,7 @@ async def _run_in_tx( return await fn(session) @staticmethod - def _apply_updates(model, **updates) -> None: + def _apply_updates(model: SQLModel, **updates: Unpack[UpdateKwargs]) -> None: for field, value in updates.items(): if value is not None: setattr(model, field, value) @@ -918,10 +957,10 @@ async def _op(session: AsyncSession) -> None: # Deprecated Methods # ==== - def get_base_stats(self, offset_sec=86400): + def get_base_stats(self, offset_sec: int = 86400) -> DeprecatedStats: """Get base statistics within the specified offset in seconds.""" - async def _inner(): + async def _inner() -> DeprecatedStats: async with self.get_db() as session: session: AsyncSession now = datetime.now() @@ -943,19 +982,19 @@ async def _inner(): result = None - def runner(): + def runner() -> None: nonlocal result result = asyncio.run(_inner()) t = threading.Thread(target=runner) t.start() t.join() - return result + return result # type:ignore - def get_total_message_count(self): + def get_total_message_count(self) -> int: """Get the total message count from platform statistics.""" - async def _inner(): + async def _inner() -> int: async with self.get_db() as session: session: AsyncSession result = await session.execute( @@ -966,18 +1005,18 @@ async def _inner(): result = None - def runner(): + def runner() -> None: nonlocal result result = asyncio.run(_inner()) t = threading.Thread(target=runner) t.start() t.join() - return result + return result # type:ignore - def get_grouped_base_stats(self, offset_sec=86400): + def get_grouped_base_stats(self, offset_sec: int = 86400) -> DeprecatedStats: # group by platform_id - async def _inner(): + async def _inner() -> DeprecatedStats: async with self.get_db() as session: session: AsyncSession now = datetime.now() @@ -1001,14 +1040,14 @@ async def _inner(): result = None - def runner(): + def runner() -> None: nonlocal result result = asyncio.run(_inner()) t = threading.Thread(target=runner) t.start() t.join() - return result + return result # type:ignore # ==== # Platform Session Management diff --git a/astrbot/core/db/vec_db/base.py b/astrbot/core/db/vec_db/base.py index 7440b6f2a..0ac3c608d 100644 --- a/astrbot/core/db/vec_db/base.py +++ b/astrbot/core/db/vec_db/base.py @@ -1,4 +1,5 @@ import abc +from collections.abc import Awaitable, Callable from dataclasses import dataclass @@ -9,7 +10,7 @@ class Result: class BaseVecDB: - async def initialize(self): + async def initialize(self) -> None: """初始化向量数据库""" @abc.abstractmethod @@ -31,7 +32,7 @@ async def insert_batch( batch_size: int = 32, tasks_limit: int = 3, max_retries: int = 3, - progress_callback=None, + progress_callback: Callable[[int, int], Awaitable[None]] | None = None, ) -> int: """批量插入文本和其对应向量,自动生成 ID 并保持一致性。 @@ -70,4 +71,4 @@ async def delete(self, doc_id: str) -> bool: ... @abc.abstractmethod - async def close(self): ... + async def close(self) -> None: ... diff --git a/astrbot/core/db/vec_db/faiss_impl/document_storage.py b/astrbot/core/db/vec_db/faiss_impl/document_storage.py index e27eb6fe8..db275006f 100644 --- a/astrbot/core/db/vec_db/faiss_impl/document_storage.py +++ b/astrbot/core/db/vec_db/faiss_impl/document_storage.py @@ -1,5 +1,6 @@ import json import os +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from datetime import datetime @@ -33,7 +34,7 @@ class Document(BaseDocModel, table=True): class DocumentStorage: - def __init__(self, db_path: str): + def __init__(self, db_path: str) -> None: self.db_path = db_path self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}" self.engine: AsyncEngine | None = None @@ -43,7 +44,7 @@ def __init__(self, db_path: str): "sqlite_init.sql", ) - async def initialize(self): + async def initialize(self) -> None: """Initialize the SQLite database and create the documents table if it doesn't exist.""" await self.connect() async with self.engine.begin() as conn: # type: ignore @@ -80,7 +81,7 @@ async def initialize(self): await conn.commit() - async def connect(self): + async def connect(self) -> None: """Connect to the SQLite database.""" if self.engine is None: self.engine = create_async_engine( @@ -95,7 +96,7 @@ async def connect(self): ) # type: ignore @asynccontextmanager - async def get_session(self): + async def get_session(self) -> AsyncGenerator[AsyncSession, None]: """Context manager for database sessions.""" async with self.async_session_maker() as session: # type: ignore yield session @@ -211,7 +212,7 @@ async def insert_documents_batch( await session.flush() # Flush to get all IDs return [doc.id for doc in documents] # type: ignore - async def delete_document_by_doc_id(self, doc_id: str): + async def delete_document_by_doc_id(self, doc_id: str) -> None: """Delete a document by its doc_id. Args: @@ -228,7 +229,7 @@ async def delete_document_by_doc_id(self, doc_id: str): if document: await session.delete(document) - async def get_document_by_doc_id(self, doc_id: str): + async def get_document_by_doc_id(self, doc_id: str) -> dict | None: """Retrieve a document by its doc_id. Args: @@ -249,7 +250,7 @@ async def get_document_by_doc_id(self, doc_id: str): return self._document_to_dict(document) return None - async def update_document_by_doc_id(self, doc_id: str, new_text: str): + async def update_document_by_doc_id(self, doc_id: str, new_text: str) -> None: """Update a document by its doc_id. Args: @@ -269,7 +270,7 @@ async def update_document_by_doc_id(self, doc_id: str, new_text: str): document.updated_at = datetime.now() session.add(document) - async def delete_documents(self, metadata_filters: dict): + async def delete_documents(self, metadata_filters: dict) -> None: """Delete documents by their metadata filters. Args: @@ -363,7 +364,7 @@ def _document_to_dict(self, document: Document) -> dict: else document.updated_at, } - async def tuple_to_dict(self, row): + async def tuple_to_dict(self, row: tuple) -> dict: """Convert a tuple to a dictionary. Args: @@ -384,7 +385,7 @@ async def tuple_to_dict(self, row): "updated_at": row[5], } - async def close(self): + async def close(self) -> None: """Close the connection to the SQLite database.""" if self.engine: await self.engine.dispose() diff --git a/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py b/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py index 564454cb1..dc6977cf8 100644 --- a/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +++ b/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py @@ -10,7 +10,7 @@ class EmbeddingStorage: - def __init__(self, dimension: int, path: str | None = None): + def __init__(self, dimension: int, path: str | None = None) -> None: self.dimension = dimension self.path = path self.index = None @@ -20,7 +20,7 @@ def __init__(self, dimension: int, path: str | None = None): base_index = faiss.IndexFlatL2(dimension) self.index = faiss.IndexIDMap(base_index) - async def insert(self, vector: np.ndarray, id: int): + async def insert(self, vector: np.ndarray, id: int) -> None: """插入向量 Args: @@ -38,7 +38,7 @@ async def insert(self, vector: np.ndarray, id: int): self.index.add_with_ids(vector.reshape(1, -1), np.array([id])) await self.save_index() - async def insert_batch(self, vectors: np.ndarray, ids: list[int]): + async def insert_batch(self, vectors: np.ndarray, ids: list[int]) -> None: """批量插入向量 Args: @@ -71,7 +71,7 @@ async def search(self, vector: np.ndarray, k: int) -> tuple: distances, indices = self.index.search(vector, k) return distances, indices - async def delete(self, ids: list[int]): + async def delete(self, ids: list[int]) -> None: """删除向量 Args: @@ -83,7 +83,7 @@ async def delete(self, ids: list[int]): self.index.remove_ids(id_array) await self.save_index() - async def save_index(self): + async def save_index(self) -> None: """保存索引 Args: diff --git a/astrbot/core/db/vec_db/faiss_impl/vec_db.py b/astrbot/core/db/vec_db/faiss_impl/vec_db.py index 14221f1e8..e6b9c209c 100644 --- a/astrbot/core/db/vec_db/faiss_impl/vec_db.py +++ b/astrbot/core/db/vec_db/faiss_impl/vec_db.py @@ -1,5 +1,6 @@ import time import uuid +from collections.abc import Awaitable, Callable import numpy as np @@ -20,7 +21,7 @@ def __init__( index_store_path: str, embedding_provider: EmbeddingProvider, rerank_provider: RerankProvider | None = None, - ): + ) -> None: self.doc_store_path = doc_store_path self.index_store_path = index_store_path self.embedding_provider = embedding_provider @@ -32,7 +33,7 @@ def __init__( self.embedding_provider = embedding_provider self.rerank_provider = rerank_provider - async def initialize(self): + async def initialize(self) -> None: await self.document_storage.initialize() async def insert( @@ -63,7 +64,7 @@ async def insert_batch( batch_size: int = 32, tasks_limit: int = 3, max_retries: int = 3, - progress_callback=None, + progress_callback: Callable[[int, int], Awaitable[None]] | None = None, ) -> list[int]: """批量插入文本和其对应向量,自动生成 ID 并保持一致性。 @@ -165,7 +166,7 @@ async def retrieve( return top_k_results - async def delete(self, doc_id: str): + async def delete(self, doc_id: str) -> None: """删除一条文档块(chunk)""" # 获得对应的 int id result = await self.document_storage.get_document_by_doc_id(doc_id) @@ -177,7 +178,7 @@ async def delete(self, doc_id: str): await self.document_storage.delete_document_by_doc_id(doc_id) await self.embedding_storage.delete([int_id]) - async def close(self): + async def close(self) -> None: await self.document_storage.close() async def count_documents(self, metadata_filter: dict | None = None) -> int: @@ -192,7 +193,7 @@ async def count_documents(self, metadata_filter: dict | None = None) -> int: ) return count - async def delete_documents(self, metadata_filters: dict): + async def delete_documents(self, metadata_filters: dict) -> None: """根据元数据过滤器删除文档""" docs = await self.document_storage.get_documents( metadata_filters=metadata_filters, diff --git a/astrbot/core/event_bus.py b/astrbot/core/event_bus.py index 0017e65fa..44cdccb83 100644 --- a/astrbot/core/event_bus.py +++ b/astrbot/core/event_bus.py @@ -28,13 +28,13 @@ def __init__( event_queue: Queue, pipeline_scheduler_mapping: dict[str, PipelineScheduler], astrbot_config_mgr: AstrBotConfigManager, - ): + ) -> None: self.event_queue = event_queue # 事件队列 # abconf uuid -> scheduler self.pipeline_scheduler_mapping = pipeline_scheduler_mapping self.astrbot_config_mgr = astrbot_config_mgr - async def dispatch(self): + async def dispatch(self) -> None: while True: event: AstrMessageEvent = await self.event_queue.get() conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin) @@ -47,7 +47,7 @@ async def dispatch(self): continue asyncio.create_task(scheduler.execute(event)) - def _print_event(self, event: AstrMessageEvent, conf_name: str): + def _print_event(self, event: AstrMessageEvent, conf_name: str) -> None: """用于记录事件信息 Args: diff --git a/astrbot/core/file_token_service.py b/astrbot/core/file_token_service.py index ea97759c1..42fbd23df 100644 --- a/astrbot/core/file_token_service.py +++ b/astrbot/core/file_token_service.py @@ -9,12 +9,12 @@ class FileTokenService: """维护一个简单的基于令牌的文件下载服务,支持超时和懒清除。""" - def __init__(self, default_timeout: float = 300): + def __init__(self, default_timeout: float = 300) -> None: self.lock = asyncio.Lock() self.staged_files = {} # token: (file_path, expire_time) self.default_timeout = default_timeout - async def _cleanup_expired_tokens(self): + async def _cleanup_expired_tokens(self) -> None: """清理过期的令牌""" now = time.time() expired_tokens = [ diff --git a/astrbot/core/initial_loader.py b/astrbot/core/initial_loader.py index f54d18641..3f836a4c4 100644 --- a/astrbot/core/initial_loader.py +++ b/astrbot/core/initial_loader.py @@ -17,13 +17,13 @@ class InitialLoader: """AstrBot 启动器,负责初始化和启动核心组件和仪表板服务器。""" - def __init__(self, db: BaseDatabase, log_broker: LogBroker): + def __init__(self, db: BaseDatabase, log_broker: LogBroker) -> None: self.db = db self.logger = logger self.log_broker = log_broker self.webui_dir: str | None = None - async def start(self): + async def start(self) -> None: core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db) try: diff --git a/astrbot/core/knowledge_base/chunking/base.py b/astrbot/core/knowledge_base/chunking/base.py index a45d86ad1..11ae0caba 100644 --- a/astrbot/core/knowledge_base/chunking/base.py +++ b/astrbot/core/knowledge_base/chunking/base.py @@ -13,7 +13,7 @@ class BaseChunker(ABC): """ @abstractmethod - async def chunk(self, text: str, **kwargs) -> list[str]: + async def chunk(self, text: str, **kwargs: object) -> list[str]: """将文本分块 Args: diff --git a/astrbot/core/knowledge_base/chunking/fixed_size.py b/astrbot/core/knowledge_base/chunking/fixed_size.py index 5439f070f..cd146f137 100644 --- a/astrbot/core/knowledge_base/chunking/fixed_size.py +++ b/astrbot/core/knowledge_base/chunking/fixed_size.py @@ -12,7 +12,7 @@ class FixedSizeChunker(BaseChunker): 按照固定的字符数分块,并支持块之间的重叠。 """ - def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50): + def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50) -> None: """初始化分块器 Args: @@ -23,7 +23,13 @@ def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50): self.chunk_size = chunk_size self.chunk_overlap = chunk_overlap - async def chunk(self, text: str, **kwargs) -> list[str]: + async def chunk( + self, + text: str, + *, + chunk_size: int | None = None, + chunk_overlap: int | None = None, + ) -> list[str]: """固定大小分块 Args: @@ -35,8 +41,11 @@ async def chunk(self, text: str, **kwargs) -> list[str]: list[str]: 分块后的文本列表 """ - chunk_size = kwargs.get("chunk_size", self.chunk_size) - chunk_overlap = kwargs.get("chunk_overlap", self.chunk_overlap) + chunk_size = self.chunk_size if chunk_size is None else chunk_size + chunk_overlap = self.chunk_overlap if chunk_overlap is None else chunk_overlap + if chunk_size <= 0: + return [text] + chunk_overlap = max(0, min(chunk_overlap, chunk_size - 1)) chunks = [] start = 0 diff --git a/astrbot/core/knowledge_base/chunking/recursive.py b/astrbot/core/knowledge_base/chunking/recursive.py index 3882b0871..9417f65be 100644 --- a/astrbot/core/knowledge_base/chunking/recursive.py +++ b/astrbot/core/knowledge_base/chunking/recursive.py @@ -11,7 +11,7 @@ def __init__( length_function: Callable[[str], int] = len, is_separator_regex: bool = False, separators: list[str] | None = None, - ): + ) -> None: """初始化递归字符文本分割器 Args: @@ -39,7 +39,13 @@ def __init__( "", # 字符 ] - async def chunk(self, text: str, **kwargs) -> list[str]: + async def chunk( + self, + text: str, + *, + chunk_size: int | None = None, + chunk_overlap: int | None = None, + ) -> list[str]: """递归地将文本分割成块 Args: @@ -54,8 +60,11 @@ async def chunk(self, text: str, **kwargs) -> list[str]: if not text: return [] - overlap = kwargs.get("chunk_overlap", self.chunk_overlap) - chunk_size = kwargs.get("chunk_size", self.chunk_size) + overlap = self.chunk_overlap if chunk_overlap is None else chunk_overlap + chunk_size = self.chunk_size if chunk_size is None else chunk_size + if chunk_size <= 0: + return [text] + overlap = max(0, min(overlap, chunk_size - 1)) text_length = self.length_function(text) if text_length <= chunk_size: diff --git a/astrbot/core/knowledge_base/kb_db_sqlite.py b/astrbot/core/knowledge_base/kb_db_sqlite.py index 5e1db842f..fa8809cb8 100644 --- a/astrbot/core/knowledge_base/kb_db_sqlite.py +++ b/astrbot/core/knowledge_base/kb_db_sqlite.py @@ -1,3 +1,4 @@ +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from pathlib import Path @@ -46,7 +47,7 @@ def __init__(self, db_path: str = "data/knowledge_base/kb.db") -> None: ) @asynccontextmanager - async def get_db(self): + async def get_db(self) -> AsyncGenerator[AsyncSession, None]: """获取数据库会话 用法: @@ -253,7 +254,7 @@ async def get_document_with_metadata(self, doc_id: str) -> dict | None: "knowledge_base": row[1], } - async def delete_document_by_id(self, doc_id: str, vec_db: FaissVecDB): + async def delete_document_by_id(self, doc_id: str, vec_db: FaissVecDB) -> None: """删除单个文档及其相关数据""" # 在知识库表中删除 async with self.get_db() as session, session.begin(): diff --git a/astrbot/core/knowledge_base/kb_helper.py b/astrbot/core/knowledge_base/kb_helper.py index 4adfb60b8..57dad4a03 100644 --- a/astrbot/core/knowledge_base/kb_helper.py +++ b/astrbot/core/knowledge_base/kb_helper.py @@ -3,7 +3,9 @@ import re import time import uuid +from collections.abc import Awaitable, Callable from pathlib import Path +from types import TracebackType import aiofiles @@ -31,14 +33,14 @@ class RateLimiter: """一个简单的速率限制器""" - def __init__(self, max_rpm: int): + def __init__(self, max_rpm: int) -> None: self.max_per_minute = max_rpm self.interval = 60.0 / max_rpm if max_rpm > 0 else 0 self.last_call_time = 0 - async def __aenter__(self): + async def __aenter__(self) -> "RateLimiter": if self.interval == 0: - return + return self now = time.monotonic() elapsed = now - self.last_call_time @@ -47,8 +49,14 @@ async def __aenter__(self): await asyncio.sleep(self.interval - elapsed) self.last_call_time = time.monotonic() + return self - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: pass @@ -116,7 +124,7 @@ def __init__( provider_manager: ProviderManager, kb_root_dir: str, chunker: BaseChunker, - ): + ) -> None: self.kb_db = kb_db self.kb = kb self.prov_mgr = provider_manager @@ -130,7 +138,7 @@ def __init__( self.kb_medias_dir.mkdir(parents=True, exist_ok=True) self.kb_files_dir.mkdir(parents=True, exist_ok=True) - async def initialize(self): + async def initialize(self) -> None: await self._ensure_vec_db() async def get_ep(self) -> EmbeddingProvider: @@ -174,7 +182,7 @@ async def _ensure_vec_db(self) -> FaissVecDB: self.vec_db = vec_db return vec_db - async def delete_vec_db(self): + async def delete_vec_db(self) -> None: """删除知识库的向量数据库和所有相关文件""" import shutil @@ -182,7 +190,7 @@ async def delete_vec_db(self): if self.kb_dir.exists(): shutil.rmtree(self.kb_dir) - async def terminate(self): + async def terminate(self) -> None: if self.vec_db: await self.vec_db.close() @@ -196,7 +204,7 @@ async def upload_document( batch_size: int = 32, tasks_limit: int = 3, max_retries: int = 3, - progress_callback=None, + progress_callback: Callable[[str, int, int], Awaitable[None]] | None = None, pre_chunked_text: list[str] | None = None, ) -> KBDocument: """上传并处理文档(带原子性保证和失败清理) @@ -293,7 +301,7 @@ async def upload_document( await progress_callback("chunking", 100, 100) # 阶段3: 生成向量(带进度回调) - async def embedding_progress_callback(current, total): + async def embedding_progress_callback(current: int, total: int) -> None: if progress_callback: await progress_callback("embedding", current, total) @@ -360,7 +368,7 @@ async def get_document(self, doc_id: str) -> KBDocument | None: doc = await self.kb_db.get_document_by_id(doc_id) return doc - async def delete_document(self, doc_id: str): + async def delete_document(self, doc_id: str) -> None: """删除单个文档及其相关数据""" await self.kb_db.delete_document_by_id( doc_id=doc_id, @@ -372,7 +380,7 @@ async def delete_document(self, doc_id: str): ) await self.refresh_kb() - async def delete_chunk(self, chunk_id: str, doc_id: str): + async def delete_chunk(self, chunk_id: str, doc_id: str) -> None: """删除单个文本块及其相关数据""" vec_db: FaissVecDB = self.vec_db # type: ignore await vec_db.delete(chunk_id) @@ -383,7 +391,7 @@ async def delete_chunk(self, chunk_id: str, doc_id: str): await self.refresh_kb() await self.refresh_document(doc_id) - async def refresh_kb(self): + async def refresh_kb(self) -> None: if self.kb: kb = await self.kb_db.get_kb_by_id(self.kb.kb_id) if kb: @@ -475,7 +483,7 @@ async def upload_from_url( batch_size: int = 32, tasks_limit: int = 3, max_retries: int = 3, - progress_callback=None, + progress_callback: Callable[[str, int, int], Awaitable[None]] | None = None, enable_cleaning: bool = False, cleaning_provider_id: str | None = None, ) -> KBDocument: @@ -562,7 +570,7 @@ async def _clean_and_rechunk_content( self, content: str, url: str, - progress_callback=None, + progress_callback: Callable[[str, int, int], Awaitable[None]] | None = None, enable_cleaning: bool = False, cleaning_provider_id: str | None = None, repair_max_rpm: int = 60, diff --git a/astrbot/core/knowledge_base/kb_mgr.py b/astrbot/core/knowledge_base/kb_mgr.py index 2219cc00b..7a95f5a3f 100644 --- a/astrbot/core/knowledge_base/kb_mgr.py +++ b/astrbot/core/knowledge_base/kb_mgr.py @@ -1,4 +1,5 @@ import traceback +from collections.abc import Awaitable, Callable from pathlib import Path from astrbot.core import logger @@ -26,14 +27,14 @@ class KnowledgeBaseManager: def __init__( self, provider_manager: ProviderManager, - ): + ) -> None: Path(DB_PATH).parent.mkdir(parents=True, exist_ok=True) self.provider_manager = provider_manager self._session_deleted_callback_registered = False self.kb_insts: dict[str, KBHelper] = {} - async def initialize(self): + async def initialize(self) -> None: """初始化知识库模块""" try: logger.info("正在初始化知识库模块...") @@ -58,13 +59,13 @@ async def initialize(self): logger.error(f"知识库模块初始化失败: {e}") logger.error(traceback.format_exc()) - async def _init_kb_database(self): + async def _init_kb_database(self) -> None: self.kb_db = KBSQLiteDatabase(DB_PATH.as_posix()) await self.kb_db.initialize() await self.kb_db.migrate_to_v1() logger.info(f"KnowledgeBase database initialized: {DB_PATH}") - async def load_kbs(self): + async def load_kbs(self) -> None: """加载所有知识库实例""" kb_records = await self.kb_db.list_kbs() for record in kb_records: @@ -268,7 +269,7 @@ def _format_context(self, results: list[RetrievalResult]) -> str: return "\n".join(lines) - async def terminate(self): + async def terminate(self) -> None: """终止所有知识库实例,关闭数据库连接""" for kb_id, kb_helper in self.kb_insts.items(): try: @@ -294,7 +295,7 @@ async def upload_from_url( batch_size: int = 32, tasks_limit: int = 3, max_retries: int = 3, - progress_callback=None, + progress_callback: Callable[[str, int, int], Awaitable[None]] | None = None, ) -> KBDocument: """从 URL 上传文档到指定的知识库 diff --git a/astrbot/core/knowledge_base/parsers/url_parser.py b/astrbot/core/knowledge_base/parsers/url_parser.py index f68e2e0c4..2867164a9 100644 --- a/astrbot/core/knowledge_base/parsers/url_parser.py +++ b/astrbot/core/knowledge_base/parsers/url_parser.py @@ -6,7 +6,7 @@ class URLExtractor: """URL 内容提取器,封装了 Tavily API 调用和密钥管理""" - def __init__(self, tavily_keys: list[str]): + def __init__(self, tavily_keys: list[str]) -> None: """ 初始化 URL 提取器 diff --git a/astrbot/core/knowledge_base/retrieval/manager.py b/astrbot/core/knowledge_base/retrieval/manager.py index 746406e90..a90cbef11 100644 --- a/astrbot/core/knowledge_base/retrieval/manager.py +++ b/astrbot/core/knowledge_base/retrieval/manager.py @@ -44,7 +44,7 @@ def __init__( sparse_retriever: SparseRetriever, rank_fusion: RankFusion, kb_db: KBSQLiteDatabase, - ): + ) -> None: """初始化检索管理器 Args: @@ -195,7 +195,7 @@ async def _dense_retrieve( query: str, kb_ids: list[str], kb_options: dict, - ): + ) -> list[Result]: """稠密检索 (向量相似度) 为每个知识库使用独立的向量数据库进行检索,然后合并结果。 diff --git a/astrbot/core/knowledge_base/retrieval/rank_fusion.py b/astrbot/core/knowledge_base/retrieval/rank_fusion.py index 26203f94b..40afd9748 100644 --- a/astrbot/core/knowledge_base/retrieval/rank_fusion.py +++ b/astrbot/core/knowledge_base/retrieval/rank_fusion.py @@ -31,7 +31,7 @@ class RankFusion: - 使用 Reciprocal Rank Fusion (RRF) 算法 """ - def __init__(self, kb_db: KBSQLiteDatabase, k: int = 60): + def __init__(self, kb_db: KBSQLiteDatabase, k: int = 60) -> None: """初始化结果融合器 Args: diff --git a/astrbot/core/knowledge_base/retrieval/sparse_retriever.py b/astrbot/core/knowledge_base/retrieval/sparse_retriever.py index ea5da1c9e..d453251d1 100644 --- a/astrbot/core/knowledge_base/retrieval/sparse_retriever.py +++ b/astrbot/core/knowledge_base/retrieval/sparse_retriever.py @@ -34,7 +34,7 @@ class SparseRetriever: - 使用 BM25 算法计算相关度 """ - def __init__(self, kb_db: KBSQLiteDatabase): + def __init__(self, kb_db: KBSQLiteDatabase) -> None: """初始化稀疏检索器 Args: diff --git a/astrbot/core/log.py b/astrbot/core/log.py index a70fdbf01..a08e19639 100644 --- a/astrbot/core/log.py +++ b/astrbot/core/log.py @@ -44,7 +44,7 @@ } -def is_plugin_path(pathname): +def is_plugin_path(pathname: str) -> bool: """检查文件路径是否来自插件目录 Args: @@ -61,7 +61,7 @@ def is_plugin_path(pathname): return ("data/plugins" in norm_path) or ("astrbot/builtin_stars/" in norm_path) -def get_short_level_name(level_name): +def get_short_level_name(level_name: str) -> str: """将日志级别名称转换为四个字母的缩写 Args: @@ -87,7 +87,7 @@ class LogBroker: 发布-订阅模式 """ - def __init__(self): + def __init__(self) -> None: self.log_cache = deque(maxlen=CACHED_SIZE) # 环形缓冲区, 保存最近的日志 self.subscribers: list[Queue] = [] # 订阅者列表 @@ -102,7 +102,7 @@ def register(self) -> Queue: self.subscribers.append(q) return q - def unregister(self, q: Queue): + def unregister(self, q: Queue) -> None: """取消订阅 Args: @@ -111,7 +111,7 @@ def unregister(self, q: Queue): """ self.subscribers.remove(q) - def publish(self, log_entry: dict): + def publish(self, log_entry: dict) -> None: """发布新日志到所有订阅者, 使用非阻塞方式投递, 避免一个订阅者阻塞整个系统 Args: @@ -133,11 +133,11 @@ class LogQueueHandler(logging.Handler): 继承自 logging.Handler """ - def __init__(self, log_broker: LogBroker): + def __init__(self, log_broker: LogBroker) -> None: super().__init__() self.log_broker = log_broker - def emit(self, record): + def emit(self, record: logging.LogRecord) -> None: """日志处理的入口方法, 接受一个日志记录, 转换为字符串后由 LogBroker 发布 这个方法会在每次日志记录时被调用 @@ -162,7 +162,7 @@ class LogManager: """ @classmethod - def GetLogger(cls, log_name: str = "default"): + def GetLogger(cls, log_name: str = "default") -> logging.Logger: """获取指定名称的日志记录器logger Args: @@ -194,7 +194,7 @@ def GetLogger(cls, log_name: str = "default"): class PluginFilter(logging.Filter): """插件过滤器类, 用于标记日志来源是插件还是核心组件""" - def filter(self, record): + def filter(self, record: logging.LogRecord) -> bool: record.plugin_tag = ( "[Plug]" if is_plugin_path(record.pathname) else "[Core]" ) @@ -206,7 +206,7 @@ class FileNameFilter(logging.Filter): """ # 获取这个文件和父文件夹的名字:. 并且去除 .py - def filter(self, record): + def filter(self, record: logging.LogRecord) -> bool: dirname = os.path.dirname(record.pathname) record.filename = ( os.path.basename(dirname) @@ -219,7 +219,7 @@ class LevelNameFilter(logging.Filter): """短日志级别名称过滤器类, 用于将日志级别名称转换为四个字母的缩写""" # 添加短日志级别名称 - def filter(self, record): + def filter(self, record: logging.LogRecord) -> bool: record.short_levelname = get_short_level_name(record.levelname) return True @@ -233,7 +233,7 @@ def filter(self, record): return logger @classmethod - def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker): + def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker) -> None: """设置队列处理器, 用于将日志消息发送到 LogBroker Args: diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 050e36521..dcc808f4f 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -25,6 +25,7 @@ import base64 import json import os +import typing as T import uuid from enum import Enum @@ -66,10 +67,10 @@ class ComponentType(str, Enum): class BaseMessageComponent(BaseModel): type: ComponentType - def __init__(self, **kwargs): + def __init__(self, **kwargs: object) -> None: super().__init__(**kwargs) - def toDict(self): + def toDict(self) -> dict: data = {} for k, v in self.__dict__.items(): if k == "type" or v is None: @@ -89,13 +90,13 @@ class Plain(BaseMessageComponent): text: str convert: bool | None = True - def __init__(self, text: str, convert: bool = True, **_): + def __init__(self, text: str, convert: bool = True, **_: object) -> None: super().__init__(text=text, convert=convert, **_) - def toDict(self): + def toDict(self) -> dict: return {"type": "text", "data": {"text": self.text.strip()}} - async def to_dict(self): + async def to_dict(self) -> dict: return {"type": "text", "data": {"text": self.text}} @@ -103,7 +104,7 @@ class Face(BaseMessageComponent): type = ComponentType.Face id: int - def __init__(self, **_): + def __init__(self, **_: object) -> None: super().__init__(**_) @@ -118,7 +119,7 @@ class Record(BaseMessageComponent): # 额外 path: str | None - def __init__(self, file: str | None, **_): + def __init__(self, file: str | None, **_: object) -> None: for k in _: if k == "url": pass @@ -126,17 +127,17 @@ def __init__(self, file: str | None, **_): super().__init__(file=file, **_) @staticmethod - def fromFileSystem(path, **_): + def fromFileSystem(path: str, **_: object) -> "Record": return Record(file=f"file:///{os.path.abspath(path)}", path=path, **_) @staticmethod - def fromURL(url: str, **_): + def fromURL(url: str, **_: object) -> "Record": if url.startswith("http://") or url.startswith("https://"): return Record(file=url, **_) raise Exception("not a valid url") @staticmethod - def fromBase64(bs64_data: str, **_): + def fromBase64(bs64_data: str, **_: object) -> "Record": return Record(file=f"base64://{bs64_data}", **_) async def convert_to_file_path(self) -> str: @@ -221,15 +222,15 @@ class Video(BaseMessageComponent): # 额外 path: str | None = "" - def __init__(self, file: str, **_): + def __init__(self, file: str, **_: object) -> None: super().__init__(file=file, **_) @staticmethod - def fromFileSystem(path, **_): + def fromFileSystem(path: str, **_: object) -> "Video": return Video(file=f"file:///{os.path.abspath(path)}", path=path, **_) @staticmethod - def fromURL(url: str, **_): + def fromURL(url: str, **_: object) -> "Video": if url.startswith("http://") or url.startswith("https://"): return Video(file=url, **_) raise Exception("not a valid url") @@ -255,7 +256,7 @@ async def convert_to_file_path(self) -> str: return os.path.abspath(url) raise Exception(f"not a valid file: {url}") - async def register_to_file_service(self): + async def register_to_file_service(self) -> str: """将视频注册到文件服务。 Returns: @@ -278,7 +279,7 @@ async def register_to_file_service(self): return f"{callback_host}/api/file/{token}" - async def to_dict(self): + async def to_dict(self) -> dict: """需要和 toDict 区分开,toDict 是同步方法""" url_or_path = self.file if url_or_path.startswith("http"): @@ -303,10 +304,10 @@ class At(BaseMessageComponent): qq: int | str # 此处str为all时代表所有人 name: str | None = "" - def __init__(self, **_): + def __init__(self, **_: object) -> None: super().__init__(**_) - def toDict(self): + def toDict(self) -> dict: return { "type": "at", "data": {"qq": str(self.qq)}, @@ -316,28 +317,28 @@ def toDict(self): class AtAll(At): qq: str = "all" - def __init__(self, **_): + def __init__(self, **_: object) -> None: super().__init__(**_) class RPS(BaseMessageComponent): # TODO type = ComponentType.RPS - def __init__(self, **_): + def __init__(self, **_: object) -> None: super().__init__(**_) class Dice(BaseMessageComponent): # TODO type = ComponentType.Dice - def __init__(self, **_): + def __init__(self, **_: object) -> None: super().__init__(**_) class Shake(BaseMessageComponent): # TODO type = ComponentType.Shake - def __init__(self, **_): + def __init__(self, **_: object) -> None: super().__init__(**_) @@ -348,7 +349,7 @@ class Share(BaseMessageComponent): content: str | None = "" image: str | None = "" - def __init__(self, **_): + def __init__(self, **_: object) -> None: super().__init__(**_) @@ -357,7 +358,7 @@ class Contact(BaseMessageComponent): # TODO _type: str # type 字段冲突 id: int | None = 0 - def __init__(self, **_): + def __init__(self, **_: object) -> None: super().__init__(**_) @@ -368,7 +369,7 @@ class Location(BaseMessageComponent): # TODO title: str | None = "" content: str | None = "" - def __init__(self, **_): + def __init__(self, **_: object) -> None: super().__init__(**_) @@ -382,7 +383,7 @@ class Music(BaseMessageComponent): content: str | None = "" image: str | None = "" - def __init__(self, **_): + def __init__(self, **_: object) -> None: # for k in _.keys(): # if k == "_type" and _[k] not in ["qq", "163", "xm", "custom"]: # logger.warn(f"Protocol: {k}={_[k]} doesn't match values") @@ -402,29 +403,29 @@ class Image(BaseMessageComponent): path: str | None = "" file_unique: str | None = "" # 某些平台可能有图片缓存的唯一标识 - def __init__(self, file: str | None, **_): + def __init__(self, file: str | None, **_: object) -> None: super().__init__(file=file, **_) @staticmethod - def fromURL(url: str, **_): + def fromURL(url: str, **_: object) -> "Image": if url.startswith("http://") or url.startswith("https://"): return Image(file=url, **_) raise Exception("not a valid url") @staticmethod - def fromFileSystem(path, **_): + def fromFileSystem(path: str, **_: object) -> "Image": return Image(file=f"file:///{os.path.abspath(path)}", path=path, **_) @staticmethod - def fromBase64(base64: str, **_): + def fromBase64(base64: str, **_: object) -> "Image": return Image(f"base64://{base64}", **_) @staticmethod - def fromBytes(byte: bytes): + def fromBytes(byte: bytes) -> "Image": return Image.fromBase64(base64.b64encode(byte).decode()) @staticmethod - def fromIO(IO): + def fromIO(IO: T.BinaryIO) -> "Image": return Image.fromBytes(IO.read()) async def convert_to_file_path(self) -> str: @@ -525,16 +526,16 @@ class Reply(BaseMessageComponent): seq: int | None = 0 """deprecated""" - def __init__(self, **_): + def __init__(self, **_: object) -> None: super().__init__(**_) class Poke(BaseMessageComponent): - type: str = ComponentType.Poke + type = ComponentType.Poke id: int | None = 0 qq: int | None = 0 - def __init__(self, type: str, **_): + def __init__(self, type: str, **_: object) -> None: type = f"Poke:{type}" super().__init__(type=type, **_) @@ -543,7 +544,7 @@ class Forward(BaseMessageComponent): type = ComponentType.Forward id: str - def __init__(self, **_): + def __init__(self, **_: object) -> None: super().__init__(**_) @@ -558,13 +559,13 @@ class Node(BaseMessageComponent): seq: str | list | None = "" # 忽略 time: int | None = 0 # 忽略 - def __init__(self, content: list[BaseMessageComponent], **_): + def __init__(self, content: list[BaseMessageComponent], **_: object) -> None: if isinstance(content, Node): # back content = [content] super().__init__(content=content, **_) - async def to_dict(self): + async def to_dict(self) -> dict: data_content = [] for comp in self.content: if isinstance(comp, (Image, Record)): @@ -605,10 +606,10 @@ class Nodes(BaseMessageComponent): type = ComponentType.Nodes nodes: list[Node] - def __init__(self, nodes: list[Node], **_): + def __init__(self, nodes: list[Node], **_: object) -> None: super().__init__(nodes=nodes, **_) - def toDict(self): + def toDict(self) -> dict: """Deprecated. Use to_dict instead""" ret = { "messages": [], @@ -631,7 +632,7 @@ class Json(BaseMessageComponent): type = ComponentType.Json data: dict - def __init__(self, data: str | dict, **_): + def __init__(self, data: str | dict, **_: object) -> None: if isinstance(data, str): data = json.loads(data) super().__init__(data=data, **_) @@ -650,7 +651,7 @@ class File(BaseMessageComponent): file_: str | None = "" # 本地路径 url: str | None = "" # url - def __init__(self, name: str, file: str = "", url: str = ""): + def __init__(self, name: str, file: str = "", url: str = "") -> None: """文件消息段。""" super().__init__(name=name, file_=file, url=url) @@ -686,7 +687,7 @@ def file(self) -> str: return "" @file.setter - def file(self, value: str): + def file(self, value: str) -> None: """向前兼容, 设置file属性, 传入的参数可能是文件路径或URL Args: @@ -721,7 +722,7 @@ async def get_file(self, allow_return_url: bool = False) -> str: return "" - async def _download_file(self): + async def _download_file(self) -> None: """下载文件""" if not self.url: raise ValueError("Download failed: No URL provided in File component.") @@ -736,7 +737,7 @@ async def _download_file(self): await download_file(self.url, file_path) self.file_ = os.path.abspath(file_path) - async def register_to_file_service(self): + async def register_to_file_service(self) -> str: """将文件注册到文件服务。 Returns: @@ -759,7 +760,7 @@ async def register_to_file_service(self): return f"{callback_host}/api/file/{token}" - async def to_dict(self): + async def to_dict(self) -> dict: """需要和 toDict 区分开,toDict 是同步方法""" url_or_path = await self.get_file(allow_return_url=True) if url_or_path.startswith("http"): @@ -786,7 +787,7 @@ class WechatEmoji(BaseMessageComponent): md5_len: int | None = 0 cdnurl: str | None = "" - def __init__(self, **_): + def __init__(self, **_: object) -> None: super().__init__(**_) diff --git a/astrbot/core/message/message_event_result.py b/astrbot/core/message/message_event_result.py index ed4e25f43..2e1527e27 100644 --- a/astrbot/core/message/message_event_result.py +++ b/astrbot/core/message/message_event_result.py @@ -2,7 +2,7 @@ from collections.abc import AsyncGenerator from dataclasses import dataclass, field -from typing_extensions import deprecated +from typing_extensions import Self, deprecated from astrbot.core.message.components import ( At, @@ -29,7 +29,7 @@ class MessageChain: type: str | None = None """消息链承载的消息的类型。可选,用于让消息平台区分不同业务场景的消息链。""" - def message(self, message: str): + def message(self, message: str) -> Self: """添加一条文本消息到消息链 `chain` 中。 Example: @@ -40,7 +40,7 @@ def message(self, message: str): self.chain.append(Plain(message)) return self - def at(self, name: str, qq: str | int): + def at(self, name: str, qq: str | int) -> Self: """添加一条 At 消息到消息链 `chain` 中。 Example: @@ -51,7 +51,7 @@ def at(self, name: str, qq: str | int): self.chain.append(At(name=name, qq=qq)) return self - def at_all(self): + def at_all(self) -> Self: """添加一条 AtAll 消息到消息链 `chain` 中。 Example: @@ -63,7 +63,7 @@ def at_all(self): return self @deprecated("请使用 message 方法代替。") - def error(self, message: str): + def error(self, message: str) -> Self: """添加一条错误消息到消息链 `chain` 中 Example: @@ -73,7 +73,7 @@ def error(self, message: str): self.chain.append(Plain(message)) return self - def url_image(self, url: str): + def url_image(self, url: str) -> Self: """添加一条图片消息(https 链接)到消息链 `chain` 中。 Note: @@ -86,7 +86,7 @@ def url_image(self, url: str): self.chain.append(Image.fromURL(url)) return self - def file_image(self, path: str): + def file_image(self, path: str) -> Self: """添加一条图片消息(本地文件路径)到消息链 `chain` 中。 Note: @@ -98,7 +98,7 @@ def file_image(self, path: str): self.chain.append(Image.fromFileSystem(path)) return self - def base64_image(self, base64_str: str): + def base64_image(self, base64_str: str) -> Self: """添加一条图片消息(base64 编码字符串)到消息链 `chain` 中。 Example: @@ -107,7 +107,7 @@ def base64_image(self, base64_str: str): self.chain.append(Image.fromBase64(base64_str)) return self - def use_t2i(self, use_t2i: bool): + def use_t2i(self, use_t2i: bool) -> Self: """设置是否使用文本转图片服务。 Args: @@ -121,7 +121,7 @@ def get_plain_text(self) -> str: """获取纯文本消息。这个方法将获取 chain 中所有 Plain 组件的文本并拼接成一条消息。空格分隔。""" return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)]) - def squash_plain(self): + def squash_plain(self) -> Self | None: """将消息链中的所有 Plain 消息段聚合到第一个 Plain 消息段中。""" if not self.chain: return None @@ -195,12 +195,12 @@ class MessageEventResult(MessageChain): async_stream: AsyncGenerator | None = None """异步流""" - def stop_event(self) -> "MessageEventResult": + def stop_event(self) -> Self: """终止事件传播。""" self.result_type = EventResultType.STOP return self - def continue_event(self) -> "MessageEventResult": + def continue_event(self) -> Self: """继续事件传播。""" self.result_type = EventResultType.CONTINUE return self @@ -209,12 +209,12 @@ def is_stopped(self) -> bool: """是否终止事件传播。""" return self.result_type == EventResultType.STOP - def set_async_stream(self, stream: AsyncGenerator) -> "MessageEventResult": + def set_async_stream(self, stream: AsyncGenerator) -> Self: """设置异步流。""" self.async_stream = stream return self - def set_result_content_type(self, typ: ResultContentType) -> "MessageEventResult": + def set_result_content_type(self, typ: ResultContentType) -> Self: """设置事件处理的结果类型。 Args: diff --git a/astrbot/core/persona_mgr.py b/astrbot/core/persona_mgr.py index b2d2c6be1..46cf7922c 100644 --- a/astrbot/core/persona_mgr.py +++ b/astrbot/core/persona_mgr.py @@ -16,7 +16,7 @@ class PersonaManager: - def __init__(self, db_helper: BaseDatabase, acm: AstrBotConfigManager): + def __init__(self, db_helper: BaseDatabase, acm: AstrBotConfigManager) -> None: self.db = db_helper self.acm = acm default_ps = acm.default_conf.get("provider_settings", {}) @@ -28,12 +28,12 @@ def __init__(self, db_helper: BaseDatabase, acm: AstrBotConfigManager): self.selected_default_persona_v3: Personality | None = None self.persona_v3_config: list[dict] = [] - async def initialize(self): + async def initialize(self) -> None: self.personas = await self.get_all_personas() self.get_v3_persona_data() logger.info(f"已加载 {len(self.personas)} 个人格。") - async def get_persona(self, persona_id: str): + async def get_persona(self, persona_id: str) -> Persona: """获取指定 persona 的信息""" persona = await self.db.get_persona_by_id(persona_id) if not persona: @@ -57,7 +57,7 @@ async def get_default_persona_v3( except Exception: return DEFAULT_PERSONALITY - async def delete_persona(self, persona_id: str): + async def delete_persona(self, persona_id: str) -> None: """删除指定 persona""" if not await self.db.get_persona_by_id(persona_id): raise ValueError(f"Persona with ID {persona_id} does not exist.") @@ -71,7 +71,7 @@ async def update_persona( system_prompt: str | None = None, begin_dialogs: list[str] | None = None, tools: list[str] | None = None, - ): + ) -> Persona | None: """更新指定 persona 的信息。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具""" existing_persona = await self.db.get_persona_by_id(persona_id) if not existing_persona: diff --git a/astrbot/core/pipeline/content_safety_check/stage.py b/astrbot/core/pipeline/content_safety_check/stage.py index b089c48e0..19037eb08 100644 --- a/astrbot/core/pipeline/content_safety_check/stage.py +++ b/astrbot/core/pipeline/content_safety_check/stage.py @@ -16,7 +16,7 @@ class ContentSafetyCheckStage(Stage): 当前只会检查文本的。 """ - async def initialize(self, ctx: PipelineContext): + async def initialize(self, ctx: PipelineContext) -> None: config = ctx.astrbot_config["content_safety"] self.strategy_selector = StrategySelector(config) diff --git a/astrbot/core/pipeline/context_utils.py b/astrbot/core/pipeline/context_utils.py index 1f5ba43a0..49fe3be3a 100644 --- a/astrbot/core/pipeline/context_utils.py +++ b/astrbot/core/pipeline/context_utils.py @@ -12,8 +12,8 @@ async def call_handler( event: AstrMessageEvent, handler: T.Callable[..., T.Awaitable[T.Any] | T.AsyncGenerator[T.Any, None]], - *args, - **kwargs, + *args: object, + **kwargs: object, ) -> T.AsyncGenerator[T.Any, None]: """执行事件处理函数并处理其返回结果 @@ -75,8 +75,8 @@ async def call_handler( async def call_event_hook( event: AstrMessageEvent, hook_type: EventType, - *args, - **kwargs, + *args: object, + **kwargs: object, ) -> bool: """调用事件钩子函数 diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 69bd04314..1a3ec31c9 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -22,6 +22,7 @@ LLMResponse, ProviderRequest, ) +from astrbot.core.provider.provider import Providers from astrbot.core.star.star_handler import EventType, star_map from astrbot.core.utils.file_extract import extract_file_moonshotai from astrbot.core.utils.llm_metadata import LLM_METADATAS @@ -82,7 +83,7 @@ async def initialize(self, ctx: PipelineContext) -> None: self.conv_manager = ctx.plugin_manager.context.conversation_manager - def _select_provider(self, event: AstrMessageEvent): + def _select_provider(self, event: AstrMessageEvent) -> Providers | None: """选择使用的 LLM 提供商""" sel_provider = event.get_extra("selected_provider") _ctx = self.ctx.plugin_manager.context @@ -114,7 +115,7 @@ async def _apply_kb( self, event: AstrMessageEvent, req: ProviderRequest, - ): + ) -> None: """Apply knowledge base context to the provider request""" if not self.kb_agentic_mode: if req.prompt is None: @@ -142,7 +143,7 @@ async def _apply_file_extract( self, event: AstrMessageEvent, req: ProviderRequest, - ): + ) -> None: """Apply file extract to the provider request""" file_paths = [] file_names = [] @@ -186,7 +187,7 @@ def _modalities_fix( self, provider: Provider, req: ProviderRequest, - ): + ) -> None: """检查提供商的模态能力,清理请求中的不支持内容""" if req.image_urls: provider_cfg = provider.provider_config.get("modalities", ["image"]) @@ -206,7 +207,7 @@ def _plugin_tool_fix( self, event: AstrMessageEvent, req: ProviderRequest, - ): + ) -> None: """根据事件中的插件设置,过滤请求中的工具列表""" if event.plugins_name is not None and req.func_tool: new_tool_set = ToolSet() @@ -226,7 +227,7 @@ async def _handle_webchat( event: AstrMessageEvent, req: ProviderRequest, prov: Provider, - ): + ) -> None: """处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title""" if not req.conversation: return @@ -284,7 +285,7 @@ async def _save_to_history( llm_response: LLMResponse | None, all_messages: list[Message], runner_stats: AgentStats | None, - ): + ) -> None: if ( not req or not req.conversation diff --git a/astrbot/core/pipeline/process_stage/method/star_request.py b/astrbot/core/pipeline/process_stage/method/star_request.py index 8a79b96c9..7f8df08cf 100644 --- a/astrbot/core/pipeline/process_stage/method/star_request.py +++ b/astrbot/core/pipeline/process_stage/method/star_request.py @@ -25,10 +25,10 @@ async def process( event: AstrMessageEvent, ) -> AsyncGenerator[Any, None]: activated_handlers: list[StarHandlerMetadata] = event.get_extra( - "activated_handlers", + "activated_handlers", [] ) handlers_parsed_params: dict[str, dict[str, Any]] = event.get_extra( - "handlers_parsed_params", + "handlers_parsed_params", {} ) if not handlers_parsed_params: handlers_parsed_params = {} diff --git a/astrbot/core/pipeline/process_stage/stage.py b/astrbot/core/pipeline/process_stage/stage.py index 076f7f12a..2093f7876 100644 --- a/astrbot/core/pipeline/process_stage/stage.py +++ b/astrbot/core/pipeline/process_stage/stage.py @@ -31,7 +31,7 @@ async def process( ) -> None | AsyncGenerator[None, None]: """处理事件""" activated_handlers: list[StarHandlerMetadata] = event.get_extra( - "activated_handlers", + "activated_handlers", [] ) # 有插件 Handler 被激活 if activated_handlers: diff --git a/astrbot/core/pipeline/process_stage/utils.py b/astrbot/core/pipeline/process_stage/utils.py index 24e052e1e..d6ad8f2c1 100644 --- a/astrbot/core/pipeline/process_stage/utils.py +++ b/astrbot/core/pipeline/process_stage/utils.py @@ -31,13 +31,17 @@ class KnowledgeBaseQueryTool(FunctionTool[AstrAgentContext]): ) async def call( - self, context: ContextWrapper[AstrAgentContext], **kwargs + self, context: ContextWrapper[AstrAgentContext], **kwargs: object ) -> ToolExecResult: query = kwargs.get("query", "") if not query: return "error: Query parameter is empty." + + # 显式转换为 str,解决类型检查报错 "object cannot be assigned to str" + query_str = str(query) + result = await retrieve_knowledge_base( - query=kwargs.get("query", ""), + query=query_str, umo=context.context.event.unified_msg_origin, context=context.context.context, ) diff --git a/astrbot/core/pipeline/rate_limit_check/stage.py b/astrbot/core/pipeline/rate_limit_check/stage.py index 64e21dd7e..392bceff3 100644 --- a/astrbot/core/pipeline/rate_limit_check/stage.py +++ b/astrbot/core/pipeline/rate_limit_check/stage.py @@ -19,7 +19,7 @@ class RateLimitStage(Stage): 如果触发限流,将 stall 流水线,直到下一个时间窗口来临时自动唤醒。 """ - def __init__(self): + def __init__(self) -> None: # 存储每个会话的请求时间队列 self.event_timestamps: defaultdict[str, deque[datetime]] = defaultdict(deque) # 为每个会话设置一个锁,避免并发冲突 diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index 60ab168b3..129277d11 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -35,7 +35,7 @@ class RespondStage(Stage): Comp.WechatEmoji: lambda comp: comp.md5 is not None, # 微信表情 } - async def initialize(self, ctx: PipelineContext): + async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx self.config = ctx.astrbot_config self.platform_settings: dict = self.config.get("platform_settings", {}) @@ -91,7 +91,7 @@ async def _calc_comp_interval(self, comp: BaseMessageComponent) -> float: # random return random.uniform(self.interval[0], self.interval[1]) - async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]): + async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]) -> bool: """检查消息链是否为空 Args: @@ -136,7 +136,7 @@ def _extract_comp( raw_chain: list[BaseMessageComponent], extract_types: set[ComponentType], modify_raw_chain: bool = True, - ): + ) -> list[BaseMessageComponent]: extracted = [] if modify_raw_chain: remaining = [] diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index e0bcd5ac9..13f969fa5 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -3,6 +3,7 @@ import time import traceback from collections.abc import AsyncGenerator +from typing import cast from astrbot.core import file_token_service, html_renderer, logger from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply @@ -20,7 +21,9 @@ @register_stage class ResultDecorateStage(Stage): - async def initialize(self, ctx: PipelineContext): + content_safe_check_stage: ContentSafetyCheckStage | None + + async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx self.reply_prefix = ctx.astrbot_config["platform_settings"]["reply_prefix"] self.reply_with_mention = ctx.astrbot_config["platform_settings"][ @@ -95,7 +98,9 @@ async def initialize(self, ctx: PipelineContext): if self.content_safe_check_reply: for stage_cls in registered_stages: if stage_cls.__name__ == "ContentSafetyCheckStage": - self.content_safe_check_stage = stage_cls() + self.content_safe_check_stage = cast( + ContentSafetyCheckStage, stage_cls() + ) await self.content_safe_check_stage.initialize(ctx) provider_cfg = ctx.astrbot_config.get("provider_settings", {}) @@ -151,7 +156,7 @@ async def process( if isinstance(self.content_safe_check_stage, ContentSafetyCheckStage): async for _ in self.content_safe_check_stage.process( event, - check_text=text, + check_text=text, # type:ignore ): yield diff --git a/astrbot/core/pipeline/scheduler.py b/astrbot/core/pipeline/scheduler.py index 5fb3034f5..dea31dcb6 100644 --- a/astrbot/core/pipeline/scheduler.py +++ b/astrbot/core/pipeline/scheduler.py @@ -15,21 +15,23 @@ class PipelineScheduler: """管道调度器,负责调度各个阶段的执行""" - def __init__(self, context: PipelineContext): + def __init__(self, context: PipelineContext) -> None: registered_stages.sort( key=lambda x: STAGES_ORDER.index(x.__name__), ) # 按照顺序排序 self.ctx = context # 上下文对象 self.stages = [] # 存储阶段实例 - async def initialize(self): + async def initialize(self) -> None: """初始化管道调度器时, 初始化所有阶段""" for stage_cls in registered_stages: stage_instance = stage_cls() # 创建实例 await stage_instance.initialize(self.ctx) self.stages.append(stage_instance) - async def _process_stages(self, event: AstrMessageEvent, from_stage=0): + async def _process_stages( + self, event: AstrMessageEvent, from_stage: int = 0 + ) -> None: """依次执行各个阶段 Args: @@ -72,7 +74,7 @@ async def _process_stages(self, event: AstrMessageEvent, from_stage=0): logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。") break - async def execute(self, event: AstrMessageEvent): + async def execute(self, event: AstrMessageEvent) -> None: """执行 pipeline Args: diff --git a/astrbot/core/pipeline/stage.py b/astrbot/core/pipeline/stage.py index 74aca4ef1..e3a91e1be 100644 --- a/astrbot/core/pipeline/stage.py +++ b/astrbot/core/pipeline/stage.py @@ -10,7 +10,7 @@ registered_stages: list[type[Stage]] = [] # 维护了所有已注册的 Stage 实现类类型 -def register_stage(cls): +def register_stage(cls: type[Stage]) -> type[Stage]: """一个简单的装饰器,用于注册 pipeline 包下的 Stage 实现类""" registered_stages.append(cls) return cls @@ -33,7 +33,7 @@ async def initialize(self, ctx: PipelineContext) -> None: async def process( self, event: AstrMessageEvent, - ) -> None | AsyncGenerator[None, None]: + ) -> None | AsyncGenerator[None]: """处理事件 Args: diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index f6eda07a9..61542a0b2 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -4,9 +4,10 @@ import re import uuid from collections.abc import AsyncGenerator -from typing import Any +from typing import Any, TypeVar, overload from astrbot import logger +from astrbot.core.agent.tool import ToolSet from astrbot.core.db.po import Conversation from astrbot.core.message.components import ( At, @@ -21,12 +22,15 @@ from astrbot.core.message.message_event_result import MessageChain, MessageEventResult from astrbot.core.platform.message_type import MessageType from astrbot.core.provider.entities import ProviderRequest +from astrbot.core.provider.func_tool_manager import FunctionToolManager from astrbot.core.utils.metrics import Metric from .astrbot_message import AstrBotMessage, Group from .message_session import MessageSesion, MessageSession # noqa from .platform_metadata import PlatformMetadata +_VT = TypeVar("_VT") + class AstrMessageEvent(abc.ABC): def __init__( @@ -35,7 +39,7 @@ def __init__( message_obj: AstrBotMessage, platform_meta: PlatformMetadata, session_id: str, - ): + ) -> None: self.message_str = message_str """纯文本的消息""" self.message_obj = message_obj @@ -72,14 +76,14 @@ def __init__( # back_compability self.platform = platform_meta - def get_platform_name(self): + def get_platform_name(self) -> str: """获取这个事件所属的平台的类型(如 aiocqhttp, slack, discord 等)。 NOTE: 用户可能会同时运行多个相同类型的平台适配器。 """ return self.platform_meta.name - def get_platform_id(self): + def get_platform_id(self) -> str: """获取这个事件所属的平台的 ID。 NOTE: 用户可能会同时运行多个相同类型的平台适配器,但能确定的是 ID 是唯一的。 @@ -157,17 +161,23 @@ def get_sender_name(self) -> str: return self.message_obj.sender.nickname return "" - def set_extra(self, key, value): + def set_extra(self, key: str, value: object) -> None: """设置额外的信息。""" self._extras[key] = value - def get_extra(self, key: str | None = None, default=None) -> Any: + @overload + def get_extra(self, key: str, default: _VT = None) -> _VT: ... + + @overload + def get_extra(self, key: None = None, default: object | None = None) -> dict: ... + + def get_extra(self, key: str | None = None, default: _VT = None) -> dict | _VT: """获取额外的信息。""" if key is None: return self._extras return self._extras.get(key, default) - def clear_extra(self): + def clear_extra(self) -> None: """清除额外的信息。""" logger.info(f"清除 {self.get_platform_name()} 的额外信息: {self._extras}") self._extras.clear() @@ -200,7 +210,7 @@ async def send_streaming( self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False, - ): + ) -> None: """发送流式消息到消息平台,使用异步生成器。 目前仅支持: telegram,qq official 私聊。 Fallback仅支持 aiocqhttp。 @@ -210,13 +220,13 @@ async def send_streaming( ) self._has_send_oper = True - async def _pre_send(self): + async def _pre_send(self) -> None: """调度器会在执行 send() 前调用该方法 deprecated in v3.5.18""" - async def _post_send(self): + async def _post_send(self) -> None: """调度器会在执行 send() 后调用该方法 deprecated in v3.5.18""" - def set_result(self, result: MessageEventResult | str): + def set_result(self, result: MessageEventResult | str) -> None: """设置消息事件的结果。 Note: @@ -245,14 +255,14 @@ async def check_count(self, event: AstrMessageEvent): result.chain = [] self._result = result - def stop_event(self): + def stop_event(self) -> None: """终止事件传播。""" if self._result is None: self.set_result(MessageEventResult().stop_event()) else: self._result.stop_event() - def continue_event(self): + def continue_event(self) -> None: """继续事件传播。""" if self._result is None: self.set_result(MessageEventResult().continue_event()) @@ -265,7 +275,7 @@ def is_stopped(self) -> bool: return False # 默认是继续传播 return self._result.is_stopped() - def should_call_llm(self, call_llm: bool): + def should_call_llm(self, call_llm: bool) -> None: """是否在此消息事件中禁止默认的 LLM 请求。 只会阻止 AstrBot 默认的 LLM 请求链路,不会阻止插件中的 LLM 请求。 @@ -276,7 +286,7 @@ def get_result(self) -> MessageEventResult | None: """获取消息事件的结果。""" return self._result - def clear_result(self): + def clear_result(self) -> None: """清除消息事件的结果。""" self._result = None @@ -321,7 +331,7 @@ def chain_result(self, chain: list[BaseMessageComponent]) -> MessageEventResult: def request_llm( self, prompt: str, - func_tool_manager=None, + func_tool_manager: FunctionToolManager | None = None, session_id: str = "", image_urls: list[str] | None = None, contexts: list | None = None, @@ -360,7 +370,9 @@ def request_llm( prompt=prompt, session_id=session_id, image_urls=image_urls, - func_tool=func_tool_manager, + func_tool=func_tool_manager.get_full_tool_set() + if func_tool_manager + else ToolSet(), contexts=contexts, system_prompt=system_prompt, conversation=conversation, @@ -368,7 +380,7 @@ def request_llm( """平台适配器""" - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: """发送消息到消息平台。 Args: @@ -387,7 +399,7 @@ async def send(self, message: MessageChain): ) self._has_send_oper = True - async def react(self, emoji: str): + async def react(self, emoji: str) -> None: """对消息添加表情回应。 默认实现为发送一条包含该表情的消息。 @@ -396,7 +408,9 @@ async def react(self, emoji: str): """ await self.send(MessageChain([Plain(emoji)])) - async def get_group(self, group_id: str | None = None, **kwargs) -> Group | None: + async def get_group( + self, group_id: str | None = None, **kwargs: object + ) -> Group | None: """获取一个群聊的数据, 如果不填写 group_id: 如果是私聊消息,返回 None。如果是群聊消息,返回当前群聊的数据。 适配情况: diff --git a/astrbot/core/platform/astrbot_message.py b/astrbot/core/platform/astrbot_message.py index 253963322..3db53fd48 100644 --- a/astrbot/core/platform/astrbot_message.py +++ b/astrbot/core/platform/astrbot_message.py @@ -11,7 +11,7 @@ class MessageMember: user_id: str # 发送者id nickname: str | None = None - def __str__(self): + def __str__(self) -> str: # 使用 f-string 来构建返回的字符串表示形式 return ( f"User ID: {self.user_id}," @@ -34,7 +34,7 @@ class Group: members: list[MessageMember] | None = None """所有群成员""" - def __str__(self): + def __str__(self) -> str: # 使用 f-string 来构建返回的字符串表示形式 return ( f"Group ID: {self.group_id}\n" @@ -78,7 +78,7 @@ def group_id(self) -> str: return "" @group_id.setter - def group_id(self, value: str | None): + def group_id(self, value: str | None) -> None: """设置 group_id""" if value: if self.group: diff --git a/astrbot/core/platform/manager.py b/astrbot/core/platform/manager.py index f4313f642..c13cca3d5 100644 --- a/astrbot/core/platform/manager.py +++ b/astrbot/core/platform/manager.py @@ -13,7 +13,7 @@ class PlatformManager: - def __init__(self, config: AstrBotConfig, event_queue: Queue): + def __init__(self, config: AstrBotConfig, event_queue: Queue) -> None: self.platform_insts: list[Platform] = [] """加载的 Platform 的实例""" @@ -27,7 +27,7 @@ def __init__(self, config: AstrBotConfig, event_queue: Queue): 约定整个项目中对 unique_session 的引用都从 default 的配置中获取""" self.event_queue = event_queue - async def initialize(self): + async def initialize(self) -> None: """初始化所有平台适配器""" for platform in self.platforms_config: try: @@ -47,7 +47,7 @@ async def initialize(self): ), ) - async def load_platform(self, platform_config: dict): + async def load_platform(self, platform_config: dict) -> None: """实例化一个平台""" # 动态导入 try: @@ -153,7 +153,9 @@ async def load_platform(self, platform_config: dict): except Exception: logger.error(traceback.format_exc()) - async def _task_wrapper(self, task: asyncio.Task, platform: Platform | None = None): + async def _task_wrapper( + self, task: asyncio.Task, platform: Platform | None = None + ) -> None: # 设置平台状态为运行中 if platform: platform.status = PlatformStatus.RUNNING @@ -175,7 +177,7 @@ async def _task_wrapper(self, task: asyncio.Task, platform: Platform | None = No if platform: platform.record_error(error_msg, tb_str) - async def reload(self, platform_config: dict): + async def reload(self, platform_config: dict) -> None: await self.terminate_platform(platform_config["id"]) if platform_config["enable"]: await self.load_platform(platform_config) @@ -186,7 +188,7 @@ async def reload(self, platform_config: dict): if key not in config_ids: await self.terminate_platform(key) - async def terminate_platform(self, platform_id: str): + async def terminate_platform(self, platform_id: str) -> None: if platform_id in self._inst_map: logger.info(f"正在尝试终止 {platform_id} 平台适配器 ...") @@ -208,12 +210,12 @@ async def terminate_platform(self, platform_id: str): if getattr(inst, "terminate", None): await inst.terminate() - async def terminate(self): + async def terminate(self) -> None: for inst in self.platform_insts: if getattr(inst, "terminate", None): await inst.terminate() - def get_insts(self): + def get_insts(self) -> list: return self.platform_insts def get_all_stats(self) -> dict: diff --git a/astrbot/core/platform/message_session.py b/astrbot/core/platform/message_session.py index bca5300b8..d14c02a8d 100644 --- a/astrbot/core/platform/message_session.py +++ b/astrbot/core/platform/message_session.py @@ -15,14 +15,14 @@ class MessageSession: session_id: str platform_id: str | None = None - def __str__(self): + def __str__(self) -> str: return f"{self.platform_id}:{self.message_type.value}:{self.session_id}" - def __post_init__(self): + def __post_init__(self) -> None: self.platform_id = self.platform_name @staticmethod - def from_str(session_str: str): + def from_str(session_str: str) -> "MessageSesion": platform_id, message_type, session_id = session_str.split(":") return MessageSession(platform_id, MessageType(message_type), session_id) diff --git a/astrbot/core/platform/platform.py b/astrbot/core/platform/platform.py index c2e55fb63..e68ed1fc0 100644 --- a/astrbot/core/platform/platform.py +++ b/astrbot/core/platform/platform.py @@ -7,6 +7,8 @@ from enum import Enum from typing import Any +from quart import Request, ResponseReturnValue + from astrbot.core.message.message_event_result import MessageChain from astrbot.core.utils.metrics import Metric @@ -34,10 +36,12 @@ class PlatformError: class Platform(abc.ABC): - def __init__(self, config: dict, event_queue: Queue): + def __init__( + self, platform_config: dict, platform_settings: dict, event_queue: Queue + ) -> None: super().__init__() # 平台配置 - self.config = config + self.config = platform_config # 维护了消息平台的事件队列,EventBus 会从这里取出事件并处理。 self._event_queue = event_queue self.client_self_id = uuid.uuid4().hex @@ -53,7 +57,7 @@ def status(self) -> PlatformStatus: return self._status @status.setter - def status(self, value: PlatformStatus): + def status(self, value: PlatformStatus) -> None: """设置平台运行状态""" self._status = value if value == PlatformStatus.RUNNING and self._started_at is None: @@ -69,12 +73,12 @@ def last_error(self) -> PlatformError | None: """获取最近的错误""" return self._errors[-1] if self._errors else None - def record_error(self, message: str, traceback_str: str | None = None): + def record_error(self, message: str, traceback_str: str | None = None) -> None: """记录一个错误""" self._errors.append(PlatformError(message=message, traceback=traceback_str)) self._status = PlatformStatus.ERROR - def clear_errors(self): + def clear_errors(self) -> None: """清除错误记录""" self._errors.clear() if self._status == PlatformStatus.ERROR: @@ -112,7 +116,7 @@ def run(self) -> Coroutine[Any, Any, None]: """得到一个平台的运行实例,需要返回一个协程对象。""" raise NotImplementedError - async def terminate(self): + async def terminate(self) -> None: """终止一个平台的运行实例。""" @abc.abstractmethod @@ -131,14 +135,14 @@ async def send_by_session( """ await Metric.upload(msg_event_tick=1, adapter_name=self.meta().name) - def commit_event(self, event: AstrMessageEvent): + def commit_event(self, event: AstrMessageEvent) -> None: """提交一个事件到事件队列。""" self._event_queue.put_nowait(event) - def get_client(self): + def get_client(self) -> None: """获取平台的客户端对象。""" - async def webhook_callback(self, request: Any) -> Any: + async def webhook_callback(self, request: Request) -> ResponseReturnValue: """统一 Webhook 回调入口。 支持统一 Webhook 模式的平台需要实现此方法。 diff --git a/astrbot/core/platform/register.py b/astrbot/core/platform/register.py index 5f550ecd1..e31d60eab 100644 --- a/astrbot/core/platform/register.py +++ b/astrbot/core/platform/register.py @@ -1,10 +1,16 @@ +from collections.abc import Callable +from typing import TypeVar + from astrbot.core import logger +from .platform import Platform from .platform_metadata import PlatformMetadata +T = TypeVar("T", bound=Platform) + platform_registry: list[PlatformMetadata] = [] """维护了通过装饰器注册的平台适配器""" -platform_cls_map: dict[str, type] = {} +platform_cls_map: dict[str, type[Platform]] = {} """维护了平台适配器名称和适配器类的映射""" @@ -15,14 +21,14 @@ def register_platform_adapter( adapter_display_name: str | None = None, logo_path: str | None = None, support_streaming_message: bool = True, -): +) -> Callable[[type[T]], type[T]]: """用于注册平台适配器的带参装饰器。 default_config_tmpl 指定了平台适配器的默认配置模板。用户填写好后将会作为 platform_config 传入你的 Platform 类的实现类。 logo_path 指定了平台适配器的 logo 文件路径,是相对于插件目录的路径。 """ - def decorator(cls): + def decorator(cls: type[T]) -> type[T]: if adapter_name in platform_cls_map: raise ValueError( f"平台适配器 {adapter_name} 已经注册过了,可能发生了适配器命名冲突。", diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py index 293b462d3..1ad26c265 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py @@ -26,7 +26,7 @@ def __init__( platform_meta, session_id, bot: CQHttp, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.bot = bot @@ -72,7 +72,7 @@ async def _dispatch_send( is_group: bool, session_id: str | None, messages: list[dict], - ): + ) -> None: # session_id 必须是纯数字字符串 session_id_int = ( int(session_id) if session_id and session_id.isdigit() else None @@ -97,7 +97,7 @@ async def send_message( event: Event | None = None, is_group: bool = False, session_id: str | None = None, - ): + ) -> None: """发送消息至 QQ 协议端(aiocqhttp)。 Args: @@ -143,7 +143,7 @@ async def send_message( await cls._dispatch_send(bot, event, is_group, session_id, messages) await asyncio.sleep(0.5) - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: """发送消息""" event = getattr(self.message_obj, "raw_message", None) diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py index 29fde59ab..f818e5415 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py @@ -38,7 +38,7 @@ def __init__( platform_settings: dict, event_queue: asyncio.Queue, ) -> None: - super().__init__(platform_config, event_queue) + super().__init__(platform_config, platform_settings, event_queue) self.settings = platform_settings self.host = platform_config["ws_reverse_host"] @@ -61,38 +61,38 @@ def __init__( ) @self.bot.on_request() - async def request(event: Event): + async def request(event: Event) -> None: abm = await self.convert_message(event) if abm: await self.handle_msg(abm) @self.bot.on_notice() - async def notice(event: Event): + async def notice(event: Event) -> None: abm = await self.convert_message(event) if abm: await self.handle_msg(abm) @self.bot.on_message("group") - async def group(event: Event): + async def group(event: Event) -> None: abm = await self.convert_message(event) if abm: await self.handle_msg(abm) @self.bot.on_message("private") - async def private(event: Event): + async def private(event: Event) -> None: abm = await self.convert_message(event) if abm: await self.handle_msg(abm) @self.bot.on_websocket_connection - def on_websocket_connection(_): + def on_websocket_connection(_) -> None: logger.info("aiocqhttp(OneBot v11) 适配器已连接。") async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: is_group = session.message_type == MessageType.GROUP_MESSAGE if is_group: session_id = session.session_id.split("_")[-1] @@ -417,17 +417,17 @@ def run(self) -> Awaitable[Any]: self.shutdown_event = asyncio.Event() return coro - async def terminate(self): + async def terminate(self) -> None: self.shutdown_event.set() - async def shutdown_trigger_placeholder(self): + async def shutdown_trigger_placeholder(self) -> None: await self.shutdown_event.wait() logger.info("aiocqhttp 适配器已被关闭") def meta(self) -> PlatformMetadata: return self.metadata - async def handle_msg(self, message: AstrBotMessage): + async def handle_msg(self, message: AstrBotMessage) -> None: message_event = AiocqhttpMessageEvent( message_str=message.message_str, message_obj=message, diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py index ec2b29a64..8e64a79f1 100644 --- a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +++ b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py @@ -2,7 +2,7 @@ import os import threading import uuid -from typing import cast +from typing import NoReturn, cast import aiohttp import dingtalk_stream @@ -48,7 +48,7 @@ def __init__( platform_settings: dict, event_queue: asyncio.Queue, ) -> None: - super().__init__(platform_config, event_queue) + super().__init__(platform_config, platform_settings, event_queue) self.client_id = platform_config["client_id"] self.client_secret = platform_config["client_secret"] @@ -88,7 +88,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> NoReturn: raise NotImplementedError("钉钉机器人适配器不支持 send_by_session") def meta(self) -> PlatformMetadata: @@ -217,7 +217,7 @@ async def get_access_token(self) -> str: return "" return (await resp.json())["data"]["accessToken"] - async def handle_msg(self, abm: AstrBotMessage): + async def handle_msg(self, abm: AstrBotMessage) -> None: event = DingtalkMessageEvent( message_str=abm.message_str, message_obj=abm, @@ -228,10 +228,10 @@ async def handle_msg(self, abm: AstrBotMessage): self._event_queue.put_nowait(event) - async def run(self): + async def run(self) -> None: # await self.client_.start() # 钉钉的 SDK 并没有实现真正的异步,start() 里面有堵塞方法。 - def start_client(loop: asyncio.AbstractEventLoop): + def start_client(loop: asyncio.AbstractEventLoop) -> None: try: self._shutdown_event = threading.Event() task = loop.create_task(self.client_.start()) @@ -247,8 +247,8 @@ def start_client(loop: asyncio.AbstractEventLoop): loop = asyncio.get_event_loop() await loop.run_in_executor(None, start_client, loop) - async def terminate(self): - def monkey_patch_close(): + async def terminate(self) -> None: + def monkey_patch_close() -> NoReturn: raise KeyboardInterrupt("Graceful shutdown") if self.client_.websocket is not None: diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_event.py b/astrbot/core/platform/sources/dingtalk/dingtalk_event.py index 197701e0d..27744d71f 100644 --- a/astrbot/core/platform/sources/dingtalk/dingtalk_event.py +++ b/astrbot/core/platform/sources/dingtalk/dingtalk_event.py @@ -16,7 +16,7 @@ def __init__( platform_meta, session_id, client: dingtalk_stream.ChatbotHandler, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.client = client @@ -24,7 +24,7 @@ async def send_with_client( self, client: dingtalk_stream.ChatbotHandler, message: MessageChain, - ): + ) -> None: icm = cast(dingtalk_stream.ChatbotMessage, self.message_obj.raw_message) ats = [] # fixes: #4218 @@ -78,7 +78,7 @@ async def send_with_client( logger.warning(f"钉钉图片处理失败: {e}, 跳过图片发送") continue - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: await self.send_with_client(self.client, message) await super().send(message) diff --git a/astrbot/core/platform/sources/discord/client.py b/astrbot/core/platform/sources/discord/client.py index ac0610f2a..ebd32c471 100644 --- a/astrbot/core/platform/sources/discord/client.py +++ b/astrbot/core/platform/sources/discord/client.py @@ -15,7 +15,7 @@ class DiscordBotClient(discord.Bot): """Discord客户端封装""" - def __init__(self, token: str, proxy: str | None = None): + def __init__(self, token: str, proxy: str | None = None) -> None: self.token = token self.proxy = proxy @@ -32,7 +32,7 @@ def __init__(self, token: str, proxy: str | None = None): self.on_ready_once_callback: Callable[[], Awaitable[None]] | None = None self._ready_once_fired = False - async def on_ready(self): + async def on_ready(self) -> None: """当机器人成功连接并准备就绪时触发""" if self.user is None: logger.error("[Discord] 客户端未正确加载用户信息 (self.user is None)") @@ -93,7 +93,7 @@ def _create_interaction_data(self, interaction: discord.Interaction) -> dict: "type": "interaction", } - async def on_message(self, message: discord.Message): + async def on_message(self, message: discord.Message) -> None: """当接收到消息时触发""" if message.author.bot: return @@ -130,12 +130,12 @@ def _extract_interaction_content(self, interaction: discord.Interaction) -> str: return str(interaction_data) - async def start_polling(self): + async def start_polling(self) -> None: """开始轮询消息,这是个阻塞方法""" await self.start(self.token) @override - async def close(self): + async def close(self) -> None: """关闭客户端""" if not self.is_closed(): await super().close() diff --git a/astrbot/core/platform/sources/discord/components.py b/astrbot/core/platform/sources/discord/components.py index f875652a0..433509f5e 100644 --- a/astrbot/core/platform/sources/discord/components.py +++ b/astrbot/core/platform/sources/discord/components.py @@ -19,7 +19,7 @@ def __init__( image: str | None = None, footer: str | None = None, fields: list[dict] | None = None, - ): + ) -> None: self.title = title self.description = description self.color = color @@ -71,7 +71,7 @@ def __init__( emoji: str | None = None, url: str | None = None, disabled: bool = False, - ): + ) -> None: self.label = label self.custom_id = custom_id self.style = style @@ -85,7 +85,7 @@ class DiscordReference(BaseMessageComponent): type: str = "discord_reference" - def __init__(self, message_id: str, channel_id: str): + def __init__(self, message_id: str, channel_id: str) -> None: self.message_id = message_id self.channel_id = channel_id @@ -99,7 +99,7 @@ def __init__( self, components: list[BaseMessageComponent] | None = None, timeout: float | None = None, - ): + ) -> None: self.components = components or [] self.timeout = timeout diff --git a/astrbot/core/platform/sources/discord/discord_platform_adapter.py b/astrbot/core/platform/sources/discord/discord_platform_adapter.py index 50aa0fe6f..36b71beda 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_adapter.py +++ b/astrbot/core/platform/sources/discord/discord_platform_adapter.py @@ -44,7 +44,7 @@ def __init__( platform_settings: dict, event_queue: asyncio.Queue, ) -> None: - super().__init__(platform_config, event_queue) + super().__init__(platform_config, platform_settings, event_queue) self.settings = platform_settings self.client_self_id: str | None = None self.registered_handlers = [] @@ -60,7 +60,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: """通过会话发送消息""" if self.client.user is None: logger.error( @@ -122,11 +122,11 @@ def meta(self) -> PlatformMetadata: ) @override - async def run(self): + async def run(self) -> None: """主要运行逻辑""" # 初始化回调函数 - async def on_received(message_data): + async def on_received(message_data) -> None: logger.debug(f"[Discord] 收到消息: {message_data}") if self.client_self_id is None: self.client_self_id = message_data.get("bot_id") @@ -143,7 +143,7 @@ async def on_received(message_data): self.client = DiscordBotClient(token, proxy) self.client.on_message_received = on_received - async def callback(): + async def callback() -> None: if self.enable_command_register: await self._collect_and_register_commands() if self.activity_name: @@ -251,7 +251,7 @@ async def convert_message(self, data: dict) -> AstrBotMessage: # 由于 on_interaction 已被禁用,我们只处理普通消息 return self._convert_message_to_abm(data) - async def handle_msg(self, message: AstrBotMessage, followup_webhook=None): + async def handle_msg(self, message: AstrBotMessage, followup_webhook=None) -> None: """处理消息""" message_event = DiscordPlatformEvent( message_str=message.message_str, @@ -323,7 +323,7 @@ async def handle_msg(self, message: AstrBotMessage, followup_webhook=None): self.commit_event(message_event) @override - async def terminate(self): + async def terminate(self) -> None: """终止适配器""" logger.info("[Discord] 正在终止适配器... (step 1: cancel polling task)") self.shutdown_event.set() @@ -358,11 +358,11 @@ async def terminate(self): logger.warning(f"[Discord] 客户端关闭异常: {e}") logger.info("[Discord] 适配器已终止。") - def register_handler(self, handler_info): + def register_handler(self, handler_info) -> None: """注册处理器信息""" self.registered_handlers.append(handler_info) - async def _collect_and_register_commands(self): + async def _collect_and_register_commands(self) -> None: """收集所有指令并注册到Discord""" logger.info("[Discord] 开始收集并注册斜杠指令...") registered_commands = [] @@ -418,7 +418,7 @@ def _create_dynamic_callback(self, cmd_name: str): async def dynamic_callback( ctx: discord.ApplicationContext, params: str | None = None - ): + ) -> None: # 将平台特定的前缀'/'剥离,以适配通用的CommandFilter logger.debug(f"[Discord] 回调函数触发: {cmd_name}") logger.debug(f"[Discord] 回调函数参数: {ctx}") diff --git a/astrbot/core/platform/sources/discord/discord_platform_event.py b/astrbot/core/platform/sources/discord/discord_platform_event.py index 053018225..02d4dae86 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_event.py +++ b/astrbot/core/platform/sources/discord/discord_platform_event.py @@ -28,7 +28,7 @@ class DiscordViewComponent(BaseMessageComponent): type: str = "discord_view" - def __init__(self, view: discord.ui.View): + def __init__(self, view: discord.ui.View) -> None: self.view = view @@ -41,12 +41,12 @@ def __init__( session_id: str, client: DiscordBotClient, interaction_followup_webhook: discord.Webhook | None = None, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.client = client self.interaction_followup_webhook = interaction_followup_webhook - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: """发送消息到Discord平台""" # 解析消息链为 Discord 所需的对象 try: @@ -267,7 +267,7 @@ async def _parse_to_discord( content = content[:2000] return content, files, view, embeds, reference_message_id - async def react(self, emoji: str): + async def react(self, emoji: str) -> None: """对原消息添加反应""" try: if hasattr(self.message_obj, "raw_message") and hasattr( diff --git a/astrbot/core/platform/sources/lark/lark_adapter.py b/astrbot/core/platform/sources/lark/lark_adapter.py index b71071167..82097fad7 100644 --- a/astrbot/core/platform/sources/lark/lark_adapter.py +++ b/astrbot/core/platform/sources/lark/lark_adapter.py @@ -4,7 +4,7 @@ import re import time import uuid -from typing import Any, cast +from typing import cast import lark_oapi as lark from lark_oapi.api.im.v1 import ( @@ -13,6 +13,7 @@ GetMessageResourceRequest, ) from lark_oapi.api.im.v1.processor import P2ImMessageReceiveV1Processor +from quart import Request, ResponseReturnValue import astrbot.api.message_components as Comp from astrbot import logger @@ -42,7 +43,7 @@ def __init__( platform_settings: dict, event_queue: asyncio.Queue, ) -> None: - super().__init__(platform_config, event_queue) + super().__init__(platform_config, platform_settings, event_queue) self.appid = platform_config["app_id"] self.appsecret = platform_config["app_secret"] @@ -56,10 +57,10 @@ def __init__( logger.warning("未设置飞书机器人名称,@ 机器人可能得不到回复。") # 初始化 WebSocket 长连接相关配置 - async def on_msg_event_recv(event: lark.im.v1.P2ImMessageReceiveV1): + async def on_msg_event_recv(event: lark.im.v1.P2ImMessageReceiveV1) -> None: await self.convert_msg(event) - def do_v2_msg_event(event: lark.im.v1.P2ImMessageReceiveV1): + def do_v2_msg_event(event: lark.im.v1.P2ImMessageReceiveV1) -> None: asyncio.create_task(on_msg_event_recv(event)) self.event_handler = ( @@ -124,7 +125,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: if self.lark_api.im is None: logger.error("[Lark] API Client im 模块未初始化,无法发送消息") return @@ -173,7 +174,7 @@ def meta(self) -> PlatformMetadata: support_streaming_message=False, ) - async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1): + async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1) -> None: if event.event is None: logger.debug("[Lark] 收到空事件(event.event is None)") return @@ -323,7 +324,7 @@ async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1): logger.debug(abm) await self.handle_msg(abm) - async def handle_msg(self, abm: AstrBotMessage): + async def handle_msg(self, abm: AstrBotMessage) -> None: event = LarkMessageEvent( message_str=abm.message_str, message_obj=abm, @@ -356,7 +357,7 @@ async def handle_webhook_event(self, event_data: dict): except Exception as e: logger.error(f"[Lark Webhook] 处理事件失败: {e}", exc_info=True) - async def run(self): + async def run(self) -> None: if self.connection_mode == "webhook": # Webhook 模式 if self.webhook_server is None: @@ -372,14 +373,14 @@ async def run(self): # 长连接模式 await self.client._connect() - async def webhook_callback(self, request: Any) -> Any: + async def webhook_callback(self, request: Request) -> ResponseReturnValue: """统一 Webhook 回调入口""" if not self.webhook_server: return {"error": "Webhook server not initialized"}, 500 return await self.webhook_server.handle_callback(request) - async def terminate(self): + async def terminate(self) -> None: if self.connection_mode == "socket": await self.client._disconnect() logger.info("飞书(Lark) 适配器已关闭") diff --git a/astrbot/core/platform/sources/lark/lark_event.py b/astrbot/core/platform/sources/lark/lark_event.py index 7b7d20b38..b6c7d5258 100644 --- a/astrbot/core/platform/sources/lark/lark_event.py +++ b/astrbot/core/platform/sources/lark/lark_event.py @@ -31,7 +31,7 @@ def __init__( platform_meta, session_id, bot: lark.Client, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.bot = bot @@ -110,7 +110,7 @@ async def _convert_to_lark(message: MessageChain, lark_client: lark.Client) -> l ret.append(_stage) return ret - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: res = await LarkMessageEvent._convert_to_lark(message, self.bot) wrapped = { "zh_cn": { @@ -144,7 +144,7 @@ async def send(self, message: MessageChain): await super().send(message) - async def react(self, emoji: str): + async def react(self, emoji: str) -> None: if self.bot.im is None: logger.error("[Lark] API Client im 模块未初始化,无法发送表情") return diff --git a/astrbot/core/platform/sources/misskey/misskey_adapter.py b/astrbot/core/platform/sources/misskey/misskey_adapter.py index d8f560b1b..fae451ab8 100644 --- a/astrbot/core/platform/sources/misskey/misskey_adapter.py +++ b/astrbot/core/platform/sources/misskey/misskey_adapter.py @@ -54,7 +54,7 @@ def __init__( platform_settings: dict, event_queue: asyncio.Queue, ) -> None: - super().__init__(platform_config or {}, event_queue) + super().__init__(platform_config or {}, platform_settings, event_queue) self.settings = platform_settings or {} self.instance_url = self.config.get("misskey_instance_url", "") self.access_token = self.config.get("misskey_token", "") @@ -121,7 +121,7 @@ def meta(self) -> PlatformMetadata: support_streaming_message=False, ) - async def run(self): + async def run(self) -> None: if not self.instance_url or not self.access_token: logger.error("[Misskey] 配置不完整,无法启动") return @@ -150,7 +150,7 @@ async def run(self): await self._start_websocket_connection() - def _register_event_handlers(self, streaming): + def _register_event_handlers(self, streaming) -> None: """注册事件处理器""" streaming.add_message_handler("notification", self._handle_notification) streaming.add_message_handler("main:notification", self._handle_notification) @@ -194,7 +194,7 @@ def _process_poll_data( message: AstrBotMessage, poll: dict[str, Any], message_parts: list[str], - ): + ) -> None: """处理投票数据,将其添加到消息中""" try: if not isinstance(message.raw_message, dict): @@ -233,7 +233,7 @@ def _extract_additional_fields(self, session, message_chain) -> dict[str, Any]: return fields - async def _start_websocket_connection(self): + async def _start_websocket_connection(self) -> None: backoff_delay = 1.0 max_backoff = 300.0 backoff_multiplier = 1.5 @@ -281,7 +281,7 @@ async def _start_websocket_connection(self): await asyncio.sleep(sleep_time) backoff_delay = min(backoff_delay * backoff_multiplier, max_backoff) - async def _handle_notification(self, data: dict[str, Any]): + async def _handle_notification(self, data: dict[str, Any]) -> None: try: notification_type = data.get("type") logger.debug( @@ -305,7 +305,7 @@ async def _handle_notification(self, data: dict[str, Any]): except Exception as e: logger.error(f"[Misskey] 处理通知失败: {e}") - async def _handle_chat_message(self, data: dict[str, Any]): + async def _handle_chat_message(self, data: dict[str, Any]) -> None: try: sender_id = str( data.get("fromUserId", "") or data.get("fromUser", {}).get("id", ""), @@ -340,7 +340,7 @@ async def _handle_chat_message(self, data: dict[str, Any]): except Exception as e: logger.error(f"[Misskey] 处理聊天消息失败: {e}") - async def _debug_handler(self, data: dict[str, Any]): + async def _debug_handler(self, data: dict[str, Any]) -> None: event_type = data.get("type", "unknown") logger.debug( f"[Misskey] 收到未处理事件: type={event_type}, channel={data.get('channel', 'unknown')}", @@ -754,7 +754,7 @@ async def convert_room_message(self, raw_data: dict[str, Any]) -> AstrBotMessage ) return message - async def terminate(self): + async def terminate(self) -> None: self._running = False if self.api: await self.api.close() diff --git a/astrbot/core/platform/sources/misskey/misskey_api.py b/astrbot/core/platform/sources/misskey/misskey_api.py index 06dc6304d..86636b12c 100644 --- a/astrbot/core/platform/sources/misskey/misskey_api.py +++ b/astrbot/core/platform/sources/misskey/misskey_api.py @@ -3,7 +3,7 @@ import random import uuid from collections.abc import Awaitable, Callable -from typing import Any +from typing import Any, NoReturn try: import aiohttp @@ -43,7 +43,7 @@ class WebSocketError(APIError): class StreamingClient: - def __init__(self, instance_url: str, access_token: str): + def __init__(self, instance_url: str, access_token: str) -> None: self.instance_url = instance_url.rstrip("/") self.access_token = access_token self.websocket: Any | None = None @@ -90,7 +90,7 @@ async def connect(self) -> bool: self.is_connected = False return False - async def disconnect(self): + async def disconnect(self) -> None: self._running = False if self.websocket: await self.websocket.close() @@ -116,7 +116,7 @@ async def subscribe_channel( self.channels[channel_id] = channel_type return channel_id - async def unsubscribe_channel(self, channel_id: str): + async def unsubscribe_channel(self, channel_id: str) -> None: if ( not self.is_connected or not self.websocket @@ -136,10 +136,10 @@ def add_message_handler( self, event_type: str, handler: Callable[[dict], Awaitable[None]], - ): + ) -> None: self.message_handlers[event_type] = handler - async def listen(self): + async def listen(self) -> None: if not self.is_connected or not self.websocket: raise WebSocketError("WebSocket 未连接") @@ -187,7 +187,7 @@ async def listen(self): except Exception: pass - async def _handle_message(self, data: dict[str, Any]): + async def _handle_message(self, data: dict[str, Any]) -> None: message_type = data.get("type") body = data.get("body", {}) @@ -334,7 +334,7 @@ def __init__( download_timeout: int = 15, chunk_size: int = 64 * 1024, max_download_bytes: int | None = None, - ): + ) -> None: self.instance_url = instance_url.rstrip("/") self.access_token = access_token self._session: aiohttp.ClientSession | None = None @@ -375,7 +375,7 @@ def session(self) -> aiohttp.ClientSession: self._session = aiohttp.ClientSession(headers=headers) return self._session - def _handle_response_status(self, status: int, endpoint: str): + def _handle_response_status(self, status: int, endpoint: str) -> NoReturn: """处理 HTTP 响应状态码""" if status == 400: logger.error(f"[Misskey API] 请求参数错误: {endpoint} (HTTP {status})") diff --git a/astrbot/core/platform/sources/misskey/misskey_event.py b/astrbot/core/platform/sources/misskey/misskey_event.py index 7975f0ec7..068f7e7a2 100644 --- a/astrbot/core/platform/sources/misskey/misskey_event.py +++ b/astrbot/core/platform/sources/misskey/misskey_event.py @@ -26,7 +26,7 @@ def __init__( platform_meta: PlatformMetadata, session_id: str, client, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.client = client @@ -40,7 +40,7 @@ def _is_system_command(self, message_str: str) -> bool: return any(message_trimmed.startswith(prefix) for prefix in system_prefixes) - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: """发送消息,使用适配器的完整上传和发送逻辑""" try: logger.debug( diff --git a/astrbot/core/platform/sources/misskey/misskey_utils.py b/astrbot/core/platform/sources/misskey/misskey_utils.py index d9388598d..dd02c13c0 100644 --- a/astrbot/core/platform/sources/misskey/misskey_utils.py +++ b/astrbot/core/platform/sources/misskey/misskey_utils.py @@ -403,7 +403,7 @@ def cache_user_info( raw_data: dict[str, Any], client_self_id: str, is_chat: bool = False, -): +) -> None: """缓存用户信息""" if is_chat: user_cache_data = { @@ -429,7 +429,7 @@ def cache_room_info( user_cache: dict[str, Any], raw_data: dict[str, Any], client_self_id: str, -): +) -> None: """缓存房间信息""" room_data = raw_data.get("toRoom") room_id = raw_data.get("toRoomId") diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py index d693c4206..2da3cb4b7 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py @@ -32,12 +32,12 @@ def __init__( platform_meta: PlatformMetadata, session_id: str, bot: Client, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.bot = bot self.send_buffer = None - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: self.send_buffer = message await self._post_send() diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py index 7de535fbf..f7dd3dae2 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py @@ -4,7 +4,7 @@ import logging import os import time -from typing import cast +from typing import NoReturn, cast import botpy import botpy.message @@ -35,11 +35,13 @@ # QQ 机器人官方框架 class botClient(Client): - def set_platform(self, platform: QQOfficialPlatformAdapter): + def set_platform(self, platform: QQOfficialPlatformAdapter) -> None: self.platform = platform # 收到群消息 - async def on_group_at_message_create(self, message: botpy.message.GroupMessage): + async def on_group_at_message_create( + self, message: botpy.message.GroupMessage + ) -> None: abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.GROUP_MESSAGE, @@ -49,7 +51,7 @@ async def on_group_at_message_create(self, message: botpy.message.GroupMessage): self._commit(abm) # 收到频道消息 - async def on_at_message_create(self, message: botpy.message.Message): + async def on_at_message_create(self, message: botpy.message.Message) -> None: abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.GROUP_MESSAGE, @@ -59,7 +61,9 @@ async def on_at_message_create(self, message: botpy.message.Message): self._commit(abm) # 收到私聊消息 - async def on_direct_message_create(self, message: botpy.message.DirectMessage): + async def on_direct_message_create( + self, message: botpy.message.DirectMessage + ) -> None: abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.FRIEND_MESSAGE, @@ -68,7 +72,7 @@ async def on_direct_message_create(self, message: botpy.message.DirectMessage): self._commit(abm) # 收到 C2C 消息 - async def on_c2c_message_create(self, message: botpy.message.C2CMessage): + async def on_c2c_message_create(self, message: botpy.message.C2CMessage) -> None: abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.FRIEND_MESSAGE, @@ -76,7 +80,7 @@ async def on_c2c_message_create(self, message: botpy.message.C2CMessage): abm.session_id = abm.sender.user_id self._commit(abm) - def _commit(self, abm: AstrBotMessage): + def _commit(self, abm: AstrBotMessage) -> None: self.platform.commit_event( QQOfficialMessageEvent( abm.message_str, @@ -96,7 +100,7 @@ def __init__( platform_settings: dict, event_queue: asyncio.Queue, ) -> None: - super().__init__(platform_config, event_queue) + super().__init__(platform_config, platform_settings, event_queue) self.appid = platform_config["appid"] self.secret = platform_config["secret"] @@ -128,7 +132,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> NoReturn: raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session") def meta(self) -> PlatformMetadata: @@ -221,6 +225,6 @@ def run(self): def get_client(self) -> botClient: return self.client - async def terminate(self): + async def terminate(self) -> None: await self.client.close() logger.info("QQ 官方机器人接口 适配器已被优雅地关闭") diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py index 80ed34245..cfb13426d 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py @@ -1,12 +1,13 @@ import asyncio import logging -from typing import Any, cast +from typing import NoReturn, cast import botpy import botpy.message import botpy.types import botpy.types.message from botpy import Client +from quart import Request, ResponseReturnValue from astrbot import logger from astrbot.api.event import MessageChain @@ -26,11 +27,13 @@ # QQ 机器人官方框架 class botClient(Client): - def set_platform(self, platform: "QQOfficialWebhookPlatformAdapter"): + def set_platform(self, platform: "QQOfficialWebhookPlatformAdapter") -> None: self.platform = platform # 收到群消息 - async def on_group_at_message_create(self, message: botpy.message.GroupMessage): + async def on_group_at_message_create( + self, message: botpy.message.GroupMessage + ) -> None: abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.GROUP_MESSAGE, @@ -40,7 +43,7 @@ async def on_group_at_message_create(self, message: botpy.message.GroupMessage): self._commit(abm) # 收到频道消息 - async def on_at_message_create(self, message: botpy.message.Message): + async def on_at_message_create(self, message: botpy.message.Message) -> None: abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.GROUP_MESSAGE, @@ -50,7 +53,9 @@ async def on_at_message_create(self, message: botpy.message.Message): self._commit(abm) # 收到私聊消息 - async def on_direct_message_create(self, message: botpy.message.DirectMessage): + async def on_direct_message_create( + self, message: botpy.message.DirectMessage + ) -> None: abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.FRIEND_MESSAGE, @@ -59,7 +64,7 @@ async def on_direct_message_create(self, message: botpy.message.DirectMessage): self._commit(abm) # 收到 C2C 消息 - async def on_c2c_message_create(self, message: botpy.message.C2CMessage): + async def on_c2c_message_create(self, message: botpy.message.C2CMessage) -> None: abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.FRIEND_MESSAGE, @@ -67,7 +72,7 @@ async def on_c2c_message_create(self, message: botpy.message.C2CMessage): abm.session_id = abm.sender.user_id self._commit(abm) - def _commit(self, abm: AstrBotMessage): + def _commit(self, abm: AstrBotMessage) -> None: self.platform.commit_event( QQOfficialWebhookMessageEvent( abm.message_str, @@ -87,7 +92,7 @@ def __init__( platform_settings: dict, event_queue: asyncio.Queue, ) -> None: - super().__init__(platform_config, event_queue) + super().__init__(platform_config, platform_settings, event_queue) self.appid = platform_config["appid"] self.secret = platform_config["secret"] @@ -110,7 +115,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> NoReturn: raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session") def meta(self) -> PlatformMetadata: @@ -120,7 +125,7 @@ def meta(self) -> PlatformMetadata: id=cast(str, self.config.get("id")), ) - async def run(self): + async def run(self) -> None: self.webhook_helper = QQOfficialWebhook( self.config, self._event_queue, @@ -140,7 +145,7 @@ async def run(self): def get_client(self) -> botClient: return self.client - async def webhook_callback(self, request: Any) -> Any: + async def webhook_callback(self, request: Request) -> ResponseReturnValue: """统一 Webhook 回调入口""" if not self.webhook_helper: return {"error": "Webhook helper not initialized"}, 500 @@ -148,7 +153,7 @@ async def webhook_callback(self, request: Any) -> Any: # 复用 webhook_helper 的回调处理逻辑 return await self.webhook_helper.handle_callback(request) - async def terminate(self): + async def terminate(self) -> None: if self.webhook_helper: self.webhook_helper.shutdown_event.set() await self.client.close() diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py index 306db5e56..5ceeb2c70 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py @@ -13,5 +13,5 @@ def __init__( platform_meta: PlatformMetadata, session_id: str, bot: Client, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id, bot) diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py index 2eda11a6c..5f35471ee 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py @@ -14,7 +14,9 @@ class QQOfficialWebhook: - def __init__(self, config: dict, event_queue: asyncio.Queue, botpy_client: Client): + def __init__( + self, config: dict, event_queue: asyncio.Queue, botpy_client: Client + ) -> None: self.appid = config["appid"] self.secret = config["secret"] self.port = config.get("port", 6196) @@ -38,7 +40,7 @@ def __init__(self, config: dict, event_queue: asyncio.Queue, botpy_client: Clien self.event_queue = event_queue self.shutdown_event = asyncio.Event() - async def initialize(self): + async def initialize(self) -> None: logger.info("正在登录到 QQ 官方机器人...") self.user = await self.http.login(self.token) logger.info(f"已登录 QQ 官方机器人账号: {self.user}") @@ -46,7 +48,7 @@ async def initialize(self): self.client.api = self.api self.client.http = self.http - async def bot_connect(): + async def bot_connect() -> None: pass self._connection = ConnectionSession( @@ -115,7 +117,7 @@ async def handle_callback(self, request) -> dict: return {"opcode": 12} - async def start_polling(self): + async def start_polling(self) -> None: logger.info( f"将在 {self.callback_server_host}:{self.port} 端口启动 QQ 官方机器人 webhook 适配器。", ) @@ -125,5 +127,5 @@ async def start_polling(self): shutdown_trigger=self.shutdown_trigger, ) - async def shutdown_trigger(self): + async def shutdown_trigger(self) -> None: await self.shutdown_event.wait() diff --git a/astrbot/core/platform/sources/satori/satori_adapter.py b/astrbot/core/platform/sources/satori/satori_adapter.py index 10912dc8e..b60630ede 100644 --- a/astrbot/core/platform/sources/satori/satori_adapter.py +++ b/astrbot/core/platform/sources/satori/satori_adapter.py @@ -38,7 +38,7 @@ def __init__( platform_settings: dict, event_queue: asyncio.Queue, ) -> None: - super().__init__(platform_config, event_queue) + super().__init__(platform_config, platform_settings, event_queue) self.settings = platform_settings self.api_base_url = self.config.get( @@ -73,7 +73,7 @@ async def send_by_session( self, session: MessageSession, message_chain: MessageChain, - ): + ) -> None: from .satori_event import SatoriPlatformEvent await SatoriPlatformEvent.send_with_adapter( @@ -99,7 +99,7 @@ def _is_websocket_closed(self, ws) -> bool: except AttributeError: return False - async def run(self): + async def run(self) -> None: self.running = True self.session = ClientSession(timeout=ClientTimeout(total=30)) @@ -133,7 +133,7 @@ async def run(self): if self.session: await self.session.close() - async def connect_websocket(self): + async def connect_websocket(self) -> None: logger.info(f"Satori 适配器正在连接到 WebSocket: {self.endpoint}") logger.info(f"Satori 适配器 HTTP API 地址: {self.api_base_url}") @@ -181,7 +181,7 @@ async def connect_websocket(self): except Exception as e: logger.error(f"Satori WebSocket 关闭异常: {e}") - async def send_identify(self): + async def send_identify(self) -> None: if not self.ws: raise Exception("WebSocket连接未建立") @@ -209,7 +209,7 @@ async def send_identify(self): logger.error(f"发送 IDENTIFY 信令失败: {e}") raise - async def heartbeat_loop(self): + async def heartbeat_loop(self) -> None: try: while self.running and self.ws: await asyncio.sleep(self.heartbeat_interval) @@ -234,7 +234,7 @@ async def heartbeat_loop(self): except Exception as e: logger.error(f"心跳任务异常: {e}") - async def handle_message(self, message: str): + async def handle_message(self, message: str) -> None: try: data = json.loads(message) op = data.get("op") @@ -275,7 +275,7 @@ async def handle_message(self, message: str): except Exception as e: logger.error(f"处理 WebSocket 消息异常: {e}") - async def handle_event(self, event_data: dict): + async def handle_event(self, event_data: dict) -> None: try: event_type = event_data.get("type") sn = event_data.get("sn") @@ -720,7 +720,7 @@ async def _parse_xml_node(self, node: ET.Element, elements: list) -> None: if child.tail and child.tail.strip(): elements.append(Plain(text=child.tail)) - async def handle_msg(self, message: AstrBotMessage): + async def handle_msg(self, message: AstrBotMessage) -> None: from .satori_event import SatoriPlatformEvent message_event = SatoriPlatformEvent( @@ -780,7 +780,7 @@ async def send_http_request( logger.error(f"Satori HTTP 请求异常: {e}") return {} - async def terminate(self): + async def terminate(self) -> None: self.running = False if self.heartbeat_task: diff --git a/astrbot/core/platform/sources/satori/satori_event.py b/astrbot/core/platform/sources/satori/satori_event.py index 81a0d222c..021422283 100644 --- a/astrbot/core/platform/sources/satori/satori_event.py +++ b/astrbot/core/platform/sources/satori/satori_event.py @@ -28,7 +28,7 @@ def __init__( platform_meta: PlatformMetadata, session_id: str, adapter: "SatoriPlatformAdapter", - ): + ) -> None: # 更新平台元数据 if adapter and hasattr(adapter, "logins") and adapter.logins: current_login = adapter.logins[0] @@ -110,7 +110,7 @@ async def send_with_adapter( logger.error(f"Satori 消息发送异常: {e}") return None - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: platform = getattr(self, "platform", None) user_id = getattr(self, "user_id", None) diff --git a/astrbot/core/platform/sources/slack/client.py b/astrbot/core/platform/sources/slack/client.py index fbdc71759..efd7a6f3d 100644 --- a/astrbot/core/platform/sources/slack/client.py +++ b/astrbot/core/platform/sources/slack/client.py @@ -27,7 +27,7 @@ def __init__( port: int = 3000, path: str = "/slack/events", event_handler: Callable | None = None, - ): + ) -> None: self.web_client = web_client self.signing_secret = signing_secret self.host = host @@ -44,7 +44,7 @@ def __init__( self.shutdown_event = asyncio.Event() - def _setup_routes(self): + def _setup_routes(self) -> None: """设置路由""" @self.app.route(self.path, methods=["POST"]) @@ -105,7 +105,7 @@ async def handle_callback(self, req): logger.error(f"处理 Slack 事件时出错: {e}") return Response("Internal Server Error", status=500) - async def start(self): + async def start(self) -> None: """启动 Webhook 服务器""" logger.info( f"Slack Webhook 服务器启动中,监听 {self.host}:{self.port}{self.path}...", @@ -118,10 +118,10 @@ async def start(self): shutdown_trigger=self.shutdown_trigger, ) - async def shutdown_trigger(self): + async def shutdown_trigger(self) -> None: await self.shutdown_event.wait() - async def stop(self): + async def stop(self) -> None: """停止 Webhook 服务器""" self.shutdown_event.set() logger.info("Slack Webhook 服务器已停止") @@ -135,7 +135,7 @@ def __init__( web_client: AsyncWebClient, app_token: str, event_handler: Callable | None = None, - ): + ) -> None: self.web_client = web_client self.app_token = app_token self.event_handler = event_handler @@ -143,7 +143,7 @@ def __init__( async def _handle_events( self, _: AsyncBaseSocketModeClient, req: SocketModeRequest - ): + ) -> None: """处理 Socket Mode 事件""" try: if self.socket_client is None: @@ -160,7 +160,7 @@ async def _handle_events( except Exception as e: logger.error(f"处理 Socket Mode 事件时出错: {e}") - async def start(self): + async def start(self) -> None: """启动 Socket Mode 连接""" self.socket_client = SocketModeClient( app_token=self.app_token, @@ -174,7 +174,7 @@ async def start(self): logger.info("Slack Socket Mode 客户端启动中...") await self.socket_client.connect() - async def stop(self): + async def stop(self) -> None: """停止 Socket Mode 连接""" if self.socket_client: await self.socket_client.disconnect() diff --git a/astrbot/core/platform/sources/slack/slack_adapter.py b/astrbot/core/platform/sources/slack/slack_adapter.py index afd80a8fe..dc06f44b6 100644 --- a/astrbot/core/platform/sources/slack/slack_adapter.py +++ b/astrbot/core/platform/sources/slack/slack_adapter.py @@ -3,9 +3,10 @@ import re import time import uuid -from typing import Any, cast +from typing import cast import aiohttp +from quart import Request, ResponseReturnValue from slack_sdk.socket_mode.request import SocketModeRequest from slack_sdk.web.async_client import AsyncWebClient @@ -39,7 +40,7 @@ def __init__( platform_settings: dict, event_queue: asyncio.Queue, ) -> None: - super().__init__(platform_config, event_queue) + super().__init__(platform_config, platform_settings, event_queue) self.settings = platform_settings self.bot_token = platform_config.get("bot_token") @@ -81,7 +82,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: blocks, text = await SlackMessageEvent._parse_slack_blocks( message_chain=message_chain, web_client=self.web_client, @@ -285,7 +286,7 @@ def _parse_blocks(self, blocks: list) -> list: return message_components - async def _handle_socket_event(self, req: SocketModeRequest): + async def _handle_socket_event(self, req: SocketModeRequest) -> None: """处理 Socket Mode 事件""" if req.type == "events_api": # 事件 API @@ -374,7 +375,7 @@ async def run(self) -> None: f"不支持的连接模式: {self.connection_mode},请使用 'socket' 或 'webhook'", ) - async def _handle_webhook_event(self, event_data: dict): + async def _handle_webhook_event(self, event_data: dict) -> None: """处理 Webhook 事件""" event = event_data.get("event", {}) @@ -394,14 +395,14 @@ async def _handle_webhook_event(self, event_data: dict): if abm: await self.handle_msg(abm) - async def webhook_callback(self, request: Any) -> Any: + async def webhook_callback(self, request: Request) -> ResponseReturnValue: """统一 Webhook 回调入口""" if self.connection_mode != "webhook" or not self.webhook_client: return {"error": "Slack adapter is not in webhook mode"}, 400 return await self.webhook_client.handle_callback(request) - async def terminate(self): + async def terminate(self) -> None: if self.socket_client: await self.socket_client.stop() if self.webhook_client: @@ -411,7 +412,7 @@ async def terminate(self): def meta(self) -> PlatformMetadata: return self.metadata - async def handle_msg(self, message: AstrBotMessage): + async def handle_msg(self, message: AstrBotMessage) -> None: message_event = SlackMessageEvent( message_str=message.message_str, message_obj=message, diff --git a/astrbot/core/platform/sources/slack/slack_event.py b/astrbot/core/platform/sources/slack/slack_event.py index 822e6fdeb..3f62690b5 100644 --- a/astrbot/core/platform/sources/slack/slack_event.py +++ b/astrbot/core/platform/sources/slack/slack_event.py @@ -24,7 +24,7 @@ def __init__( platform_meta, session_id, web_client: AsyncWebClient, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.web_client = web_client @@ -126,7 +126,7 @@ async def _parse_slack_blocks( return blocks, "" if blocks else text_content - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: blocks, text = await SlackMessageEvent._parse_slack_blocks( message, self.web_client, diff --git a/astrbot/core/platform/sources/telegram/tg_adapter.py b/astrbot/core/platform/sources/telegram/tg_adapter.py index 218d13bdc..db720160f 100644 --- a/astrbot/core/platform/sources/telegram/tg_adapter.py +++ b/astrbot/core/platform/sources/telegram/tg_adapter.py @@ -42,7 +42,7 @@ def __init__( platform_settings: dict, event_queue: asyncio.Queue, ) -> None: - super().__init__(platform_config, event_queue) + super().__init__(platform_config, platform_settings, event_queue) self.settings = platform_settings self.client_self_id = uuid.uuid4().hex[:8] @@ -94,7 +94,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: from_username = session.session_id await TelegramPlatformEvent.send_with_client( self.client, @@ -109,7 +109,7 @@ def meta(self) -> PlatformMetadata: return PlatformMetadata(name="telegram", description="telegram 适配器", id=id_) @override - async def run(self): + async def run(self) -> None: await self.application.initialize() await self.application.start() @@ -134,7 +134,7 @@ async def run(self): logger.info("Telegram Platform Adapter is running.") await queue - async def register_commands(self): + async def register_commands(self) -> None: """收集所有注册的指令并注册到 Telegram""" try: commands = self.collect_commands() @@ -210,7 +210,7 @@ def _extract_command_info( description = description[:30] + "..." return cmd_name, description - async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE): + async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: if not update.effective_chat: logger.warning( "Received a start command without an effective chat, skipping /start reply.", @@ -221,7 +221,9 @@ async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE): text=self.config["start_message"], ) - async def message_handler(self, update: Update, context: ContextTypes.DEFAULT_TYPE): + async def message_handler( + self, update: Update, context: ContextTypes.DEFAULT_TYPE + ) -> None: logger.debug(f"Telegram message: {update.message}") abm = await self.convert_message(update, context) if abm: @@ -397,7 +399,7 @@ async def convert_message( return message - async def handle_msg(self, message: AstrBotMessage): + async def handle_msg(self, message: AstrBotMessage) -> None: message_event = TelegramPlatformEvent( message_str=message.message_str, message_obj=message, @@ -410,7 +412,7 @@ async def handle_msg(self, message: AstrBotMessage): def get_client(self) -> ExtBot: return self.client - async def terminate(self): + async def terminate(self) -> None: try: if self.scheduler.running: self.scheduler.shutdown() diff --git a/astrbot/core/platform/sources/telegram/tg_event.py b/astrbot/core/platform/sources/telegram/tg_event.py index 5faba6803..1df289d83 100644 --- a/astrbot/core/platform/sources/telegram/tg_event.py +++ b/astrbot/core/platform/sources/telegram/tg_event.py @@ -38,7 +38,7 @@ def __init__( platform_meta: PlatformMetadata, session_id: str, client: ExtBot, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.client = client @@ -73,7 +73,7 @@ async def send_with_client( client: ExtBot, message: MessageChain, user_name: str, - ): + ) -> None: image_path = None has_reply = False @@ -134,14 +134,14 @@ async def send_with_client( path = await i.convert_to_file_path() await client.send_voice(voice=path, **cast(Any, payload)) - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: if self.get_message_type() == MessageType.GROUP_MESSAGE: await self.send_with_client(self.client, message, self.message_obj.group_id) else: await self.send_with_client(self.client, message, self.get_sender_id()) await super().send(message) - async def react(self, emoji: str | None, big: bool = False): + async def react(self, emoji: str | None, big: bool = False) -> None: """给原消息添加 Telegram 反应: - 普通 emoji:传入 '👍'、'😂' 等 - 自定义表情:传入其 custom_emoji_id(纯数字字符串) diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py index 43a562026..32c64b606 100644 --- a/astrbot/core/platform/sources/webchat/webchat_adapter.py +++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py @@ -31,7 +31,7 @@ def __init__(self, webchat_queue_mgr: WebChatQueueMgr, callback: Callable) -> No self.callback = callback self.running_tasks = set() - async def listen_to_queue(self, conversation_id: str): + async def listen_to_queue(self, conversation_id: str) -> None: """Listen to a specific conversation queue""" queue = self.webchat_queue_mgr.get_or_create_queue(conversation_id) while True: @@ -44,7 +44,7 @@ async def listen_to_queue(self, conversation_id: str): ) break - async def run(self): + async def run(self) -> None: """Monitor for new conversation queues and start listeners""" monitored_conversations = set() @@ -76,7 +76,7 @@ def __init__( platform_settings: dict, event_queue: asyncio.Queue, ) -> None: - super().__init__(platform_config, event_queue) + super().__init__(platform_config, platform_settings, event_queue) self.settings = platform_settings self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs") @@ -92,7 +92,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: await WebChatMessageEvent._send(message_chain, session.session_id) await super().send_by_session(session, message_chain) @@ -207,7 +207,7 @@ async def convert_message(self, data: tuple) -> AstrBotMessage: return abm def run(self) -> Coroutine[Any, Any, None]: - async def callback(data: tuple): + async def callback(data: tuple) -> None: abm = await self.convert_message(data) await self.handle_msg(abm) @@ -217,7 +217,7 @@ async def callback(data: tuple): def meta(self) -> PlatformMetadata: return self.metadata - async def handle_msg(self, message: AstrBotMessage): + async def handle_msg(self, message: AstrBotMessage) -> None: message_event = WebChatMessageEvent( message_str=message.message_str, message_obj=message, @@ -234,6 +234,6 @@ async def handle_msg(self, message: AstrBotMessage): self.commit_event(message_event) - async def terminate(self): + async def terminate(self) -> None: # Do nothing pass diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index 2e529bb1d..bc3d12778 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -15,7 +15,7 @@ class WebChatMessageEvent(AstrMessageEvent): - def __init__(self, message_str, message_obj, platform_meta, session_id): + def __init__(self, message_str, message_obj, platform_meta, session_id) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) os.makedirs(imgs_dir, exist_ok=True) @@ -107,11 +107,11 @@ async def _send( return data - async def send(self, message: MessageChain | None): + async def send(self, message: MessageChain | None) -> None: await WebChatMessageEvent._send(message, session_id=self.session_id) await super().send(MessageChain([])) - async def send_streaming(self, generator, use_fallback: bool = False): + async def send_streaming(self, generator, use_fallback: bool = False) -> None: final_data = "" reasoning_content = "" cid = self.session_id.split("!")[-1] diff --git a/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py b/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py index 6c365cb3a..4824e2de9 100644 --- a/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py +++ b/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py @@ -20,7 +20,7 @@ def get_or_create_back_queue(self, conversation_id: str) -> asyncio.Queue: self.back_queues[conversation_id] = asyncio.Queue() return self.back_queues[conversation_id] - def remove_queues(self, conversation_id: str): + def remove_queues(self, conversation_id: str) -> None: """Remove queues for the given conversation ID""" if conversation_id in self.queues: del self.queues[conversation_id] diff --git a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py index 214ac782c..22ad9d363 100644 --- a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +++ b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py @@ -43,7 +43,7 @@ def __init__( platform_settings: dict, event_queue: asyncio.Queue, ) -> None: - super().__init__(platform_config, event_queue) + super().__init__(platform_config, platform_settings, event_queue) self._shutdown_event = None self.wxnewpass = None self.settings = platform_settings @@ -149,7 +149,7 @@ def load_credentials(self): logger.error(f"加载 WeChatPadPro 凭据失败: {e}") return None - def save_credentials(self): + def save_credentials(self) -> None: """将 auth_key 和 wxid 保存到文件。""" credentials = { "auth_key": self.auth_key, @@ -164,7 +164,7 @@ def save_credentials(self): except Exception as e: logger.error(f"保存 WeChatPadPro 凭据失败: {e}") - async def check_online_status(self): + async def check_online_status(self) -> bool | None: """检查 WeChatPadPro 设备是否在线。""" if not self.auth_key: return False @@ -218,7 +218,7 @@ def _extract_auth_key(self, data): return data[0] return None - async def generate_auth_key(self): + async def generate_auth_key(self) -> None: """生成授权码。""" url = f"{self.base_url}/admin/GenAuthKey1" params = {"key": self.admin_key} @@ -291,7 +291,7 @@ async def get_login_qr_code(self): logger.error(f"获取登录二维码时发生错误: {e}") return None - async def check_login_status(self): + async def check_login_status(self) -> bool: """循环检测扫码状态。 尝试 6 次后跳出循环,添加倒计时。 返回 True 如果登录成功,否则返回 False。 @@ -358,7 +358,7 @@ async def check_login_status(self): logger.warning("登录检测超过最大尝试次数,退出检测。") return False - async def connect_websocket(self): + async def connect_websocket(self) -> None: """建立 WebSocket 连接并处理接收到的消息。""" os.environ["no_proxy"] = f"localhost,127.0.0.1,{self.host}" ws_url = f"ws://{self.host}:{self.port}/ws/GetSyncMsg?key={self.auth_key}" @@ -398,7 +398,7 @@ async def connect_websocket(self): ) await asyncio.sleep(5) - async def handle_websocket_message(self, message: str | bytes): + async def handle_websocket_message(self, message: str | bytes) -> None: """处理从 WebSocket 接收到的消息。""" logger.debug(f"收到 WebSocket 消息: {message}") try: @@ -486,7 +486,7 @@ async def _process_chat_type( to_user_name: str, content: str, push_content: str, - ): + ) -> bool: """判断消息是群聊还是私聊,并设置 AstrBotMessage 的基本属性。""" if from_user_name == "weixin": return False @@ -638,7 +638,7 @@ async def _process_message_content( raw_message: dict, msg_type: int, content: str, - ): + ) -> None: """根据消息类型处理消息内容,填充 AstrBotMessage 的 message 列表。""" if msg_type == 1: # 文本消息 abm.message_str = content @@ -835,7 +835,7 @@ async def _process_message_content( else: logger.warning(f"收到未处理的消息类型: {msg_type}。") - async def terminate(self): + async def terminate(self) -> None: """终止一个平台的运行实例。""" logger.info("终止 WeChatPadPro 适配器。") try: @@ -854,7 +854,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: dummy_message_obj = AstrBotMessage() dummy_message_obj.session_id = session.session_id # 根据 session_id 判断消息类型 diff --git a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py index 08ab27013..8ecc4c64f 100644 --- a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +++ b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py @@ -32,12 +32,12 @@ def __init__( platform_meta: PlatformMetadata, session_id: str, adapter: "WeChatPadProAdapter", # 传递适配器实例 - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.message_obj = message_obj # Save the full message object self.adapter = adapter # Save the adapter instance - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: async with aiohttp.ClientSession() as session: for comp in message.chain: await asyncio.sleep(1) @@ -66,7 +66,7 @@ async def send_streaming( await self.send(buffer) return await super().send_streaming(generator, use_fallback) - async def _send_image(self, session: aiohttp.ClientSession, comp: Image): + async def _send_image(self, session: aiohttp.ClientSession, comp: Image) -> None: b64 = await comp.convert_to_base64() raw = self._validate_base64(b64) b64c = self._compress_image(raw) @@ -78,7 +78,7 @@ async def _send_image(self, session: aiohttp.ClientSession, comp: Image): url = f"{self.adapter.base_url}/message/SendImageNewMessage" await self._post(session, url, payload) - async def _send_text(self, session: aiohttp.ClientSession, text: str): + async def _send_text(self, session: aiohttp.ClientSession, text: str) -> None: if ( self.message_obj.type == MessageType.GROUP_MESSAGE # 确保是群聊消息 and self.adapter.settings.get( @@ -114,7 +114,9 @@ async def _send_text(self, session: aiohttp.ClientSession, text: str): url = f"{self.adapter.base_url}/message/SendTextMessage" await self._post(session, url, payload) - async def _send_emoji(self, session: aiohttp.ClientSession, comp: WechatEmoji): + async def _send_emoji( + self, session: aiohttp.ClientSession, comp: WechatEmoji + ) -> None: payload = { "EmojiList": [ { @@ -127,7 +129,7 @@ async def _send_emoji(self, session: aiohttp.ClientSession, comp: WechatEmoji): url = f"{self.adapter.base_url}/message/SendEmojiMessage" await self._post(session, url, payload) - async def _send_voice(self, session: aiohttp.ClientSession, comp: Record): + async def _send_voice(self, session: aiohttp.ClientSession, comp: Record) -> None: record_path = await comp.convert_to_file_path() # 默认已经存在 data/temp 中 b64, duration = await audio_to_tencent_silk_base64(record_path) @@ -157,7 +159,7 @@ def _compress_image(data: bytes) -> str: # logger.info("图片处理完成!!!") return base64.b64encode(buf.getvalue()).decode() - async def _post(self, session, url, payload): + async def _post(self, session, url, payload) -> None: params = {"key": self.adapter.auth_key} try: async with session.post(url, params=params, json=payload) as resp: diff --git a/astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py b/astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py index 09924edb6..cf23c6b54 100644 --- a/astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py +++ b/astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py @@ -20,7 +20,7 @@ def __init__( cached_images=None, raw_message: dict | None = None, downloader=None, - ): + ) -> None: self._xml = None self.content = content self.is_private_chat = is_private_chat diff --git a/astrbot/core/platform/sources/wecom/wecom_adapter.py b/astrbot/core/platform/sources/wecom/wecom_adapter.py index 44ed75117..417ea03b9 100644 --- a/astrbot/core/platform/sources/wecom/wecom_adapter.py +++ b/astrbot/core/platform/sources/wecom/wecom_adapter.py @@ -3,7 +3,7 @@ import sys import uuid from collections.abc import Awaitable, Callable -from typing import Any, cast +from typing import cast import quart from requests import Response @@ -39,7 +39,7 @@ class WecomServer: - def __init__(self, event_queue: asyncio.Queue, config: dict): + def __init__(self, event_queue: asyncio.Queue, config: dict) -> None: self.server = quart.Quart(__name__) self.port = int(cast(str, config.get("port"))) self.callback_server_host = config.get("callback_server_host", "0.0.0.0") @@ -123,7 +123,7 @@ async def handle_callback(self, request) -> str: return "success" - async def start_polling(self): + async def start_polling(self) -> None: logger.info( f"将在 {self.callback_server_host}:{self.port} 端口启动 企业微信 适配器。", ) @@ -133,7 +133,7 @@ async def start_polling(self): shutdown_trigger=self.shutdown_trigger, ) - async def shutdown_trigger(self): + async def shutdown_trigger(self) -> None: await self.shutdown_event.wait() @@ -145,7 +145,7 @@ def __init__( platform_settings: dict, event_queue: asyncio.Queue, ) -> None: - super().__init__(platform_config, event_queue) + super().__init__(platform_config, platform_settings, event_queue) self.settingss = platform_settings self.client_self_id = uuid.uuid4().hex[:8] self.api_base_url = platform_config.get( @@ -182,7 +182,7 @@ def __init__( self.client.__setattr__("API_BASE_URL", self.api_base_url) - async def callback(msg: BaseMessage): + async def callback(msg: BaseMessage) -> None: if msg.type == "unknown" and msg._data["Event"] == "kf_msg_or_event": def get_latest_msg_item() -> dict | None: @@ -214,7 +214,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: await super().send_by_session(session, message_chain) @override @@ -227,7 +227,7 @@ def meta(self) -> PlatformMetadata: ) @override - async def run(self): + async def run(self) -> None: loop = asyncio.get_event_loop() if self.kf_name: try: @@ -269,7 +269,9 @@ async def run(self): else: await self.server.start_polling() - async def webhook_callback(self, request: Any) -> Any: + async def webhook_callback( + self, request: quart.Request + ) -> quart.ResponseReturnValue: """统一 Webhook 回调入口""" # 根据请求方法分发到不同的处理函数 if request.method == "GET": @@ -403,7 +405,7 @@ async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None: return await self.handle_msg(abm) - async def handle_msg(self, message: AstrBotMessage): + async def handle_msg(self, message: AstrBotMessage) -> None: message_event = WecomPlatformEvent( message_str=message.message_str, message_obj=message, @@ -416,7 +418,7 @@ async def handle_msg(self, message: AstrBotMessage): def get_client(self) -> WeChatClient: return self.client - async def terminate(self): + async def terminate(self) -> None: self.server.shutdown_event.set() try: await self.server.server.shutdown() diff --git a/astrbot/core/platform/sources/wecom/wecom_event.py b/astrbot/core/platform/sources/wecom/wecom_event.py index 0b5dae272..865a14234 100644 --- a/astrbot/core/platform/sources/wecom/wecom_event.py +++ b/astrbot/core/platform/sources/wecom/wecom_event.py @@ -28,7 +28,7 @@ def __init__( platform_meta: PlatformMetadata, session_id: str, client: WeChatClient, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.client = client @@ -37,7 +37,7 @@ async def send_with_client( client: WeChatClient, message: MessageChain, user_name: str, - ): + ) -> None: pass async def split_plain(self, plain: str) -> list[str]: @@ -86,7 +86,7 @@ async def split_plain(self, plain: str) -> list[str]: return result - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: message_obj = self.message_obj is_wechat_kf = hasattr(self.client, "kf_message") diff --git a/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py b/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py index 2df09a763..260b950d1 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py @@ -14,6 +14,7 @@ import socket import struct import time +from typing import NoReturn from Crypto.Cipher import AES @@ -30,7 +31,7 @@ class FormatException(Exception): pass -def throw_exception(message, exception_class=FormatException): +def throw_exception(message, exception_class=FormatException) -> NoReturn: """My define raise exception function""" raise exception_class(message) @@ -145,7 +146,7 @@ class Prpcrypt: MIN_RANDOM_VALUE = 1000000000000000 # 最小值: 1000000000000000 (16位) RANDOM_RANGE = 9000000000000000 # 范围大小: 确保最大值为 9999999999999999 (16位) - def __init__(self, key): + def __init__(self, key) -> None: # self.key = base64.b64decode(key+"=") self.key = key # 设置加解密模式为AES的CBC模式 @@ -220,7 +221,7 @@ def get_random_str(self): class WXBizJsonMsgCrypt: # 构造函数 - def __init__(self, sToken, sEncodingAESKey, sReceiveId): + def __init__(self, sToken, sEncodingAESKey, sReceiveId) -> None: try: self.key = base64.b64decode(sEncodingAESKey + "=") assert len(self.key) == 32 diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py index 70581e7ea..35ee79425 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py @@ -11,6 +11,8 @@ from collections.abc import Awaitable, Callable from typing import Any +from quart import Request, ResponseReturnValue + from astrbot.api import logger from astrbot.api.event import MessageChain from astrbot.api.message_components import At, Image, Plain @@ -53,7 +55,7 @@ def __init__( self.callback = callback self.running_tasks = set() - async def listen_to_queue(self, session_id: str): + async def listen_to_queue(self, session_id: str) -> None: """监听特定会话的队列""" queue = self.queue_mgr.get_or_create_queue(session_id) while True: @@ -64,7 +66,7 @@ async def listen_to_queue(self, session_id: str): logger.error(f"处理会话 {session_id} 消息时发生错误: {e}") break - async def run(self): + async def run(self) -> None: """监控新会话队列并启动监听器""" monitored_sessions = set() @@ -104,7 +106,7 @@ def __init__( platform_settings: dict, event_queue: asyncio.Queue, ) -> None: - super().__init__(platform_config, event_queue) + super().__init__(platform_config, platform_settings, event_queue) self.settings = platform_settings # 初始化配置参数 @@ -153,7 +155,7 @@ def __init__( self._handle_queued_message, ) - async def _handle_queued_message(self, data: dict): + async def _handle_queued_message(self, data: dict) -> None: """处理队列中的消息,类似webchat的callback""" try: abm = await self.convert_message(data) @@ -313,7 +315,7 @@ async def _enqueue_message( callback_params: dict[str, str], stream_id: str, session_id: str, - ): + ) -> None: """将消息放入队列进行异步处理""" input_queue = self.queue_mgr.get_or_create_queue(stream_id) _ = self.queue_mgr.get_or_create_back_queue(stream_id) @@ -417,7 +419,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: """通过会话发送消息""" # 企业微信智能机器人主要通过回调响应,这里记录日志 logger.info("会话发送消息: %s -> %s", session.session_id, message_chain) @@ -426,7 +428,7 @@ async def send_by_session( def run(self) -> Awaitable[Any]: """运行适配器,同时启动HTTP服务器和队列监听器""" - async def run_both(): + async def run_both() -> None: # 如果启用统一 webhook 模式,则不启动独立服务器 webhook_uuid = self.config.get("webhook_uuid") if self.unified_webhook_mode and webhook_uuid: @@ -445,7 +447,7 @@ async def run_both(): return run_both() - async def webhook_callback(self, request: Any) -> Any: + async def webhook_callback(self, request: Request) -> ResponseReturnValue: """统一 Webhook 回调入口""" # 根据请求方法分发到不同的处理函数 if request.method == "GET": @@ -453,7 +455,7 @@ async def webhook_callback(self, request: Any) -> Any: else: return await self.server.handle_callback(request) - async def terminate(self): + async def terminate(self) -> None: """终止适配器""" logger.info("企业微信智能机器人适配器正在关闭...") self.shutdown_event.set() @@ -463,7 +465,7 @@ def meta(self) -> PlatformMetadata: """获取平台元数据""" return self.metadata - async def handle_msg(self, message: AstrBotMessage): + async def handle_msg(self, message: AstrBotMessage) -> None: """处理消息,创建消息事件并提交到事件队列""" try: message_event = WecomAIBotMessageEvent( diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py index 6c448a97e..97831fbb2 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py @@ -19,7 +19,7 @@ class WecomAIBotAPIClient: """企业微信智能机器人 API 客户端""" - def __init__(self, token: str, encoding_aes_key: str): + def __init__(self, token: str, encoding_aes_key: str) -> None: """初始化 API 客户端 Args: diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py index fd11d7ceb..90a9e363b 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py @@ -22,7 +22,7 @@ def __init__( session_id: str, api_client: WecomAIBotAPIClient, queue_mgr: WecomAIQueueMgr, - ): + ) -> None: """初始化消息事件 Args: @@ -90,7 +90,7 @@ async def _send( return data - async def send(self, message: MessageChain | None): + async def send(self, message: MessageChain | None) -> None: """发送消息""" raw = self.message_obj.raw_message assert isinstance(raw, dict), ( @@ -100,7 +100,7 @@ async def send(self, message: MessageChain | None): await WecomAIBotMessageEvent._send(message, stream_id, self.queue_mgr) await super().send(MessageChain([])) - async def send_streaming(self, generator, use_fallback=False): + async def send_streaming(self, generator, use_fallback=False) -> None: """流式发送消息,参考webchat的send_streaming设计""" final_data = "" raw = self.message_obj.raw_message diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py index 3a982bdf7..db6aa408e 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py @@ -52,7 +52,7 @@ def get_or_create_back_queue(self, session_id: str) -> asyncio.Queue: logger.debug(f"[WecomAI] 创建输出队列: {session_id}") return self.back_queues[session_id] - def remove_queues(self, session_id: str): + def remove_queues(self, session_id: str) -> None: """移除指定会话的所有队列 Args: @@ -95,7 +95,9 @@ def has_back_queue(self, session_id: str) -> bool: """ return session_id in self.back_queues - def set_pending_response(self, session_id: str, callback_params: dict[str, str]): + def set_pending_response( + self, session_id: str, callback_params: dict[str, str] + ) -> None: """设置待处理的响应参数 Args: @@ -121,7 +123,7 @@ def get_pending_response(self, session_id: str) -> dict[str, Any] | None: """ return self.pending_responses.get(session_id) - def cleanup_expired_responses(self, max_age_seconds: int = 300): + def cleanup_expired_responses(self, max_age_seconds: int = 300) -> None: """清理过期的待处理响应 Args: diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py index 5cbdd1130..80ec5179e 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py @@ -23,7 +23,7 @@ def __init__( port: int, api_client: WecomAIBotAPIClient, message_handler: Callable[[dict[str, Any], dict[str, str]], Any] | None = None, - ): + ) -> None: """初始化服务器 Args: @@ -43,7 +43,7 @@ def __init__( self.shutdown_event = asyncio.Event() - def _setup_routes(self): + def _setup_routes(self) -> None: """设置 Quart 路由""" # 使用 Quart 的 add_url_rule 方法添加路由 self.app.add_url_rule( @@ -162,7 +162,7 @@ async def handle_callback(self, request): logger.error("处理消息时发生异常: %s", e) return "内部服务器错误", 500 - async def start_server(self): + async def start_server(self) -> None: """启动服务器""" logger.info("启动企业微信智能机器人服务器,监听 %s:%d", self.host, self.port) @@ -176,11 +176,11 @@ async def start_server(self): logger.error("服务器运行异常: %s", e) raise - async def shutdown_trigger(self): + async def shutdown_trigger(self) -> None: """关闭触发器""" await self.shutdown_event.wait() - async def shutdown(self): + async def shutdown(self) -> None: """关闭服务器""" logger.info("企业微信智能机器人服务器正在关闭...") self.shutdown_event.set() diff --git a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py index 2828c0392..ba53db945 100644 --- a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +++ b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py @@ -2,7 +2,7 @@ import sys import uuid from collections.abc import Awaitable, Callable -from typing import Any, cast +from typing import cast import quart from requests import Response @@ -35,7 +35,7 @@ class WeixinOfficialAccountServer: - def __init__(self, event_queue: asyncio.Queue, config: dict): + def __init__(self, event_queue: asyncio.Queue, config: dict) -> None: self.server = quart.Quart(__name__) self.port = int(cast(int | str, config.get("port"))) self.callback_server_host = config.get("callback_server_host", "0.0.0.0") @@ -129,7 +129,7 @@ async def handle_callback(self, request) -> str: return "success" - async def start_polling(self): + async def start_polling(self) -> None: logger.info( f"将在 {self.callback_server_host}:{self.port} 端口启动 微信公众平台 适配器。", ) @@ -139,7 +139,7 @@ async def start_polling(self): shutdown_trigger=self.shutdown_trigger, ) - async def shutdown_trigger(self): + async def shutdown_trigger(self) -> None: await self.shutdown_event.wait() @@ -153,7 +153,7 @@ def __init__( platform_settings: dict, event_queue: asyncio.Queue, ) -> None: - super().__init__(platform_config, event_queue) + super().__init__(platform_config, platform_settings, event_queue) self.settingss = platform_settings self.client_self_id = uuid.uuid4().hex[:8] self.api_base_url = platform_config.get( @@ -218,7 +218,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: await super().send_by_session(session, message_chain) @override @@ -231,7 +231,7 @@ def meta(self) -> PlatformMetadata: ) @override - async def run(self): + async def run(self) -> None: # 如果启用统一 webhook 模式,则不启动独立服务器 webhook_uuid = self.config.get("webhook_uuid") if self.unified_webhook_mode and webhook_uuid: @@ -241,7 +241,9 @@ async def run(self): else: await self.server.start_polling() - async def webhook_callback(self, request: Any) -> Any: + async def webhook_callback( + self, request: quart.Request + ) -> quart.ResponseReturnValue: """统一 Webhook 回调入口""" # 根据请求方法分发到不同的处理函数 if request.method == "GET": @@ -330,7 +332,7 @@ async def convert_message( logger.info(f"abm: {abm}") await self.handle_msg(abm) - async def handle_msg(self, message: AstrBotMessage): + async def handle_msg(self, message: AstrBotMessage) -> None: message_event = WeixinOfficialAccountPlatformEvent( message_str=message.message_str, message_obj=message, @@ -343,7 +345,7 @@ async def handle_msg(self, message: AstrBotMessage): def get_client(self) -> WeChatClient: return self.client - async def terminate(self): + async def terminate(self) -> None: self.server.shutdown_event.set() try: await self.server.server.shutdown() diff --git a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py index c1f137a41..995b16690 100644 --- a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +++ b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py @@ -26,7 +26,7 @@ def __init__( platform_meta: PlatformMetadata, session_id: str, client: WeChatClient, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.client = client @@ -35,7 +35,7 @@ async def send_with_client( client: WeChatClient, message: MessageChain, user_name: str, - ): + ) -> None: pass async def split_plain(self, plain: str) -> list[str]: @@ -84,7 +84,7 @@ async def split_plain(self, plain: str) -> list[str]: return result - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: message_obj = self.message_obj active_send_mode = cast(dict, message_obj.raw_message).get( "active_send_mode", False diff --git a/astrbot/core/platform_message_history_mgr.py b/astrbot/core/platform_message_history_mgr.py index d6d524698..ad8bb44f6 100644 --- a/astrbot/core/platform_message_history_mgr.py +++ b/astrbot/core/platform_message_history_mgr.py @@ -3,7 +3,7 @@ class PlatformMessageHistoryManager: - def __init__(self, db_helper: BaseDatabase): + def __init__(self, db_helper: BaseDatabase) -> None: self.db = db_helper async def insert( @@ -40,7 +40,9 @@ async def get( history.reverse() return history - async def delete(self, platform_id: str, user_id: str, offset_sec: int = 86400): + async def delete( + self, platform_id: str, user_id: str, offset_sec: int = 86400 + ) -> None: """Delete platform message history records older than the specified offset.""" await self.db.delete_platform_message_offset( platform_id=platform_id, diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index a1a6039f4..fc27f0c1a 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -111,7 +111,7 @@ class ProviderRequest: model: str | None = None """模型名称,为 None 时使用提供商的默认模型""" - def __repr__(self): + def __repr__(self) -> str: return ( f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, " f"image_count={len(self.image_urls or [])}, " @@ -121,10 +121,10 @@ def __repr__(self): f"conversation_id={self.conversation.cid if self.conversation else 'N/A'}, " ) - def __str__(self): + def __str__(self) -> str: return self.__repr__() - def append_tool_calls_result(self, tool_calls_result: ToolCallsResult): + def append_tool_calls_result(self, tool_calls_result: ToolCallsResult) -> None: """添加工具调用结果到请求中""" if not self.tool_calls_result: self.tool_calls_result = [] @@ -132,7 +132,7 @@ def append_tool_calls_result(self, tool_calls_result: ToolCallsResult): self.tool_calls_result = [self.tool_calls_result] self.tool_calls_result.append(tool_calls_result) - def _print_friendly_context(self): + def _print_friendly_context(self) -> list[str] | str: """打印友好的消息上下文。将 image_url 的值替换为 """ if not self.contexts: return f"prompt: {self.prompt}, image_count: {len(self.image_urls or [])}" @@ -309,7 +309,7 @@ def __init__( is_chunk: bool = False, id: str | None = None, usage: TokenUsage | None = None, - ): + ) -> None: """初始化 LLMResponse Args: @@ -333,7 +333,7 @@ def __init__( tools_call_extra_content = {} self.role = role - self.completion_text = completion_text + self.completion_text = completion_text or "" self.result_chain = result_chain self.tools_call_args = tools_call_args self.tools_call_name = tools_call_name @@ -350,13 +350,13 @@ def __init__( self.usage = usage @property - def completion_text(self): + def completion_text(self) -> str: if self.result_chain: return self.result_chain.get_plain_text() return self._completion_text @completion_text.setter - def completion_text(self, value): + def completion_text(self, value: str) -> None: if self.result_chain: self.result_chain.chain = [ comp diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 7aad86bdd..f24a3b8a5 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -169,7 +169,7 @@ def remove_func(self, name: str) -> None: self.func_list.pop(i) break - def get_func(self, name) -> FuncTool | None: + def get_func(self, name: str) -> FuncTool | None: for f in self.func_list: if f.name == name: return f @@ -401,7 +401,9 @@ async def disable_mcp_server( f for f in self.func_list if not isinstance(f, MCPTool) ] - def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list: + def get_func_desc_openai_style( + self, omit_empty_parameter_field: bool = False + ) -> list: """获得 OpenAI API 风格的**已经激活**的工具描述""" tools = [f for f in self.func_list if f.active] toolset = ToolSet(tools) @@ -481,11 +483,11 @@ def activate_llm_tool(self, name: str, star_map: dict) -> bool: return False @property - def mcp_config_path(self): + def mcp_config_path(self) -> str: data_dir = get_astrbot_data_path() return os.path.join(data_dir, "mcp_server.json") - def load_mcp_config(self): + def load_mcp_config(self) -> dict: if not os.path.exists(self.mcp_config_path): # 配置文件不存在,创建默认配置 os.makedirs(os.path.dirname(self.mcp_config_path), exist_ok=True) @@ -500,7 +502,7 @@ def load_mcp_config(self): logger.error(f"加载 MCP 配置失败: {e}") return DEFAULT_MCP_CONFIG - def save_mcp_config(self, config: dict): + def save_mcp_config(self, config: dict) -> bool: try: with open(self.mcp_config_path, "w", encoding="utf-8") as f: json.dump(config, f, ensure_ascii=False, indent=4) @@ -575,10 +577,10 @@ async def sync_modelscope_mcp_servers(self, access_token: str) -> None: except Exception as e: raise Exception(f"同步 ModelScope MCP 服务器时发生错误: {e!s}") - def __str__(self): + def __str__(self) -> str: return str(self.func_list) - def __repr__(self): + def __repr__(self) -> str: return str(self.func_list) diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index b523a0661..65c766980 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -6,6 +6,7 @@ from astrbot.core import astrbot_config, logger, sp from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.db import BaseDatabase +from astrbot.core.db.po import Personality from ..persona_mgr import PersonaManager from .entities import ProviderType @@ -31,7 +32,7 @@ def __init__( acm: AstrBotConfigManager, db_helper: BaseDatabase, persona_mgr: PersonaManager, - ): + ) -> None: self.reload_lock = asyncio.Lock() self.resource_lock = asyncio.Lock() self.persona_mgr = persona_mgr @@ -82,7 +83,7 @@ def personas(self) -> list: return self.persona_mgr.personas_v3 @property - def selected_default_persona(self): + def selected_default_persona(self) -> Personality | None: """动态获取最新的默认选中 persona。已弃用,请使用 context.persona_mgr.get_default_persona_v3()""" return self.persona_mgr.selected_default_persona_v3 @@ -91,7 +92,7 @@ async def set_provider( provider_id: str, provider_type: ProviderType, umo: str | None = None, - ): + ) -> None: """设置提供商。 Args: @@ -153,7 +154,7 @@ async def get_provider_by_id(self, provider_id: str) -> Providers | None: return self.inst_map.get(provider_id) def get_using_provider( - self, provider_type: ProviderType, umo=None + self, provider_type: ProviderType, umo: str | None = None ) -> Providers | None: """获取正在使用的提供商实例。 @@ -212,7 +213,7 @@ def get_using_provider( return provider - async def initialize(self): + async def initialize(self) -> None: # 逐个初始化提供商 for provider_config in self.providers_config: try: @@ -276,7 +277,7 @@ async def initialize(self): # 初始化 MCP Client 连接 asyncio.create_task(self.llm_tools.init_mcp_clients(), name="init_mcp_clients") - def dynamic_import_provider(self, type: str): + def dynamic_import_provider(self, type: str) -> None: """动态导入提供商适配器模块 Args: @@ -402,7 +403,7 @@ def get_merged_provider_config(self, provider_config: dict) -> dict: pc = merged_config return pc - async def load_provider(self, provider_config: dict): + async def load_provider(self, provider_config: dict) -> None: # 如果 provider_source_id 存在且不为空,则从 provider_sources 中找到对应的配置并合并 provider_config = self.get_merged_provider_config(provider_config) @@ -553,7 +554,7 @@ async def load_provider(self, provider_config: dict): f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}", ) - async def reload(self, provider_config: dict): + async def reload(self, provider_config: dict) -> None: async with self.reload_lock: await self.terminate_provider(provider_config["id"]) if provider_config["enable"]: @@ -596,10 +597,10 @@ async def reload(self, provider_config: dict): f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。", ) - def get_insts(self): + def get_insts(self) -> list[Provider]: return self.provider_insts - async def terminate_provider(self, provider_id: str): + async def terminate_provider(self, provider_id: str) -> None: if provider_id in self.inst_map: logger.info( f"终止 {provider_id} 提供商适配器({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)}) ...", @@ -635,7 +636,7 @@ async def terminate_provider(self, provider_id: str): async def delete_provider( self, provider_id: str | None = None, provider_source_id: str | None = None - ): + ) -> None: """Delete provider and/or provider source from config and terminate the instances. Config will be saved after deletion.""" async with self.resource_lock: # delete from config @@ -655,7 +656,7 @@ async def delete_provider( config.save_config() logger.info(f"Provider {target_prov_ids} 已从配置中删除。") - async def update_provider(self, origin_provider_id: str, new_config: dict): + async def update_provider(self, origin_provider_id: str, new_config: dict) -> None: """Update provider config and reload the instance. Config will be saved after update.""" async with self.resource_lock: npid = new_config.get("id", None) @@ -679,7 +680,7 @@ async def update_provider(self, origin_provider_id: str, new_config: dict): # reload instance await self.reload(new_config) - async def create_provider(self, new_config: dict): + async def create_provider(self, new_config: dict) -> None: """Add new provider config and load the instance. Config will be saved after addition.""" async with self.resource_lock: npid = new_config.get("id", None) @@ -695,7 +696,7 @@ async def create_provider(self, new_config: dict): # load instance await self.load_provider(new_config) - async def terminate(self): + async def terminate(self) -> None: for provider_inst in self.provider_insts: if hasattr(provider_inst, "terminate"): await provider_inst.terminate() # type: ignore diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 6fb6d8953..524f099e9 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -1,8 +1,8 @@ import abc import asyncio import os -from collections.abc import AsyncGenerator -from typing import TypeAlias, Union +from collections.abc import AsyncGenerator, Awaitable, Callable +from typing import NoReturn, TypeAlias, Union from astrbot.core.agent.message import ContentPart, Message from astrbot.core.agent.tool import ToolSet @@ -32,7 +32,7 @@ def __init__(self, provider_config: dict) -> None: self.model_name = "" self.provider_config = provider_config - def set_model(self, model_name: str): + def set_model(self, model_name: str) -> None: """Set the current model name""" self.model_name = model_name @@ -54,7 +54,7 @@ def meta(self) -> ProviderMeta: ) return meta - async def test(self): + async def test(self) -> None: """test the provider is a raises: @@ -84,7 +84,7 @@ def get_keys(self) -> list[str]: return keys or [""] @abc.abstractmethod - def set_key(self, key: str): + def set_key(self, key: str) -> NoReturn: raise NotImplementedError @abc.abstractmethod @@ -104,7 +104,7 @@ async def text_chat( tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None, model: str | None = None, extra_user_content_parts: list[ContentPart] | None = None, - **kwargs, + **kwargs: object, ) -> LLMResponse: """获得 LLM 的文本对话结果。会使用当前的模型进行对话。 @@ -157,7 +157,7 @@ async def text_chat_stream( yield None # type: ignore raise NotImplementedError() - async def pop_record(self, context: list): + async def pop_record(self, context: list) -> None: """弹出 context 第一条非系统提示词对话记录""" poped = 0 indexs_to_pop = [] @@ -188,7 +188,7 @@ def _ensure_message_to_dicts( return dicts - async def test(self, timeout: float = 45.0): + async def test(self, timeout: float = 45.0) -> None: await asyncio.wait_for( self.text_chat(prompt="REPLY `PONG` ONLY"), timeout=timeout, @@ -206,7 +206,7 @@ async def get_text(self, audio_url: str) -> str: """获取音频的文本""" raise NotImplementedError - async def test(self): + async def test(self) -> None: sample_audio_path = os.path.join( get_astrbot_path(), "samples", @@ -226,7 +226,7 @@ async def get_audio(self, text: str) -> str: """获取文本的音频,返回音频文件路径""" raise NotImplementedError - async def test(self): + async def test(self) -> None: await self.get_audio("hi") @@ -251,7 +251,7 @@ def get_dim(self) -> int: """获取向量的维度""" ... - async def test(self): + async def test(self) -> None: await self.get_embedding("astrbot") async def get_embeddings_batch( @@ -260,7 +260,7 @@ async def get_embeddings_batch( batch_size: int = 16, tasks_limit: int = 3, max_retries: int = 3, - progress_callback=None, + progress_callback: Callable[[int, int], Awaitable[None]] | None = None, ) -> list[list[float]]: """批量获取文本的向量,分批处理以节省内存 @@ -281,7 +281,7 @@ async def get_embeddings_batch( completed_count = 0 total_count = len(texts) - async def process_batch(batch_idx: int, batch_texts: list[str]): + async def process_batch(batch_idx: int, batch_texts: list[str]) -> None: nonlocal completed_count async with semaphore: for attempt in range(max_retries): @@ -338,7 +338,7 @@ async def rerank( """获取查询和文档的重排序分数""" ... - async def test(self): + async def test(self) -> None: result = await self.rerank("Apple", documents=["apple", "banana"]) if not result: raise Exception("Rerank provider test failed, no results returned") diff --git a/astrbot/core/provider/register.py b/astrbot/core/provider/register.py index 3ad83784e..3b927e129 100644 --- a/astrbot/core/provider/register.py +++ b/astrbot/core/provider/register.py @@ -1,8 +1,18 @@ +from collections.abc import Callable +from typing import TYPE_CHECKING, TypeVar + from astrbot.core import logger from .entities import ProviderMetaData, ProviderType from .func_tool_manager import FuncCall +if TYPE_CHECKING: + from .provider import AbstractProvider + + T = TypeVar("T", bound=AbstractProvider) +else: + T = TypeVar("T") + provider_registry: list[ProviderMetaData] = [] """维护了通过装饰器注册的 Provider""" provider_cls_map: dict[str, ProviderMetaData] = {} @@ -17,10 +27,10 @@ def register_provider_adapter( provider_type: ProviderType = ProviderType.CHAT_COMPLETION, default_config_tmpl: dict | None = None, provider_display_name: str | None = None, -): +) -> Callable[[type[T]], type[T]]: """用于注册平台适配器的带参装饰器""" - def decorator(cls): + def decorator(cls: type[T]) -> type[T]: if provider_type_name in provider_cls_map: raise ValueError( f"检测到大模型提供商适配器 {provider_type_name} 已经注册,可能发生了大模型提供商适配器类型命名冲突。", diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index 7ce36b0f5..12ccbae46 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -562,5 +562,5 @@ async def get_models(self) -> list[str]: models_str.append(model.id) return models_str - def set_key(self, key: str): + def set_key(self, key: str) -> None: self.chosen_api_key = key diff --git a/astrbot/core/provider/sources/azure_tts_source.py b/astrbot/core/provider/sources/azure_tts_source.py index 2ccf146ca..08180222a 100644 --- a/astrbot/core/provider/sources/azure_tts_source.py +++ b/astrbot/core/provider/sources/azure_tts_source.py @@ -21,7 +21,7 @@ class OTTSProvider: - def __init__(self, config: dict): + def __init__(self, config: dict) -> None: self.skey = config["OTTS_SKEY"] self.api_url = config["OTTS_URL"] self.auth_time_url = config["OTTS_AUTH_TIME"] @@ -48,7 +48,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): await self._client.aclose() self._client = None - async def _sync_time(self): + async def _sync_time(self) -> None: try: response = await self.client.get(self.auth_time_url) response.raise_for_status() @@ -103,7 +103,7 @@ async def get_audio(self, text: str, voice_params: dict) -> str: class AzureNativeProvider(TTSProvider): - def __init__(self, provider_config: dict, provider_settings: dict): + def __init__(self, provider_config: dict, provider_settings: dict) -> None: super().__init__(provider_config, provider_settings) self.subscription_key = provider_config.get( "azure_tts_subscription_key", @@ -149,7 +149,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): await self._client.aclose() self._client = None - async def _refresh_token(self): + async def _refresh_token(self) -> None: token_url = ( f"https://{self.region}.api.cognitive.microsoft.com/sts/v1.0/issuetoken" ) @@ -195,7 +195,7 @@ async def get_audio(self, text: str) -> str: @register_provider_adapter("azure_tts", "Azure TTS", ProviderType.TEXT_TO_SPEECH) class AzureTTSProvider(TTSProvider): - def __init__(self, provider_config: dict, provider_settings: dict): + def __init__(self, provider_config: dict, provider_settings: dict) -> None: super().__init__(provider_config, provider_settings) key_value = provider_config.get("azure_tts_subscription_key", "") self.provider = self._parse_provider(key_value, provider_config) diff --git a/astrbot/core/provider/sources/fishaudio_tts_api_source.py b/astrbot/core/provider/sources/fishaudio_tts_api_source.py index e246e00ed..70eabd289 100644 --- a/astrbot/core/provider/sources/fishaudio_tts_api_source.py +++ b/astrbot/core/provider/sources/fishaudio_tts_api_source.py @@ -63,7 +63,7 @@ def __init__( self.headers = { "Authorization": f"Bearer {self.chosen_api_key}", } - self.set_model(provider_config.get("model", None)) + self.set_model(provider_config.get("model", "")) async def _get_reference_id_by_character(self, character: str) -> str | None: """获取角色的reference_id diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 97c072d0e..2a6001038 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -834,7 +834,7 @@ def get_current_key(self) -> str: def get_keys(self) -> list[str]: return self.api_keys - def set_key(self, key): + def set_key(self, key) -> None: self.chosen_api_key = key self._init_client() @@ -916,5 +916,5 @@ async def encode_image_bs64(self, image_url: str) -> str: image_bs64 = base64.b64encode(f.read()).decode("utf-8") return "data:image/jpeg;base64," + image_bs64 - async def terminate(self): + async def terminate(self) -> None: logger.info("Google GenAI 适配器已终止。") diff --git a/astrbot/core/provider/sources/gsv_selfhosted_source.py b/astrbot/core/provider/sources/gsv_selfhosted_source.py index 7f8d39eac..029f6af10 100644 --- a/astrbot/core/provider/sources/gsv_selfhosted_source.py +++ b/astrbot/core/provider/sources/gsv_selfhosted_source.py @@ -39,7 +39,7 @@ def __init__( self.timeout = provider_config.get("timeout", 60) self._session: aiohttp.ClientSession | None = None - async def initialize(self): + async def initialize(self) -> None: """异步初始化:在 ProviderManager 中被调用""" self._session = aiohttp.ClientSession( timeout=aiohttp.ClientTimeout(total=self.timeout), @@ -85,7 +85,7 @@ async def _make_request( logger.error(f"[GSV TTS] 请求 {endpoint} 最终失败:{e}") raise - async def _set_model_weights(self): + async def _set_model_weights(self) -> None: """设置模型路径""" try: if self.gpt_weights_path: @@ -144,7 +144,7 @@ def build_synthesis_params(self, text: str) -> dict: # TODO: 在此处添加情绪分析,例如 params["emotion"] = detect_emotion(text) return params - async def terminate(self): + async def terminate(self) -> None: """终止释放资源:在 ProviderManager 中被调用""" if self._session and not self._session.closed: await self._session.close() diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 2544782f4..1015fc8da 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -617,7 +617,7 @@ def get_current_key(self) -> str: def get_keys(self) -> list[str]: return self.api_keys - def set_key(self, key): + def set_key(self, key) -> None: self.client.api_key = key async def assemble_context( diff --git a/astrbot/core/provider/sources/sensevoice_selfhosted_source.py b/astrbot/core/provider/sources/sensevoice_selfhosted_source.py index a41bd72fd..965b83a5a 100644 --- a/astrbot/core/provider/sources/sensevoice_selfhosted_source.py +++ b/astrbot/core/provider/sources/sensevoice_selfhosted_source.py @@ -37,7 +37,7 @@ def __init__( self.model = None self.is_emotion = provider_config.get("is_emotion", False) - async def initialize(self): + async def initialize(self) -> None: logger.info("下载或者加载 SenseVoice 模型中,这可能需要一些时间 ...") # 将模型加载放到线程池中执行 @@ -52,7 +52,7 @@ async def get_timestamped_path(self) -> str: timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") return os.path.join("data", "temp", f"{timestamp}") - async def _is_silk_file(self, file_path): + async def _is_silk_file(self, file_path) -> bool: silk_header = b"SILK" with open(file_path, "rb") as f: file_header = f.read(8) diff --git a/astrbot/core/provider/sources/whisper_api_source.py b/astrbot/core/provider/sources/whisper_api_source.py index fa69206ef..9532ca5b9 100644 --- a/astrbot/core/provider/sources/whisper_api_source.py +++ b/astrbot/core/provider/sources/whisper_api_source.py @@ -38,7 +38,7 @@ def __init__( self.set_model(provider_config["model"]) - async def _get_audio_format(self, file_path): + async def _get_audio_format(self, file_path) -> str | None: # 定义要检测的头部字节 silk_header = b"SILK" amr_header = b"#!AMR" diff --git a/astrbot/core/provider/sources/whisper_selfhosted_source.py b/astrbot/core/provider/sources/whisper_selfhosted_source.py index a14f93f14..d5d2dc340 100644 --- a/astrbot/core/provider/sources/whisper_selfhosted_source.py +++ b/astrbot/core/provider/sources/whisper_selfhosted_source.py @@ -30,7 +30,7 @@ def __init__( self.set_model(provider_config["model"]) self.model = None - async def initialize(self): + async def initialize(self) -> None: loop = asyncio.get_event_loop() logger.info("下载或者加载 Whisper 模型中,这可能需要一些时间 ...") self.model = await loop.run_in_executor( @@ -40,7 +40,7 @@ async def initialize(self): ) logger.info("Whisper 模型加载完成。") - async def _is_silk_file(self, file_path): + async def _is_silk_file(self, file_path) -> bool: silk_header = b"SILK" with open(file_path, "rb") as f: file_header = f.read(8) diff --git a/astrbot/core/provider/sources/xinference_rerank_source.py b/astrbot/core/provider/sources/xinference_rerank_source.py index 960408550..9c3a77c15 100644 --- a/astrbot/core/provider/sources/xinference_rerank_source.py +++ b/astrbot/core/provider/sources/xinference_rerank_source.py @@ -37,7 +37,7 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None: self.model: AsyncRESTfulRerankModelHandle | None = None self.model_uid = None - async def initialize(self): + async def initialize(self) -> None: if self.api_key: logger.info("Xinference Rerank: Using API key for authentication.") self.client = Client(self.base_url, api_key=self.api_key) diff --git a/astrbot/core/provider/sources/xinference_stt_provider.py b/astrbot/core/provider/sources/xinference_stt_provider.py index 4b947b3f0..a3e5be352 100644 --- a/astrbot/core/provider/sources/xinference_stt_provider.py +++ b/astrbot/core/provider/sources/xinference_stt_provider.py @@ -40,7 +40,7 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None: self.client = None self.model_uid = None - async def initialize(self): + async def initialize(self) -> None: if self.api_key: logger.info("Xinference STT: Using API key for authentication.") self.client = Client(self.base_url, api_key=self.api_key) diff --git a/astrbot/core/star/__init__.py b/astrbot/core/star/__init__.py index c474962c5..570ab26e0 100644 --- a/astrbot/core/star/__init__.py +++ b/astrbot/core/star/__init__.py @@ -15,11 +15,11 @@ class Star(CommandParserMixin, PluginKVStoreMixin): author: str name: str - def __init__(self, context: Context, config: dict | None = None): + def __init__(self, context: Context, config: dict | None = None) -> None: StarTools.initialize(context) self.context = context - def __init_subclass__(cls, **kwargs): + def __init_subclass__(cls, **kwargs: object) -> None: super().__init_subclass__(**kwargs) if not star_map.get(cls.__module__): metadata = StarMetadata( @@ -32,7 +32,7 @@ def __init_subclass__(cls, **kwargs): star_map[cls.__module__].star_cls_type = cls star_map[cls.__module__].module_path = cls.__module__ - async def text_to_image(self, text: str, return_url=True) -> str: + async def text_to_image(self, text: str, return_url: bool = True) -> str: """将文本转换为图片""" return await html_renderer.render_t2i( text, @@ -44,7 +44,7 @@ async def html_render( self, tmpl: str, data: dict, - return_url=True, + return_url: bool = True, options: dict | None = None, ) -> str: """渲染 HTML""" @@ -55,13 +55,13 @@ async def html_render( options=options, ) - async def initialize(self): + async def initialize(self) -> None: """当插件被激活时会调用这个方法""" - async def terminate(self): + async def terminate(self) -> None: """当插件被禁用、重载插件时会调用这个方法""" - def __del__(self): + def __del__(self) -> None: """[Deprecated] 当插件被禁用、重载插件时会调用这个方法""" diff --git a/astrbot/core/star/config.py b/astrbot/core/star/config.py index a9af974c5..e556fe360 100644 --- a/astrbot/core/star/config.py +++ b/astrbot/core/star/config.py @@ -22,7 +22,13 @@ def load_config(namespace: str) -> dict | bool: return ret -def put_config(namespace: str, name: str, key: str, value, description: str): +def put_config( + namespace: str, + name: str, + key: str, + value: str | int | float | bool | list[object], + description: str, +) -> None: """将配置项写入以namespace为名字的配置文件,如果key不存在于目标配置文件中。当前 value 仅支持 str, int, float, bool, list 类型(暂不支持 dict)。 namespace: str, 配置的唯一识别符,也就是配置文件的名字。 name: str, 配置项的显示名字。 @@ -64,7 +70,9 @@ def put_config(namespace: str, name: str, key: str, value, description: str): f.flush() -def update_config(namespace: str, key: str, value): +def update_config( + namespace: str, key: str, value: str | int | float | bool | list[object] +) -> None: """更新配置文件中的配置项。 namespace: str, 配置的唯一识别符,也就是配置文件的名字。 key: str, 配置项的键。 diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index a64d2a9ee..49d250d66 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -1,7 +1,7 @@ import logging from asyncio import Queue from collections.abc import Awaitable, Callable -from typing import Any +from typing import Any, cast from deprecated import deprecated @@ -65,7 +65,7 @@ def __init__( persona_manager: PersonaManager, astrbot_config_mgr: AstrBotConfigManager, knowledge_base_manager: KnowledgeBaseManager, - ): + ) -> None: self._event_queue = event_queue """事件队列。消息平台通过事件队列传递消息事件。""" self._config = config @@ -89,7 +89,7 @@ async def llm_generate( tools: ToolSet | None = None, system_prompt: str | None = None, contexts: list[Message] | None = None, - **kwargs: Any, + **kwargs: object, ) -> LLMResponse: """Call the LLM to generate a response. The method will not automatically execute tool calls. If you want to use tool calls, please use `tool_loop_agent()`. @@ -117,7 +117,7 @@ async def llm_generate( func_tool=tools, contexts=contexts, system_prompt=system_prompt, - **kwargs, + **cast(dict[str, Any], kwargs), ) return llm_resp @@ -133,7 +133,7 @@ async def tool_loop_agent( contexts: list[Message] | None = None, max_steps: int = 30, tool_call_timeout: int = 60, - **kwargs: Any, + **kwargs, ) -> LLMResponse: """Run an agent loop that allows the LLM to call tools iteratively until a final answer is produced. If you do not pass the agent_context parameter, the method will recreate a new agent context. @@ -173,8 +173,18 @@ async def tool_loop_agent( if not prov or not isinstance(prov, Provider): raise ProviderNotFoundError(f"Provider {chat_provider_id} not found") - agent_hooks = kwargs.get("agent_hooks") or BaseAgentRunHooks[AstrAgentContext]() - agent_context = kwargs.get("agent_context") + _kw = cast(dict[str, Any], kwargs) + agent_hooks_obj = _kw.get("agent_hooks") + if isinstance(agent_hooks_obj, BaseAgentRunHooks): + agent_hooks = cast(BaseAgentRunHooks[Any], agent_hooks_obj) + else: + agent_hooks = BaseAgentRunHooks[AstrAgentContext]() + + agent_context_obj = _kw.get("agent_context") + if isinstance(agent_context_obj, AstrAgentContext): + agent_context = agent_context_obj + else: + agent_context = None context_ = [] for msg in contexts or []: @@ -415,7 +425,7 @@ def register_web_api( view_handler: Awaitable, methods: list, desc: str, - ): + ) -> None: for idx, api in enumerate(self.registered_web_apis): if api[0] == route and methods == api[2]: self.registered_web_apis[idx] = (route, view_handler, methods, desc) @@ -465,7 +475,7 @@ def get_db(self) -> BaseDatabase: """获取 AstrBot 数据库。""" return self._db - def register_provider(self, provider: Provider): + def register_provider(self, provider: Provider) -> None: """注册一个 LLM Provider(Chat_Completion 类型)。""" self.provider_manager.provider_insts.append(provider) @@ -508,9 +518,9 @@ def register_commands( desc: str, priority: int, awaitable: Callable[..., Awaitable[Any]], - use_regex=False, - ignore_prefix=False, - ): + use_regex: bool = False, + ignore_prefix: bool = False, + ) -> None: """注册一个命令。 [Deprecated] 推荐使用装饰器注册指令。该方法将在未来的版本中被移除。 @@ -539,6 +549,6 @@ def register_commands( ) star_handlers_registry.append(md) - def register_task(self, task: Awaitable, desc: str): + def register_task(self, task: Awaitable, desc: str) -> None: """[DEPRECATED]注册一个异步任务。""" self._register_tasks.append(task) diff --git a/astrbot/core/star/filter/command.py b/astrbot/core/star/filter/command.py index 51ad5f089..e42045e1e 100755 --- a/astrbot/core/star/filter/command.py +++ b/astrbot/core/star/filter/command.py @@ -16,7 +16,7 @@ class GreedyStr(str): """标记指令完成其他参数接收后的所有剩余文本。""" -def unwrap_optional(annotation) -> tuple: +def unwrap_optional(annotation: object) -> tuple: """去掉 Optional[T] / Union[T, None] / T|None,返回 T""" args = typing.get_args(annotation) non_none_args = [a for a in args if a is not type(None)] @@ -37,7 +37,7 @@ def __init__( alias: set | None = None, handler_md: StarHandlerMetadata | None = None, parent_command_names: list[str] | None = None, - ): + ) -> None: self.command_name = command_name self.alias = alias if alias else set() self._original_command_name = command_name @@ -51,7 +51,7 @@ def __init__( # Cache for complete command names list self._cmpl_cmd_names: list | None = None - def print_types(self): + def print_types(self) -> str: parts = [] for k, v in self.handler_params.items(): if isinstance(v, type): @@ -63,7 +63,7 @@ def print_types(self): result = "".join(parts).rstrip(",") return result - def init_handler_md(self, handle_md: StarHandlerMetadata): + def init_handler_md(self, handle_md: StarHandlerMetadata) -> None: self.handler_md = handle_md signature = inspect.signature(self.handler_md.handler) self.handler_params = {} # 参数名 -> 参数类型,如果有默认值则为默认值 @@ -81,7 +81,7 @@ def init_handler_md(self, handle_md: StarHandlerMetadata): def get_handler_md(self) -> StarHandlerMetadata: return self.handler_md - def add_custom_filter(self, custom_filter: CustomFilter): + def add_custom_filter(self, custom_filter: CustomFilter) -> None: self.custom_filter_list.append(custom_filter) def custom_filter_ok(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: @@ -172,7 +172,7 @@ def validate_and_convert_params( ) return result - def get_complete_command_names(self): + def get_complete_command_names(self) -> list[str]: if self._cmpl_cmd_names is not None: return self._cmpl_cmd_names self._cmpl_cmd_names = [ diff --git a/astrbot/core/star/filter/command_group.py b/astrbot/core/star/filter/command_group.py index 4cbd2c007..52fb6a452 100755 --- a/astrbot/core/star/filter/command_group.py +++ b/astrbot/core/star/filter/command_group.py @@ -15,7 +15,7 @@ def __init__( group_name: str, alias: set | None = None, parent_group: CommandGroupFilter | None = None, - ): + ) -> None: self.group_name = group_name self.alias = alias if alias else set() self._original_group_name = group_name @@ -29,10 +29,10 @@ def __init__( def add_sub_command_filter( self, sub_command_filter: CommandFilter | CommandGroupFilter, - ): + ) -> None: self.sub_command_filters.append(sub_command_filter) - def add_custom_filter(self, custom_filter: CustomFilter): + def add_custom_filter(self, custom_filter: CustomFilter) -> None: self.custom_filter_list.append(custom_filter) def get_complete_command_names(self) -> list[str]: diff --git a/astrbot/core/star/filter/custom_filter.py b/astrbot/core/star/filter/custom_filter.py index d57b5cac0..58b07804a 100644 --- a/astrbot/core/star/filter/custom_filter.py +++ b/astrbot/core/star/filter/custom_filter.py @@ -7,19 +7,19 @@ class CustomFilterMeta(ABCMeta): - def __and__(cls, other): + def __and__(cls, other: type) -> "CustomFilter": if not issubclass(other, CustomFilter): raise TypeError("Operands must be subclasses of CustomFilter.") return CustomFilterAnd(cls(), other()) - def __or__(cls, other): + def __or__(cls, other: type) -> "CustomFilter": if not issubclass(other, CustomFilter): raise TypeError("Operands must be subclasses of CustomFilter.") return CustomFilterOr(cls(), other()) class CustomFilter(HandlerFilter, metaclass=CustomFilterMeta): - def __init__(self, raise_error: bool = True, **kwargs): + def __init__(self, raise_error: bool = True, **kwargs: object) -> None: self.raise_error = raise_error @abstractmethod @@ -27,17 +27,17 @@ def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: """一个用于重写的自定义Filter""" raise NotImplementedError - def __or__(self, other): + def __or__(self, other: "CustomFilter") -> "CustomFilter": return CustomFilterOr(self, other) - def __and__(self, other): + def __and__(self, other: "CustomFilter") -> "CustomFilter": return CustomFilterAnd(self, other) class CustomFilterOr(CustomFilter): - def __init__(self, filter1: CustomFilter, filter2: CustomFilter): + def __init__(self, filter1: CustomFilter, filter2: CustomFilter) -> None: super().__init__() - if not isinstance(filter1, (CustomFilter, CustomFilterAnd, CustomFilterOr)): + if not isinstance(filter1, CustomFilter): raise ValueError( "CustomFilter lass can only operate with other CustomFilter.", ) @@ -49,9 +49,9 @@ def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: class CustomFilterAnd(CustomFilter): - def __init__(self, filter1: CustomFilter, filter2: CustomFilter): + def __init__(self, filter1: CustomFilter, filter2: CustomFilter) -> None: super().__init__() - if not isinstance(filter1, (CustomFilter, CustomFilterAnd, CustomFilterOr)): + if not isinstance(filter1, CustomFilter): raise ValueError( "CustomFilter lass can only operate with other CustomFilter.", ) diff --git a/astrbot/core/star/filter/event_message_type.py b/astrbot/core/star/filter/event_message_type.py index 7f350bd38..604fc3ed3 100644 --- a/astrbot/core/star/filter/event_message_type.py +++ b/astrbot/core/star/filter/event_message_type.py @@ -22,7 +22,7 @@ class EventMessageType(enum.Flag): class EventMessageTypeFilter(HandlerFilter): - def __init__(self, event_message_type: EventMessageType): + def __init__(self, event_message_type: EventMessageType) -> None: self.event_message_type = event_message_type def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: diff --git a/astrbot/core/star/filter/permission.py b/astrbot/core/star/filter/permission.py index 3374544c2..a70299fa9 100644 --- a/astrbot/core/star/filter/permission.py +++ b/astrbot/core/star/filter/permission.py @@ -14,7 +14,9 @@ class PermissionType(enum.Flag): class PermissionTypeFilter(HandlerFilter): - def __init__(self, permission_type: PermissionType, raise_error: bool = True): + def __init__( + self, permission_type: PermissionType, raise_error: bool = True + ) -> None: self.permission_type = permission_type self.raise_error = raise_error diff --git a/astrbot/core/star/filter/platform_adapter_type.py b/astrbot/core/star/filter/platform_adapter_type.py index 1182ff9b0..d5ff6146a 100644 --- a/astrbot/core/star/filter/platform_adapter_type.py +++ b/astrbot/core/star/filter/platform_adapter_type.py @@ -58,7 +58,7 @@ class PlatformAdapterType(enum.Flag): class PlatformAdapterTypeFilter(HandlerFilter): - def __init__(self, platform_adapter_type_or_str: PlatformAdapterType | str): + def __init__(self, platform_adapter_type_or_str: PlatformAdapterType | str) -> None: if isinstance(platform_adapter_type_or_str, str): self.platform_type = ADAPTER_NAME_2_TYPE.get(platform_adapter_type_or_str) else: diff --git a/astrbot/core/star/filter/regex.py b/astrbot/core/star/filter/regex.py index cd5bebdb4..abec5a488 100644 --- a/astrbot/core/star/filter/regex.py +++ b/astrbot/core/star/filter/regex.py @@ -10,7 +10,7 @@ class RegexFilter(HandlerFilter): """正则表达式过滤器""" - def __init__(self, regex: str): + def __init__(self, regex: str) -> None: self.regex_str = regex self.regex = re.compile(regex) diff --git a/astrbot/core/star/register/star.py b/astrbot/core/star/register/star.py index 617cd5ff7..c424f19f4 100644 --- a/astrbot/core/star/register/star.py +++ b/astrbot/core/star/register/star.py @@ -1,9 +1,13 @@ import warnings +from collections.abc import Callable +from typing import TypeVar from astrbot.core.star import StarMetadata, star_map _warned_register_star = False +T = TypeVar("T") + def register_star( name: str, @@ -11,7 +15,7 @@ def register_star( desc: str, version: str, repo: str | None = None, -): +) -> Callable[[type[T]], type[T]]: """注册一个插件(Star)。 [DEPRECATED] 该装饰器已废弃,将在未来版本中移除。 @@ -44,7 +48,7 @@ class MyPlugin(star.Star): stacklevel=2, ) - def decorator(cls): + def decorator(cls: type[T]) -> type[T]: if not star_map.get(cls.__module__): metadata = StarMetadata( name=name, diff --git a/astrbot/core/star/register/star_handler.py b/astrbot/core/star/register/star_handler.py index 085414cd4..8eb5b7aba 100644 --- a/astrbot/core/star/register/star_handler.py +++ b/astrbot/core/star/register/star_handler.py @@ -155,9 +155,7 @@ def register_custom_filter(custom_type_filter, *args, **kwargs): def decorator(awaitable): # 裸指令,子指令与指令组的区分,指令组会因为标记跳过wake。 - if ( - not add_to_event_filters and isinstance(awaitable, RegisteringCommandable) - ) or (add_to_event_filters and isinstance(awaitable, RegisteringCommandable)): + if isinstance(awaitable, RegisteringCommandable): # 指令组 与 根指令组,添加到本层的grouphandle中一起判断 awaitable.parent_group.add_custom_filter(custom_filter) else: @@ -250,7 +248,7 @@ class RegisteringCommandable: command: Callable[..., Callable[..., None]] = register_command custom_filter: Callable[..., Callable[..., Any]] = register_custom_filter - def __init__(self, parent_group: CommandGroupFilter): + def __init__(self, parent_group: CommandGroupFilter) -> None: self.parent_group = parent_group @@ -516,7 +514,7 @@ def llm_tool(self, *args, **kwargs): kwargs["registering_agent"] = self return register_llm_tool(*args, **kwargs) - def __init__(self, agent: Agent[AstrAgentContext]): + def __init__(self, agent: Agent[AstrAgentContext]) -> None: self._agent = agent diff --git a/astrbot/core/star/star_handler.py b/astrbot/core/star/star_handler.py index f36acedff..a59e14e2a 100644 --- a/astrbot/core/star/star_handler.py +++ b/astrbot/core/star/star_handler.py @@ -1,7 +1,7 @@ from __future__ import annotations import enum -from collections.abc import AsyncGenerator, Awaitable, Callable +from collections.abc import AsyncGenerator, Awaitable, Callable, Iterator from dataclasses import dataclass, field from typing import Any, Generic, Literal, TypeVar, overload @@ -12,11 +12,11 @@ class StarHandlerRegistry(Generic[T]): - def __init__(self): + def __init__(self) -> None: self.star_handlers_map: dict[str, StarHandlerMetadata] = {} self._handlers: list[StarHandlerMetadata] = [] - def append(self, handler: StarHandlerMetadata): + def append(self, handler: StarHandlerMetadata) -> None: """添加一个 Handler,并保持按优先级有序""" if "priority" not in handler.extras_configs: handler.extras_configs["priority"] = 0 @@ -25,7 +25,7 @@ def append(self, handler: StarHandlerMetadata): self._handlers.append(handler) self._handlers.sort(key=lambda h: -h.extras_configs["priority"]) - def _print_handlers(self): + def _print_handlers(self) -> None: for handler in self._handlers: print(handler.handler_full_name) @@ -33,7 +33,7 @@ def _print_handlers(self): def get_handlers_by_event_type( self, event_type: Literal[EventType.OnAstrBotLoadedEvent], - only_activated=True, + only_activated: bool = True, plugins_name: list[str] | None = None, ) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ... @@ -41,7 +41,7 @@ def get_handlers_by_event_type( def get_handlers_by_event_type( self, event_type: Literal[EventType.OnPlatformLoadedEvent], - only_activated=True, + only_activated: bool = True, plugins_name: list[str] | None = None, ) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ... @@ -49,7 +49,7 @@ def get_handlers_by_event_type( def get_handlers_by_event_type( self, event_type: Literal[EventType.AdapterMessageEvent], - only_activated=True, + only_activated: bool = True, plugins_name: list[str] | None = None, ) -> list[ StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]] @@ -59,7 +59,7 @@ def get_handlers_by_event_type( def get_handlers_by_event_type( self, event_type: Literal[EventType.OnLLMRequestEvent], - only_activated=True, + only_activated: bool = True, plugins_name: list[str] | None = None, ) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ... @@ -67,7 +67,7 @@ def get_handlers_by_event_type( def get_handlers_by_event_type( self, event_type: Literal[EventType.OnLLMResponseEvent], - only_activated=True, + only_activated: bool = True, plugins_name: list[str] | None = None, ) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ... @@ -75,7 +75,7 @@ def get_handlers_by_event_type( def get_handlers_by_event_type( self, event_type: Literal[EventType.OnDecoratingResultEvent], - only_activated=True, + only_activated: bool = True, plugins_name: list[str] | None = None, ) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ... @@ -83,7 +83,7 @@ def get_handlers_by_event_type( def get_handlers_by_event_type( self, event_type: Literal[EventType.OnCallingFuncToolEvent], - only_activated=True, + only_activated: bool = True, plugins_name: list[str] | None = None, ) -> list[ StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]] @@ -93,7 +93,7 @@ def get_handlers_by_event_type( def get_handlers_by_event_type( self, event_type: Literal[EventType.OnAfterMessageSentEvent], - only_activated=True, + only_activated: bool = True, plugins_name: list[str] | None = None, ) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ... @@ -101,7 +101,7 @@ def get_handlers_by_event_type( def get_handlers_by_event_type( self, event_type: EventType, - only_activated=True, + only_activated: bool = True, plugins_name: list[str] | None = None, ) -> list[ StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]] @@ -110,7 +110,7 @@ def get_handlers_by_event_type( def get_handlers_by_event_type( self, event_type: EventType, - only_activated=True, + only_activated: bool = True, plugins_name: list[str] | None = None, ) -> list[StarHandlerMetadata]: handlers = [] @@ -156,18 +156,18 @@ def get_handlers_by_module_name( if handler.handler_module_path == module_name ] - def clear(self): + def clear(self) -> None: self.star_handlers_map.clear() self._handlers.clear() - def remove(self, handler: StarHandlerMetadata): + def remove(self, handler: StarHandlerMetadata) -> None: self.star_handlers_map.pop(handler.handler_full_name, None) self._handlers = [h for h in self._handlers if h != handler] - def __iter__(self): + def __iter__(self) -> Iterator: return iter(self._handlers) - def __len__(self): + def __len__(self) -> int: return len(self._handlers) @@ -225,7 +225,7 @@ class StarHandlerMetadata(Generic[H]): enabled: bool = True - def __lt__(self, other: StarHandlerMetadata): + def __lt__(self, other: StarHandlerMetadata) -> bool: """定义小于运算符以支持优先队列""" return self.extras_configs.get("priority", 0) < other.extras_configs.get( "priority", diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index cf3ab0698..0a80177d3 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -11,6 +11,7 @@ from types import ModuleType import yaml +from watchfiles.main import FileChange from astrbot.core import logger, pip_installer, sp from astrbot.core.agent.handoff import FunctionTool, HandoffTool @@ -39,7 +40,7 @@ class PluginManager: - def __init__(self, context: Context, config: AstrBotConfig): + def __init__(self, context: Context, config: AstrBotConfig) -> None: self.updator = PluginUpdator() self.context = context @@ -64,7 +65,7 @@ def __init__(self, context: Context, config: AstrBotConfig): if os.getenv("ASTRBOT_RELOAD", "0") == "1": asyncio.create_task(self._watch_plugins_changes()) - async def _watch_plugins_changes(self): + async def _watch_plugins_changes(self) -> None: """监视插件文件变化""" try: async for changes in awatch( @@ -81,7 +82,7 @@ async def _watch_plugins_changes(self): logger.error(f"插件热重载监视任务异常: {e!s}") logger.error(traceback.format_exc()) - async def _handle_file_changes(self, changes): + async def _handle_file_changes(self, changes: set[FileChange]) -> None: """处理文件变化""" logger.info(f"检测到文件变化: {changes}") plugins_to_check = [] @@ -117,7 +118,7 @@ async def _handle_file_changes(self, changes): break @staticmethod - def _get_classes(arg: ModuleType): + def _get_classes(arg: ModuleType) -> list[str]: """获取指定模块(可以理解为一个 python 文件)下所有的类""" classes = [] clsmembers = inspect.getmembers(arg, inspect.isclass) @@ -128,7 +129,7 @@ def _get_classes(arg: ModuleType): return classes @staticmethod - def _get_modules(path): + def _get_modules(path: str) -> list: modules = [] dirs = os.listdir(path) @@ -165,7 +166,9 @@ def _get_plugin_modules(self) -> list[dict]: plugins.extend(_p) return plugins - async def _check_plugin_dept_update(self, target_plugin: str | None = None): + async def _check_plugin_dept_update( + self, target_plugin: str | None = None + ) -> bool | None: """检查插件的依赖 如果 target_plugin 为 None,则检查所有插件的依赖 """ @@ -189,7 +192,9 @@ async def _check_plugin_dept_update(self, target_plugin: str | None = None): logger.error(f"更新插件 {p} 的依赖失败。Code: {e!s}") @staticmethod - def _load_plugin_metadata(plugin_path: str, plugin_obj=None) -> StarMetadata | None: + def _load_plugin_metadata( + plugin_path: str, plugin_obj: object | None = None + ) -> StarMetadata | None: """先寻找 metadata.yaml 文件,如果不存在,则使用插件对象的 info() 函数获取元数据。 Notes: 旧版本 AstrBot 插件可能使用的是 info() 函数来获取元数据。 @@ -207,7 +212,7 @@ def _load_plugin_metadata(plugin_path: str, plugin_obj=None) -> StarMetadata | N metadata = yaml.safe_load(f) elif plugin_obj and hasattr(plugin_obj, "info"): # 使用 info() 函数 - metadata = plugin_obj.info() + metadata = getattr(plugin_obj, "info")() if isinstance(metadata, dict): if "desc" not in metadata and "description" in metadata: @@ -262,7 +267,7 @@ def _purge_modules( module_patterns: list[str] | None = None, root_dir_name: str | None = None, is_reserved: bool = False, - ): + ) -> None: """从 sys.modules 中移除指定的模块 可以基于模块名模式或插件目录名移除模块,用于清理插件相关的模块缓存 @@ -291,7 +296,9 @@ def _purge_modules( except KeyError: logger.warning(f"模块 {module_name} 未载入") - async def reload(self, specified_plugin_name=None): + async def reload( + self, specified_plugin_name: str | None = None + ) -> tuple[bool, str]: """重新加载插件 Args: @@ -347,7 +354,11 @@ async def reload(self, specified_plugin_name=None): return result - async def load(self, specified_module_path=None, specified_dir_name=None): + async def load( + self, + specified_module_path: str | None = None, + specified_dir_name: str | None = None, + ) -> tuple[bool, str]: """载入插件。 当 specified_module_path 或者 specified_dir_name 不为 None 时,只载入指定的插件。 @@ -636,11 +647,11 @@ async def load(self, specified_module_path=None, specified_dir_name=None): logger.error(traceback.format_exc()) if not fail_rec: - return True, None + return True, "" self.failed_plugin_info = fail_rec return False, fail_rec - async def install_plugin(self, repo_url: str, proxy=""): + async def install_plugin(self, repo_url: str, proxy: str = "") -> dict | None: """从仓库 URL 安装插件 从指定的仓库 URL 下载并安装插件,然后加载该插件到系统中 @@ -701,7 +712,7 @@ async def uninstall_plugin( plugin_name: str, delete_config: bool = False, delete_data: bool = False, - ): + ) -> None: """卸载指定的插件。 Args: @@ -790,7 +801,7 @@ async def uninstall_plugin( except Exception as e: logger.warning(f"删除插件持久化数据失败 (plugins_data): {e!s}") - async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str): + async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str) -> None: """解绑并移除一个插件。 Args: @@ -841,7 +852,7 @@ async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str): is_reserved=plugin.reserved, ) - async def update_plugin(self, plugin_name: str, proxy=""): + async def update_plugin(self, plugin_name: str, proxy: str = "") -> None: """升级一个插件""" plugin = self.context.get_registered_star(plugin_name) if not plugin: @@ -852,7 +863,7 @@ async def update_plugin(self, plugin_name: str, proxy=""): await self.updator.update(plugin, proxy=proxy) await self.reload(plugin_name) - async def turn_off_plugin(self, plugin_name: str): + async def turn_off_plugin(self, plugin_name: str) -> None: """禁用一个插件。 调用插件的 terminate() 方法, 将插件的 module_path 加入到 data/shared_preferences.json 的 inactivated_plugins 列表中。 @@ -894,7 +905,7 @@ async def turn_off_plugin(self, plugin_name: str): plugin.activated = False @staticmethod - async def _terminate_plugin(star_metadata: StarMetadata): + async def _terminate_plugin(star_metadata: StarMetadata) -> None: """终止插件,调用插件的 terminate() 和 __del__() 方法""" logger.info(f"正在终止插件 {star_metadata.name} ...") @@ -914,7 +925,7 @@ async def _terminate_plugin(star_metadata: StarMetadata): elif "terminate" in star_metadata.star_cls_type.__dict__: await star_metadata.star_cls.terminate() - async def turn_on_plugin(self, plugin_name: str): + async def turn_on_plugin(self, plugin_name: str) -> None: plugin = self.context.get_registered_star(plugin_name) if plugin is None: raise Exception(f"插件 {plugin_name} 不存在。") @@ -940,7 +951,7 @@ async def turn_on_plugin(self, plugin_name: str): await self.reload(plugin_name) - async def install_plugin_from_file(self, zip_file_path: str): + async def install_plugin_from_file(self, zip_file_path: str) -> dict | None: dir_name = os.path.basename(zip_file_path).replace(".zip", "") dir_name = dir_name.removesuffix("-master").removesuffix("-main").lower() desti_dir = os.path.join(self.plugin_store_path, dir_name) diff --git a/astrbot/core/star/star_tools.py b/astrbot/core/star/star_tools.py index 7a66449b4..4d85131fc 100644 --- a/astrbot/core/star/star_tools.py +++ b/astrbot/core/star/star_tools.py @@ -89,7 +89,7 @@ async def send_message_by_id( id: str, message_chain: MessageChain, platform: str = "aiocqhttp", - ): + ) -> None: """根据 id(例如qq号, 群号等) 直接, 主动地发送消息 Args: diff --git a/astrbot/core/star/updator.py b/astrbot/core/star/updator.py index 8793ad505..f3e4d44a0 100644 --- a/astrbot/core/star/updator.py +++ b/astrbot/core/star/updator.py @@ -18,7 +18,7 @@ def __init__(self, repo_mirror: str = "") -> None: def get_plugin_store_path(self) -> str: return self.plugin_store_path - async def install(self, repo_url: str, proxy="") -> str: + async def install(self, repo_url: str, proxy: str = "") -> str: _, repo_name, _ = self.parse_github_url(repo_url) repo_name = self.format_name(repo_name) plugin_path = os.path.join(self.plugin_store_path, repo_name) @@ -27,7 +27,7 @@ async def install(self, repo_url: str, proxy="") -> str: return plugin_path - async def update(self, plugin: StarMetadata, proxy="") -> str: + async def update(self, plugin: StarMetadata, proxy: str = "") -> str: repo_url = plugin.repo if not repo_url: @@ -52,7 +52,7 @@ async def update(self, plugin: StarMetadata, proxy="") -> str: return plugin_path - def unzip_file(self, zip_path: str, target_dir: str): + def unzip_file(self, zip_path: str, target_dir: str) -> None: os.makedirs(target_dir, exist_ok=True) update_dir = "" logger.info(f"正在解压压缩包: {zip_path}") diff --git a/astrbot/core/umop_config_router.py b/astrbot/core/umop_config_router.py index 1f2289f4d..d8b010d50 100644 --- a/astrbot/core/umop_config_router.py +++ b/astrbot/core/umop_config_router.py @@ -6,15 +6,15 @@ class UmopConfigRouter: """UMOP 配置路由器""" - def __init__(self, sp: SharedPreferences): + def __init__(self, sp: SharedPreferences) -> None: self.umop_to_conf_id: dict[str, str] = {} """UMOP 到配置文件 ID 的映射""" self.sp = sp - async def initialize(self): + async def initialize(self) -> None: await self._load_routing_table() - async def _load_routing_table(self): + async def _load_routing_table(self) -> None: """加载路由表""" # 从 SharedPreferences 中加载 umop_to_conf_id 映射 sp_data = await self.sp.get_async( @@ -50,7 +50,7 @@ def get_conf_id_for_umop(self, umo: str) -> str | None: return conf_id return None - async def update_routing_data(self, new_routing: dict[str, str]): + async def update_routing_data(self, new_routing: dict[str, str]) -> None: """更新路由表 Args: @@ -70,7 +70,7 @@ async def update_routing_data(self, new_routing: dict[str, str]): self.umop_to_conf_id = new_routing await self.sp.global_put("umop_config_routing", self.umop_to_conf_id) - async def update_route(self, umo: str, conf_id: str): + async def update_route(self, umo: str, conf_id: str) -> None: """更新一条路由 Args: @@ -89,7 +89,7 @@ async def update_route(self, umo: str, conf_id: str): self.umop_to_conf_id[umo] = conf_id await self.sp.global_put("umop_config_routing", self.umop_to_conf_id) - async def delete_route(self, umo: str): + async def delete_route(self, umo: str) -> None: """删除一条路由 Args: diff --git a/astrbot/core/updator.py b/astrbot/core/updator.py index 0a7116a0d..69065b717 100644 --- a/astrbot/core/updator.py +++ b/astrbot/core/updator.py @@ -23,7 +23,7 @@ def __init__(self, repo_mirror: str = "") -> None: self.MAIN_PATH = get_astrbot_path() self.ASTRBOT_RELEASE_API = "https://api.soulter.top/releases" - def terminate_child_processes(self): + def terminate_child_processes(self) -> None: """终止当前进程的所有子进程 使用 psutil 库获取当前进程的所有子进程,并尝试终止它们 """ @@ -44,7 +44,7 @@ def terminate_child_processes(self): except psutil.NoSuchProcess: pass - def _reboot(self, delay: int = 3): + def _reboot(self, delay: int = 3) -> None: """重启当前程序 在指定的延迟后,终止所有子进程并重新启动程序 这里只能使用 os.exec* 来重启程序 @@ -85,7 +85,13 @@ async def check_update( async def get_releases(self) -> list: return await self.fetch_release_info(self.ASTRBOT_RELEASE_API) - async def update(self, reboot=False, latest=True, version=None, proxy=""): + async def update( + self, + reboot: bool = False, + latest: bool = True, + version: object | None = None, + proxy: str = "", + ) -> None: update_data = await self.fetch_release_info(self.ASTRBOT_RELEASE_API, latest) file_url = None diff --git a/astrbot/core/utils/command_parser.py b/astrbot/core/utils/command_parser.py index 557793f0a..da7122a14 100644 --- a/astrbot/core/utils/command_parser.py +++ b/astrbot/core/utils/command_parser.py @@ -13,7 +13,7 @@ def get(self, idx: int) -> str | None: class CommandParserMixin: - def parse_commands(self, message: str): + def parse_commands(self, message: str) -> CommandTokens: cmd_tokens = CommandTokens() cmd_tokens.tokens = re.split(r"\s+", message) cmd_tokens.len = len(cmd_tokens.tokens) diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index fcf5bb3c7..0a4a37405 100644 --- a/astrbot/core/utils/io.py +++ b/astrbot/core/utils/io.py @@ -7,7 +7,10 @@ import time import uuid import zipfile +from collections.abc import Callable from pathlib import Path +from types import TracebackType +from typing import Any import aiohttp import certifi @@ -19,7 +22,11 @@ logger = logging.getLogger("astrbot") -def on_error(func, path, exc_info): +def on_error( + func: Callable[..., Any], + path: str, + exc_info: tuple[type[BaseException], BaseException, TracebackType], +) -> None: """A callback of the rmtree function.""" import stat @@ -37,7 +44,7 @@ def remove_dir(file_path: str) -> bool: return True -def port_checker(port: int, host: str = "localhost"): +def port_checker(port: int, host: str = "localhost") -> bool | None: sk = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sk.settimeout(1) try: @@ -134,7 +141,7 @@ async def download_image_by_url( raise e -async def download_file(url: str, path: str, show_progress: bool = False): +async def download_file(url: str, path: str, show_progress: bool = False) -> None: """从指定 url 下载文件到指定路径 path""" try: ssl_context = ssl.create_default_context( @@ -217,7 +224,7 @@ def file_to_base64(file_path: str) -> str: return "base64://" + base64_str -def get_local_ip_addresses(): +def get_local_ip_addresses() -> list[str]: net_interfaces = psutil.net_if_addrs() network_ips = [] @@ -229,7 +236,7 @@ def get_local_ip_addresses(): return network_ips -async def get_dashboard_version(): +async def get_dashboard_version() -> str | None: dist_dir = os.path.join(get_astrbot_data_path(), "dist") if os.path.exists(dist_dir): version_file = os.path.join(dist_dir, "assets", "version") diff --git a/astrbot/core/utils/llm_metadata.py b/astrbot/core/utils/llm_metadata.py index 540c1efd9..915d8d8f9 100644 --- a/astrbot/core/utils/llm_metadata.py +++ b/astrbot/core/utils/llm_metadata.py @@ -29,7 +29,7 @@ class LLMMetadata(TypedDict): LLM_METADATAS: dict[str, LLMMetadata] = {} -async def update_llm_metadata(): +async def update_llm_metadata() -> None: url = "https://models.dev/api.json" try: async with aiohttp.ClientSession() as session: diff --git a/astrbot/core/utils/log_pipe.py b/astrbot/core/utils/log_pipe.py index 2e931dd81..7cfe0be95 100644 --- a/astrbot/core/utils/log_pipe.py +++ b/astrbot/core/utils/log_pipe.py @@ -1,16 +1,18 @@ import os import threading +from collections.abc import Callable from logging import Logger +from typing import Any class LogPipe(threading.Thread): def __init__( self, - level, + level: int, logger: Logger, - identifier=None, - callback=None, - ): + identifier: str | None = None, + callback: Callable[..., Any] | None = None, + ) -> None: threading.Thread.__init__(self) self.daemon = True self.level = level @@ -21,10 +23,10 @@ def __init__( self.reader = os.fdopen(self.fd_read) self.start() - def fileno(self): + def fileno(self) -> int: return self.fd_write - def run(self): + def run(self) -> None: for line in iter(self.reader.readline, ""): if self.callback: self.callback(line.strip()) @@ -32,5 +34,5 @@ def run(self): self.reader.close() - def close(self): + def close(self) -> None: os.close(self.fd_write) diff --git a/astrbot/core/utils/metrics.py b/astrbot/core/utils/metrics.py index f12019e3c..7745246bb 100644 --- a/astrbot/core/utils/metrics.py +++ b/astrbot/core/utils/metrics.py @@ -10,10 +10,10 @@ class Metric: - _iid_cache = None + _iid_cache: str | None = None @staticmethod - def get_installation_id(): + def get_installation_id() -> str: """获取或创建一个唯一的安装ID""" if Metric._iid_cache is not None: return Metric._iid_cache @@ -40,7 +40,7 @@ def get_installation_id(): return "null" @staticmethod - async def upload(**kwargs): + async def upload(**kwargs) -> None: """上传相关非敏感的指标以更好地了解 AstrBot 的使用情况。上传的指标不会包含任何有关消息文本、用户信息等敏感信息。 Powered by TickStats. diff --git a/astrbot/core/utils/migra_helper.py b/astrbot/core/utils/migra_helper.py index 6a300302d..f91878447 100644 --- a/astrbot/core/utils/migra_helper.py +++ b/astrbot/core/utils/migra_helper.py @@ -2,9 +2,11 @@ from astrbot.core import astrbot_config, logger from astrbot.core.astrbot_config_mgr import AstrBotConfig, AstrBotConfigManager +from astrbot.core.db import BaseDatabase from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46 from astrbot.core.db.migration.migra_token_usage import migrate_token_usage from astrbot.core.db.migration.migra_webchat_session import migrate_webchat_session +from astrbot.core.umop_config_router import UmopConfigRouter def _migra_agent_runner_configs(conf: AstrBotConfig, ids_map: dict) -> None: @@ -120,7 +122,10 @@ def _migra_provider_to_source_structure(conf: AstrBotConfig) -> None: async def migra( - db, astrbot_config_mgr, umop_config_router, acm: AstrBotConfigManager + db: BaseDatabase, + astrbot_config_mgr: AstrBotConfigManager, + umop_config_router: UmopConfigRouter, + acm: AstrBotConfigManager, ) -> None: """ Stores the migration logic here. diff --git a/astrbot/core/utils/path_util.py b/astrbot/core/utils/path_util.py index 9520d481d..d88a1a695 100644 --- a/astrbot/core/utils/path_util.py +++ b/astrbot/core/utils/path_util.py @@ -3,7 +3,7 @@ from astrbot.core import logger -def path_Mapping(mappings, srcPath: str) -> str: +def path_Mapping(mappings: list[str], srcPath: str) -> str: """路径映射处理函数。尝试支援 Windows 和 Linux 的路径映射。 Args: mappings: 映射规则列表 diff --git a/astrbot/core/utils/pip_installer.py b/astrbot/core/utils/pip_installer.py index 663afc081..1fcbd52dd 100644 --- a/astrbot/core/utils/pip_installer.py +++ b/astrbot/core/utils/pip_installer.py @@ -25,7 +25,7 @@ def _robust_decode(line: bytes) -> str: class PipInstaller: - def __init__(self, pip_install_arg: str, pypi_index_url: str | None = None): + def __init__(self, pip_install_arg: str, pypi_index_url: str | None = None) -> None: self.pip_install_arg = pip_install_arg self.pypi_index_url = pypi_index_url @@ -34,7 +34,7 @@ async def install( package_name: str | None = None, requirements_path: str | None = None, mirror: str | None = None, - ): + ) -> None: args = ["install"] if package_name: args.append(package_name) diff --git a/astrbot/core/utils/session_lock.py b/astrbot/core/utils/session_lock.py index 912d91e53..16a1be89d 100644 --- a/astrbot/core/utils/session_lock.py +++ b/astrbot/core/utils/session_lock.py @@ -1,16 +1,17 @@ import asyncio from collections import defaultdict +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager class SessionLockManager: - def __init__(self): + def __init__(self) -> None: self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) self._lock_count: dict[str, int] = defaultdict(int) self._access_lock = asyncio.Lock() @asynccontextmanager - async def acquire_lock(self, session_id: str): + async def acquire_lock(self, session_id: str) -> AsyncGenerator[None, None]: async with self._access_lock: lock = self._locks[session_id] self._lock_count[session_id] += 1 diff --git a/astrbot/core/utils/session_waiter.py b/astrbot/core/utils/session_waiter.py index e1f2fbef7..4ae9c4fbd 100644 --- a/astrbot/core/utils/session_waiter.py +++ b/astrbot/core/utils/session_waiter.py @@ -18,7 +18,7 @@ class SessionController: """控制一个 Session 是否已经结束""" - def __init__(self): + def __init__(self) -> None: self.future = asyncio.Future() self.current_event: asyncio.Event | None = None """当前正在等待的所用的异步事件""" @@ -29,7 +29,7 @@ def __init__(self): self.history_chains: list[list[Comp.BaseMessageComponent]] = [] - def stop(self, error: Exception | None = None): + def stop(self, error: Exception | None = None) -> None: """立即结束这个会话""" if not self.future.done(): if error: @@ -37,7 +37,7 @@ def stop(self, error: Exception | None = None): else: self.future.set_result(None) - def keep(self, timeout: float = 0, reset_timeout=False): + def keep(self, timeout: float = 0, reset_timeout: bool = False) -> None: """保持这个会话 Args: @@ -71,7 +71,7 @@ def keep(self, timeout: float = 0, reset_timeout=False): asyncio.create_task(self._holding(new_event, timeout)) # 开始新的 keep - async def _holding(self, event: asyncio.Event, timeout: float): + async def _holding(self, event: asyncio.Event, timeout: float) -> None: """等待事件结束或超时""" try: await asyncio.wait_for(event.wait(), timeout) @@ -107,7 +107,7 @@ def __init__( session_filter: SessionFilter, session_id: str, record_history_chains: bool, - ): + ) -> None: self.session_id = session_id self.session_filter = session_filter self.handler: ( @@ -141,7 +141,7 @@ async def register_wait( finally: self._cleanup() - def _cleanup(self, error: Exception | None = None): + def _cleanup(self, error: Exception | None = None) -> None: """清理会话""" USER_SESSIONS.pop(self.session_id, None) try: @@ -151,7 +151,7 @@ def _cleanup(self, error: Exception | None = None): self.session_controller.stop(error) @classmethod - async def trigger(cls, session_id: str, event: AstrMessageEvent): + async def trigger(cls, session_id: str, event: AstrMessageEvent) -> None: """外部输入触发会话处理""" session = USER_SESSIONS.get(session_id) if not session or session.session_controller.future.done(): diff --git a/astrbot/core/utils/shared_preferences.py b/astrbot/core/utils/shared_preferences.py index ccd394ee4..1e8325bce 100644 --- a/astrbot/core/utils/shared_preferences.py +++ b/astrbot/core/utils/shared_preferences.py @@ -1,7 +1,7 @@ import asyncio import os import threading -from typing import Any, TypeVar, overload +from typing import TypeVar, overload from astrbot.core.db import BaseDatabase from astrbot.core.db.po import Preference @@ -12,7 +12,9 @@ class SharedPreferences: - def __init__(self, db_helper: BaseDatabase, json_storage_path=None): + def __init__( + self, db_helper: BaseDatabase, json_storage_path: str | None = None + ) -> None: if json_storage_path is None: json_storage_path = os.path.join( get_astrbot_data_path(), @@ -66,7 +68,7 @@ async def session_get( self, umo: None, key: str, - default: Any = None, + default: object = None, ) -> list[Preference]: ... @overload @@ -74,7 +76,7 @@ async def session_get( self, umo: str, key: None, - default: Any = None, + default: object = None, ) -> list[Preference]: ... @overload @@ -82,7 +84,7 @@ async def session_get( self, umo: None, key: None, - default: Any = None, + default: object = None, ) -> list[Preference]: ... async def session_get( @@ -100,7 +102,9 @@ async def session_get( return await self.get_async("umo", umo, key, default) @overload - async def global_get(self, key: None, default: Any = None) -> list[Preference]: ... + async def global_get( + self, key: None, default: object = None + ) -> list[Preference]: ... @overload async def global_get(self, key: str, default: _VT = None) -> _VT: ... @@ -118,7 +122,9 @@ async def global_get( return await self.range_get_async("global", "global", key) return await self.get_async("global", "global", key, default) - async def put_async(self, scope: str, scope_id: str, key: str, value: Any): + async def put_async( + self, scope: str, scope_id: str, key: str, value: object + ) -> None: """设置指定范围和键的偏好设置""" await self.db_helper.insert_preference_or_update( scope, @@ -127,24 +133,24 @@ async def put_async(self, scope: str, scope_id: str, key: str, value: Any): {"val": value}, ) - async def session_put(self, umo: str, key: str, value: Any): + async def session_put(self, umo: str, key: str, value: object) -> None: await self.put_async("umo", umo, key, value) - async def global_put(self, key: str, value: Any): + async def global_put(self, key: str, value: object) -> None: await self.put_async("global", "global", key, value) - async def remove_async(self, scope: str, scope_id: str, key: str): + async def remove_async(self, scope: str, scope_id: str, key: str) -> None: """删除指定范围和键的偏好设置""" await self.db_helper.remove_preference(scope, scope_id, key) - async def session_remove(self, umo: str, key: str): + async def session_remove(self, umo: str, key: str) -> None: await self.remove_async("umo", umo, key) - async def global_remove(self, key: str): + async def global_remove(self, key: str) -> None: """删除全局偏好设置""" await self.remove_async("global", "global", key) - async def clear_async(self, scope: str, scope_id: str): + async def clear_async(self, scope: str, scope_id: str) -> None: """清空指定范围的所有偏好设置""" await self.db_helper.clear_preferences(scope, scope_id) @@ -188,21 +194,29 @@ def range_get( return result - def put(self, key, value, scope: str | None = None, scope_id: str | None = None): + def put( + self, + key: str, + value: object, + scope: str | None = None, + scope_id: str | None = None, + ) -> None: """设置偏好设置(已弃用)""" asyncio.run_coroutine_threadsafe( self.put_async(scope or "unknown", scope_id or "unknown", key, value), self._sync_loop, ).result() - def remove(self, key, scope: str | None = None, scope_id: str | None = None): + def remove( + self, key: str, scope: str | None = None, scope_id: str | None = None + ) -> None: """删除偏好设置(已弃用)""" asyncio.run_coroutine_threadsafe( self.remove_async(scope or "unknown", scope_id or "unknown", key), self._sync_loop, ).result() - def clear(self, scope: str | None = None, scope_id: str | None = None): + def clear(self, scope: str | None = None, scope_id: str | None = None) -> None: """清空偏好设置(已弃用)""" asyncio.run_coroutine_threadsafe( self.clear_async(scope or "unknown", scope_id or "unknown"), diff --git a/astrbot/core/utils/t2i/network_strategy.py b/astrbot/core/utils/t2i/network_strategy.py index 7ebba5669..b7ec16782 100644 --- a/astrbot/core/utils/t2i/network_strategy.py +++ b/astrbot/core/utils/t2i/network_strategy.py @@ -28,7 +28,7 @@ def __init__(self, base_url: str | None = None) -> None: self.endpoints = [self.BASE_RENDER_URL] self.template_manager = TemplateManager() - async def initialize(self): + async def initialize(self) -> None: if self.BASE_RENDER_URL == ASTRBOT_T2I_DEFAULT_ENDPOINT: asyncio.create_task(self.get_official_endpoints()) @@ -36,7 +36,7 @@ async def get_template(self, name: str = "base") -> str: """通过名称获取文转图 HTML 模板""" return self.template_manager.get_template(name) - async def get_official_endpoints(self): + async def get_official_endpoints(self) -> None: """获取官方的 t2i 端点列表。""" try: async with aiohttp.ClientSession() as session: @@ -57,7 +57,7 @@ async def get_official_endpoints(self): except Exception as e: logger.error(f"Failed to get official endpoints: {e}") - def _clean_url(self, url: str): + def _clean_url(self, url: str) -> str: url = url.removesuffix("/") if not url.endswith("text2img"): url += "/text2img" diff --git a/astrbot/core/utils/t2i/renderer.py b/astrbot/core/utils/t2i/renderer.py index 2ce7a5ebf..19c695d5e 100644 --- a/astrbot/core/utils/t2i/renderer.py +++ b/astrbot/core/utils/t2i/renderer.py @@ -7,11 +7,11 @@ class HtmlRenderer: - def __init__(self, endpoint_url: str | None = None): + def __init__(self, endpoint_url: str | None = None) -> None: self.network_strategy = NetworkRenderStrategy(endpoint_url) self.local_strategy = LocalRenderStrategy() - async def initialize(self): + async def initialize(self) -> None: await self.network_strategy.initialize() async def render_custom_template( @@ -20,7 +20,7 @@ async def render_custom_template( tmpl_data: dict, return_url: bool = False, options: dict | None = None, - ): + ) -> str: """使用自定义文转图模板。该方法会通过网络调用 t2i 终结点图文渲染API。 @param tmpl_str: HTML Jinja2 模板。 @param tmpl_data: jinja2 模板数据。 @@ -43,7 +43,7 @@ async def render_t2i( use_network: bool = True, return_url: bool = False, template_name: str | None = None, - ): + ) -> str: """使用默认文转图模板。""" if use_network: try: diff --git a/astrbot/core/utils/t2i/template_manager.py b/astrbot/core/utils/t2i/template_manager.py index 6d44f735b..b3eb0c9ff 100644 --- a/astrbot/core/utils/t2i/template_manager.py +++ b/astrbot/core/utils/t2i/template_manager.py @@ -14,7 +14,7 @@ class TemplateManager: CORE_TEMPLATES = ["base.html", "astrbot_powershell.html"] - def __init__(self): + def __init__(self) -> None: self.builtin_template_dir = os.path.join( get_astrbot_path(), "astrbot", @@ -28,7 +28,7 @@ def __init__(self): os.makedirs(self.user_template_dir, exist_ok=True) self._initialize_user_templates() - def _copy_core_templates(self, overwrite: bool = False): + def _copy_core_templates(self, overwrite: bool = False) -> None: """从内置目录复制核心模板到用户目录。""" for filename in self.CORE_TEMPLATES: src = os.path.join(self.builtin_template_dir, filename) @@ -36,7 +36,7 @@ def _copy_core_templates(self, overwrite: bool = False): if os.path.exists(src) and (overwrite or not os.path.exists(dst)): shutil.copyfile(src, dst) - def _initialize_user_templates(self): + def _initialize_user_templates(self) -> None: """如果用户目录下缺少核心模板,则进行复制。""" self._copy_core_templates(overwrite=False) @@ -80,7 +80,7 @@ def get_template(self, name: str) -> str: raise FileNotFoundError("模板不存在。") - def create_template(self, name: str, content: str): + def create_template(self, name: str, content: str) -> None: """在用户目录中创建一个新的模板文件。""" path = self._get_user_template_path(name) if os.path.exists(path): @@ -88,7 +88,7 @@ def create_template(self, name: str, content: str): with open(path, "w", encoding="utf-8") as f: f.write(content) - def update_template(self, name: str, content: str): + def update_template(self, name: str, content: str) -> None: """更新一个模板。此操作始终写入用户目录。 如果更新的是一个内置模板,此操作实际上会在用户目录中创建一个修改后的副本, 从而实现对内置模板的“覆盖”。 @@ -97,7 +97,7 @@ def update_template(self, name: str, content: str): with open(path, "w", encoding="utf-8") as f: f.write(content) - def delete_template(self, name: str): + def delete_template(self, name: str) -> None: """仅删除用户目录中的模板文件。 如果删除的是一个覆盖了内置模板的用户模板,这将有效地“恢复”到内置版本。 """ @@ -106,6 +106,6 @@ def delete_template(self, name: str): raise FileNotFoundError("用户模板不存在,无法删除。") os.remove(path) - def reset_default_template(self): + def reset_default_template(self) -> None: """将核心模板从内置目录强制重置到用户目录。""" self._copy_core_templates(overwrite=True) diff --git a/astrbot/core/utils/version_comparator.py b/astrbot/core/utils/version_comparator.py index 4ad2da10e..eea89fb42 100644 --- a/astrbot/core/utils/version_comparator.py +++ b/astrbot/core/utils/version_comparator.py @@ -13,7 +13,7 @@ def compare_version(v1: str, v2: str) -> int: v1 = v1.lower().replace("v", "") v2 = v2.lower().replace("v", "") - def split_version(version): + def split_version(version: str) -> tuple[list, list | None]: match = re.match( r"^([0-9]+(?:\.[0-9]+)*)(?:-([0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?(?:\+(.+))?$", version, @@ -75,7 +75,7 @@ def split_version(version): return 0 # 数字部分和预发布标签都相同 @staticmethod - def _split_prerelease(prerelease): + def _split_prerelease(prerelease: str) -> list[str] | None: if not prerelease: return None parts = prerelease.split(".") diff --git a/astrbot/core/utils/webhook_utils.py b/astrbot/core/utils/webhook_utils.py index 0e1c3f9cd..07abc115a 100644 --- a/astrbot/core/utils/webhook_utils.py +++ b/astrbot/core/utils/webhook_utils.py @@ -20,7 +20,7 @@ def _get_dashboard_port() -> int: return 6185 -def log_webhook_info(platform_name: str, webhook_uuid: str): +def log_webhook_info(platform_name: str, webhook_uuid: str) -> None: """打印美观的 webhook 信息日志 Args: diff --git a/astrbot/core/zip_updator.py b/astrbot/core/zip_updator.py index 728dfdabb..75318437e 100644 --- a/astrbot/core/zip_updator.py +++ b/astrbot/core/zip_updator.py @@ -3,6 +3,7 @@ import shutil import ssl import zipfile +from typing import NoReturn import aiohttp import certifi @@ -101,10 +102,10 @@ def github_api_release_parser(self, releases: list) -> list: ) return ret - def unzip(self): + def unzip(self) -> NoReturn: raise NotImplementedError - async def update(self): + async def update(self) -> NoReturn: raise NotImplementedError def compare_version(self, v1: str, v2: str) -> int: @@ -148,7 +149,9 @@ async def check_update( body=f"{tag_name}\n\n{sel_release_data['body']}", ) - async def download_from_repo_url(self, target_path: str, repo_url: str, proxy=""): + async def download_from_repo_url( + self, target_path: str, repo_url: str, proxy: str = "" + ) -> None: author, repo, branch = self.parse_github_url(repo_url) logger.info(f"正在下载更新 {repo} ...") @@ -185,7 +188,7 @@ async def download_from_repo_url(self, target_path: str, repo_url: str, proxy="" await download_file(release_url, target_path + ".zip") - def parse_github_url(self, url: str): + def parse_github_url(self, url: str) -> tuple[str, str, str]: """使用正则表达式解析 GitHub 仓库 URL,支持 `.git` 后缀和 `tree/branch` 结构 Returns: tuple[str, str, str]: 返回作者名、仓库名和分支名 @@ -203,7 +206,7 @@ def parse_github_url(self, url: str): return author, repo, branch raise ValueError("无效的 GitHub URL") - def unzip_file(self, zip_path: str, target_dir: str): + def unzip_file(self, zip_path: str, target_dir: str) -> None: """解压缩文件, 并将压缩包内**第一个**文件夹内的文件移动到 target_dir""" os.makedirs(target_dir, exist_ok=True) update_dir = "" diff --git a/astrbot/dashboard/routes/auth.py b/astrbot/dashboard/routes/auth.py index 4ee0d57d4..75c80b0bb 100644 --- a/astrbot/dashboard/routes/auth.py +++ b/astrbot/dashboard/routes/auth.py @@ -19,7 +19,7 @@ def __init__(self, context: RouteContext) -> None: } self.register_routes() - async def login(self): + async def login(self) -> dict: username = self.config["dashboard"]["username"] password = self.config["dashboard"]["password"] post_data = await request.json @@ -47,7 +47,7 @@ async def login(self): await asyncio.sleep(3) return Response().error("用户名或密码错误").__dict__ - async def edit_account(self): + async def edit_account(self) -> dict: if DEMO_MODE: return ( Response() @@ -77,7 +77,7 @@ async def edit_account(self): return Response().ok(None, "修改成功").__dict__ - def generate_jwt(self, username): + def generate_jwt(self, username: str) -> str: payload = { "username": username, "exp": datetime.datetime.utcnow() + datetime.timedelta(days=7), diff --git a/astrbot/dashboard/routes/backup.py b/astrbot/dashboard/routes/backup.py index ee39399dc..6b536f31b 100644 --- a/astrbot/dashboard/routes/backup.py +++ b/astrbot/dashboard/routes/backup.py @@ -9,11 +9,13 @@ import traceback import uuid import zipfile +from collections.abc import Awaitable, Callable from datetime import datetime from pathlib import Path import jwt -from quart import request, send_file +import quart +from quart import ResponseReturnValue, request, send_file from astrbot.core import logger from astrbot.core.backup.exporter import AstrBotExporter @@ -180,10 +182,14 @@ def _update_progress( if message is not None: p["message"] = message - def _make_progress_callback(self, task_id: str): + def _make_progress_callback( + self, task_id: str + ) -> Callable[[str, int, int, str], Awaitable[None]]: """创建进度回调函数""" - async def _callback(stage: str, current: int, total: int, message: str = ""): + async def _callback( + stage: str, current: int, total: int, message: str = "" + ) -> None: self._update_progress( task_id, status="processing", @@ -195,7 +201,7 @@ async def _callback(stage: str, current: int, total: int, message: str = ""): return _callback - def _ensure_cleanup_task_started(self): + def _ensure_cleanup_task_started(self) -> None: """确保后台清理任务已启动(在异步上下文中延迟启动)""" if self._cleanup_task is None or self._cleanup_task.done(): try: @@ -206,7 +212,7 @@ def _ensure_cleanup_task_started(self): # 如果没有运行中的事件循环,跳过(等待下次异步调用时启动) pass - async def _cleanup_expired_uploads(self): + async def _cleanup_expired_uploads(self) -> None: """定期清理过期的上传会话 基于 last_activity 字段判断过期,避免清理活跃的上传会话。 @@ -233,7 +239,7 @@ async def _cleanup_expired_uploads(self): except Exception as e: logger.error(f"清理过期上传会话失败: {e}") - async def _cleanup_upload_session(self, upload_id: str): + async def _cleanup_upload_session(self, upload_id: str) -> None: """清理上传会话""" if upload_id in self.upload_sessions: session = self.upload_sessions[upload_id] @@ -266,7 +272,7 @@ def _get_backup_manifest(self, zip_path: str) -> dict | None: logger.debug(f"读取备份 manifest 失败: {e}") return None # 无法读取,不是有效备份 - async def list_backups(self): + async def list_backups(self) -> dict: # 确保后台清理任务已启动 self._ensure_cleanup_task_started() @@ -340,7 +346,7 @@ async def list_backups(self): logger.error(traceback.format_exc()) return Response().error(f"获取备份列表失败: {e!s}").__dict__ - async def export_backup(self): + async def export_backup(self) -> dict: """创建备份 返回: @@ -371,7 +377,7 @@ async def export_backup(self): logger.error(traceback.format_exc()) return Response().error(f"创建备份失败: {e!s}").__dict__ - async def _background_export_task(self, task_id: str): + async def _background_export_task(self, task_id: str) -> None: """后台导出任务""" try: self._update_progress(task_id, status="processing", message="正在初始化...") @@ -409,7 +415,7 @@ async def _background_export_task(self, task_id: str): logger.error(traceback.format_exc()) self._set_task_result(task_id, "failed", error=str(e)) - async def upload_backup(self): + async def upload_backup(self) -> dict: """上传备份文件 将备份文件上传到服务器,返回保存的文件名。 @@ -459,7 +465,7 @@ async def upload_backup(self): logger.error(traceback.format_exc()) return Response().error(f"上传备份文件失败: {e!s}").__dict__ - async def upload_init(self): + async def upload_init(self) -> dict: """初始化分片上传 创建一个上传会话,返回 upload_id 供后续分片上传使用。 @@ -538,7 +544,7 @@ async def upload_init(self): logger.error(traceback.format_exc()) return Response().error(f"初始化分片上传失败: {e!s}").__dict__ - async def upload_chunk(self): + async def upload_chunk(self) -> dict: """上传分片 上传单个分片数据。 @@ -642,7 +648,7 @@ def _mark_backup_as_uploaded(self, zip_path: str) -> None: except Exception as e: logger.warning(f"标记备份来源失败: {e}") - async def upload_complete(self): + async def upload_complete(self) -> dict: """完成分片上传 合并所有分片为完整文件。 @@ -732,7 +738,7 @@ async def upload_complete(self): logger.error(traceback.format_exc()) return Response().error(f"完成分片上传失败: {e!s}").__dict__ - async def upload_abort(self): + async def upload_abort(self) -> dict: """取消分片上传 取消上传并清理已上传的分片。 @@ -762,7 +768,7 @@ async def upload_abort(self): logger.error(traceback.format_exc()) return Response().error(f"取消上传失败: {e!s}").__dict__ - async def check_backup(self): + async def check_backup(self) -> dict: """预检查备份文件 检查备份文件的版本兼容性,返回确认信息。 @@ -806,7 +812,7 @@ async def check_backup(self): logger.error(traceback.format_exc()) return Response().error(f"预检查备份文件失败: {e!s}").__dict__ - async def import_backup(self): + async def import_backup(self) -> dict: """执行备份导入 在用户确认后执行实际的导入操作。 @@ -866,7 +872,7 @@ async def import_backup(self): logger.error(traceback.format_exc()) return Response().error(f"导入备份失败: {e!s}").__dict__ - async def _background_import_task(self, task_id: str, zip_path: str): + async def _background_import_task(self, task_id: str, zip_path: str) -> None: """后台导入任务""" try: self._update_progress(task_id, status="processing", message="正在初始化...") @@ -908,7 +914,7 @@ async def _background_import_task(self, task_id: str, zip_path: str): logger.error(traceback.format_exc()) self._set_task_result(task_id, "failed", error=str(e)) - async def get_progress(self): + async def get_progress(self) -> dict: """获取任务进度 Query 参数: @@ -949,7 +955,7 @@ async def get_progress(self): logger.error(traceback.format_exc()) return Response().error(f"获取任务进度失败: {e!s}").__dict__ - async def download_backup(self): + async def download_backup(self) -> dict | quart.Response: """下载备份文件 Query 参数: @@ -1000,7 +1006,7 @@ async def download_backup(self): logger.error(traceback.format_exc()) return Response().error(f"下载备份失败: {e!s}").__dict__ - async def delete_backup(self): + async def delete_backup(self) -> ResponseReturnValue: """删除备份文件 Body: @@ -1027,7 +1033,7 @@ async def delete_backup(self): logger.error(traceback.format_exc()) return Response().error(f"删除备份失败: {e!s}").__dict__ - async def rename_backup(self): + async def rename_backup(self) -> dict: """重命名备份文件 Body: diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index 71c3fecd3..2b125cf40 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -3,15 +3,18 @@ import mimetypes import os import uuid +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from typing import cast +import quart from quart import Response as QuartResponse from quart import g, make_response, request, send_file from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.db import BaseDatabase +from astrbot.core.db.po import PlatformMessageHistory from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr from astrbot.core.utils.astrbot_path import get_astrbot_data_path @@ -19,7 +22,7 @@ @asynccontextmanager -async def track_conversation(convs: dict, conv_id: str): +async def track_conversation(convs: dict, conv_id: str) -> AsyncGenerator[None, None]: convs[conv_id] = True try: yield @@ -62,7 +65,7 @@ def __init__( self.running_convs: dict[str, bool] = {} - async def get_file(self): + async def get_file(self) -> dict | quart.Response: filename = request.args.get("filename") if not filename: return Response().error("Missing key: filename").__dict__ @@ -85,7 +88,7 @@ async def get_file(self): except (FileNotFoundError, OSError): return Response().error("File access error").__dict__ - async def get_attachment(self): + async def get_attachment(self) -> dict | quart.Response: """Get attachment file by attachment_id.""" attachment_id = request.args.get("attachment_id") if not attachment_id: @@ -104,7 +107,7 @@ async def get_attachment(self): except (FileNotFoundError, OSError): return Response().error("File access error").__dict__ - async def post_file(self): + async def post_file(self) -> dict: """Upload a file and create an attachment record, return attachment_id.""" post_data = await request.files if "file" not in post_data: @@ -228,7 +231,7 @@ async def _save_bot_message( media_parts: list, reasoning: str, agent_stats: dict, - ): + ) -> PlatformMessageHistory: """保存 bot 消息到历史记录,返回保存的记录""" bot_message_parts = [] bot_message_parts.extend(media_parts) @@ -250,7 +253,7 @@ async def _save_bot_message( ) return record - async def chat(self): + async def chat(self) -> dict | quart.Response: username = g.get("username", "guest") post_data = await request.json @@ -292,7 +295,7 @@ async def chat(self): # 构建用户消息段(包含 path 用于传递给 adapter) message_parts = await self._build_user_message_parts(message) - async def stream(): + async def stream() -> AsyncGenerator[str, None]: client_disconnected = False accumulated_parts = [] accumulated_text = "" @@ -484,7 +487,7 @@ async def stream(): response.timeout = None # fix SSE auto disconnect issue return response - async def delete_webchat_session(self): + async def delete_webchat_session(self) -> dict: """Delete a Platform session and all its related data.""" session_id = request.args.get("session_id") if not session_id: @@ -540,7 +543,9 @@ async def delete_webchat_session(self): return Response().ok().__dict__ - def _extract_attachment_ids(self, history_list) -> list[str]: + def _extract_attachment_ids( + self, history_list: list[PlatformMessageHistory] + ) -> list[str]: """从消息历史中提取所有 attachment_id""" attachment_ids = [] for history in history_list: @@ -553,7 +558,7 @@ def _extract_attachment_ids(self, history_list) -> list[str]: attachment_ids.append(part["attachment_id"]) return attachment_ids - async def _delete_attachments(self, attachment_ids: list[str]): + async def _delete_attachments(self, attachment_ids: list[str]) -> None: """删除附件(包括数据库记录和磁盘文件)""" try: attachments = await self.db.get_attachments(attachment_ids) @@ -575,7 +580,7 @@ async def _delete_attachments(self, attachment_ids: list[str]): except Exception as e: logger.warning(f"Failed to delete attachments: {e}") - async def new_session(self): + async def new_session(self) -> dict: """Create a new Platform session (default: webchat).""" username = g.get("username", "guest") @@ -600,7 +605,7 @@ async def new_session(self): .__dict__ ) - async def get_sessions(self): + async def get_sessions(self) -> dict: """Get all Platform sessions for the current user.""" username = g.get("username", "guest") @@ -631,7 +636,7 @@ async def get_sessions(self): return Response().ok(data=sessions_data).__dict__ - async def get_session(self): + async def get_session(self) -> dict: """Get session information and message history by session_id.""" session_id = request.args.get("session_id") if not session_id: @@ -662,7 +667,7 @@ async def get_session(self): .__dict__ ) - async def update_session_display_name(self): + async def update_session_display_name(self) -> dict: """Update a Platform session's display name.""" post_data = await request.json diff --git a/astrbot/dashboard/routes/command.py b/astrbot/dashboard/routes/command.py index abd38d886..b42131ee3 100644 --- a/astrbot/dashboard/routes/command.py +++ b/astrbot/dashboard/routes/command.py @@ -25,7 +25,7 @@ def __init__(self, context: RouteContext) -> None: } self.register_routes() - async def get_commands(self): + async def get_commands(self) -> dict: commands = await list_commands() summary = { "total": len(commands), @@ -34,11 +34,11 @@ async def get_commands(self): } return Response().ok({"items": commands, "summary": summary}).__dict__ - async def get_conflicts(self): + async def get_conflicts(self) -> dict: conflicts = await list_command_conflicts() return Response().ok(conflicts).__dict__ - async def toggle_command(self): + async def toggle_command(self) -> dict: data = await request.get_json() handler_full_name = data.get("handler_full_name") enabled = data.get("enabled") @@ -57,7 +57,7 @@ async def toggle_command(self): payload = await _get_command_payload(handler_full_name) return Response().ok(payload).__dict__ - async def rename_command(self): + async def rename_command(self) -> dict: data = await request.get_json() handler_full_name = data.get("handler_full_name") new_name = data.get("new_name") @@ -75,7 +75,7 @@ async def rename_command(self): return Response().ok(payload).__dict__ -async def _get_command_payload(handler_full_name: str): +async def _get_command_payload(handler_full_name: str) -> dict: commands = await list_commands() for cmd in commands: if cmd["handler_full_name"] == handler_full_name: diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index bd2f9a264..224d6537f 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -2,9 +2,10 @@ import inspect import os import traceback -from typing import Any +from collections.abc import Callable, Iterable +from typing import Any, Literal, overload -from quart import request +from quart import ResponseReturnValue, request from astrbot.core import astrbot_config, file_token_service, logger from astrbot.core.config.astrbot_config import AstrBotConfig @@ -17,8 +18,10 @@ ) from astrbot.core.config.i18n_utils import ConfigMetadataI18n from astrbot.core.core_lifecycle import AstrBotCoreLifecycle +from astrbot.core.platform.platform_metadata import PlatformMetadata from astrbot.core.platform.register import platform_cls_map, platform_registry from astrbot.core.provider import Provider +from astrbot.core.provider.provider import AbstractProvider from astrbot.core.provider.register import provider_registry from astrbot.core.star.star import star_registry from astrbot.core.utils.llm_metadata import LLM_METADATAS @@ -26,8 +29,19 @@ from .route import Response, Route, RouteContext +# 兼容 isinstance 的类型提示,等价于内建 _ClassInfo 的可用子集 +ClassInfo = type | tuple[type, ...] -def try_cast(value: Any, type_: str): + +@overload +def try_cast(value: object, type_: Literal["int"]) -> int | None: ... + + +@overload +def try_cast(value: object, type_: Literal["float"]) -> float | None: ... + + +def try_cast(value: Any, type_: str): # noqa:ANN401 if type_ == "int": try: return int(value) @@ -46,17 +60,34 @@ def try_cast(value: Any, type_: str): return None -def _expect_type(value, expected_type, path_key, errors, expected_name=None): +def _expect_type( + value: object, + expected_type: ClassInfo, + path_key: str, + errors: list, + expected_name: str | None = None, +) -> bool: if not isinstance(value, expected_type): + exp_name = expected_name or ( + expected_type.__name__ + if isinstance(expected_type, type) + else " | ".join(t.__name__ for t in expected_type if isinstance(t, type)) + or "unknown" + ) errors.append( - f"错误的类型 {path_key}: 期望是 {expected_name or expected_type.__name__}, " - f"得到了 {type(value).__name__}" + f"错误的类型 {path_key}: 期望是 {exp_name}, 得到了 {type(value).__name__}" ) return False return True -def _validate_template_list(value, meta, path_key, errors, validate_fn): +def _validate_template_list( + value: Iterable, + meta: dict, + path_key: str, + errors: list, + validate_fn: Callable[[dict, dict, str], None], +) -> None: if not _expect_type(value, list, path_key, errors, "list"): return @@ -82,14 +113,14 @@ def _validate_template_list(value, meta, path_key, errors, validate_fn): validate_fn( item, template_meta.get("items", {}), - path=f"{item_path}.", + f"{item_path}.", ) -def validate_config(data, schema: dict, is_core: bool) -> tuple[list[str], dict]: +def validate_config(data: dict, schema: dict, is_core: bool) -> tuple[list[str], dict]: errors = [] - def validate(data: dict, metadata: dict = schema, path=""): + def validate(data: dict, metadata: dict = schema, path: str = "") -> None: for key, value in data.items(): if key not in metadata: continue @@ -167,7 +198,9 @@ def validate(data: dict, metadata: dict = schema, path=""): return errors, data -def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False): +def save_config( + post_config: dict, config: AstrBotConfig, is_core: bool = False +) -> None: """验证并保存配置""" errors = None logger.info(f"Saving config, is_core={is_core}") @@ -245,7 +278,7 @@ def __init__( } self.register_routes() - async def delete_provider_source(self): + async def delete_provider_source(self) -> ResponseReturnValue: """删除 provider_source,并更新关联的 providers""" post_data = await request.json if not post_data: @@ -287,7 +320,7 @@ async def delete_provider_source(self): return Response().ok(message="删除 provider source 成功").__dict__ - async def update_provider_source(self): + async def update_provider_source(self) -> ResponseReturnValue: """更新或新增 provider_source,并重载关联的 providers""" post_data = await request.json if not post_data: @@ -365,7 +398,7 @@ async def update_provider_source(self): return Response().ok(message="更新 provider source 成功").__dict__ - async def get_provider_template(self): + async def get_provider_template(self) -> ResponseReturnValue: config_schema = { "provider": CONFIG_METADATA_2["provider_group"]["metadata"]["provider"] } @@ -376,11 +409,11 @@ async def get_provider_template(self): } return Response().ok(data=data).__dict__ - async def get_uc_table(self): + async def get_uc_table(self) -> ResponseReturnValue: """获取 UMOP 配置路由表""" return Response().ok({"routing": self.ucr.umop_to_conf_id}).__dict__ - async def update_ucr_all(self): + async def update_ucr_all(self) -> ResponseReturnValue: """更新 UMOP 配置路由表的全部内容""" post_data = await request.json if not post_data: @@ -398,7 +431,7 @@ async def update_ucr_all(self): logger.error(traceback.format_exc()) return Response().error(f"更新路由表失败: {e!s}").__dict__ - async def update_ucr(self): + async def update_ucr(self) -> ResponseReturnValue: """更新 UMOP 配置路由表""" post_data = await request.json if not post_data: @@ -417,7 +450,7 @@ async def update_ucr(self): logger.error(traceback.format_exc()) return Response().error(f"更新路由表失败: {e!s}").__dict__ - async def delete_ucr(self): + async def delete_ucr(self) -> ResponseReturnValue: """删除 UMOP 配置路由表中的一项""" post_data = await request.json if not post_data: @@ -437,17 +470,17 @@ async def delete_ucr(self): logger.error(traceback.format_exc()) return Response().error(f"删除路由表项失败: {e!s}").__dict__ - async def get_default_config(self): + async def get_default_config(self) -> ResponseReturnValue: """获取默认配置文件""" metadata = ConfigMetadataI18n.convert_to_i18n_keys(CONFIG_METADATA_3) return Response().ok({"config": DEFAULT_CONFIG, "metadata": metadata}).__dict__ - async def get_abconf_list(self): + async def get_abconf_list(self) -> ResponseReturnValue: """获取所有 AstrBot 配置文件的列表""" abconf_list = self.acm.get_conf_list() return Response().ok({"info_list": abconf_list}).__dict__ - async def create_abconf(self): + async def create_abconf(self) -> ResponseReturnValue: """创建新的 AstrBot 配置文件""" post_data = await request.json if not post_data: @@ -461,7 +494,7 @@ async def create_abconf(self): except ValueError as e: return Response().error(str(e)).__dict__ - async def get_abconf(self): + async def get_abconf(self) -> ResponseReturnValue: """获取指定 AstrBot 配置文件""" abconf_id = request.args.get("id") system_config = request.args.get("system_config", "0").lower() == "1" @@ -483,7 +516,7 @@ async def get_abconf(self): except ValueError as e: return Response().error(str(e)).__dict__ - async def delete_abconf(self): + async def delete_abconf(self) -> ResponseReturnValue: """删除指定 AstrBot 配置文件""" post_data = await request.json if not post_data: @@ -504,7 +537,7 @@ async def delete_abconf(self): logger.error(traceback.format_exc()) return Response().error(f"删除配置文件失败: {e!s}").__dict__ - async def update_abconf(self): + async def update_abconf(self) -> ResponseReturnValue: """更新指定 AstrBot 配置文件信息""" post_data = await request.json if not post_data: @@ -527,7 +560,7 @@ async def update_abconf(self): logger.error(traceback.format_exc()) return Response().error(f"更新配置文件失败: {e!s}").__dict__ - async def _test_single_provider(self, provider): + async def _test_single_provider(self, provider: AbstractProvider) -> dict: """辅助函数:测试单个 provider 的可用性""" meta = provider.meta() provider_name = provider.provider_config.get("id", "Unknown Provider") @@ -567,15 +600,15 @@ def _error_response( self, message: str, status_code: int = 500, - log_fn=logger.error, - ): + log_fn: Callable[[str], Any] = logger.error, + ) -> ResponseReturnValue: log_fn(message) # 记录更详细的traceback信息,但只在是严重错误时 if status_code == 500: log_fn(traceback.format_exc()) return Response().error(message).__dict__ - async def check_one_provider_status(self): + async def check_one_provider_status(self) -> ResponseReturnValue: """API: check a single LLM Provider's status by id""" provider_id = request.args.get("id") if not provider_id: @@ -609,7 +642,7 @@ async def check_one_provider_status(self): 500, ) - async def get_configs(self): + async def get_configs(self) -> ResponseReturnValue: # plugin_name 为空时返回 AstrBot 配置 # 否则返回指定 plugin_name 的插件配置 plugin_name = request.args.get("plugin_name", None) @@ -617,7 +650,7 @@ async def get_configs(self): return Response().ok(await self._get_astrbot_config()).__dict__ return Response().ok(await self._get_plugin_config(plugin_name)).__dict__ - async def get_provider_config_list(self): + async def get_provider_config_list(self) -> ResponseReturnValue: provider_type = request.args.get("provider_type", None) if not provider_type: return Response().error("缺少参数 provider_type").__dict__ @@ -645,7 +678,7 @@ async def get_provider_config_list(self): provider_list.append(provider) return Response().ok(provider_list).__dict__ - async def get_provider_model_list(self): + async def get_provider_model_list(self) -> ResponseReturnValue: """获取指定提供商的模型列表""" provider_id = request.args.get("provider_id", None) if not provider_id: @@ -682,7 +715,7 @@ async def get_provider_model_list(self): logger.error(traceback.format_exc()) return Response().error(str(e)).__dict__ - async def get_embedding_dim(self): + async def get_embedding_dim(self) -> ResponseReturnValue: """获取嵌入模型的维度""" post_data = await request.json provider_config = post_data.get("provider_config", None) @@ -737,7 +770,7 @@ async def get_embedding_dim(self): logger.error(traceback.format_exc()) return Response().error(f"获取嵌入维度失败: {e!s}").__dict__ - async def get_provider_source_models(self): + async def get_provider_source_models(self) -> ResponseReturnValue: """获取指定 provider_source 支持的模型列表 本质上会临时初始化一个 Provider 实例,调用 get_models() 获取模型列表,然后销毁实例 @@ -835,14 +868,14 @@ async def get_provider_source_models(self): logger.error(traceback.format_exc()) return Response().error(f"获取模型列表失败: {e!s}").__dict__ - async def get_platform_list(self): + async def get_platform_list(self) -> ResponseReturnValue: """获取所有平台的列表""" platform_list = [] for platform in self.config["platform"]: platform_list.append(platform) return Response().ok({"platforms": platform_list}).__dict__ - async def post_astrbot_configs(self): + async def post_astrbot_configs(self) -> ResponseReturnValue: data = await request.json config = data.get("config", None) conf_id = data.get("conf_id", None) @@ -862,7 +895,7 @@ async def post_astrbot_configs(self): logger.error(traceback.format_exc()) return Response().error(str(e)).__dict__ - async def post_plugin_configs(self): + async def post_plugin_configs(self) -> ResponseReturnValue: post_configs = await request.json plugin_name = request.args.get("plugin_name", "unknown") try: @@ -876,7 +909,7 @@ async def post_plugin_configs(self): except Exception as e: return Response().error(str(e)).__dict__ - async def post_new_platform(self): + async def post_new_platform(self) -> ResponseReturnValue: new_platform_config = await request.json # 如果是支持统一 webhook 模式的平台,生成 webhook_uuid @@ -892,7 +925,7 @@ async def post_new_platform(self): return Response().error(str(e)).__dict__ return Response().ok(None, "新增平台配置成功~").__dict__ - async def post_new_provider(self): + async def post_new_provider(self) -> ResponseReturnValue: new_provider_config = await request.json try: @@ -903,7 +936,7 @@ async def post_new_provider(self): return Response().error(str(e)).__dict__ return Response().ok(None, "新增服务提供商配置成功").__dict__ - async def post_update_platform(self): + async def post_update_platform(self) -> ResponseReturnValue: update_platform_config = await request.json origin_platform_id = update_platform_config.get("id", None) new_config = update_platform_config.get("config", None) @@ -930,7 +963,7 @@ async def post_update_platform(self): return Response().error(str(e)).__dict__ return Response().ok(None, "更新平台配置成功~").__dict__ - async def post_update_provider(self): + async def post_update_provider(self) -> ResponseReturnValue: update_provider_config = await request.json origin_provider_id = update_provider_config.get("id", None) new_config = update_provider_config.get("config", None) @@ -945,7 +978,7 @@ async def post_update_provider(self): return Response().error(str(e)).__dict__ return Response().ok(None, "更新成功,已经实时生效~").__dict__ - async def post_delete_platform(self): + async def post_delete_platform(self) -> ResponseReturnValue: platform_id = await request.json platform_id = platform_id.get("id") for i, platform in enumerate(self.config["platform"]): @@ -961,7 +994,7 @@ async def post_delete_platform(self): return Response().error(str(e)).__dict__ return Response().ok(None, "删除平台配置成功~").__dict__ - async def post_delete_provider(self): + async def post_delete_provider(self) -> ResponseReturnValue: provider_id = await request.json provider_id = provider_id.get("id", "") if not provider_id: @@ -975,13 +1008,15 @@ async def post_delete_provider(self): return Response().error(str(e)).__dict__ return Response().ok(None, "删除成功,已经实时生效。").__dict__ - async def get_llm_tools(self): + async def get_llm_tools(self) -> dict: """获取函数调用工具。包含了本地加载的以及 MCP 服务的工具""" tool_mgr = self.core_lifecycle.provider_manager.llm_tools tools = tool_mgr.get_func_desc_openai_style() return Response().ok(tools).__dict__ - async def _register_platform_logo(self, platform, platform_default_tmpl): + async def _register_platform_logo( + self, platform: PlatformMetadata, platform_default_tmpl: dict + ) -> None: """注册平台logo文件并生成访问令牌""" if not platform.logo_path: return @@ -1048,7 +1083,7 @@ async def _register_platform_logo(self, platform, platform_default_tmpl): f"Unexpected error registering logo for platform {platform.name}: {e}", ) - async def _get_astrbot_config(self): + async def _get_astrbot_config(self) -> dict: config = self.config # 平台适配器的默认配置模板注入 @@ -1081,7 +1116,7 @@ async def _get_astrbot_config(self): return {"metadata": CONFIG_METADATA_2, "config": config} - async def _get_plugin_config(self, plugin_name: str): + async def _get_plugin_config(self, plugin_name: str) -> dict: ret: dict = {"metadata": None, "config": None} for plugin_md in star_registry: @@ -1104,7 +1139,7 @@ async def _get_plugin_config(self, plugin_name: str): async def _save_astrbot_configs( self, post_configs: dict, conf_id: str | None = None - ): + ) -> None: try: if conf_id not in self.acm.confs: raise ValueError(f"配置文件 {conf_id} 不存在") @@ -1120,7 +1155,7 @@ async def _save_astrbot_configs( except Exception as e: raise e - async def _save_plugin_configs(self, post_configs: dict, plugin_name: str): + async def _save_plugin_configs(self, post_configs: dict, plugin_name: str) -> None: md = None for plugin_md in star_registry: if plugin_md.name == plugin_name: diff --git a/astrbot/dashboard/routes/conversation.py b/astrbot/dashboard/routes/conversation.py index 513d3603f..4bb1d3578 100644 --- a/astrbot/dashboard/routes/conversation.py +++ b/astrbot/dashboard/routes/conversation.py @@ -3,7 +3,7 @@ from datetime import datetime from io import BytesIO -from quart import request, send_file +from quart import ResponseReturnValue, request, send_file from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle @@ -39,7 +39,7 @@ def __init__( self.core_lifecycle = core_lifecycle self.register_routes() - async def list_conversations(self): + async def list_conversations(self) -> ResponseReturnValue: """获取对话列表,支持分页、排序和筛选""" try: # 获取分页参数 @@ -104,7 +104,7 @@ async def list_conversations(self): logger.error(error_msg) return Response().error(f"获取对话列表失败: {e!s}").__dict__ - async def get_conv_detail(self): + async def get_conv_detail(self) -> ResponseReturnValue: """获取指定对话详情(通过POST请求)""" try: data = await request.get_json() @@ -141,7 +141,7 @@ async def get_conv_detail(self): logger.error(f"获取对话详情失败: {e!s}\n{traceback.format_exc()}") return Response().error(f"获取对话详情失败: {e!s}").__dict__ - async def upd_conv(self): + async def upd_conv(self) -> ResponseReturnValue: """更新对话信息(标题和角色ID)""" try: data = await request.get_json() @@ -171,7 +171,7 @@ async def upd_conv(self): logger.error(f"更新对话信息失败: {e!s}\n{traceback.format_exc()}") return Response().error(f"更新对话信息失败: {e!s}").__dict__ - async def del_conv(self): + async def del_conv(self) -> ResponseReturnValue: """删除对话""" try: data = await request.get_json() @@ -240,7 +240,7 @@ async def del_conv(self): logger.error(f"删除对话失败: {e!s}\n{traceback.format_exc()}") return Response().error(f"删除对话失败: {e!s}").__dict__ - async def update_history(self): + async def update_history(self) -> ResponseReturnValue: """更新对话历史内容""" try: data = await request.get_json() @@ -287,7 +287,7 @@ async def update_history(self): logger.error(f"更新对话历史失败: {e!s}\n{traceback.format_exc()}") return Response().error(f"更新对话历史失败: {e!s}").__dict__ - async def export_conversations(self): + async def export_conversations(self) -> ResponseReturnValue: """批量导出对话为 JSONL 格式""" try: data = await request.get_json() diff --git a/astrbot/dashboard/routes/file.py b/astrbot/dashboard/routes/file.py index 71d867fe1..2dcfb6f23 100644 --- a/astrbot/dashboard/routes/file.py +++ b/astrbot/dashboard/routes/file.py @@ -1,4 +1,4 @@ -from quart import abort, send_file +from quart import ResponseReturnValue, abort, send_file from astrbot import logger from astrbot.core import file_token_service @@ -17,7 +17,7 @@ def __init__( } self.register_routes() - async def serve_file(self, file_token: str): + async def serve_file(self, file_token: str) -> ResponseReturnValue: try: file_path = await file_token_service.handle_file(file_token) return await send_file(file_path) diff --git a/astrbot/dashboard/routes/knowledge_base.py b/astrbot/dashboard/routes/knowledge_base.py index 537a81f0b..851a33c29 100644 --- a/astrbot/dashboard/routes/knowledge_base.py +++ b/astrbot/dashboard/routes/knowledge_base.py @@ -4,12 +4,15 @@ import os import traceback import uuid +from collections.abc import Awaitable, Callable import aiofiles -from quart import request +from quart import ResponseReturnValue, request from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle +from astrbot.core.knowledge_base.kb_helper import KBHelper +from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider from ..utils import generate_tsne_visualization @@ -64,7 +67,7 @@ def __init__( } self.register_routes() - def _get_kb_manager(self): + def _get_kb_manager(self) -> KnowledgeBaseManager: return self.core_lifecycle.kb_manager def _init_task(self, task_id: str, status: str = "pending") -> None: @@ -75,7 +78,11 @@ def _init_task(self, task_id: str, status: str = "pending") -> None: } def _set_task_result( - self, task_id: str, status: str, result: any = None, error: str | None = None + self, + task_id: str, + status: str, + result: str | dict | None = None, + error: str | None = None, ) -> None: self.upload_tasks[task_id] = { "status": status, @@ -112,8 +119,10 @@ def _update_progress( if total is not None: p["total"] = total - def _make_progress_callback(self, task_id: str, file_idx: int, file_name: str): - async def _callback(stage: str, current: int, total: int): + def _make_progress_callback( + self, task_id: str, file_idx: int, file_name: str + ) -> Callable[[str, int, int], Awaitable[None]]: + async def _callback(stage: str, current: int, total: int) -> None: self._update_progress( task_id, status="processing", @@ -129,14 +138,14 @@ async def _callback(stage: str, current: int, total: int): async def _background_upload_task( self, task_id: str, - kb_helper, + kb_helper: KBHelper, files_to_upload: list, chunk_size: int, chunk_overlap: int, batch_size: int, tasks_limit: int, max_retries: int, - ): + ) -> None: """后台上传任务""" try: # 初始化任务状态 @@ -210,12 +219,12 @@ async def _background_upload_task( async def _background_import_task( self, task_id: str, - kb_helper, + kb_helper: KBHelper, documents: list, batch_size: int, tasks_limit: int, max_retries: int, - ): + ) -> None: """后台导入预切片文档任务""" try: # 初始化任务状态 @@ -294,7 +303,7 @@ async def _background_import_task( logger.error(traceback.format_exc()) self._set_task_result(task_id, "failed", error=str(e)) - async def list_kbs(self): + async def list_kbs(self) -> ResponseReturnValue: """获取知识库列表 Query 参数: @@ -326,7 +335,7 @@ async def list_kbs(self): logger.error(traceback.format_exc()) return Response().error(f"获取知识库列表失败: {e!s}").__dict__ - async def create_kb(self): + async def create_kb(self) -> ResponseReturnValue: """创建知识库 Body: @@ -423,7 +432,7 @@ async def create_kb(self): logger.error(traceback.format_exc()) return Response().error(f"创建知识库失败: {e!s}").__dict__ - async def get_kb(self): + async def get_kb(self) -> ResponseReturnValue: """获取知识库详情 Query 参数: @@ -449,7 +458,7 @@ async def get_kb(self): logger.error(traceback.format_exc()) return Response().error(f"获取知识库详情失败: {e!s}").__dict__ - async def update_kb(self): + async def update_kb(self) -> ResponseReturnValue: """更新知识库 Body: @@ -529,7 +538,7 @@ async def update_kb(self): logger.error(traceback.format_exc()) return Response().error(f"更新知识库失败: {e!s}").__dict__ - async def delete_kb(self): + async def delete_kb(self) -> ResponseReturnValue: """删除知识库 Body: @@ -556,7 +565,7 @@ async def delete_kb(self): logger.error(traceback.format_exc()) return Response().error(f"删除知识库失败: {e!s}").__dict__ - async def get_kb_stats(self): + async def get_kb_stats(self) -> ResponseReturnValue: """获取知识库统计信息 Query 参数: @@ -593,7 +602,7 @@ async def get_kb_stats(self): # ===== 文档管理 API ===== - async def list_documents(self): + async def list_documents(self) -> ResponseReturnValue: """获取文档列表 Query 参数: @@ -633,7 +642,7 @@ async def list_documents(self): logger.error(traceback.format_exc()) return Response().error(f"获取文档列表失败: {e!s}").__dict__ - async def upload_document(self): + async def upload_document(self) -> ResponseReturnValue: """上传文档 支持两种方式: @@ -771,7 +780,7 @@ async def upload_document(self): logger.error(traceback.format_exc()) return Response().error(f"上传文档失败: {e!s}").__dict__ - def _validate_import_request(self, data: dict): + def _validate_import_request(self, data: dict) -> tuple[str, list, int, int, int]: kb_id = data.get("kb_id") if not kb_id: raise ValueError("缺少参数 kb_id") @@ -795,7 +804,7 @@ def _validate_import_request(self, data: dict): max_retries = data.get("max_retries", 3) return kb_id, documents, batch_size, tasks_limit, max_retries - async def import_documents(self): + async def import_documents(self) -> ResponseReturnValue: """导入预切片文档 Body: @@ -858,7 +867,7 @@ async def import_documents(self): logger.error(traceback.format_exc()) return Response().error(f"导入文档失败: {e!s}").__dict__ - async def get_upload_progress(self): + async def get_upload_progress(self) -> ResponseReturnValue: """获取上传进度和结果 Query 参数: @@ -911,7 +920,7 @@ async def get_upload_progress(self): logger.error(traceback.format_exc()) return Response().error(f"获取上传进度失败: {e!s}").__dict__ - async def get_document(self): + async def get_document(self) -> ResponseReturnValue: """获取文档详情 Query 参数: @@ -942,7 +951,7 @@ async def get_document(self): logger.error(traceback.format_exc()) return Response().error(f"获取文档详情失败: {e!s}").__dict__ - async def delete_document(self): + async def delete_document(self) -> ResponseReturnValue: """删除文档 Body: @@ -974,7 +983,7 @@ async def delete_document(self): logger.error(traceback.format_exc()) return Response().error(f"删除文档失败: {e!s}").__dict__ - async def delete_chunk(self): + async def delete_chunk(self) -> ResponseReturnValue: """删除文本块 Body: @@ -1009,7 +1018,7 @@ async def delete_chunk(self): logger.error(traceback.format_exc()) return Response().error(f"删除文本块失败: {e!s}").__dict__ - async def list_chunks(self): + async def list_chunks(self) -> ResponseReturnValue: """获取块列表 Query 参数: @@ -1058,7 +1067,7 @@ async def list_chunks(self): # ===== 检索 API ===== - async def retrieve(self): + async def retrieve(self) -> ResponseReturnValue: """检索知识库 Body: @@ -1121,7 +1130,7 @@ async def retrieve(self): logger.error(traceback.format_exc()) return Response().error(f"检索失败: {e!s}").__dict__ - async def upload_document_from_url(self): + async def upload_document_from_url(self) -> ResponseReturnValue: """从 URL 上传文档 Body: @@ -1205,7 +1214,7 @@ async def upload_document_from_url(self): async def _background_upload_from_url_task( self, task_id: str, - kb_helper, + kb_helper: KBHelper, url: str, chunk_size: int, chunk_overlap: int, @@ -1214,7 +1223,7 @@ async def _background_upload_from_url_task( max_retries: int, enable_cleaning: bool, cleaning_provider_id: str | None, - ): + ) -> None: """后台上传URL任务""" try: # 初始化任务状态 diff --git a/astrbot/dashboard/routes/log.py b/astrbot/dashboard/routes/log.py index d5aa7c1de..360118e40 100644 --- a/astrbot/dashboard/routes/log.py +++ b/astrbot/dashboard/routes/log.py @@ -5,7 +5,7 @@ from typing import cast from quart import Response as QuartResponse -from quart import make_response, request +from quart import ResponseReturnValue, make_response, request from astrbot.core import LogBroker, logger @@ -54,7 +54,7 @@ async def _replay_cached_logs( async def log(self) -> QuartResponse: last_event_id = request.headers.get("Last-Event-ID") - async def stream(): + async def stream() -> AsyncGenerator[str, None]: queue = None try: if last_event_id: @@ -90,7 +90,7 @@ async def stream(): response.timeout = None # type: ignore return response - async def log_history(self): + async def log_history(self) -> ResponseReturnValue: """获取日志历史""" try: logs = list(self.log_broker.log_cache) diff --git a/astrbot/dashboard/routes/persona.py b/astrbot/dashboard/routes/persona.py index 7ddb75f17..3ee1c4882 100644 --- a/astrbot/dashboard/routes/persona.py +++ b/astrbot/dashboard/routes/persona.py @@ -1,6 +1,6 @@ import traceback -from quart import request +from quart import ResponseReturnValue, request from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle @@ -28,7 +28,7 @@ def __init__( self.persona_mgr = core_lifecycle.persona_mgr self.register_routes() - async def list_personas(self): + async def list_personas(self) -> ResponseReturnValue: """获取所有人格列表""" try: personas = await self.persona_mgr.get_all_personas() @@ -57,7 +57,7 @@ async def list_personas(self): logger.error(f"获取人格列表失败: {e!s}\n{traceback.format_exc()}") return Response().error(f"获取人格列表失败: {e!s}").__dict__ - async def get_persona_detail(self): + async def get_persona_detail(self) -> ResponseReturnValue: """获取指定人格的详细信息""" try: data = await request.get_json() @@ -92,7 +92,7 @@ async def get_persona_detail(self): logger.error(f"获取人格详情失败: {e!s}\n{traceback.format_exc()}") return Response().error(f"获取人格详情失败: {e!s}").__dict__ - async def create_persona(self): + async def create_persona(self) -> ResponseReturnValue: """创建新人格""" try: data = await request.get_json() @@ -149,7 +149,7 @@ async def create_persona(self): logger.error(f"创建人格失败: {e!s}\n{traceback.format_exc()}") return Response().error(f"创建人格失败: {e!s}").__dict__ - async def update_persona(self): + async def update_persona(self) -> ResponseReturnValue: """更新人格信息""" try: data = await request.get_json() @@ -183,7 +183,7 @@ async def update_persona(self): logger.error(f"更新人格失败: {e!s}\n{traceback.format_exc()}") return Response().error(f"更新人格失败: {e!s}").__dict__ - async def delete_persona(self): + async def delete_persona(self) -> ResponseReturnValue: """删除人格""" try: data = await request.get_json() diff --git a/astrbot/dashboard/routes/platform.py b/astrbot/dashboard/routes/platform.py index 4d8fdddfe..835ea6224 100644 --- a/astrbot/dashboard/routes/platform.py +++ b/astrbot/dashboard/routes/platform.py @@ -3,7 +3,7 @@ 提供统一的 webhook 回调入口,支持多个平台使用同一端口接收回调。 """ -from quart import request +from quart import ResponseReturnValue, request from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle @@ -26,7 +26,7 @@ def __init__( self._register_webhook_routes() - def _register_webhook_routes(self): + def _register_webhook_routes(self) -> None: """注册 webhook 路由""" # 统一 webhook 入口,支持 GET 和 POST self.app.add_url_rule( @@ -42,7 +42,7 @@ def _register_webhook_routes(self): methods=["GET"], ) - async def unified_webhook_callback(self, webhook_uuid: str): + async def unified_webhook_callback(self, webhook_uuid: str) -> ResponseReturnValue: """统一 webhook 回调入口 Args: @@ -86,7 +86,7 @@ def _find_platform_by_uuid(self, webhook_uuid: str) -> Platform | None: return platform return None - async def get_platform_stats(self): + async def get_platform_stats(self) -> ResponseReturnValue: """获取所有平台的统计信息 Returns: diff --git a/astrbot/dashboard/routes/plugin.py b/astrbot/dashboard/routes/plugin.py index fd808c6c9..c4fca25f9 100644 --- a/astrbot/dashboard/routes/plugin.py +++ b/astrbot/dashboard/routes/plugin.py @@ -9,7 +9,7 @@ import aiohttp import certifi -from quart import request +from quart import ResponseReturnValue, request from astrbot.api import sp from astrbot.core import DEMO_MODE, file_token_service, logger @@ -73,7 +73,7 @@ def __init__( self._logo_cache = {} - async def reload_plugins(self): + async def reload_plugins(self) -> ResponseReturnValue: if DEMO_MODE: return ( Response() @@ -92,7 +92,7 @@ async def reload_plugins(self): logger.error(f"/api/plugin/reload: {traceback.format_exc()}") return Response().error(str(e)).__dict__ - async def get_online_plugins(self): + async def get_online_plugins(self) -> ResponseReturnValue: custom = request.args.get("custom_registry") force_refresh = request.args.get("force_refresh", "false").lower() == "true" @@ -244,7 +244,7 @@ async def _is_cache_valid(self, source: RegistrySource) -> bool: logger.warning(f"检查缓存有效性失败: {e}") return False - def _load_plugin_cache(self, cache_file: str): + def _load_plugin_cache(self, cache_file: str) -> dict | None: """加载本地缓存的插件市场数据""" try: if os.path.exists(cache_file): @@ -260,7 +260,9 @@ def _load_plugin_cache(self, cache_file: str): logger.warning(f"加载插件市场缓存失败: {e}") return None - def _save_plugin_cache(self, cache_file: str, data, md5: str | None = None): + def _save_plugin_cache( + self, cache_file: str, data: object, md5: str | None = None + ) -> None: """保存插件市场数据到本地缓存""" try: # 确保目录存在 @@ -278,7 +280,7 @@ def _save_plugin_cache(self, cache_file: str, data, md5: str | None = None): except Exception as e: logger.warning(f"保存插件市场缓存失败: {e}") - async def get_plugin_logo_token(self, logo_path: str): + async def get_plugin_logo_token(self, logo_path: str) -> str | None: try: if token := self._logo_cache.get(logo_path): if not await file_token_service.check_token_expired(token): @@ -290,7 +292,7 @@ async def get_plugin_logo_token(self, logo_path: str): logger.warning(f"获取插件 Logo 失败: {e}") return None - async def get_plugins(self): + async def get_plugins(self) -> ResponseReturnValue: _plugin_resp = [] plugin_name = request.args.get("name") for plugin in self.plugin_manager.context.get_all_stars(): @@ -321,7 +323,7 @@ async def get_plugins(self): .__dict__ ) - async def get_plugin_handlers_info(self, handler_full_names: list[str]): + async def get_plugin_handlers_info(self, handler_full_names: list[str]) -> list: """解析插件行为""" handlers = [] @@ -382,7 +384,7 @@ async def get_plugin_handlers_info(self, handler_full_names: list[str]): return handlers - async def install_plugin(self): + async def install_plugin(self) -> ResponseReturnValue: if DEMO_MODE: return ( Response() @@ -407,7 +409,7 @@ async def install_plugin(self): logger.error(traceback.format_exc()) return Response().error(str(e)).__dict__ - async def install_plugin_upload(self): + async def install_plugin_upload(self) -> ResponseReturnValue: if DEMO_MODE: return ( Response() @@ -429,7 +431,7 @@ async def install_plugin_upload(self): logger.error(traceback.format_exc()) return Response().error(str(e)).__dict__ - async def uninstall_plugin(self): + async def uninstall_plugin(self) -> ResponseReturnValue: if DEMO_MODE: return ( Response() @@ -454,7 +456,7 @@ async def uninstall_plugin(self): logger.error(traceback.format_exc()) return Response().error(str(e)).__dict__ - async def update_plugin(self): + async def update_plugin(self) -> ResponseReturnValue: if DEMO_MODE: return ( Response() @@ -476,7 +478,7 @@ async def update_plugin(self): logger.error(f"/api/plugin/update: {traceback.format_exc()}") return Response().error(str(e)).__dict__ - async def update_all_plugins(self): + async def update_all_plugins(self) -> ResponseReturnValue: if DEMO_MODE: return ( Response() @@ -494,7 +496,7 @@ async def update_all_plugins(self): results = [] sem = asyncio.Semaphore(PLUGIN_UPDATE_CONCURRENCY) - async def _update_one(name: str): + async def _update_one(name: str) -> dict[str, str]: async with sem: try: logger.info(f"批量更新插件 {name}") @@ -529,7 +531,7 @@ async def _update_one(name: str): return Response().ok({"results": results}, message).__dict__ - async def off_plugin(self): + async def off_plugin(self) -> ResponseReturnValue: if DEMO_MODE: return ( Response() @@ -547,7 +549,7 @@ async def off_plugin(self): logger.error(f"/api/plugin/off: {traceback.format_exc()}") return Response().error(str(e)).__dict__ - async def on_plugin(self): + async def on_plugin(self) -> ResponseReturnValue: if DEMO_MODE: return ( Response() @@ -565,7 +567,7 @@ async def on_plugin(self): logger.error(f"/api/plugin/on: {traceback.format_exc()}") return Response().error(str(e)).__dict__ - async def get_plugin_readme(self): + async def get_plugin_readme(self) -> ResponseReturnValue: plugin_name = request.args.get("name") logger.debug(f"正在获取插件 {plugin_name} 的README文件内容") @@ -615,12 +617,12 @@ async def get_plugin_readme(self): logger.error(f"/api/plugin/readme: {traceback.format_exc()}") return Response().error(f"读取README文件失败: {e!s}").__dict__ - async def get_custom_source(self): + async def get_custom_source(self) -> ResponseReturnValue: """获取自定义插件源""" sources = await sp.global_get("custom_plugin_sources", []) return Response().ok(sources).__dict__ - async def save_custom_source(self): + async def save_custom_source(self) -> ResponseReturnValue: """保存自定义插件源""" try: data = await request.get_json() diff --git a/astrbot/dashboard/routes/route.py b/astrbot/dashboard/routes/route.py index 01ab292d4..a332e58b3 100644 --- a/astrbot/dashboard/routes/route.py +++ b/astrbot/dashboard/routes/route.py @@ -1,6 +1,7 @@ +from collections.abc import Awaitable, Callable from dataclasses import dataclass -from quart import Quart +from quart import Quart, ResponseReturnValue from astrbot.core.config.astrbot_config import AstrBotConfig @@ -14,12 +15,14 @@ class RouteContext: class Route: routes: list | dict - def __init__(self, context: RouteContext): + def __init__(self, context: RouteContext) -> None: self.app = context.app self.config = context.config - def register_routes(self): - def _add_rule(path, method, func): + def register_routes(self) -> None: + def _add_rule( + path: str, method: str, func: Callable[[], Awaitable[ResponseReturnValue]] + ) -> None: # 统一添加 /api 前缀 full_path = f"/api{path}" self.app.add_url_rule(full_path, view_func=func, methods=[method]) @@ -45,12 +48,14 @@ class Response: message: str | None = None data: dict | list | None = None - def error(self, message: str): + def error(self, message: str) -> "Response": self.status = "error" self.message = message return self - def ok(self, data: dict | list | None = None, message: str | None = None): + def ok( + self, data: dict | list | None = None, message: str | None = None + ) -> "Response": self.status = "ok" if data is None: data = {} diff --git a/astrbot/dashboard/routes/session_management.py b/astrbot/dashboard/routes/session_management.py index a938d662d..7274fb36c 100644 --- a/astrbot/dashboard/routes/session_management.py +++ b/astrbot/dashboard/routes/session_management.py @@ -1,4 +1,4 @@ -from quart import request +from quart import ResponseReturnValue, request from sqlalchemy.ext.asyncio import AsyncSession from sqlmodel import col, select @@ -109,7 +109,7 @@ async def _get_umo_rules( return paginated_rules, total - async def list_session_rule(self): + async def list_session_rule(self) -> ResponseReturnValue: """获取所有自定义的规则(支持分页和搜索) 返回已配置规则的 umo 列表及其规则内容,以及可用的 personas 和 providers @@ -240,7 +240,7 @@ async def list_session_rule(self): logger.error(f"获取规则列表失败: {e!s}") return Response().error(f"获取规则列表失败: {e!s}").__dict__ - async def update_session_rule(self): + async def update_session_rule(self) -> ResponseReturnValue: """更新某个 umo 的自定义规则 请求体: @@ -280,7 +280,7 @@ async def update_session_rule(self): logger.error(f"更新会话规则失败: {e!s}") return Response().error(f"更新会话规则失败: {e!s}").__dict__ - async def delete_session_rule(self): + async def delete_session_rule(self) -> ResponseReturnValue: """删除某个 umo 的自定义规则 请求体: @@ -315,7 +315,7 @@ async def delete_session_rule(self): logger.error(f"删除会话规则失败: {e!s}") return Response().error(f"删除会话规则失败: {e!s}").__dict__ - async def batch_delete_session_rule(self): + async def batch_delete_session_rule(self) -> ResponseReturnValue: """批量删除多个 umo 的自定义规则 请求体: @@ -371,7 +371,7 @@ async def batch_delete_session_rule(self): logger.error(f"批量删除会话规则失败: {e!s}") return Response().error(f"批量删除会话规则失败: {e!s}").__dict__ - async def list_umos(self): + async def list_umos(self) -> ResponseReturnValue: """列出所有有对话记录的 umo,从 Conversations 表中找 仅返回 umo 字符串列表,用于用户在创建规则时选择 umo diff --git a/astrbot/dashboard/routes/stat.py b/astrbot/dashboard/routes/stat.py index 054eec995..3516a777f 100644 --- a/astrbot/dashboard/routes/stat.py +++ b/astrbot/dashboard/routes/stat.py @@ -7,7 +7,7 @@ import aiohttp import psutil -from quart import request +from quart import ResponseReturnValue, request from astrbot.core import DEMO_MODE, logger from astrbot.core.config import VERSION @@ -42,7 +42,7 @@ def __init__( self.register_routes() self.core_lifecycle = core_lifecycle - async def restart_core(self): + async def restart_core(self) -> ResponseReturnValue: if DEMO_MODE: return ( Response() @@ -53,13 +53,13 @@ async def restart_core(self): await self.core_lifecycle.restart() return Response().ok().__dict__ - def _get_running_time_components(self, total_seconds: int): + def _get_running_time_components(self, total_seconds: int) -> dict[str, int]: """将总秒数转换为时分秒组件""" minutes, seconds = divmod(total_seconds, 60) hours, minutes = divmod(minutes, 60) return {"hours": hours, "minutes": minutes, "seconds": seconds} - def is_default_cred(self): + def is_default_cred(self) -> bool: username = self.config["dashboard"]["username"] password = self.config["dashboard"]["password"] return ( @@ -68,7 +68,7 @@ def is_default_cred(self): and not DEMO_MODE ) - async def get_version(self): + async def get_version(self) -> ResponseReturnValue: need_migration = await check_migration_needed_v4(self.core_lifecycle.db) return ( @@ -84,10 +84,10 @@ async def get_version(self): .__dict__ ) - async def get_start_time(self): + async def get_start_time(self) -> ResponseReturnValue: return Response().ok({"start_time": self.core_lifecycle.start_time}).__dict__ - async def get_stat(self): + async def get_stat(self) -> ResponseReturnValue: offset_sec = request.args.get("offset_sec", 86400) offset_sec = int(offset_sec) try: @@ -156,7 +156,7 @@ async def get_stat(self): logger.error(traceback.format_exc()) return Response().error(e.__str__()).__dict__ - async def test_ghproxy_connection(self): + async def test_ghproxy_connection(self) -> ResponseReturnValue: """测试 GitHub 代理连接是否可用。""" try: data = await request.get_json() @@ -191,7 +191,7 @@ async def test_ghproxy_connection(self): logger.error(traceback.format_exc()) return Response().error(f"Error: {e!s}").__dict__ - async def get_changelog(self): + async def get_changelog(self) -> ResponseReturnValue: """获取指定版本的更新日志""" try: version = request.args.get("version") @@ -249,7 +249,7 @@ async def get_changelog(self): logger.error(traceback.format_exc()) return Response().error(f"Error: {e!s}").__dict__ - async def list_changelog_versions(self): + async def list_changelog_versions(self) -> ResponseReturnValue: """获取所有可用的更新日志版本列表""" try: project_path = get_astrbot_path() diff --git a/astrbot/dashboard/routes/static_file.py b/astrbot/dashboard/routes/static_file.py index 3d3d0ca51..1a16963d8 100644 --- a/astrbot/dashboard/routes/static_file.py +++ b/astrbot/dashboard/routes/static_file.py @@ -1,3 +1,5 @@ +from quart import ResponseReturnValue + from .route import Route, RouteContext @@ -30,8 +32,8 @@ def __init__(self, context: RouteContext) -> None: self.app.add_url_rule(i, view_func=self.index) @self.app.errorhandler(404) - async def page_not_found(e): + async def page_not_found(e: object) -> ResponseReturnValue: return "404 Not found。如果你初次使用打开面板发现 404, 请参考文档: https://astrbot.app/faq.html。如果你正在测试回调地址可达性,显示这段文字说明测试成功了。" - async def index(self): + async def index(self) -> ResponseReturnValue: return await self.app.send_static_file("index.html") diff --git a/astrbot/dashboard/routes/t2i.py b/astrbot/dashboard/routes/t2i.py index db70a8820..deea9dbfd 100644 --- a/astrbot/dashboard/routes/t2i.py +++ b/astrbot/dashboard/routes/t2i.py @@ -2,7 +2,7 @@ from dataclasses import asdict -from quart import jsonify, request +from quart import ResponseReturnValue, jsonify, request from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle @@ -12,7 +12,9 @@ class T2iRoute(Route): - def __init__(self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle): + def __init__( + self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle + ) -> None: super().__init__(context) self.core_lifecycle = core_lifecycle self.config = core_lifecycle.astrbot_config @@ -36,7 +38,7 @@ def __init__(self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle): ] self.register_routes() - async def list_templates(self): + async def list_templates(self) -> ResponseReturnValue: """获取所有T2I模板列表""" try: templates = self.manager.list_templates() @@ -46,7 +48,7 @@ async def list_templates(self): response.status_code = 500 return response - async def get_active_template(self): + async def get_active_template(self) -> ResponseReturnValue: """获取当前激活的T2I模板""" try: active_template = self.config.get("t2i_active_template", "base") @@ -59,7 +61,7 @@ async def get_active_template(self): response.status_code = 500 return response - async def get_template(self, name: str): + async def get_template(self, name: str) -> ResponseReturnValue: """获取指定名称的T2I模板内容""" try: content = self.manager.get_template(name) @@ -75,7 +77,7 @@ async def get_template(self, name: str): response.status_code = 500 return response - async def create_template(self): + async def create_template(self) -> ResponseReturnValue: """创建一个新的T2I模板""" try: data = await request.json @@ -115,7 +117,7 @@ async def create_template(self): response.status_code = 500 return response - async def update_template(self, name: str): + async def update_template(self, name: str) -> ResponseReturnValue: """更新一个已存在的T2I模板""" try: name = name.strip() @@ -146,7 +148,7 @@ async def update_template(self, name: str): response.status_code = 500 return response - async def delete_template(self, name: str): + async def delete_template(self, name: str) -> ResponseReturnValue: """删除一个T2I模板""" try: name = name.strip() @@ -167,7 +169,7 @@ async def delete_template(self, name: str): response.status_code = 500 return response - async def set_active_template(self): + async def set_active_template(self) -> ResponseReturnValue: """设置当前活动的T2I模板""" try: data = await request.json @@ -202,7 +204,7 @@ async def set_active_template(self): response.status_code = 500 return response - async def reset_default_template(self): + async def reset_default_template(self) -> ResponseReturnValue: """重置默认的'base'模板""" try: self.manager.reset_default_template() diff --git a/astrbot/dashboard/routes/tools.py b/astrbot/dashboard/routes/tools.py index d7b082000..1674ff03d 100644 --- a/astrbot/dashboard/routes/tools.py +++ b/astrbot/dashboard/routes/tools.py @@ -1,6 +1,6 @@ import traceback -from quart import request +from quart import ResponseReturnValue, request from astrbot.core import logger from astrbot.core.agent.mcp_client import MCPTool @@ -33,7 +33,7 @@ def __init__( self.register_routes() self.tool_mgr = self.core_lifecycle.provider_manager.llm_tools - async def get_mcp_servers(self): + async def get_mcp_servers(self) -> ResponseReturnValue: try: config = self.tool_mgr.load_mcp_config() servers = [] @@ -69,7 +69,7 @@ async def get_mcp_servers(self): logger.error(traceback.format_exc()) return Response().error(f"获取 MCP 服务器列表失败: {e!s}").__dict__ - async def add_mcp_server(self): + async def add_mcp_server(self) -> ResponseReturnValue: try: server_data = await request.json @@ -125,7 +125,7 @@ async def add_mcp_server(self): logger.error(traceback.format_exc()) return Response().error(f"添加 MCP 服务器失败: {e!s}").__dict__ - async def update_mcp_server(self): + async def update_mcp_server(self) -> ResponseReturnValue: try: server_data = await request.json @@ -229,7 +229,7 @@ async def update_mcp_server(self): logger.error(traceback.format_exc()) return Response().error(f"更新 MCP 服务器失败: {e!s}").__dict__ - async def delete_mcp_server(self): + async def delete_mcp_server(self) -> ResponseReturnValue: try: server_data = await request.json name = server_data.get("name", "") @@ -265,7 +265,7 @@ async def delete_mcp_server(self): logger.error(traceback.format_exc()) return Response().error(f"删除 MCP 服务器失败: {e!s}").__dict__ - async def test_mcp_connection(self): + async def test_mcp_connection(self) -> ResponseReturnValue: """测试 MCP 服务器连接""" try: server_data = await request.json @@ -293,7 +293,7 @@ async def test_mcp_connection(self): logger.error(traceback.format_exc()) return Response().error(f"测试 MCP 连接失败: {e!s}").__dict__ - async def get_tool_list(self): + async def get_tool_list(self) -> ResponseReturnValue: """获取所有注册的工具列表""" try: tools = self.tool_mgr.func_list @@ -326,7 +326,7 @@ async def get_tool_list(self): logger.error(traceback.format_exc()) return Response().error(f"获取工具列表失败: {e!s}").__dict__ - async def toggle_tool(self): + async def toggle_tool(self) -> ResponseReturnValue: """启用或停用指定的工具""" try: data = await request.json @@ -352,7 +352,7 @@ async def toggle_tool(self): logger.error(traceback.format_exc()) return Response().error(f"操作工具失败: {e!s}").__dict__ - async def sync_provider(self): + async def sync_provider(self) -> ResponseReturnValue: """同步 MCP 提供者配置""" try: data = await request.json diff --git a/astrbot/dashboard/routes/update.py b/astrbot/dashboard/routes/update.py index b0520c315..5a6fa5520 100644 --- a/astrbot/dashboard/routes/update.py +++ b/astrbot/dashboard/routes/update.py @@ -1,6 +1,6 @@ import traceback -from quart import request +from quart import ResponseReturnValue, request from astrbot.core import DEMO_MODE, logger, pip_installer from astrbot.core.config.default import VERSION @@ -34,7 +34,7 @@ def __init__( self.core_lifecycle = core_lifecycle self.register_routes() - async def do_migration(self): + async def do_migration(self) -> ResponseReturnValue: need_migration = await check_migration_needed_v4(self.core_lifecycle.db) if not need_migration: return Response().ok(None, "不需要进行迁移。").__dict__ @@ -51,7 +51,7 @@ async def do_migration(self): logger.error(f"迁移失败: {traceback.format_exc()}") return Response().error(f"迁移失败: {e!s}").__dict__ - async def check_update(self): + async def check_update(self) -> ResponseReturnValue: type_ = request.args.get("type", None) try: @@ -77,7 +77,7 @@ async def check_update(self): logger.warning(f"检查更新失败: {e!s} (不影响除项目更新外的正常使用)") return Response().error(e.__str__()).__dict__ - async def get_releases(self): + async def get_releases(self) -> ResponseReturnValue: try: ret = await self.astrbot_updator.get_releases() return Response().ok(ret).__dict__ @@ -85,7 +85,7 @@ async def get_releases(self): logger.error(f"/api/update/releases: {traceback.format_exc()}") return Response().error(e.__str__()).__dict__ - async def update_project(self): + async def update_project(self) -> ResponseReturnValue: data = await request.json version = data.get("version", "") reboot = data.get("reboot", True) @@ -136,7 +136,7 @@ async def update_project(self): logger.error(f"/api/update_project: {traceback.format_exc()}") return Response().error(e.__str__()).__dict__ - async def update_dashboard(self): + async def update_dashboard(self) -> ResponseReturnValue: try: try: await download_dashboard(version=f"v{VERSION}", latest=False) @@ -149,7 +149,7 @@ async def update_dashboard(self): logger.error(f"/api/update_dashboard: {traceback.format_exc()}") return Response().error(e.__str__()).__dict__ - async def install_pip_package(self): + async def install_pip_package(self) -> ResponseReturnValue: if DEMO_MODE: return ( Response() diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index ad83c4886..6e1870482 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -177,7 +177,7 @@ def get_process_using_port(self, port: int) -> str: except Exception as e: return f"获取进程信息失败: {e!s}" - def _init_jwt_secret(self): + def _init_jwt_secret(self) -> None: if not self.config.get("dashboard", {}).get("jwt_secret", None): # 如果没有设置 JWT 密钥,则生成一个新的密钥 jwt_secret = os.urandom(32).hex() @@ -247,6 +247,6 @@ def run(self): shutdown_trigger=self.shutdown_trigger, ) - async def shutdown_trigger(self): + async def shutdown_trigger(self) -> None: await self.shutdown_event.wait() logger.info("AstrBot WebUI 已经被优雅地关闭") diff --git a/main.py b/main.py index 60879f065..339e3a728 100644 --- a/main.py +++ b/main.py @@ -25,7 +25,7 @@ """ -def check_env(): +def check_env() -> None: if not (sys.version_info.major == 3 and sys.version_info.minor >= 10): logger.error("请使用 Python3.10+ 运行本项目。") exit() diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py index 969f0da6d..0e41d13e1 100644 --- a/tests/test_dashboard.py +++ b/tests/test_dashboard.py @@ -61,7 +61,7 @@ async def authenticated_header(app: Quart, core_lifecycle_td: AstrBotCoreLifecyc @pytest.mark.asyncio -async def test_auth_login(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle): +async def test_auth_login(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle) -> None: """Tests the login functionality with both wrong and correct credentials.""" test_client = app.test_client() response = await test_client.post( @@ -83,7 +83,7 @@ async def test_auth_login(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle): @pytest.mark.asyncio -async def test_get_stat(app: Quart, authenticated_header: dict): +async def test_get_stat(app: Quart, authenticated_header: dict) -> None: test_client = app.test_client() response = await test_client.get("/api/stat/get") assert response.status_code == 401 @@ -94,7 +94,7 @@ async def test_get_stat(app: Quart, authenticated_header: dict): @pytest.mark.asyncio -async def test_plugins(app: Quart, authenticated_header: dict): +async def test_plugins(app: Quart, authenticated_header: dict) -> None: test_client = app.test_client() # 已经安装的插件 response = await test_client.get("/api/plugin/get", headers=authenticated_header) @@ -189,7 +189,7 @@ async def test_commands_api(app: Quart, authenticated_header: dict): @pytest.mark.asyncio -async def test_check_update(app: Quart, authenticated_header: dict): +async def test_check_update(app: Quart, authenticated_header: dict) -> None: test_client = app.test_client() response = await test_client.get("/api/update/check", headers=authenticated_header) assert response.status_code == 200 @@ -204,22 +204,22 @@ async def test_do_update( core_lifecycle_td: AstrBotCoreLifecycle, monkeypatch, tmp_path_factory, -): +) -> None: test_client = app.test_client() # Use a temporary path for the mock update to avoid side effects temp_release_dir = tmp_path_factory.mktemp("release") release_path = temp_release_dir / "astrbot" - async def mock_update(*args, **kwargs): + async def mock_update(*args, **kwargs) -> None: """Mocks the update process by creating a directory in the temp path.""" os.makedirs(release_path, exist_ok=True) - async def mock_download_dashboard(*args, **kwargs): + async def mock_download_dashboard(*args, **kwargs) -> None: """Mocks the dashboard download to prevent network access.""" return - async def mock_pip_install(*args, **kwargs): + async def mock_pip_install(*args, **kwargs) -> None: """Mocks pip install to prevent actual installation.""" return diff --git a/tests/test_main.py b/tests/test_main.py index 0453a51ee..d84cd44c9 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -12,12 +12,12 @@ class _version_info: - def __init__(self, major, minor): + def __init__(self, major, minor) -> None: self.major = major self.minor = minor -def test_check_env(monkeypatch): +def test_check_env(monkeypatch) -> None: version_info_correct = _version_info(3, 10) version_info_wrong = _version_info(3, 9) monkeypatch.setattr(sys, "version_info", version_info_correct) @@ -33,7 +33,7 @@ def test_check_env(monkeypatch): @pytest.mark.asyncio -async def test_check_dashboard_files_not_exists(monkeypatch): +async def test_check_dashboard_files_not_exists(monkeypatch) -> None: """Tests dashboard download when files do not exist.""" monkeypatch.setattr(os.path, "exists", lambda x: False) @@ -43,7 +43,7 @@ async def test_check_dashboard_files_not_exists(monkeypatch): @pytest.mark.asyncio -async def test_check_dashboard_files_exists_and_version_match(monkeypatch): +async def test_check_dashboard_files_exists_and_version_match(monkeypatch) -> None: """Tests that dashboard is not downloaded when it exists and version matches.""" # Mock os.path.exists to return True monkeypatch.setattr(os.path, "exists", lambda x: True) @@ -62,7 +62,7 @@ async def test_check_dashboard_files_exists_and_version_match(monkeypatch): @pytest.mark.asyncio -async def test_check_dashboard_files_exists_but_version_mismatch(monkeypatch): +async def test_check_dashboard_files_exists_but_version_mismatch(monkeypatch) -> None: """Tests that a warning is logged when dashboard version mismatches.""" monkeypatch.setattr(os.path, "exists", lambda x: True) @@ -77,7 +77,7 @@ async def test_check_dashboard_files_exists_but_version_mismatch(monkeypatch): @pytest.mark.asyncio -async def test_check_dashboard_files_with_webui_dir_arg(monkeypatch): +async def test_check_dashboard_files_with_webui_dir_arg(monkeypatch) -> None: """Tests that providing a valid webui_dir skips all checks.""" valid_dir = "/tmp/my-custom-webui" monkeypatch.setattr(os.path, "exists", lambda path: path == valid_dir) diff --git a/tests/test_plugin_manager.py b/tests/test_plugin_manager.py index 1e4cd866a..6e3cbc2c9 100644 --- a/tests/test_plugin_manager.py +++ b/tests/test_plugin_manager.py @@ -59,21 +59,21 @@ def plugin_manager_pm(tmp_path): return manager -def test_plugin_manager_initialization(plugin_manager_pm: PluginManager): +def test_plugin_manager_initialization(plugin_manager_pm: PluginManager) -> None: assert plugin_manager_pm is not None assert plugin_manager_pm.context is not None assert plugin_manager_pm.config is not None @pytest.mark.asyncio -async def test_plugin_manager_reload(plugin_manager_pm: PluginManager): +async def test_plugin_manager_reload(plugin_manager_pm: PluginManager) -> None: success, err_message = await plugin_manager_pm.reload() assert success is True assert err_message is None @pytest.mark.asyncio -async def test_install_plugin(plugin_manager_pm: PluginManager): +async def test_install_plugin(plugin_manager_pm: PluginManager) -> None: """Tests successful plugin installation in an isolated environment.""" test_repo = "https://github.com/Soulter/astrbot_plugin_essential" plugin_info = await plugin_manager_pm.install_plugin(test_repo) @@ -90,7 +90,7 @@ async def test_install_plugin(plugin_manager_pm: PluginManager): @pytest.mark.asyncio -async def test_install_nonexistent_plugin(plugin_manager_pm: PluginManager): +async def test_install_nonexistent_plugin(plugin_manager_pm: PluginManager) -> None: """Tests that installing a non-existent plugin raises an exception.""" with pytest.raises(Exception): await plugin_manager_pm.install_plugin( @@ -99,7 +99,7 @@ async def test_install_nonexistent_plugin(plugin_manager_pm: PluginManager): @pytest.mark.asyncio -async def test_update_plugin(plugin_manager_pm: PluginManager): +async def test_update_plugin(plugin_manager_pm: PluginManager) -> None: """Tests updating an existing plugin in an isolated environment.""" # First, install the plugin test_repo = "https://github.com/Soulter/astrbot_plugin_essential" @@ -110,14 +110,14 @@ async def test_update_plugin(plugin_manager_pm: PluginManager): @pytest.mark.asyncio -async def test_update_nonexistent_plugin(plugin_manager_pm: PluginManager): +async def test_update_nonexistent_plugin(plugin_manager_pm: PluginManager) -> None: """Tests that updating a non-existent plugin raises an exception.""" with pytest.raises(Exception): await plugin_manager_pm.update_plugin("non_existent_plugin") @pytest.mark.asyncio -async def test_uninstall_plugin(plugin_manager_pm: PluginManager): +async def test_uninstall_plugin(plugin_manager_pm: PluginManager) -> None: """Tests successful plugin uninstallation in an isolated environment.""" # First, install the plugin test_repo = "https://github.com/Soulter/astrbot_plugin_essential" @@ -144,7 +144,7 @@ async def test_uninstall_plugin(plugin_manager_pm: PluginManager): @pytest.mark.asyncio -async def test_uninstall_nonexistent_plugin(plugin_manager_pm: PluginManager): +async def test_uninstall_nonexistent_plugin(plugin_manager_pm: PluginManager) -> None: """Tests that uninstalling a non-existent plugin raises an exception.""" with pytest.raises(Exception): await plugin_manager_pm.uninstall_plugin("non_existent_plugin") diff --git a/tests/test_security_fixes.py b/tests/test_security_fixes.py index d4e455541..8e5cf7f6d 100644 --- a/tests/test_security_fixes.py +++ b/tests/test_security_fixes.py @@ -10,7 +10,7 @@ import pytest -def test_wecom_crypto_uses_secrets(): +def test_wecom_crypto_uses_secrets() -> None: """Test that WXBizJsonMsgCrypt uses secrets module instead of random.""" from astrbot.core.platform.sources.wecom_ai_bot.WXBizJsonMsgCrypt import Prpcrypt @@ -33,7 +33,7 @@ def test_wecom_crypto_uses_secrets(): assert 1000000000000000 <= int(decoded) <= 9999999999999999 -def test_wecomai_utils_uses_secrets(): +def test_wecomai_utils_uses_secrets() -> None: """Test that wecomai_utils uses secrets module for random string generation.""" from astrbot.core.platform.sources.wecom_ai_bot.wecomai_utils import ( generate_random_string, @@ -53,7 +53,7 @@ def test_wecomai_utils_uses_secrets(): assert len(set(random_strings)) >= 19 # Allow for 1 collision in 20 (very unlikely) -def test_azure_tts_signature_uses_secrets(): +def test_azure_tts_signature_uses_secrets() -> None: """Test that Azure TTS signature generation uses secrets module.""" import asyncio @@ -66,7 +66,7 @@ def test_azure_tts_signature_uses_secrets(): "OTTS_AUTH_TIME": "https://example.com/api/time", } - async def test_nonce_generation(): + async def test_nonce_generation() -> None: async with OTTSProvider(config) as provider: # Mock time sync to avoid actual API calls provider.time_offset = 0 @@ -94,7 +94,7 @@ async def test_nonce_generation(): asyncio.run(test_nonce_generation()) -def test_ssl_context_fallback_explicit(): +def test_ssl_context_fallback_explicit() -> None: """Test that SSL context fallback is properly configured.""" # This test verifies the SSL context configuration # We can't easily test the full io.py functions without network calls, @@ -113,7 +113,7 @@ def test_ssl_context_fallback_explicit(): # The actual code only uses this when certificate validation fails -def test_io_module_has_ssl_imports(): +def test_io_module_has_ssl_imports() -> None: """Verify that io.py properly imports ssl module.""" from astrbot.core.utils import io @@ -124,7 +124,7 @@ def test_io_module_has_ssl_imports(): assert hasattr(io.ssl, "CERT_NONE") -def test_secrets_module_randomness_quality(): +def test_secrets_module_randomness_quality() -> None: """Test that secrets module provides high-quality randomness.""" import secrets