From de888fa3f56a7ba88874df44f5fcf16674322529 Mon Sep 17 00:00:00 2001 From: XiaoDeng3386 <1744793737@qq.com> Date: Sat, 24 Jan 2026 23:36:14 +0800 Subject: [PATCH 1/7] feat(chat): add LLM-based sticker classification and validation --- migrations/versions/24c7dc8e7e67_.py | 61 +++++++ src/lang/zh_hans/chat.yaml | 1 + src/plugins/nonebot_plugin_chat/models.py | 18 +- .../utils/sticker_manager.py | 170 +++++++++++++++++- .../utils/tools/sticker.py | 4 +- 5 files changed, 240 insertions(+), 14 deletions(-) create mode 100644 migrations/versions/24c7dc8e7e67_.py diff --git a/migrations/versions/24c7dc8e7e67_.py b/migrations/versions/24c7dc8e7e67_.py new file mode 100644 index 00000000..fad9f871 --- /dev/null +++ b/migrations/versions/24c7dc8e7e67_.py @@ -0,0 +1,61 @@ +"""empty message + +迁移 ID: 24c7dc8e7e67 +父迁移: a1b2c3d4e5f6 +创建时间: 2026-01-24 23:33:36.695599 + +""" +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa + + +revision: str = '24c7dc8e7e67' +down_revision: str | Sequence[str] | None = 'a1b2c3d4e5f6' +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade(name: str = "") -> None: + if name: + return + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('nonebot_plugin_chat_chatuser') + op.drop_table('nonebot_plugin_chat_sessionmessage') + with op.batch_alter_table('nonebot_plugin_chat_sticker', schema=None) as batch_op: + batch_op.add_column(sa.Column('is_meme', sa.Boolean(), nullable=True)) + batch_op.add_column(sa.Column('meme_text', sa.Text(), nullable=True)) + batch_op.add_column(sa.Column('emotion', sa.String(length=64), nullable=True)) + batch_op.add_column(sa.Column('labels', sa.Text(), nullable=True)) + batch_op.add_column(sa.Column('context_keywords', sa.Text(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(name: str = "") -> None: + if name: + return + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('nonebot_plugin_chat_sticker', schema=None) as batch_op: + batch_op.drop_column('context_keywords') + batch_op.drop_column('labels') + batch_op.drop_column('emotion') + batch_op.drop_column('meme_text') + batch_op.drop_column('is_meme') + + op.create_table('nonebot_plugin_chat_sessionmessage', + sa.Column('id_', sa.INTEGER(), nullable=False), + sa.Column('user_id', sa.VARCHAR(length=128), nullable=False), + sa.Column('content', sa.TEXT(), nullable=False), + sa.Column('role', sa.VARCHAR(length=16), nullable=False), + sa.PrimaryKeyConstraint('id_', name=op.f('pk_nonebot_plugin_chat_sessionmessage')) + ) + op.create_table('nonebot_plugin_chat_chatuser', + sa.Column('user_id', sa.VARCHAR(length=128), nullable=False), + sa.Column('latest_chat', sa.DATETIME(), nullable=False), + sa.PrimaryKeyConstraint('user_id', name=op.f('pk_nonebot_plugin_chat_chatuser')) + ) + # ### end Alembic commands ### diff --git a/src/lang/zh_hans/chat.yaml b/src/lang/zh_hans/chat.yaml index bb8585b5..340f77ef 100644 --- a/src/lang/zh_hans/chat.yaml +++ b/src/lang/zh_hans/chat.yaml @@ -469,3 +469,4 @@ sticker: send_failed: '发送失败: {}' id_not_found: '未找到 ID 为 {} 的表情包' duplicate: '该表情包已经收藏过了 (ID: {}, 相似度: {:.1%}),无需重复收藏' + not_meme: '该图片不是表情包,无法收藏' diff --git a/src/plugins/nonebot_plugin_chat/models.py b/src/plugins/nonebot_plugin_chat/models.py index 24de473f..e0d09feb 100644 --- a/src/plugins/nonebot_plugin_chat/models.py +++ b/src/plugins/nonebot_plugin_chat/models.py @@ -10,18 +10,6 @@ CompatibleBlob = LargeBinary().with_variant(MEDIUMBLOB(), "mysql") -class SessionMessage(Model): - id_: Mapped[int] = mapped_column(autoincrement=True, primary_key=True) - user_id: Mapped[str] = mapped_column(String(128)) - content: Mapped[str] = mapped_column(Text()) - role: Mapped[str] = mapped_column(String(16)) - - -class ChatUser(Model): - user_id: Mapped[str] = mapped_column(String(128), primary_key=True) - latest_chat: Mapped[datetime] - - class ChatGroup(Model): group_id: Mapped[str] = mapped_column(String(128), primary_key=True) blocked_user: Mapped[str] = mapped_column(Text(), default="[]") @@ -56,3 +44,9 @@ class Sticker(Model): group_id: Mapped[Optional[str]] = mapped_column(String(128), nullable=True, index=True) # 来源群聊 created_time: Mapped[float] = mapped_column(Float()) # 创建时间戳 p_hash: Mapped[Optional[str]] = mapped_column(String(64), nullable=True) # 感知哈希,用于图片查重 + # 表情包分类索引信息(LLM 生成) + is_meme: Mapped[Optional[bool]] = mapped_column(nullable=True) # 是否为表情包 + meme_text: Mapped[Optional[str]] = mapped_column(Text(), nullable=True) # 表情包中的文本 + emotion: Mapped[Optional[str]] = mapped_column(String(64), nullable=True) # 表情包表达的情绪 + labels: Mapped[Optional[str]] = mapped_column(Text(), nullable=True) # 表情包标签(JSON 数组) + context_keywords: Mapped[Optional[str]] = mapped_column(Text(), nullable=True) # 适用语境关键词(JSON 数组) diff --git a/src/plugins/nonebot_plugin_chat/utils/sticker_manager.py b/src/plugins/nonebot_plugin_chat/utils/sticker_manager.py index 34538406..a8557c97 100644 --- a/src/plugins/nonebot_plugin_chat/utils/sticker_manager.py +++ b/src/plugins/nonebot_plugin_chat/utils/sticker_manager.py @@ -15,16 +15,149 @@ # along with this program. If not, see . # ############################################################################## +import base64 +import json +import re +import traceback from datetime import datetime -from typing import List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union +from nonebot import logger from nonebot_plugin_orm import get_session +from nonebot_plugin_openai.utils.chat import fetch_message +from nonebot_plugin_openai.utils.message import generate_message from sqlalchemy import select from ..models import Sticker from .sticker_similarity import calculate_hash_async, check_sticker_duplicate +# 表情包分类结果类型 +class MemeClassification(TypedDict): + is_meme: bool + text: str + emotion: str + labels: List[str] + context_keywords: List[str] + + +# 表情包分类提示词 +MEME_CLASSIFICATION_PROMPT = """你是一个表情包分析 AI。 +我会向你提供一张表情包图片,你需要分析表情包的内容,并对其进行分类。 + +### 输出格式 +一段 JSON,不要包含除了 JSON 结构以外的任何内容。 + +{ + "is_meme": boolen, // 这张图片是一个表情包吗?如果是为 true。 + "text": string, // 表情包中的文本,如果没有请填空字符串。 + "emotion": string, // 表情包所表达的情绪的类型,如:高兴、难过、生气、恐惧、伤心。 + "labels": array[string], // 表情包的标签,按照参考的标签库分类中给出的示例进行编写。 + "context_keywords": array[string] // 表情包适用的语境,这个表情包适合在群聊中谈到什么关键词时出现? +} + +### 参考的标签库分类 +1. **社交回应类**:`赞同`、`反对`、`无语`、`震惊`、`委屈`、`认怂`。 +2. **网络梗类**:`吃瓜`、`摆烂`、`摸鱼`、`内卷`、`抽象`、`典`。 +3. **时间/天气类**:`早安`、`周五`、`放假`。 +4. **互动类**:`贴贴`、`抱抱`、`禁言`、`反弹`。""" + + +def extract_json_from_response(response: str) -> Optional[Dict[str, Any]]: + """ + 从 LLM 响应中提取 JSON,处理可能包含 markdown 代码块的情况 + + Args: + response: LLM 返回的原始响应文本 + + Returns: + 解析后的 JSON 字典,如果解析失败返回 None + """ + # 去除首尾空白 + response = response.strip() + + # 尝试直接解析 + try: + return json.loads(response) + except json.JSONDecodeError: + pass + + # 尝试提取 markdown 代码块中的 JSON + # 匹配 ```json ... ``` 或 ``` ... ``` + code_block_pattern = r'```(?:json)?\s*\n?([\s\S]*?)\n?```' + matches = re.findall(code_block_pattern, response) + + for match in matches: + try: + return json.loads(match.strip()) + except json.JSONDecodeError: + continue + + # 尝试查找 JSON 对象(以 { 开头,以 } 结尾) + json_pattern = r'\{[\s\S]*\}' + json_matches = re.findall(json_pattern, response) + + for match in json_matches: + try: + return json.loads(match) + except json.JSONDecodeError: + continue + + return None + + +async def classify_meme(image_data: bytes) -> Optional[MemeClassification]: + """ + 使用 LLM 对表情包进行分类 + + Args: + image_data: 图片二进制数据 + + Returns: + MemeClassification 分类结果,如果分类失败返回 None + """ + try: + # 转换图片为 base64 + image_base64 = base64.b64encode(image_data).decode("utf-8") + + # 构建消息 + messages = [ + generate_message(MEME_CLASSIFICATION_PROMPT, "system"), + generate_message( + [ + {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}}, + {"type": "text", "text": "请分析这张图片并输出分类 JSON。"}, + ], + "user", + ), + ] + + # 调用 LLM + response = (await fetch_message(messages, identify="Meme Classification")).strip() + + # 解析 JSON(处理可能的 markdown 代码块) + result = extract_json_from_response(response) + + if result is None: + logger.warning(f"Failed to parse meme classification response: {response}") + return None + + # 验证并转换结果 + classification: MemeClassification = { + "is_meme": bool(result.get("is_meme", False)), + "text": str(result.get("text", "")), + "emotion": str(result.get("emotion", "")), + "labels": list(result.get("labels", [])), + "context_keywords": list(result.get("context_keywords", [])), + } + + return classification + + except Exception as e: + logger.warning(f"Failed to classify meme: {e}\n{traceback.format_exc()}") + return None + + class DuplicateStickerError(Exception): """表情包重复异常""" @@ -34,6 +167,14 @@ def __init__(self, existing_sticker: Sticker, similarity: float): super().__init__(f"发现重复的表情包 (ID: {existing_sticker.id}, 相似度: {similarity:.2%})") +class NotMemeError(Exception): + """图片不是表情包异常""" + + def __init__(self, message: str = "该图片不是表情包"): + self.message = message + super().__init__(message) + + class StickerManager: """Sticker management system for saving, searching and retrieving stickers""" @@ -66,6 +207,28 @@ async def save_sticker( # 计算感知哈希 p_hash = await calculate_hash_async(raw) + + # 调用 LLM 进行表情包分类 + classification = await classify_meme(raw) + + # 准备分类数据 + is_meme: Optional[bool] = None + meme_text: Optional[str] = None + emotion: Optional[str] = None + labels_json: Optional[str] = None + context_keywords_json: Optional[str] = None + + if classification is not None: + is_meme = classification["is_meme"] + + # 如果不是表情包,拒绝添加 + if not is_meme: + raise NotMemeError("该图片不是表情包,无法收藏") + + meme_text = classification["text"] + emotion = classification["emotion"] + labels_json = json.dumps(classification["labels"], ensure_ascii=False) + context_keywords_json = json.dumps(classification["context_keywords"], ensure_ascii=False) sticker = Sticker( description=description, @@ -73,6 +236,11 @@ async def save_sticker( group_id=group_id, created_time=current_time.timestamp(), p_hash=p_hash if p_hash else None, + is_meme=is_meme, + meme_text=meme_text, + emotion=emotion, + labels=labels_json, + context_keywords=context_keywords_json, ) session.add(sticker) diff --git a/src/plugins/nonebot_plugin_chat/utils/tools/sticker.py b/src/plugins/nonebot_plugin_chat/utils/tools/sticker.py index ce7c3fb4..1c513558 100644 --- a/src/plugins/nonebot_plugin_chat/utils/tools/sticker.py +++ b/src/plugins/nonebot_plugin_chat/utils/tools/sticker.py @@ -38,7 +38,7 @@ async def save_sticker_func(session: "GroupSession", image_id: str) -> str: Returns: Success or error message """ - from ..sticker_manager import DuplicateStickerError + from ..sticker_manager import DuplicateStickerError, NotMemeError # Get image data from cache image_data = await get_image_by_id(image_id) @@ -57,6 +57,8 @@ async def save_sticker_func(session: "GroupSession", image_id: str) -> str: return await lang.text("sticker.saved", session.user_id, sticker.id) except DuplicateStickerError as e: return await lang.text("sticker.duplicate", session.user_id, e.existing_sticker.id, e.similarity) + except NotMemeError: + return await lang.text("sticker.not_meme", session.user_id) async def search_sticker_func(session: "GroupSession", query: str) -> str: From c5ced3cd8b6b7b50818f29626a84d5b6205704aa Mon Sep 17 00:00:00 2001 From: XiaoDeng3386 <1744793737@qq.com> Date: Sat, 24 Jan 2026 23:47:48 +0800 Subject: [PATCH 2/7] refactor(chat): remove unused is_meme column from Sticker model Remove the is_meme field from the database model and sticker manager since it was only used for validation during sticker creation and never stored or queried afterward. The classification check is now done inline without persisting the boolean value. --- src/plugins/nonebot_plugin_chat/models.py | 1 - src/plugins/nonebot_plugin_chat/utils/sticker_manager.py | 6 +----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/src/plugins/nonebot_plugin_chat/models.py b/src/plugins/nonebot_plugin_chat/models.py index e0d09feb..778aef63 100644 --- a/src/plugins/nonebot_plugin_chat/models.py +++ b/src/plugins/nonebot_plugin_chat/models.py @@ -45,7 +45,6 @@ class Sticker(Model): created_time: Mapped[float] = mapped_column(Float()) # 创建时间戳 p_hash: Mapped[Optional[str]] = mapped_column(String(64), nullable=True) # 感知哈希,用于图片查重 # 表情包分类索引信息(LLM 生成) - is_meme: Mapped[Optional[bool]] = mapped_column(nullable=True) # 是否为表情包 meme_text: Mapped[Optional[str]] = mapped_column(Text(), nullable=True) # 表情包中的文本 emotion: Mapped[Optional[str]] = mapped_column(String(64), nullable=True) # 表情包表达的情绪 labels: Mapped[Optional[str]] = mapped_column(Text(), nullable=True) # 表情包标签(JSON 数组) diff --git a/src/plugins/nonebot_plugin_chat/utils/sticker_manager.py b/src/plugins/nonebot_plugin_chat/utils/sticker_manager.py index a8557c97..a0f4777f 100644 --- a/src/plugins/nonebot_plugin_chat/utils/sticker_manager.py +++ b/src/plugins/nonebot_plugin_chat/utils/sticker_manager.py @@ -212,17 +212,14 @@ async def save_sticker( classification = await classify_meme(raw) # 准备分类数据 - is_meme: Optional[bool] = None meme_text: Optional[str] = None emotion: Optional[str] = None labels_json: Optional[str] = None context_keywords_json: Optional[str] = None if classification is not None: - is_meme = classification["is_meme"] - # 如果不是表情包,拒绝添加 - if not is_meme: + if not classification["is_meme"]: raise NotMemeError("该图片不是表情包,无法收藏") meme_text = classification["text"] @@ -236,7 +233,6 @@ async def save_sticker( group_id=group_id, created_time=current_time.timestamp(), p_hash=p_hash if p_hash else None, - is_meme=is_meme, meme_text=meme_text, emotion=emotion, labels=labels_json, From 920b1f72f5f965d0099f78fbcc5c3bd5abca18b1 Mon Sep 17 00:00:00 2001 From: XiaoDeng3386 <1744793737@qq.com> Date: Sun, 25 Jan 2026 12:06:12 +0800 Subject: [PATCH 3/7] feat(chat): add sticker filtering and classification methods Add new methods to StickerManager for filtering and classifying stickers: - filter_by_emotion: filter stickers by emotion field - filter_by_label: filter stickers by label in JSON array - filter_by_context_keyword: filter stickers by context keyword - filter_by_classification: combined filter with AND logic - migrate_existing_stickers: batch classify unclassified stickers via LLM - classify_single_sticker: classify individual sticker by ID --- .../utils/sticker_manager.py | 213 +++++++++++++++++- 1 file changed, 212 insertions(+), 1 deletion(-) diff --git a/src/plugins/nonebot_plugin_chat/utils/sticker_manager.py b/src/plugins/nonebot_plugin_chat/utils/sticker_manager.py index a0f4777f..8292f429 100644 --- a/src/plugins/nonebot_plugin_chat/utils/sticker_manager.py +++ b/src/plugins/nonebot_plugin_chat/utils/sticker_manager.py @@ -20,7 +20,7 @@ import re import traceback from datetime import datetime -from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union +from typing import Any, Dict, List, Optional, TypedDict from nonebot import logger from nonebot_plugin_orm import get_session @@ -350,6 +350,217 @@ async def get_all_stickers(self, limit: int = 100) -> List[Sticker]: result = await session.scalars(stmt) return list(result.all()) + async def filter_by_emotion(self, emotion: str, limit: int = 10) -> List[Sticker]: + """ + Filter stickers by emotion + + Args: + emotion: Emotion to filter by (e.g., "高兴", "难过", "生气") + limit: Maximum number of results to return + + Returns: + List of matching Sticker objects + """ + async with get_session() as session: + stmt = ( + select(Sticker) + .where(Sticker.emotion.contains(emotion)) + .order_by(Sticker.created_time.desc()) + .limit(limit) + ) + result = await session.scalars(stmt) + return list(result.all()) + + async def filter_by_label(self, label: str, limit: int = 10) -> List[Sticker]: + """ + Filter stickers by label (searches within JSON array) + + Args: + label: Label to filter by (e.g., "赞同", "摆烂", "贴贴") + limit: Maximum number of results to return + + Returns: + List of matching Sticker objects + """ + async with get_session() as session: + # labels 字段是 JSON 数组字符串,使用 contains 进行模糊匹配 + stmt = ( + select(Sticker) + .where(Sticker.labels.contains(label)) + .order_by(Sticker.created_time.desc()) + .limit(limit) + ) + result = await session.scalars(stmt) + return list(result.all()) + + async def filter_by_context_keyword(self, keyword: str, limit: int = 10) -> List[Sticker]: + """ + Filter stickers by context keyword (searches within JSON array) + + Args: + keyword: Context keyword to filter by + limit: Maximum number of results to return + + Returns: + List of matching Sticker objects + """ + async with get_session() as session: + # context_keywords 字段是 JSON 数组字符串,使用 contains 进行模糊匹配 + stmt = ( + select(Sticker) + .where(Sticker.context_keywords.contains(keyword)) + .order_by(Sticker.created_time.desc()) + .limit(limit) + ) + result = await session.scalars(stmt) + return list(result.all()) + + async def filter_by_classification( + self, + emotion: Optional[str] = None, + labels: Optional[List[str]] = None, + context_keywords: Optional[List[str]] = None, + limit: int = 10, + ) -> List[Sticker]: + """ + Filter stickers by multiple classification criteria (AND logic) + + Args: + emotion: Emotion to filter by (optional) + labels: List of labels to filter by, all must match (optional) + context_keywords: List of context keywords to filter by, all must match (optional) + limit: Maximum number of results to return + + Returns: + List of matching Sticker objects + """ + async with get_session() as session: + stmt = select(Sticker) + + # 应用情绪筛选 + if emotion: + stmt = stmt.where(Sticker.emotion.contains(emotion)) + + # 应用标签筛选(所有标签都必须匹配) + if labels: + for label in labels: + stmt = stmt.where(Sticker.labels.contains(label)) + + # 应用语境关键词筛选(所有关键词都必须匹配) + if context_keywords: + for keyword in context_keywords: + stmt = stmt.where(Sticker.context_keywords.contains(keyword)) + + stmt = stmt.order_by(Sticker.created_time.desc()).limit(limit) + result = await session.scalars(stmt) + return list(result.all()) + + async def migrate_existing_stickers(self, batch_size: int = 10) -> Dict[str, int]: + """ + Migrate existing stickers by classifying them with LLM + + This method processes stickers that don't have classification data + (meme_text, emotion, labels, context_keywords are all None) + + Args: + batch_size: Number of stickers to process in each batch + + Returns: + Dict with migration statistics: + - total: Total number of stickers processed + - success: Number of successfully classified stickers + - failed: Number of stickers that failed classification + - skipped: Number of stickers already classified + """ + stats = {"total": 0, "success": 0, "failed": 0, "skipped": 0} + + async with get_session() as session: + # 查找所有未分类的表情包(分类字段都为 None) + stmt = select(Sticker).where( + Sticker.meme_text.is_(None), + Sticker.emotion.is_(None), + Sticker.labels.is_(None), + Sticker.context_keywords.is_(None), + ) + result = await session.scalars(stmt) + stickers_to_migrate = list(result.all()) + + stats["total"] = len(stickers_to_migrate) + + for sticker in stickers_to_migrate: + try: + # 调用 LLM 进行分类 + classification = await classify_meme(sticker.raw) + + if classification is None: + logger.warning(f"Failed to classify sticker {sticker.id}: LLM returned None") + stats["failed"] += 1 + continue + + # 更新分类信息 + sticker.meme_text = classification["text"] + sticker.emotion = classification["emotion"] + sticker.labels = json.dumps(classification["labels"], ensure_ascii=False) + sticker.context_keywords = json.dumps(classification["context_keywords"], ensure_ascii=False) + + session.add(sticker) + stats["success"] += 1 + + logger.info(f"Successfully classified sticker {sticker.id}: emotion={classification['emotion']}") + + except Exception as e: + logger.warning(f"Failed to migrate sticker {sticker.id}: {e}") + stats["failed"] += 1 + + # 提交所有更改 + await session.commit() + + logger.info( + f"Sticker migration completed: {stats['success']} success, " + f"{stats['failed']} failed, {stats['skipped']} skipped out of {stats['total']} total" + ) + + return stats + + async def classify_single_sticker(self, sticker_id: int) -> bool: + """ + Classify a single sticker by its ID + + Args: + sticker_id: The ID of the sticker to classify + + Returns: + True if classification was successful, False otherwise + """ + async with get_session() as session: + sticker = await session.get(Sticker, sticker_id) + + if sticker is None: + logger.warning(f"Sticker {sticker_id} not found") + return False + + try: + classification = await classify_meme(sticker.raw) + + if classification is None: + logger.warning(f"Failed to classify sticker {sticker_id}: LLM returned None") + return False + + sticker.meme_text = classification["text"] + sticker.emotion = classification["emotion"] + sticker.labels = json.dumps(classification["labels"], ensure_ascii=False) + sticker.context_keywords = json.dumps(classification["context_keywords"], ensure_ascii=False) + + session.add(sticker) + await session.commit() + + logger.info(f"Successfully classified sticker {sticker_id}") + return True + + except Exception as e: + logger.warning(f"Failed to classify sticker {sticker_id}: {e}") + return False + # Global sticker manager instance sticker_manager = StickerManager() From 1f39f3b95f2a4d997b236b0f6d835a65490e80ee Mon Sep 17 00:00:00 2001 From: XiaoDeng3386 <1744793737@qq.com> Date: Sun, 25 Jan 2026 13:28:41 +0800 Subject: [PATCH 4/7] feat(chat): add dynamic sticker recommendations based on mood and context Integrate sticker recommendations into group chat system prompts by: - Extract mood from LLM reasoning content using regex pattern matching - Filter stickers by emotion and context keywords from chat history - Update system prompt with personalized sticker recommendations - Change get_sticker_manager to synchronous function for simpler usage The recommendations are dynamically updated after detecting the thinking process in bot responses, providing contextually relevant sticker suggestions based on current conversation mood and topics. --- src/lang/zh_hans/chat.yaml | 6 +- .../nonebot_plugin_chat/matcher/group.py | 171 +++++++++++++++++- .../utils/sticker_manager.py | 4 +- .../utils/tools/sticker.py | 8 +- 4 files changed, 177 insertions(+), 12 deletions(-) diff --git a/src/lang/zh_hans/chat.yaml b/src/lang/zh_hans/chat.yaml index 340f77ef..bb15190d 100644 --- a/src/lang/zh_hans/chat.yaml +++ b/src/lang/zh_hans/chat.yaml @@ -82,7 +82,7 @@ prompt_group: ```markdown ## 思考过程: - 当前状态: 你对当前群聊中正在发生的事情的观察和判断,以及你对群聊氛围的评估。 - - 心情和想法: 你的心情和想法(我正在想XX事,我很开心/难过/生气/平静等) + - 心情: 你当前的心情(很高兴/悲伤/生气/恐惧/平静) - 是否需要回应: 你是否需要回应群友的消息,你想要回应哪些消息或话题,以及你决定回应的原因。 - 连续发送消息数量: 评估你在本次回应中需要发送的消息数量,判断该数量是否合理。 @@ -187,8 +187,8 @@ prompt_group: 部分群友的介绍: {4} - - + 表情包推荐(根据当前你的情绪或群聊中的话题筛选): + {5} image_describe_system: | # Role diff --git a/src/plugins/nonebot_plugin_chat/matcher/group.py b/src/plugins/nonebot_plugin_chat/matcher/group.py index 66d4b8d9..53d0f294 100644 --- a/src/plugins/nonebot_plugin_chat/matcher/group.py +++ b/src/plugins/nonebot_plugin_chat/matcher/group.py @@ -29,6 +29,7 @@ from typing import Literal, TypedDict, Optional, Any from nonebot_plugin_apscheduler import scheduler from nonebot_plugin_alconna import UniMessage, Target, get_target +from nonebot_plugin_chat.utils.sticker_manager import get_sticker_manager from nonebot_plugin_userinfo import EventUserInfo, UserInfo from nonebot_plugin_larkuser import get_user @@ -138,6 +139,7 @@ def __init__(self, processor: "MessageProcessor", max_message_count: int = 10) - self.processor = processor self.max_message_count = max_message_count self.messages: list[OpenAIMessage] = [] + self.fetcher_lock = asyncio.Lock() self.consecutive_bot_messages = 0 # 连续发送消息计数器 @@ -148,11 +150,11 @@ def clean_special_message(self) -> None: break self.messages.pop(0) - async def get_messages(self) -> list[OpenAIMessage]: + async def get_messages(self, reasoning_content: Optional[str] = None) -> list[OpenAIMessage]: self.clean_special_message() self.messages = self.messages[-self.max_message_count :] messages = copy.deepcopy(self.messages) - messages.insert(0, await self.processor.generate_system_prompt()) + messages.insert(0, await self.processor.generate_system_prompt(reasoning_content)) return messages async def fetch_reply(self) -> None: @@ -161,6 +163,27 @@ async def fetch_reply(self) -> None: async with self.fetcher_lock: await self._fetch_reply() + def _extract_reasoning_content(self, message: OpenAIMessage) -> Optional[str]: + """ + 从消息中提取思考过程内容 + + Args: + message: OpenAI 消息对象 + + Returns: + 思考过程内容,如果没有找到则返回 None + """ + content = None + if isinstance(message, dict): + content = message.get("content", "") + elif hasattr(message, "content"): + content = message.content + + if content and isinstance(content, str) and content.strip().startswith("## 思考过程"): + return content + + return None + async def _fetch_reply(self) -> None: messages = await self.get_messages() self.messages.clear() @@ -176,7 +199,27 @@ async def _fetch_reply(self) -> None: logger.info(f"Moonlark 说: {message}") fetcher.session.messages.extend(self.messages) self.messages = [] + + # 在消息流结束后检测思考过程并更新 system 消息 self.messages = fetcher.get_messages() + + # 检查返回的消息中是否包含思考过程 + reasoning_content: Optional[str] = None + for msg in self.messages: + extracted = self._extract_reasoning_content(msg) + if extracted: + reasoning_content = extracted + break + + # 如果检测到思考过程,更新表情包推荐并重新生成 system 消息 + if reasoning_content: + logger.debug("检测到思考过程,正在更新表情包推荐...") + new_system_prompt = await self.processor.generate_system_prompt(reasoning_content) + # 更新 self.messages 中的 system 消息(如果有的话),或在开头插入 + if self.messages and get_role(self.messages[0]) == "system": + self.messages[0] = new_system_prompt + else: + self.messages.insert(0, new_system_prompt) def append_user_message(self, message: str) -> None: self.consecutive_bot_messages = 0 # 收到用户消息时重置计数器 @@ -208,10 +251,127 @@ def insert_warning_message(self) -> None: class MessageProcessor: + async def get_sticker_recommendations(self, reasoning_content: Optional[str] = None) -> list[str]: + """ + 根据思考过程中的心情和上下文关键词获取表情包推荐 + + Args: + reasoning_content: LLM 输出的思考过程内容(以 "## 思考过程" 开头) + 如果为 None,则只根据 context_keywords 进行匹配,不根据心情筛选 + + Returns: + 推荐的表情包列表(格式为 "ID: 描述") + """ + recommendations: list[str] = [] + seen_ids: set[int] = set() # 用于去重 + + # 获取聊天记录内容 + chat_history = "\n".join(self.get_message_content_list()) + + # 只有当提供了 reasoning_content 时才根据心情筛选 + # 如果请求来自 MessageProcessor 或其他地方(reasoning_content 为 None),则跳过心情筛选 + if reasoning_content: + # 从思考过程中提取心情 + mood = self._extract_mood_from_reasoning(reasoning_content) + + # 根据心情筛选表情包 + if mood: + stickers = await self.sticker_manager.filter_by_emotion(mood, limit=3) + for sticker in stickers: + if sticker.id not in seen_ids: + seen_ids.add(sticker.id) + desc = sticker.description + recommendations.append(f"{sticker.id}: {desc}") + + # 根据 context_keywords 匹配聊天记录和思考过程 + combined_text = chat_history + if reasoning_content: + combined_text += "\n" + reasoning_content + + matched_stickers = await self._match_stickers_by_context(combined_text, exclude_ids=seen_ids) + for sticker in matched_stickers: + if sticker.id not in seen_ids: + seen_ids.add(sticker.id) + desc = sticker.description + recommendations.append(f"{sticker.id}: {desc}") + + # 限制推荐数量 + return recommendations[:10] + + def _extract_mood_from_reasoning(self, reasoning_content: Optional[str]) -> Optional[str]: + """ + 从思考过程中提取心情 + + Args: + reasoning_content: LLM 输出的思考过程内容 + + Returns: + 提取的心情字符串,如果未找到返回 None + """ + if not reasoning_content: + return None + + # 匹配 "- 心情: XXX" 格式 + mood_pattern = r"-\s*心情[::]\s*(.+?)(?:\n|$)" + match = re.search(mood_pattern, reasoning_content) + if match: + mood = match.group(1).strip() + # 清理可能的括号内容,如 "很高兴(因为...)" -> "很高兴" + mood = re.sub(r'[((].+?[))]', '', mood).strip() + return mood + + return None + + async def _match_stickers_by_context(self, text: str, exclude_ids: set[int], limit: int = 5) -> list: + """ + 根据上下文关键词匹配表情包 + + Args: + text: 要匹配的文本(聊天记录 + 思考过程) + exclude_ids: 要排除的表情包 ID 集合 + limit: 返回的最大数量 + + Returns: + 匹配的 Sticker 对象列表 + """ + from nonebot_plugin_orm import get_session + from sqlalchemy import select + from ..models import Sticker + + matched: list = [] + + async with get_session() as session: + # 获取所有有 context_keywords 的表情包 + stmt = select(Sticker).where(Sticker.context_keywords.isnot(None)) + result = await session.scalars(stmt) + stickers = list(result.all()) + + for sticker in stickers: + if sticker.id in exclude_ids: + continue + + # 解析 context_keywords JSON + try: + keywords = json.loads(sticker.context_keywords) if sticker.context_keywords else [] + except json.JSONDecodeError: + continue + + # 检查关键词是否出现在文本中 + for keyword in keywords: + if keyword and keyword in text: + matched.append(sticker) + break + + if len(matched) >= limit: + break + + return matched + def __init__(self, session: "GroupSession"): self.openai_messages = MessageQueue(self, 50) self.session = session self.enabled = True + self.sticker_manager = get_sticker_manager() self.interrupter = Interrupter(session) self.cold_until = datetime.now() self.blocked = False @@ -587,7 +747,7 @@ async def _get_user_profiles(self, chat_history: str) -> dict[str, str]: profiles[nickname] = (await session.get_one(UserProfile, {"user_id": user_id})).profile_content return profiles - async def generate_system_prompt(self) -> OpenAIMessage: + async def generate_system_prompt(self, reasoning_content: Optional[str] = None) -> OpenAIMessage: chat_history = "\n".join(self.get_message_content_list()) # 获取相关笔记 note_manager = await get_context_notes(self.session.group_id) @@ -606,6 +766,10 @@ def format_note(note): created_time = datetime.fromtimestamp(note.created_time).strftime("%y-%m-%d") return f"- {note.content} (#{note.id},创建于 {created_time})" + # 获取表情包推荐 + sticker_recommendations = await self.get_sticker_recommendations(reasoning_content) + sticker_text = "\n".join([f"- {rec}" for rec in sticker_recommendations]) if sticker_recommendations else "暂无推荐" + return generate_message( await lang.text( "prompt_group.default", @@ -619,6 +783,7 @@ def format_note(note): else "暂无" ), profiles_text, + sticker_text, ), "system", ) diff --git a/src/plugins/nonebot_plugin_chat/utils/sticker_manager.py b/src/plugins/nonebot_plugin_chat/utils/sticker_manager.py index 8292f429..e5ac047a 100644 --- a/src/plugins/nonebot_plugin_chat/utils/sticker_manager.py +++ b/src/plugins/nonebot_plugin_chat/utils/sticker_manager.py @@ -51,7 +51,7 @@ class MemeClassification(TypedDict): { "is_meme": boolen, // 这张图片是一个表情包吗?如果是为 true。 "text": string, // 表情包中的文本,如果没有请填空字符串。 - "emotion": string, // 表情包所表达的情绪的类型,如:高兴、难过、生气、恐惧、伤心。 + "emotion": string, // 表情包所表达的情绪的类型,如:高兴、难过、生气、恐惧。 "labels": array[string], // 表情包的标签,按照参考的标签库分类中给出的示例进行编写。 "context_keywords": array[string] // 表情包适用的语境,这个表情包适合在群聊中谈到什么关键词时出现? } @@ -566,7 +566,7 @@ async def classify_single_sticker(self, sticker_id: int) -> bool: sticker_manager = StickerManager() -async def get_sticker_manager() -> StickerManager: +def get_sticker_manager() -> StickerManager: """ Get the global StickerManager instance diff --git a/src/plugins/nonebot_plugin_chat/utils/tools/sticker.py b/src/plugins/nonebot_plugin_chat/utils/tools/sticker.py index 1c513558..7dbd383a 100644 --- a/src/plugins/nonebot_plugin_chat/utils/tools/sticker.py +++ b/src/plugins/nonebot_plugin_chat/utils/tools/sticker.py @@ -47,7 +47,7 @@ async def save_sticker_func(session: "GroupSession", image_id: str) -> str: return await lang.text("sticker.not_found", session.user_id) # Get sticker manager and save - manager = await get_sticker_manager() + manager = get_sticker_manager() try: sticker = await manager.save_sticker( description=image_data["description"], @@ -72,7 +72,7 @@ async def search_sticker_func(session: "GroupSession", query: str) -> str: Returns: Formatted list of matching stickers or empty message """ - manager = await get_sticker_manager() + manager = get_sticker_manager() # First try AND matching (all keywords must match) stickers = await manager.search_sticker(query, limit=5) @@ -105,7 +105,7 @@ async def send_sticker_func(session: "GroupSession", sticker_id: int) -> str: Returns: Success or error message """ - manager = await get_sticker_manager() + manager = get_sticker_manager() sticker = await manager.get_sticker(sticker_id) if sticker is None: @@ -164,7 +164,7 @@ async def send_sticker(sticker_id: int) -> str: description=( "从收藏的表情包库中搜索合适的表情包。\n" "**何时调用**: 当你想用表情包回复群友时,先调用此工具搜索合适的表情包。\n" - "**搜索技巧**: 使用描述性的关键词,如情绪(开心、悲伤、嘲讽)、动作(大笑、哭泣)或内容(猫、狗、熊猫头)。" + "**搜索技巧**: 使用描述性的关键词,如情绪(开心、悲伤、嘲讽)、动作(大笑、哭泣)或内容。" ), parameters={ "query": FunctionParameter( From 6f1fb28c9dd4f8f657fe83ceb597d0a9204d2ca2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 25 Jan 2026 05:30:50 +0000 Subject: [PATCH 5/7] =?UTF-8?q?=E6=A0=BC=E5=BC=8F=E5=8C=96=E4=BB=A3?= =?UTF-8?q?=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- migrations/versions/24c7dc8e7e67_.py | 58 +++++------ .../nonebot_plugin_chat/matcher/group.py | 66 +++++++------ .../utils/sticker_manager.py | 97 +++++++++---------- 3 files changed, 111 insertions(+), 110 deletions(-) diff --git a/migrations/versions/24c7dc8e7e67_.py b/migrations/versions/24c7dc8e7e67_.py index fad9f871..f643686d 100644 --- a/migrations/versions/24c7dc8e7e67_.py +++ b/migrations/versions/24c7dc8e7e67_.py @@ -5,6 +5,7 @@ 创建时间: 2026-01-24 23:33:36.695599 """ + from __future__ import annotations from collections.abc import Sequence @@ -12,9 +13,8 @@ from alembic import op import sqlalchemy as sa - -revision: str = '24c7dc8e7e67' -down_revision: str | Sequence[str] | None = 'a1b2c3d4e5f6' +revision: str = "24c7dc8e7e67" +down_revision: str | Sequence[str] | None = "a1b2c3d4e5f6" branch_labels: str | Sequence[str] | None = None depends_on: str | Sequence[str] | None = None @@ -23,14 +23,14 @@ def upgrade(name: str = "") -> None: if name: return # ### commands auto generated by Alembic - please adjust! ### - op.drop_table('nonebot_plugin_chat_chatuser') - op.drop_table('nonebot_plugin_chat_sessionmessage') - with op.batch_alter_table('nonebot_plugin_chat_sticker', schema=None) as batch_op: - batch_op.add_column(sa.Column('is_meme', sa.Boolean(), nullable=True)) - batch_op.add_column(sa.Column('meme_text', sa.Text(), nullable=True)) - batch_op.add_column(sa.Column('emotion', sa.String(length=64), nullable=True)) - batch_op.add_column(sa.Column('labels', sa.Text(), nullable=True)) - batch_op.add_column(sa.Column('context_keywords', sa.Text(), nullable=True)) + op.drop_table("nonebot_plugin_chat_chatuser") + op.drop_table("nonebot_plugin_chat_sessionmessage") + with op.batch_alter_table("nonebot_plugin_chat_sticker", schema=None) as batch_op: + batch_op.add_column(sa.Column("is_meme", sa.Boolean(), nullable=True)) + batch_op.add_column(sa.Column("meme_text", sa.Text(), nullable=True)) + batch_op.add_column(sa.Column("emotion", sa.String(length=64), nullable=True)) + batch_op.add_column(sa.Column("labels", sa.Text(), nullable=True)) + batch_op.add_column(sa.Column("context_keywords", sa.Text(), nullable=True)) # ### end Alembic commands ### @@ -39,23 +39,25 @@ def downgrade(name: str = "") -> None: if name: return # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('nonebot_plugin_chat_sticker', schema=None) as batch_op: - batch_op.drop_column('context_keywords') - batch_op.drop_column('labels') - batch_op.drop_column('emotion') - batch_op.drop_column('meme_text') - batch_op.drop_column('is_meme') - - op.create_table('nonebot_plugin_chat_sessionmessage', - sa.Column('id_', sa.INTEGER(), nullable=False), - sa.Column('user_id', sa.VARCHAR(length=128), nullable=False), - sa.Column('content', sa.TEXT(), nullable=False), - sa.Column('role', sa.VARCHAR(length=16), nullable=False), - sa.PrimaryKeyConstraint('id_', name=op.f('pk_nonebot_plugin_chat_sessionmessage')) + with op.batch_alter_table("nonebot_plugin_chat_sticker", schema=None) as batch_op: + batch_op.drop_column("context_keywords") + batch_op.drop_column("labels") + batch_op.drop_column("emotion") + batch_op.drop_column("meme_text") + batch_op.drop_column("is_meme") + + op.create_table( + "nonebot_plugin_chat_sessionmessage", + sa.Column("id_", sa.INTEGER(), nullable=False), + sa.Column("user_id", sa.VARCHAR(length=128), nullable=False), + sa.Column("content", sa.TEXT(), nullable=False), + sa.Column("role", sa.VARCHAR(length=16), nullable=False), + sa.PrimaryKeyConstraint("id_", name=op.f("pk_nonebot_plugin_chat_sessionmessage")), ) - op.create_table('nonebot_plugin_chat_chatuser', - sa.Column('user_id', sa.VARCHAR(length=128), nullable=False), - sa.Column('latest_chat', sa.DATETIME(), nullable=False), - sa.PrimaryKeyConstraint('user_id', name=op.f('pk_nonebot_plugin_chat_chatuser')) + op.create_table( + "nonebot_plugin_chat_chatuser", + sa.Column("user_id", sa.VARCHAR(length=128), nullable=False), + sa.Column("latest_chat", sa.DATETIME(), nullable=False), + sa.PrimaryKeyConstraint("user_id", name=op.f("pk_nonebot_plugin_chat_chatuser")), ) # ### end Alembic commands ### diff --git a/src/plugins/nonebot_plugin_chat/matcher/group.py b/src/plugins/nonebot_plugin_chat/matcher/group.py index 53d0f294..1843bea2 100644 --- a/src/plugins/nonebot_plugin_chat/matcher/group.py +++ b/src/plugins/nonebot_plugin_chat/matcher/group.py @@ -139,7 +139,7 @@ def __init__(self, processor: "MessageProcessor", max_message_count: int = 10) - self.processor = processor self.max_message_count = max_message_count self.messages: list[OpenAIMessage] = [] - + self.fetcher_lock = asyncio.Lock() self.consecutive_bot_messages = 0 # 连续发送消息计数器 @@ -166,10 +166,10 @@ async def fetch_reply(self) -> None: def _extract_reasoning_content(self, message: OpenAIMessage) -> Optional[str]: """ 从消息中提取思考过程内容 - + Args: message: OpenAI 消息对象 - + Returns: 思考过程内容,如果没有找到则返回 None """ @@ -178,10 +178,10 @@ def _extract_reasoning_content(self, message: OpenAIMessage) -> Optional[str]: content = message.get("content", "") elif hasattr(message, "content"): content = message.content - + if content and isinstance(content, str) and content.strip().startswith("## 思考过程"): return content - + return None async def _fetch_reply(self) -> None: @@ -202,7 +202,7 @@ async def _fetch_reply(self) -> None: # 在消息流结束后检测思考过程并更新 system 消息 self.messages = fetcher.get_messages() - + # 检查返回的消息中是否包含思考过程 reasoning_content: Optional[str] = None for msg in self.messages: @@ -210,7 +210,7 @@ async def _fetch_reply(self) -> None: if extracted: reasoning_content = extracted break - + # 如果检测到思考过程,更新表情包推荐并重新生成 system 消息 if reasoning_content: logger.debug("检测到思考过程,正在更新表情包推荐...") @@ -254,26 +254,26 @@ class MessageProcessor: async def get_sticker_recommendations(self, reasoning_content: Optional[str] = None) -> list[str]: """ 根据思考过程中的心情和上下文关键词获取表情包推荐 - + Args: reasoning_content: LLM 输出的思考过程内容(以 "## 思考过程" 开头) 如果为 None,则只根据 context_keywords 进行匹配,不根据心情筛选 - + Returns: 推荐的表情包列表(格式为 "ID: 描述") """ recommendations: list[str] = [] seen_ids: set[int] = set() # 用于去重 - + # 获取聊天记录内容 chat_history = "\n".join(self.get_message_content_list()) - + # 只有当提供了 reasoning_content 时才根据心情筛选 # 如果请求来自 MessageProcessor 或其他地方(reasoning_content 为 None),则跳过心情筛选 if reasoning_content: # 从思考过程中提取心情 mood = self._extract_mood_from_reasoning(reasoning_content) - + # 根据心情筛选表情包 if mood: stickers = await self.sticker_manager.filter_by_emotion(mood, limit=3) @@ -282,89 +282,89 @@ async def get_sticker_recommendations(self, reasoning_content: Optional[str] = N seen_ids.add(sticker.id) desc = sticker.description recommendations.append(f"{sticker.id}: {desc}") - + # 根据 context_keywords 匹配聊天记录和思考过程 combined_text = chat_history if reasoning_content: combined_text += "\n" + reasoning_content - + matched_stickers = await self._match_stickers_by_context(combined_text, exclude_ids=seen_ids) for sticker in matched_stickers: if sticker.id not in seen_ids: seen_ids.add(sticker.id) desc = sticker.description recommendations.append(f"{sticker.id}: {desc}") - + # 限制推荐数量 return recommendations[:10] - + def _extract_mood_from_reasoning(self, reasoning_content: Optional[str]) -> Optional[str]: """ 从思考过程中提取心情 - + Args: reasoning_content: LLM 输出的思考过程内容 - + Returns: 提取的心情字符串,如果未找到返回 None """ if not reasoning_content: return None - + # 匹配 "- 心情: XXX" 格式 mood_pattern = r"-\s*心情[::]\s*(.+?)(?:\n|$)" match = re.search(mood_pattern, reasoning_content) if match: mood = match.group(1).strip() # 清理可能的括号内容,如 "很高兴(因为...)" -> "很高兴" - mood = re.sub(r'[((].+?[))]', '', mood).strip() + mood = re.sub(r"[((].+?[))]", "", mood).strip() return mood - + return None - + async def _match_stickers_by_context(self, text: str, exclude_ids: set[int], limit: int = 5) -> list: """ 根据上下文关键词匹配表情包 - + Args: text: 要匹配的文本(聊天记录 + 思考过程) exclude_ids: 要排除的表情包 ID 集合 limit: 返回的最大数量 - + Returns: 匹配的 Sticker 对象列表 """ from nonebot_plugin_orm import get_session from sqlalchemy import select from ..models import Sticker - + matched: list = [] - + async with get_session() as session: # 获取所有有 context_keywords 的表情包 stmt = select(Sticker).where(Sticker.context_keywords.isnot(None)) result = await session.scalars(stmt) stickers = list(result.all()) - + for sticker in stickers: if sticker.id in exclude_ids: continue - + # 解析 context_keywords JSON try: keywords = json.loads(sticker.context_keywords) if sticker.context_keywords else [] except json.JSONDecodeError: continue - + # 检查关键词是否出现在文本中 for keyword in keywords: if keyword and keyword in text: matched.append(sticker) break - + if len(matched) >= limit: break - + return matched def __init__(self, session: "GroupSession"): @@ -768,7 +768,9 @@ def format_note(note): # 获取表情包推荐 sticker_recommendations = await self.get_sticker_recommendations(reasoning_content) - sticker_text = "\n".join([f"- {rec}" for rec in sticker_recommendations]) if sticker_recommendations else "暂无推荐" + sticker_text = ( + "\n".join([f"- {rec}" for rec in sticker_recommendations]) if sticker_recommendations else "暂无推荐" + ) return generate_message( await lang.text( diff --git a/src/plugins/nonebot_plugin_chat/utils/sticker_manager.py b/src/plugins/nonebot_plugin_chat/utils/sticker_manager.py index e5ac047a..4e697ddc 100644 --- a/src/plugins/nonebot_plugin_chat/utils/sticker_manager.py +++ b/src/plugins/nonebot_plugin_chat/utils/sticker_manager.py @@ -66,60 +66,60 @@ class MemeClassification(TypedDict): def extract_json_from_response(response: str) -> Optional[Dict[str, Any]]: """ 从 LLM 响应中提取 JSON,处理可能包含 markdown 代码块的情况 - + Args: response: LLM 返回的原始响应文本 - + Returns: 解析后的 JSON 字典,如果解析失败返回 None """ # 去除首尾空白 response = response.strip() - + # 尝试直接解析 try: return json.loads(response) except json.JSONDecodeError: pass - + # 尝试提取 markdown 代码块中的 JSON # 匹配 ```json ... ``` 或 ``` ... ``` - code_block_pattern = r'```(?:json)?\s*\n?([\s\S]*?)\n?```' + code_block_pattern = r"```(?:json)?\s*\n?([\s\S]*?)\n?```" matches = re.findall(code_block_pattern, response) - + for match in matches: try: return json.loads(match.strip()) except json.JSONDecodeError: continue - + # 尝试查找 JSON 对象(以 { 开头,以 } 结尾) - json_pattern = r'\{[\s\S]*\}' + json_pattern = r"\{[\s\S]*\}" json_matches = re.findall(json_pattern, response) - + for match in json_matches: try: return json.loads(match) except json.JSONDecodeError: continue - + return None async def classify_meme(image_data: bytes) -> Optional[MemeClassification]: """ 使用 LLM 对表情包进行分类 - + Args: image_data: 图片二进制数据 - + Returns: MemeClassification 分类结果,如果分类失败返回 None """ try: # 转换图片为 base64 image_base64 = base64.b64encode(image_data).decode("utf-8") - + # 构建消息 messages = [ generate_message(MEME_CLASSIFICATION_PROMPT, "system"), @@ -131,17 +131,17 @@ async def classify_meme(image_data: bytes) -> Optional[MemeClassification]: "user", ), ] - + # 调用 LLM response = (await fetch_message(messages, identify="Meme Classification")).strip() - + # 解析 JSON(处理可能的 markdown 代码块) result = extract_json_from_response(response) - + if result is None: logger.warning(f"Failed to parse meme classification response: {response}") return None - + # 验证并转换结果 classification: MemeClassification = { "is_meme": bool(result.get("is_meme", False)), @@ -150,9 +150,9 @@ async def classify_meme(image_data: bytes) -> Optional[MemeClassification]: "labels": list(result.get("labels", [])), "context_keywords": list(result.get("context_keywords", [])), } - + return classification - + except Exception as e: logger.warning(f"Failed to classify meme: {e}\n{traceback.format_exc()}") return None @@ -207,21 +207,21 @@ async def save_sticker( # 计算感知哈希 p_hash = await calculate_hash_async(raw) - + # 调用 LLM 进行表情包分类 classification = await classify_meme(raw) - + # 准备分类数据 meme_text: Optional[str] = None emotion: Optional[str] = None labels_json: Optional[str] = None context_keywords_json: Optional[str] = None - + if classification is not None: # 如果不是表情包,拒绝添加 if not classification["is_meme"]: raise NotMemeError("该图片不是表情包,无法收藏") - + meme_text = classification["text"] emotion = classification["emotion"] labels_json = json.dumps(classification["labels"], ensure_ascii=False) @@ -385,10 +385,7 @@ async def filter_by_label(self, label: str, limit: int = 10) -> List[Sticker]: async with get_session() as session: # labels 字段是 JSON 数组字符串,使用 contains 进行模糊匹配 stmt = ( - select(Sticker) - .where(Sticker.labels.contains(label)) - .order_by(Sticker.created_time.desc()) - .limit(limit) + select(Sticker).where(Sticker.labels.contains(label)).order_by(Sticker.created_time.desc()).limit(limit) ) result = await session.scalars(stmt) return list(result.all()) @@ -458,13 +455,13 @@ async def filter_by_classification( async def migrate_existing_stickers(self, batch_size: int = 10) -> Dict[str, int]: """ Migrate existing stickers by classifying them with LLM - + This method processes stickers that don't have classification data (meme_text, emotion, labels, context_keywords are all None) - + Args: batch_size: Number of stickers to process in each batch - + Returns: Dict with migration statistics: - total: Total number of stickers processed @@ -473,7 +470,7 @@ async def migrate_existing_stickers(self, batch_size: int = 10) -> Dict[str, int - skipped: Number of stickers already classified """ stats = {"total": 0, "success": 0, "failed": 0, "skipped": 0} - + async with get_session() as session: # 查找所有未分类的表情包(分类字段都为 None) stmt = select(Sticker).where( @@ -484,79 +481,79 @@ async def migrate_existing_stickers(self, batch_size: int = 10) -> Dict[str, int ) result = await session.scalars(stmt) stickers_to_migrate = list(result.all()) - + stats["total"] = len(stickers_to_migrate) - + for sticker in stickers_to_migrate: try: # 调用 LLM 进行分类 classification = await classify_meme(sticker.raw) - + if classification is None: logger.warning(f"Failed to classify sticker {sticker.id}: LLM returned None") stats["failed"] += 1 continue - + # 更新分类信息 sticker.meme_text = classification["text"] sticker.emotion = classification["emotion"] sticker.labels = json.dumps(classification["labels"], ensure_ascii=False) sticker.context_keywords = json.dumps(classification["context_keywords"], ensure_ascii=False) - + session.add(sticker) stats["success"] += 1 - + logger.info(f"Successfully classified sticker {sticker.id}: emotion={classification['emotion']}") - + except Exception as e: logger.warning(f"Failed to migrate sticker {sticker.id}: {e}") stats["failed"] += 1 - + # 提交所有更改 await session.commit() - + logger.info( f"Sticker migration completed: {stats['success']} success, " f"{stats['failed']} failed, {stats['skipped']} skipped out of {stats['total']} total" ) - + return stats async def classify_single_sticker(self, sticker_id: int) -> bool: """ Classify a single sticker by its ID - + Args: sticker_id: The ID of the sticker to classify - + Returns: True if classification was successful, False otherwise """ async with get_session() as session: sticker = await session.get(Sticker, sticker_id) - + if sticker is None: logger.warning(f"Sticker {sticker_id} not found") return False - + try: classification = await classify_meme(sticker.raw) - + if classification is None: logger.warning(f"Failed to classify sticker {sticker_id}: LLM returned None") return False - + sticker.meme_text = classification["text"] sticker.emotion = classification["emotion"] sticker.labels = json.dumps(classification["labels"], ensure_ascii=False) sticker.context_keywords = json.dumps(classification["context_keywords"], ensure_ascii=False) - + session.add(sticker) await session.commit() - + logger.info(f"Successfully classified sticker {sticker_id}") return True - + except Exception as e: logger.warning(f"Failed to classify sticker {sticker_id}: {e}") return False From 6af4d839a1c2412e4571b011cf37e676b8e746d8 Mon Sep 17 00:00:00 2001 From: XiaoDeng3386 <1744793737@qq.com> Date: Sun, 25 Jan 2026 13:33:20 +0800 Subject: [PATCH 6/7] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E5=BA=93=E8=BF=81=E7=A7=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- migrations/versions/3875dac35e58_.py | 39 ++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 migrations/versions/3875dac35e58_.py diff --git a/migrations/versions/3875dac35e58_.py b/migrations/versions/3875dac35e58_.py new file mode 100644 index 00000000..49a19dd3 --- /dev/null +++ b/migrations/versions/3875dac35e58_.py @@ -0,0 +1,39 @@ +"""empty message + +迁移 ID: 3875dac35e58 +父迁移: 24c7dc8e7e67 +创建时间: 2026-01-25 13:33:07.981017 + +""" +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa + + +revision: str = '3875dac35e58' +down_revision: str | Sequence[str] | None = '24c7dc8e7e67' +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade(name: str = "") -> None: + if name: + return + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('nonebot_plugin_chat_sticker', schema=None) as batch_op: + batch_op.drop_column('is_meme') + + # ### end Alembic commands ### + + +def downgrade(name: str = "") -> None: + if name: + return + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('nonebot_plugin_chat_sticker', schema=None) as batch_op: + batch_op.add_column(sa.Column('is_meme', sa.BOOLEAN(), nullable=True)) + + # ### end Alembic commands ### From 26bbd628b4243d5599b3b72d7636c8b7ab487ca4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 25 Jan 2026 05:33:41 +0000 Subject: [PATCH 7/7] =?UTF-8?q?=E6=A0=BC=E5=BC=8F=E5=8C=96=E4=BB=A3?= =?UTF-8?q?=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- migrations/versions/3875dac35e58_.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/migrations/versions/3875dac35e58_.py b/migrations/versions/3875dac35e58_.py index 49a19dd3..38119573 100644 --- a/migrations/versions/3875dac35e58_.py +++ b/migrations/versions/3875dac35e58_.py @@ -5,6 +5,7 @@ 创建时间: 2026-01-25 13:33:07.981017 """ + from __future__ import annotations from collections.abc import Sequence @@ -12,9 +13,8 @@ from alembic import op import sqlalchemy as sa - -revision: str = '3875dac35e58' -down_revision: str | Sequence[str] | None = '24c7dc8e7e67' +revision: str = "3875dac35e58" +down_revision: str | Sequence[str] | None = "24c7dc8e7e67" branch_labels: str | Sequence[str] | None = None depends_on: str | Sequence[str] | None = None @@ -23,8 +23,8 @@ def upgrade(name: str = "") -> None: if name: return # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('nonebot_plugin_chat_sticker', schema=None) as batch_op: - batch_op.drop_column('is_meme') + with op.batch_alter_table("nonebot_plugin_chat_sticker", schema=None) as batch_op: + batch_op.drop_column("is_meme") # ### end Alembic commands ### @@ -33,7 +33,7 @@ def downgrade(name: str = "") -> None: if name: return # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('nonebot_plugin_chat_sticker', schema=None) as batch_op: - batch_op.add_column(sa.Column('is_meme', sa.BOOLEAN(), nullable=True)) + with op.batch_alter_table("nonebot_plugin_chat_sticker", schema=None) as batch_op: + batch_op.add_column(sa.Column("is_meme", sa.BOOLEAN(), nullable=True)) # ### end Alembic commands ###