diff --git a/migrations/versions/24c7dc8e7e67_.py b/migrations/versions/24c7dc8e7e67_.py
new file mode 100644
index 00000000..f643686d
--- /dev/null
+++ b/migrations/versions/24c7dc8e7e67_.py
@@ -0,0 +1,63 @@
+"""empty message
+
+迁移 ID: 24c7dc8e7e67
+父迁移: a1b2c3d4e5f6
+创建时间: 2026-01-24 23:33:36.695599
+
+"""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+from alembic import op
+import sqlalchemy as sa
+
+revision: str = "24c7dc8e7e67"
+down_revision: str | Sequence[str] | None = "a1b2c3d4e5f6"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None
+
+
+def upgrade(name: str = "") -> None:
+    if name:
+        return
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_table("nonebot_plugin_chat_chatuser")
+    op.drop_table("nonebot_plugin_chat_sessionmessage")
+    with op.batch_alter_table("nonebot_plugin_chat_sticker", schema=None) as batch_op:
+        batch_op.add_column(sa.Column("is_meme", sa.Boolean(), nullable=True))
+        batch_op.add_column(sa.Column("meme_text", sa.Text(), nullable=True))
+        batch_op.add_column(sa.Column("emotion", sa.String(length=64), nullable=True))
+        batch_op.add_column(sa.Column("labels", sa.Text(), nullable=True))
+        batch_op.add_column(sa.Column("context_keywords", sa.Text(), nullable=True))
+
+    # ### end Alembic commands ###
+
+
+def downgrade(name: str = "") -> None:
+    if name:
+        return
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table("nonebot_plugin_chat_sticker", schema=None) as batch_op:
+        batch_op.drop_column("context_keywords")
+        batch_op.drop_column("labels")
+        batch_op.drop_column("emotion")
+        batch_op.drop_column("meme_text")
+        batch_op.drop_column("is_meme")
+
+    op.create_table(
+        "nonebot_plugin_chat_sessionmessage",
+        sa.Column("id_", sa.INTEGER(), nullable=False),
+        sa.Column("user_id", sa.VARCHAR(length=128), nullable=False),
+        sa.Column("content", sa.TEXT(), nullable=False),
+        sa.Column("role", sa.VARCHAR(length=16), nullable=False),
+        sa.PrimaryKeyConstraint("id_", name=op.f("pk_nonebot_plugin_chat_sessionmessage")),
+    )
+    op.create_table(
+        "nonebot_plugin_chat_chatuser",
+        sa.Column("user_id", sa.VARCHAR(length=128), nullable=False),
+        sa.Column("latest_chat", sa.DATETIME(), nullable=False),
+        sa.PrimaryKeyConstraint("user_id", name=op.f("pk_nonebot_plugin_chat_chatuser")),
+    )
+    # ### end Alembic commands ###
diff --git a/migrations/versions/3875dac35e58_.py b/migrations/versions/3875dac35e58_.py
new file mode 100644
index 00000000..38119573
--- /dev/null
+++ b/migrations/versions/3875dac35e58_.py
@@ -0,0 +1,39 @@
+"""empty message
+
+迁移 ID: 3875dac35e58
+父迁移: 24c7dc8e7e67
+创建时间: 2026-01-25 13:33:07.981017
+
+"""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+from alembic import op
+import sqlalchemy as sa
+
+revision: str = "3875dac35e58"
+down_revision: str | Sequence[str] | None = "24c7dc8e7e67"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None
+
+
+def upgrade(name: str = "") -> None:
+    if name:
+        return
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table("nonebot_plugin_chat_sticker", schema=None) as batch_op:
+        batch_op.drop_column("is_meme")
+
+    # ### end Alembic commands ###
+
+
+def downgrade(name: str = "") -> None:
+    if name:
+        return
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table("nonebot_plugin_chat_sticker", schema=None) as batch_op:
+        batch_op.add_column(sa.Column("is_meme", sa.BOOLEAN(), nullable=True))
+
+    # ### end Alembic commands ###
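Optional sanity check after applying both migrations — a minimal sketch using SQLAlchemy's inspector. The sqlite URL is a placeholder assumption; point it at whatever database nonebot-plugin-orm is actually configured to use. Note that `is_meme` is intentionally absent: the second migration drops it again, so only the four surviving columns are checked.

```python
# Confirm the new Sticker columns exist and the dropped tables are gone.
# "sqlite:///data.db" is a placeholder, not the plugin's real database path.
from sqlalchemy import create_engine, inspect

engine = create_engine("sqlite:///data.db")
inspector = inspect(engine)

columns = {col["name"] for col in inspector.get_columns("nonebot_plugin_chat_sticker")}
assert {"meme_text", "emotion", "labels", "context_keywords"} <= columns

tables = set(inspector.get_table_names())
assert "nonebot_plugin_chat_sessionmessage" not in tables
assert "nonebot_plugin_chat_chatuser" not in tables
```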
diff --git a/src/lang/zh_hans/chat.yaml b/src/lang/zh_hans/chat.yaml
index bb8585b5..bb15190d 100644
--- a/src/lang/zh_hans/chat.yaml
+++ b/src/lang/zh_hans/chat.yaml
@@ -82,7 +82,7 @@ prompt_group:
     ```markdown
     ## 思考过程:
     - 当前状态: 你对当前群聊中正在发生的事情的观察和判断,以及你对群聊氛围的评估。
-    - 心情和想法: 你的心情和想法(我正在想XX事,我很开心/难过/生气/平静等)
+    - 心情: 你当前的心情(很高兴/悲伤/生气/恐惧/平静)
     - 是否需要回应: 你是否需要回应群友的消息,你想要回应哪些消息或话题,以及你决定回应的原因。
     - 连续发送消息数量: 评估你在本次回应中需要发送的消息数量,判断该数量是否合理。
@@ -187,8 +187,8 @@ prompt_group:
     部分群友的介绍:
     {4}
-
-
+    表情包推荐(根据当前你的情绪或群聊中的话题筛选):
+    {5}
 
   image_describe_system: |
     # Role
@@ -469,3 +469,4 @@ sticker:
   send_failed: '发送失败: {}'
   id_not_found: '未找到 ID 为 {} 的表情包'
   duplicate: '该表情包已经收藏过了 (ID: {}, 相似度: {:.1%}),无需重复收藏'
+  not_meme: '该图片不是表情包,无法收藏'
diff --git a/src/plugins/nonebot_plugin_chat/matcher/group.py b/src/plugins/nonebot_plugin_chat/matcher/group.py
index 66d4b8d9..1843bea2 100644
--- a/src/plugins/nonebot_plugin_chat/matcher/group.py
+++ b/src/plugins/nonebot_plugin_chat/matcher/group.py
@@ -29,6 +29,7 @@ from typing import Literal, TypedDict, Optional, Any
 
 from nonebot_plugin_apscheduler import scheduler
 from nonebot_plugin_alconna import UniMessage, Target, get_target
+from nonebot_plugin_chat.utils.sticker_manager import get_sticker_manager
 from nonebot_plugin_userinfo import EventUserInfo, UserInfo
 from nonebot_plugin_larkuser import get_user
@@ -138,6 +139,7 @@ def __init__(self, processor: "MessageProcessor", max_message_count: int = 10) -> None:
         self.processor = processor
         self.max_message_count = max_message_count
         self.messages: list[OpenAIMessage] = []
+        self.fetcher_lock = asyncio.Lock()
         self.consecutive_bot_messages = 0  # 连续发送消息计数器
@@ -148,11 +150,11 @@ def clean_special_message(self) -> None:
                 break
             self.messages.pop(0)
 
-    async def get_messages(self) -> list[OpenAIMessage]:
+    async def get_messages(self, reasoning_content: Optional[str] = None) -> list[OpenAIMessage]:
         self.clean_special_message()
         self.messages = self.messages[-self.max_message_count :]
         messages = copy.deepcopy(self.messages)
-        messages.insert(0, await self.processor.generate_system_prompt())
+        messages.insert(0, await self.processor.generate_system_prompt(reasoning_content))
         return messages
 
     async def fetch_reply(self) -> None:
@@ -161,6 +163,27 @@ async def fetch_reply(self) -> None:
         async with self.fetcher_lock:
             await self._fetch_reply()
 
+    def _extract_reasoning_content(self, message: OpenAIMessage) -> Optional[str]:
+        """
+        从消息中提取思考过程内容
+
+        Args:
+            message: OpenAI 消息对象
+
+        Returns:
+            思考过程内容,如果没有找到则返回 None
+        """
+        content = None
+        if isinstance(message, dict):
+            content = message.get("content", "")
+        elif hasattr(message, "content"):
+            content = message.content
+
+        if content and isinstance(content, str) and content.strip().startswith("## 思考过程"):
+            return content
+
+        return None
+
     async def _fetch_reply(self) -> None:
         messages = await self.get_messages()
         self.messages.clear()
@@ -176,8 +199,28 @@ async def _fetch_reply(self) -> None:
             logger.info(f"Moonlark 说: {message}")
         fetcher.session.messages.extend(self.messages)
         self.messages = []
+
+        # 在消息流结束后检测思考过程并更新 system 消息
         self.messages = fetcher.get_messages()
+
+        # 检查返回的消息中是否包含思考过程
+        reasoning_content: Optional[str] = None
+        for msg in self.messages:
+            extracted = self._extract_reasoning_content(msg)
+            if extracted:
+                reasoning_content = extracted
+                break
+
+        # 如果检测到思考过程,更新表情包推荐并重新生成 system 消息
+        if reasoning_content:
+            logger.debug("检测到思考过程,正在更新表情包推荐...")
+            new_system_prompt = await self.processor.generate_system_prompt(reasoning_content)
+            # 更新 self.messages 中的 system 消息(如果有的话),或在开头插入
+            if self.messages and get_role(self.messages[0]) == "system":
+                self.messages[0] = new_system_prompt
+            else:
+                self.messages.insert(0, new_system_prompt)
+
     def append_user_message(self, message: str) -> None:
         self.consecutive_bot_messages = 0  # 收到用户消息时重置计数器
         self.messages.append(generate_message(message, "user"))
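The system-prompt refresh above hinges on a single rule: a returned message counts as a reasoning block only when its string content starts with `## 思考过程` after stripping whitespace. A standalone illustration of that check (it mirrors `_extract_reasoning_content` rather than importing it):

```python
# Minimal mirror of the detection rule used by _extract_reasoning_content.
from typing import Optional


def extract_reasoning(message: dict) -> Optional[str]:
    content = message.get("content", "")
    if isinstance(content, str) and content.strip().startswith("## 思考过程"):
        return content
    return None


assert extract_reasoning({"role": "assistant", "content": "## 思考过程:\n- 心情: 很高兴"}) is not None
assert extract_reasoning({"role": "assistant", "content": "好耶!"}) is None
```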
@@ -208,10 +251,127 @@ def insert_warning_message(self) -> None:
 
 
 class MessageProcessor:
+    async def get_sticker_recommendations(self, reasoning_content: Optional[str] = None) -> list[str]:
+        """
+        根据思考过程中的心情和上下文关键词获取表情包推荐
+
+        Args:
+            reasoning_content: LLM 输出的思考过程内容(以 "## 思考过程" 开头)
+                如果为 None,则只根据 context_keywords 进行匹配,不根据心情筛选
+
+        Returns:
+            推荐的表情包列表(格式为 "ID: 描述")
+        """
+        recommendations: list[str] = []
+        seen_ids: set[int] = set()  # 用于去重
+
+        # 获取聊天记录内容
+        chat_history = "\n".join(self.get_message_content_list())
+
+        # 只有当提供了 reasoning_content 时才根据心情筛选
+        # 如果请求来自 MessageProcessor 或其他地方(reasoning_content 为 None),则跳过心情筛选
+        if reasoning_content:
+            # 从思考过程中提取心情
+            mood = self._extract_mood_from_reasoning(reasoning_content)
+
+            # 根据心情筛选表情包
+            if mood:
+                stickers = await self.sticker_manager.filter_by_emotion(mood, limit=3)
+                for sticker in stickers:
+                    if sticker.id not in seen_ids:
+                        seen_ids.add(sticker.id)
+                        desc = sticker.description
+                        recommendations.append(f"{sticker.id}: {desc}")
+
+        # 根据 context_keywords 匹配聊天记录和思考过程
+        combined_text = chat_history
+        if reasoning_content:
+            combined_text += "\n" + reasoning_content
+
+        matched_stickers = await self._match_stickers_by_context(combined_text, exclude_ids=seen_ids)
+        for sticker in matched_stickers:
+            if sticker.id not in seen_ids:
+                seen_ids.add(sticker.id)
+                desc = sticker.description
+                recommendations.append(f"{sticker.id}: {desc}")
+
+        # 限制推荐数量
+        return recommendations[:10]
+
+    def _extract_mood_from_reasoning(self, reasoning_content: Optional[str]) -> Optional[str]:
+        """
+        从思考过程中提取心情
+
+        Args:
+            reasoning_content: LLM 输出的思考过程内容
+
+        Returns:
+            提取的心情字符串,如果未找到返回 None
+        """
+        if not reasoning_content:
+            return None
+
+        # 匹配 "- 心情: XXX" 格式
+        mood_pattern = r"-\s*心情[::]\s*(.+?)(?:\n|$)"
+        match = re.search(mood_pattern, reasoning_content)
+        if match:
+            mood = match.group(1).strip()
+            # 清理可能的括号内容,如 "很高兴(因为...)" -> "很高兴"
+            mood = re.sub(r"[((].+?[))]", "", mood).strip()
+            return mood
+
+        return None
+
+    async def _match_stickers_by_context(self, text: str, exclude_ids: set[int], limit: int = 5) -> list:
+        """
+        根据上下文关键词匹配表情包
+
+        Args:
+            text: 要匹配的文本(聊天记录 + 思考过程)
+            exclude_ids: 要排除的表情包 ID 集合
+            limit: 返回的最大数量
+
+        Returns:
+            匹配的 Sticker 对象列表
+        """
+        from nonebot_plugin_orm import get_session
+        from sqlalchemy import select
+        from ..models import Sticker
+
+        matched: list = []
+
+        async with get_session() as session:
+            # 获取所有有 context_keywords 的表情包
+            stmt = select(Sticker).where(Sticker.context_keywords.isnot(None))
+            result = await session.scalars(stmt)
+            stickers = list(result.all())
+
+            for sticker in stickers:
+                if sticker.id in exclude_ids:
+                    continue
+
+                # 解析 context_keywords JSON
+                try:
+                    keywords = json.loads(sticker.context_keywords) if sticker.context_keywords else []
+                except json.JSONDecodeError:
+                    continue
+
+                # 检查关键词是否出现在文本中
+                for keyword in keywords:
+                    if keyword and keyword in text:
+                        matched.append(sticker)
+                        break
+
+                if len(matched) >= limit:
+                    break
+
+        return matched
+
     def __init__(self, session: "GroupSession"):
         self.openai_messages = MessageQueue(self, 50)
         self.session = session
         self.enabled = True
+        self.sticker_manager = get_sticker_manager()
         self.interrupter = Interrupter(session)
         self.cold_until = datetime.now()
         self.blocked = False
@@ -587,7 +747,7 @@ async def _get_user_profiles(self, chat_history: str) -> dict[str, str]:
             profiles[nickname] = (await session.get_one(UserProfile, {"user_id": user_id})).profile_content
         return profiles
 
-    async def generate_system_prompt(self) -> OpenAIMessage:
+    async def generate_system_prompt(self, reasoning_content: Optional[str] = None) -> OpenAIMessage:
         chat_history = "\n".join(self.get_message_content_list())
         # 获取相关笔记
         note_manager = await get_context_notes(self.session.group_id)
@@ -606,6 +766,12 @@ def format_note(note):
             created_time = datetime.fromtimestamp(note.created_time).strftime("%y-%m-%d")
             return f"- {note.content} (#{note.id},创建于 {created_time})"
 
+        # 获取表情包推荐
+        sticker_recommendations = await self.get_sticker_recommendations(reasoning_content)
+        sticker_text = (
+            "\n".join([f"- {rec}" for rec in sticker_recommendations]) if sticker_recommendations else "暂无推荐"
+        )
+
         return generate_message(
             await lang.text(
                 "prompt_group.default",
@@ -619,6 +785,7 @@ def format_note(note):
                     else "暂无"
                 ),
                 profiles_text,
+                sticker_text,
             ),
             "system",
         )
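For reference, this is how the `- 心情:` line is parsed by `_extract_mood_from_reasoning`: capture up to the end of the line, then strip any parenthesised explanation (both ASCII and full-width parentheses). A self-contained sketch of the same two regexes on a sample reasoning block:

```python
# Same patterns as _extract_mood_from_reasoning, run on a sample "## 思考过程" block.
import re

reasoning = "## 思考过程:\n- 当前状态: 大家在聊猫\n- 心情: 很高兴(因为话题很有趣)\n- 是否需要回应: 是"

match = re.search(r"-\s*心情[::]\s*(.+?)(?:\n|$)", reasoning)
mood = re.sub(r"[((].+?[))]", "", match.group(1)).strip() if match else None
assert mood == "很高兴"
```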
diff --git a/src/plugins/nonebot_plugin_chat/models.py b/src/plugins/nonebot_plugin_chat/models.py
index 24de473f..778aef63 100644
--- a/src/plugins/nonebot_plugin_chat/models.py
+++ b/src/plugins/nonebot_plugin_chat/models.py
@@ -10,18 +10,6 @@
 CompatibleBlob = LargeBinary().with_variant(MEDIUMBLOB(), "mysql")
 
 
-class SessionMessage(Model):
-    id_: Mapped[int] = mapped_column(autoincrement=True, primary_key=True)
-    user_id: Mapped[str] = mapped_column(String(128))
-    content: Mapped[str] = mapped_column(Text())
-    role: Mapped[str] = mapped_column(String(16))
-
-
-class ChatUser(Model):
-    user_id: Mapped[str] = mapped_column(String(128), primary_key=True)
-    latest_chat: Mapped[datetime]
-
-
 class ChatGroup(Model):
     group_id: Mapped[str] = mapped_column(String(128), primary_key=True)
     blocked_user: Mapped[str] = mapped_column(Text(), default="[]")
@@ -56,3 +44,8 @@ class Sticker(Model):
     group_id: Mapped[Optional[str]] = mapped_column(String(128), nullable=True, index=True)  # 来源群聊
     created_time: Mapped[float] = mapped_column(Float())  # 创建时间戳
     p_hash: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)  # 感知哈希,用于图片查重
+    # 表情包分类索引信息(LLM 生成)
+    meme_text: Mapped[Optional[str]] = mapped_column(Text(), nullable=True)  # 表情包中的文本
+    emotion: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)  # 表情包表达的情绪
+    labels: Mapped[Optional[str]] = mapped_column(Text(), nullable=True)  # 表情包标签(JSON 数组)
+    context_keywords: Mapped[Optional[str]] = mapped_column(Text(), nullable=True)  # 适用语境关键词(JSON 数组)
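`labels` and `context_keywords` are plain `Text()` columns holding JSON arrays, serialised with `json.dumps(..., ensure_ascii=False)` elsewhere in this diff. A small round-trip sketch; the helper names are illustrative only, and the `None` guard matters because rows created before this migration keep NULLs in these columns:

```python
# Round-trip for the JSON-array-in-Text storage used by the new Sticker columns.
import json
from typing import Optional


def dump_keywords(keywords: list[str]) -> str:
    return json.dumps(keywords, ensure_ascii=False)


def load_keywords(raw: Optional[str]) -> list[str]:
    if not raw:
        return []  # pre-migration rows store NULL
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        return []


stored = dump_keywords(["猫", "下班", "摸鱼"])
assert load_keywords(stored) == ["猫", "下班", "摸鱼"]
assert load_keywords(None) == []
```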
diff --git a/src/plugins/nonebot_plugin_chat/utils/sticker_manager.py b/src/plugins/nonebot_plugin_chat/utils/sticker_manager.py
index 34538406..4e697ddc 100644
--- a/src/plugins/nonebot_plugin_chat/utils/sticker_manager.py
+++ b/src/plugins/nonebot_plugin_chat/utils/sticker_manager.py
@@ -15,16 +15,149 @@
 # along with this program. If not, see <https://www.gnu.org/licenses/>.
 # ##############################################################################
 
+import base64
+import json
+import re
+import traceback
 from datetime import datetime
-from typing import List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, TypedDict
 
+from nonebot import logger
 from nonebot_plugin_orm import get_session
+from nonebot_plugin_openai.utils.chat import fetch_message
+from nonebot_plugin_openai.utils.message import generate_message
 from sqlalchemy import select
 
 from ..models import Sticker
 from .sticker_similarity import calculate_hash_async, check_sticker_duplicate
 
+
+# 表情包分类结果类型
+class MemeClassification(TypedDict):
+    is_meme: bool
+    text: str
+    emotion: str
+    labels: List[str]
+    context_keywords: List[str]
+
+
+# 表情包分类提示词
+MEME_CLASSIFICATION_PROMPT = """你是一个表情包分析 AI。
+我会向你提供一张表情包图片,你需要分析表情包的内容,并对其进行分类。
+
+### 输出格式
+一段 JSON,不要包含除了 JSON 结构以外的任何内容。
+
+{
+    "is_meme": boolean,  // 这张图片是一个表情包吗?如果是为 true。
+    "text": string,  // 表情包中的文本,如果没有请填空字符串。
+    "emotion": string,  // 表情包所表达的情绪的类型,如:高兴、难过、生气、恐惧。
+    "labels": array[string],  // 表情包的标签,按照参考的标签库分类中给出的示例进行编写。
+    "context_keywords": array[string]  // 表情包适用的语境,这个表情包适合在群聊中谈到什么关键词时出现?
+}
+
+### 参考的标签库分类
+1. **社交回应类**:`赞同`、`反对`、`无语`、`震惊`、`委屈`、`认怂`。
+2. **网络梗类**:`吃瓜`、`摆烂`、`摸鱼`、`内卷`、`抽象`、`典`。
+3. **时间/天气类**:`早安`、`周五`、`放假`。
+4. **互动类**:`贴贴`、`抱抱`、`禁言`、`反弹`。"""
+
+
+def extract_json_from_response(response: str) -> Optional[Dict[str, Any]]:
+    """
+    从 LLM 响应中提取 JSON,处理可能包含 markdown 代码块的情况
+
+    Args:
+        response: LLM 返回的原始响应文本
+
+    Returns:
+        解析后的 JSON 字典,如果解析失败返回 None
+    """
+    # 去除首尾空白
+    response = response.strip()
+
+    # 尝试直接解析
+    try:
+        return json.loads(response)
+    except json.JSONDecodeError:
+        pass
+
+    # 尝试提取 markdown 代码块中的 JSON
+    # 匹配 ```json ... ``` 或 ``` ... ```
+    code_block_pattern = r"```(?:json)?\s*\n?([\s\S]*?)\n?```"
+    matches = re.findall(code_block_pattern, response)
+
+    for match in matches:
+        try:
+            return json.loads(match.strip())
+        except json.JSONDecodeError:
+            continue
+
+    # 尝试查找 JSON 对象(以 { 开头,以 } 结尾)
+    json_pattern = r"\{[\s\S]*\}"
+    json_matches = re.findall(json_pattern, response)
+
+    for match in json_matches:
+        try:
+            return json.loads(match)
+        except json.JSONDecodeError:
+            continue
+
+    return None
+
+
+async def classify_meme(image_data: bytes) -> Optional[MemeClassification]:
+    """
+    使用 LLM 对表情包进行分类
+
+    Args:
+        image_data: 图片二进制数据
+
+    Returns:
+        MemeClassification 分类结果,如果分类失败返回 None
+    """
+    try:
+        # 转换图片为 base64
+        image_base64 = base64.b64encode(image_data).decode("utf-8")
+
+        # 构建消息
+        messages = [
+            generate_message(MEME_CLASSIFICATION_PROMPT, "system"),
+            generate_message(
+                [
+                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}},
+                    {"type": "text", "text": "请分析这张图片并输出分类 JSON。"},
+                ],
+                "user",
+            ),
+        ]
+
+        # 调用 LLM
+        response = (await fetch_message(messages, identify="Meme Classification")).strip()
+
+        # 解析 JSON(处理可能的 markdown 代码块)
+        result = extract_json_from_response(response)
+
+        if result is None:
+            logger.warning(f"Failed to parse meme classification response: {response}")
+            return None
+
+        # 验证并转换结果
+        classification: MemeClassification = {
+            "is_meme": bool(result.get("is_meme", False)),
+            "text": str(result.get("text", "")),
+            "emotion": str(result.get("emotion", "")),
+            "labels": list(result.get("labels", [])),
+            "context_keywords": list(result.get("context_keywords", [])),
+        }
+
+        return classification
+
+    except Exception as e:
+        logger.warning(f"Failed to classify meme: {e}\n{traceback.format_exc()}")
+        return None
+
+
 class DuplicateStickerError(Exception):
     """表情包重复异常"""
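Two inputs `extract_json_from_response` is meant to survive: a bare JSON object, and the same object wrapped by the model in a markdown code fence. Sketch only; the import assumes the bot environment is available, since the module pulls in nonebot-plugin-orm and the OpenAI plugin at import time.

```python
# Exercise both parsing paths of extract_json_from_response.
from nonebot_plugin_chat.utils.sticker_manager import extract_json_from_response

bare = '{"is_meme": true, "text": "", "emotion": "高兴", "labels": ["赞同"], "context_keywords": ["猫"]}'
fenced = "```json\n" + bare + "\n```"  # typical fenced reply from the model

for reply in (bare, fenced):
    parsed = extract_json_from_response(reply)
    assert parsed is not None and parsed["is_meme"] is True
```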
"""表情包重复异常""" @@ -34,6 +167,14 @@ def __init__(self, existing_sticker: Sticker, similarity: float): super().__init__(f"发现重复的表情包 (ID: {existing_sticker.id}, 相似度: {similarity:.2%})") +class NotMemeError(Exception): + """图片不是表情包异常""" + + def __init__(self, message: str = "该图片不是表情包"): + self.message = message + super().__init__(message) + + class StickerManager: """Sticker management system for saving, searching and retrieving stickers""" @@ -67,12 +208,35 @@ async def save_sticker( # 计算感知哈希 p_hash = await calculate_hash_async(raw) + # 调用 LLM 进行表情包分类 + classification = await classify_meme(raw) + + # 准备分类数据 + meme_text: Optional[str] = None + emotion: Optional[str] = None + labels_json: Optional[str] = None + context_keywords_json: Optional[str] = None + + if classification is not None: + # 如果不是表情包,拒绝添加 + if not classification["is_meme"]: + raise NotMemeError("该图片不是表情包,无法收藏") + + meme_text = classification["text"] + emotion = classification["emotion"] + labels_json = json.dumps(classification["labels"], ensure_ascii=False) + context_keywords_json = json.dumps(classification["context_keywords"], ensure_ascii=False) + sticker = Sticker( description=description, raw=raw, group_id=group_id, created_time=current_time.timestamp(), p_hash=p_hash if p_hash else None, + meme_text=meme_text, + emotion=emotion, + labels=labels_json, + context_keywords=context_keywords_json, ) session.add(sticker) @@ -186,12 +350,220 @@ async def get_all_stickers(self, limit: int = 100) -> List[Sticker]: result = await session.scalars(stmt) return list(result.all()) + async def filter_by_emotion(self, emotion: str, limit: int = 10) -> List[Sticker]: + """ + Filter stickers by emotion + + Args: + emotion: Emotion to filter by (e.g., "高兴", "难过", "生气") + limit: Maximum number of results to return + + Returns: + List of matching Sticker objects + """ + async with get_session() as session: + stmt = ( + select(Sticker) + .where(Sticker.emotion.contains(emotion)) + .order_by(Sticker.created_time.desc()) + .limit(limit) + ) + result = await session.scalars(stmt) + return list(result.all()) + + async def filter_by_label(self, label: str, limit: int = 10) -> List[Sticker]: + """ + Filter stickers by label (searches within JSON array) + + Args: + label: Label to filter by (e.g., "赞同", "摆烂", "贴贴") + limit: Maximum number of results to return + + Returns: + List of matching Sticker objects + """ + async with get_session() as session: + # labels 字段是 JSON 数组字符串,使用 contains 进行模糊匹配 + stmt = ( + select(Sticker).where(Sticker.labels.contains(label)).order_by(Sticker.created_time.desc()).limit(limit) + ) + result = await session.scalars(stmt) + return list(result.all()) + + async def filter_by_context_keyword(self, keyword: str, limit: int = 10) -> List[Sticker]: + """ + Filter stickers by context keyword (searches within JSON array) + + Args: + keyword: Context keyword to filter by + limit: Maximum number of results to return + + Returns: + List of matching Sticker objects + """ + async with get_session() as session: + # context_keywords 字段是 JSON 数组字符串,使用 contains 进行模糊匹配 + stmt = ( + select(Sticker) + .where(Sticker.context_keywords.contains(keyword)) + .order_by(Sticker.created_time.desc()) + .limit(limit) + ) + result = await session.scalars(stmt) + return list(result.all()) + + async def filter_by_classification( + self, + emotion: Optional[str] = None, + labels: Optional[List[str]] = None, + context_keywords: Optional[List[str]] = None, + limit: int = 10, + ) -> List[Sticker]: + """ + Filter stickers by multiple classification criteria 
+    async def filter_by_classification(
+        self,
+        emotion: Optional[str] = None,
+        labels: Optional[List[str]] = None,
+        context_keywords: Optional[List[str]] = None,
+        limit: int = 10,
+    ) -> List[Sticker]:
+        """
+        Filter stickers by multiple classification criteria (AND logic)
+
+        Args:
+            emotion: Emotion to filter by (optional)
+            labels: List of labels to filter by, all must match (optional)
+            context_keywords: List of context keywords to filter by, all must match (optional)
+            limit: Maximum number of results to return
+
+        Returns:
+            List of matching Sticker objects
+        """
+        async with get_session() as session:
+            stmt = select(Sticker)
+
+            # 应用情绪筛选
+            if emotion:
+                stmt = stmt.where(Sticker.emotion.contains(emotion))
+
+            # 应用标签筛选(所有标签都必须匹配)
+            if labels:
+                for label in labels:
+                    stmt = stmt.where(Sticker.labels.contains(label))
+
+            # 应用语境关键词筛选(所有关键词都必须匹配)
+            if context_keywords:
+                for keyword in context_keywords:
+                    stmt = stmt.where(Sticker.context_keywords.contains(keyword))
+
+            stmt = stmt.order_by(Sticker.created_time.desc()).limit(limit)
+            result = await session.scalars(stmt)
+            return list(result.all())
+
+    async def migrate_existing_stickers(self, batch_size: int = 10) -> Dict[str, int]:
+        """
+        Migrate existing stickers by classifying them with LLM
+
+        This method processes stickers that don't have classification data
+        (meme_text, emotion, labels, context_keywords are all None)
+
+        Args:
+            batch_size: Number of stickers to process in each batch
+
+        Returns:
+            Dict with migration statistics:
+            - total: Total number of stickers processed
+            - success: Number of successfully classified stickers
+            - failed: Number of stickers that failed classification
+            - skipped: Number of stickers already classified
+        """
+        stats = {"total": 0, "success": 0, "failed": 0, "skipped": 0}
+
+        async with get_session() as session:
+            # 查找所有未分类的表情包(分类字段都为 None)
+            stmt = select(Sticker).where(
+                Sticker.meme_text.is_(None),
+                Sticker.emotion.is_(None),
+                Sticker.labels.is_(None),
+                Sticker.context_keywords.is_(None),
+            )
+            result = await session.scalars(stmt)
+            stickers_to_migrate = list(result.all())
+
+            stats["total"] = len(stickers_to_migrate)
+
+            for sticker in stickers_to_migrate:
+                try:
+                    # 调用 LLM 进行分类
+                    classification = await classify_meme(sticker.raw)
+
+                    if classification is None:
+                        logger.warning(f"Failed to classify sticker {sticker.id}: LLM returned None")
+                        stats["failed"] += 1
+                        continue
+
+                    # 更新分类信息
+                    sticker.meme_text = classification["text"]
+                    sticker.emotion = classification["emotion"]
+                    sticker.labels = json.dumps(classification["labels"], ensure_ascii=False)
+                    sticker.context_keywords = json.dumps(classification["context_keywords"], ensure_ascii=False)
+
+                    session.add(sticker)
+                    stats["success"] += 1
+
+                    logger.info(f"Successfully classified sticker {sticker.id}: emotion={classification['emotion']}")
+
+                except Exception as e:
+                    logger.warning(f"Failed to migrate sticker {sticker.id}: {e}")
+                    stats["failed"] += 1
+
+            # 提交所有更改
+            await session.commit()
+
+        logger.info(
+            f"Sticker migration completed: {stats['success']} success, "
+            f"{stats['failed']} failed, {stats['skipped']} skipped out of {stats['total']} total"
+        )
+
+        return stats
+
+    async def classify_single_sticker(self, sticker_id: int) -> bool:
+        """
+        Classify a single sticker by its ID
+
+        Args:
+            sticker_id: The ID of the sticker to classify
+
+        Returns:
+            True if classification was successful, False otherwise
+        """
+        async with get_session() as session:
+            sticker = await session.get(Sticker, sticker_id)
+
+            if sticker is None:
+                logger.warning(f"Sticker {sticker_id} not found")
+                return False
+
+            try:
+                classification = await classify_meme(sticker.raw)
+
+                if classification is None:
+                    logger.warning(f"Failed to classify sticker {sticker_id}: LLM returned None")
+                    return False
+
+                sticker.meme_text = classification["text"]
+                sticker.emotion = classification["emotion"]
+                sticker.labels = json.dumps(classification["labels"], ensure_ascii=False)
+                sticker.context_keywords = json.dumps(classification["context_keywords"], ensure_ascii=False)
+
+                session.add(sticker)
+                await session.commit()
+
+                logger.info(f"Successfully classified sticker {sticker_id}")
+                return True
+
+            except Exception as e:
+                logger.warning(f"Failed to classify sticker {sticker_id}: {e}")
+                return False
+
 
 # Global sticker manager instance
 sticker_manager = StickerManager()
 
 
-async def get_sticker_manager() -> StickerManager:
+def get_sticker_manager() -> StickerManager:
     """
     Get the global StickerManager instance
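Nothing in this diff triggers `migrate_existing_stickers`; a hypothetical superuser-only command like the one below could run the backfill once and report the stats dict it returns. The command name and wording are assumptions, not part of this change.

```python
# Hypothetical one-off trigger for the classification backfill (not in this diff).
from nonebot import on_command
from nonebot.permission import SUPERUSER

from nonebot_plugin_chat.utils.sticker_manager import get_sticker_manager

classify_all = on_command("classify_stickers", permission=SUPERUSER)


@classify_all.handle()
async def _() -> None:
    # Runs the LLM classification over every sticker that still lacks metadata.
    stats = await get_sticker_manager().migrate_existing_stickers()
    await classify_all.finish(
        f"分类完成: 成功 {stats['success']}, 失败 {stats['failed']}, 共 {stats['total']}"
    )
```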
sticker.meme_text = classification["text"] + sticker.emotion = classification["emotion"] + sticker.labels = json.dumps(classification["labels"], ensure_ascii=False) + sticker.context_keywords = json.dumps(classification["context_keywords"], ensure_ascii=False) + + session.add(sticker) + await session.commit() + + logger.info(f"Successfully classified sticker {sticker_id}") + return True + + except Exception as e: + logger.warning(f"Failed to classify sticker {sticker_id}: {e}") + return False + # Global sticker manager instance sticker_manager = StickerManager() -async def get_sticker_manager() -> StickerManager: +def get_sticker_manager() -> StickerManager: """ Get the global StickerManager instance diff --git a/src/plugins/nonebot_plugin_chat/utils/tools/sticker.py b/src/plugins/nonebot_plugin_chat/utils/tools/sticker.py index ce7c3fb4..7dbd383a 100644 --- a/src/plugins/nonebot_plugin_chat/utils/tools/sticker.py +++ b/src/plugins/nonebot_plugin_chat/utils/tools/sticker.py @@ -38,7 +38,7 @@ async def save_sticker_func(session: "GroupSession", image_id: str) -> str: Returns: Success or error message """ - from ..sticker_manager import DuplicateStickerError + from ..sticker_manager import DuplicateStickerError, NotMemeError # Get image data from cache image_data = await get_image_by_id(image_id) @@ -47,7 +47,7 @@ async def save_sticker_func(session: "GroupSession", image_id: str) -> str: return await lang.text("sticker.not_found", session.user_id) # Get sticker manager and save - manager = await get_sticker_manager() + manager = get_sticker_manager() try: sticker = await manager.save_sticker( description=image_data["description"], @@ -57,6 +57,8 @@ async def save_sticker_func(session: "GroupSession", image_id: str) -> str: return await lang.text("sticker.saved", session.user_id, sticker.id) except DuplicateStickerError as e: return await lang.text("sticker.duplicate", session.user_id, e.existing_sticker.id, e.similarity) + except NotMemeError: + return await lang.text("sticker.not_meme", session.user_id) async def search_sticker_func(session: "GroupSession", query: str) -> str: @@ -70,7 +72,7 @@ async def search_sticker_func(session: "GroupSession", query: str) -> str: Returns: Formatted list of matching stickers or empty message """ - manager = await get_sticker_manager() + manager = get_sticker_manager() # First try AND matching (all keywords must match) stickers = await manager.search_sticker(query, limit=5) @@ -103,7 +105,7 @@ async def send_sticker_func(session: "GroupSession", sticker_id: int) -> str: Returns: Success or error message """ - manager = await get_sticker_manager() + manager = get_sticker_manager() sticker = await manager.get_sticker(sticker_id) if sticker is None: @@ -162,7 +164,7 @@ async def send_sticker(sticker_id: int) -> str: description=( "从收藏的表情包库中搜索合适的表情包。\n" "**何时调用**: 当你想用表情包回复群友时,先调用此工具搜索合适的表情包。\n" - "**搜索技巧**: 使用描述性的关键词,如情绪(开心、悲伤、嘲讽)、动作(大笑、哭泣)或内容(猫、狗、熊猫头)。" + "**搜索技巧**: 使用描述性的关键词,如情绪(开心、悲伤、嘲讽)、动作(大笑、哭泣)或内容。" ), parameters={ "query": FunctionParameter(