diff --git a/.gitignore b/.gitignore
new file mode 100644
index 00000000..a9f6f68a
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,7 @@
+.DS_Store
+.idea
+__pycache__/
+venv*
+*.pyc
+config.json
+QR.png
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 00000000..5f0162e5
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,19 @@
+Copyright (c) 2023 zhayujie
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 00000000..702d996d
--- /dev/null
+++ b/README.md
@@ -0,0 +1,62 @@
+# 简介
+
+将 **AI模型** 接入各类 **消息应用**,开发者通过轻量配置即可在二者之间选择一条连线,运行起一个智能对话机器人,在一个项目中轻松完成多条链路的切换。该架构扩展性强,每接入一个应用可复用已有的算法能力,同样每接入一个模型也可作用于所有应用之上。
+
+**模型:**
+
+ - [x] ChatGPT
+ - ...
+
+**应用:**
+
+ - [ ] 终端
+ - [ ] Web
+ - [x] 个人微信
+ - [x] 公众号
+ - [ ] 企业微信
+ - [ ] Telegram
+ - [ ] QQ
+ - [ ] 钉钉
+ - ...
+
+
+
+# 快速开始
+
+## 一、准备
+
+### 1.运行环境
+
+支持 Linux、MacOS、Windows 系统(Linux服务器上可长期运行)。同时需安装 Python,建议Python版本在 3.7.1~3.10 之间。
+
+项目代码克隆:
+
+```bash
+git clone https://github.com/zhayujie/bot-on-anything
+cd bot-on-anything/
+```
+> 或在 Release 直接手动下载源码。
+
+### 2.配置说明
+
+核心配置文件为 `config.json`,
+
+
+
+## 二、选择模型
+
+### 1.ChatGPT
+
+
+## 三、选择应用
+
+### 1.微信
+
+
+### 2.公众号
+
+
+
+
+## 四、运行
+
diff --git a/app.py b/app.py
new file mode 100644
index 00000000..fa41023b
--- /dev/null
+++ b/app.py
@@ -0,0 +1,21 @@
+# encoding:utf-8
+
+import config
+from channel import channel_factory
+from common.log import logger
+
+
+if __name__ == '__main__':
+    try:
+        # load config
+        config.load_config()
+        logger.info("[INIT] load config: {}".format(config.conf()))
+
+        # create channel
+        channel = channel_factory.create_channel(config.conf().get("channel"))
+
+        # startup channel
+        channel.startup()
+    except Exception as e:
+        logger.error("App startup failed!")
+        logger.exception(e)
diff --git a/bridge/bridge.py b/bridge/bridge.py
new file mode 100644
index 00000000..bc5b2f3b
--- /dev/null
+++ b/bridge/bridge.py
@@ -0,0 +1,9 @@
+from model import model_factory
+import config
+
+class Bridge(object):
+    def __init__(self):
+        pass
+
+    def fetch_reply_content(self, query, context):
+        return model_factory.create_bot(config.conf().get("model")).reply(query, context)
diff --git a/channel/channel.py b/channel/channel.py
new file mode 100644
index 00000000..e2617d1f
--- /dev/null
+++ b/channel/channel.py
@@ -0,0 +1,31 @@
+"""
+Message sending channel abstract class
+"""
+
+from bridge.bridge import Bridge
+
+class Channel(object):
+    def startup(self):
+        """
+        init channel
+        """
+        raise NotImplementedError
+
+    def handle(self, msg):
+        """
+        process received msg
+        :param msg: message object
+        """
+        raise NotImplementedError
+
+    def send(self, msg, receiver):
+        """
+        send message to user
+        :param msg: message content
+        :param receiver: receiver channel account
+        :return:
+        """
+        raise NotImplementedError
+
+    def build_reply_content(self, query, context=None):
+        return Bridge().fetch_reply_content(query, context)
diff --git a/channel/channel_factory.py b/channel/channel_factory.py
new file mode 100644
index 00000000..ad60c4be
--- /dev/null
+++ b/channel/channel_factory.py
@@ -0,0 +1,21 @@
+"""
+channel factory
+"""
+from common import const
+
+def create_channel(channel_type):
+    """
+    create a channel instance
+    :param channel_type: channel type code
+    :return: channel instance
+    """
+    if channel_type == const.WECHAT:
+        from channel.wechat.wechat_channel import WechatChannel
+        return WechatChannel()
+
+    elif channel_type == const.WECHAT_MP:
+        from channel.wechat.wechat_mp_channel import WechatPublicAccount
+        return WechatPublicAccount()
+
+    else:
+        raise RuntimeError
diff --git a/channel/terminal/terminal_channel.py b/channel/terminal/terminal_channel.py
new file mode 100644
index 00000000..e69de29b
diff --git a/channel/wechat/wechat_channel.py b/channel/wechat/wechat_channel.py
new file mode 100644
index 00000000..75c3c687
--- /dev/null
+++ b/channel/wechat/wechat_channel.py
@@ -0,0 +1,165 @@
+# encoding:utf-8
+
+"""
+wechat channel
+"""
+import itchat
+import json
+from itchat.content import *
+from channel.channel import Channel
+from concurrent.futures import ThreadPoolExecutor
+from common.log import logger
+from config import conf
+import requests
+import io
+
+thread_pool = ThreadPoolExecutor(max_workers=8)
+
+
+@itchat.msg_register(TEXT)
+def handler_single_msg(msg):
+    WechatChannel().handle(msg)
+    return None
+
+
+@itchat.msg_register(TEXT, isGroupChat=True)
+def handler_group_msg(msg):
+    WechatChannel().handle_group(msg)
+    return None
+
+
+class WechatChannel(Channel):
+    def __init__(self):
+        pass
+
+    def startup(self):
+        # login by scan QRCode
+        itchat.auto_login(enableCmdQR=2)
+
+        # start message listener
+        itchat.run()
+
+    def handle(self, msg):
+        logger.debug("[WX]receive msg: " + json.dumps(msg, ensure_ascii=False))
+        from_user_id = msg['FromUserName']
+        to_user_id = msg['ToUserName'] # 接收人id
+        other_user_id = msg['User']['UserName'] # 对手方id
+        content = msg['Text']
+        match_prefix = self.check_prefix(content, conf().get('single_chat_prefix'))
+        if from_user_id == other_user_id and match_prefix is not None:
+            # 好友向自己发送消息
+            if match_prefix != '':
+                str_list = content.split(match_prefix, 1)
+                if len(str_list) == 2:
+                    content = str_list[1].strip()
+
+            img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
+            if img_match_prefix:
+                content = content.split(img_match_prefix, 1)[1].strip()
+                thread_pool.submit(self._do_send_img, content, from_user_id)
+            else:
+                thread_pool.submit(self._do_send, content, from_user_id)
+
+        elif to_user_id == other_user_id and match_prefix:
+            # 自己给好友发送消息
+            str_list = content.split(match_prefix, 1)
+            if len(str_list) == 2:
+                content = str_list[1].strip()
+            img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
+            if img_match_prefix:
+                content = content.split(img_match_prefix, 1)[1].strip()
+                thread_pool.submit(self._do_send_img, content, to_user_id)
+            else:
+                thread_pool.submit(self._do_send, content, to_user_id)
+
+
+    def handle_group(self, msg):
+        logger.debug("[WX]receive group msg: " + json.dumps(msg, ensure_ascii=False))
+        group_name = msg['User'].get('NickName', None)
+        group_id = msg['User'].get('UserName', None)
+        if not group_name:
+            return ""
+        origin_content = msg['Content']
+        content = msg['Content']
+        content_list = content.split(' ', 1)
+        context_special_list = content.split('\u2005', 1)
+        if len(context_special_list) == 2:
+            content = context_special_list[1]
+        elif len(content_list) == 2:
+            content = content_list[1]
+
+        config = conf()
+        match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or self.check_prefix(origin_content, config.get('group_chat_prefix')) \
+            or self.check_contain(origin_content, config.get('group_chat_keyword'))
+        if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or self.check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix:
+            img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
+            if img_match_prefix:
+                content = content.split(img_match_prefix, 1)[1].strip()
+                thread_pool.submit(self._do_send_img, content, group_id)
+            else:
+                thread_pool.submit(self._do_send_group, content, msg)
+
+    def send(self, msg, receiver):
+        logger.info('[WX] sendMsg={}, receiver={}'.format(msg, receiver))
+        itchat.send(msg, toUserName=receiver)
+
+    def _do_send(self, query, reply_user_id):
+        try:
+            if not query:
+                return
+            context = dict()
+            context['from_user_id'] = reply_user_id
+            reply_text = super().build_reply_content(query, context)
+            if reply_text:
+                self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id)
+        except Exception as e:
+            logger.exception(e)
+
+    def _do_send_img(self, query, reply_user_id):
+        try:
+            if not query:
+                return
+            context = dict()
+            context['type'] = 'IMAGE_CREATE'
+            img_url = super().build_reply_content(query, context)
+            if not img_url:
+                return
+
+            # 图片下载
+            pic_res = requests.get(img_url, stream=True)
+            image_storage = io.BytesIO()
+            for block in pic_res.iter_content(1024):
+                image_storage.write(block)
+            image_storage.seek(0)
+
+            # 图片发送
+            logger.info('[WX] sendImage, receiver={}'.format(reply_user_id))
+            itchat.send_image(image_storage, reply_user_id)
+        except Exception as e:
+            logger.exception(e)
+
+    def _do_send_group(self, query, msg):
+        if not query:
+            return
+        context = dict()
+        context['from_user_id'] = msg['ActualUserName']
+        reply_text = super().build_reply_content(query, context)
+        if reply_text:
+            reply_text = '@' + msg['ActualNickName'] + ' ' + reply_text.strip()
+            self.send(conf().get("group_chat_reply_prefix", "") + reply_text, msg['User']['UserName'])
+
+
+    def check_prefix(self, content, prefix_list):
+        for prefix in prefix_list:
+            if content.startswith(prefix):
+                return prefix
+        return None
+
+
+    def check_contain(self, content, keyword_list):
+        if not keyword_list:
+            return None
+        for ky in keyword_list:
+            if content.find(ky) != -1:
+                return True
+        return None
diff --git a/channel/wechat/wechat_mp_channel.py b/channel/wechat/wechat_mp_channel.py
new file mode 100644
index 00000000..6a4fac63
--- /dev/null
+++ b/channel/wechat/wechat_mp_channel.py
@@ -0,0 +1,58 @@
+import werobot
+import time
+import config
+from common import const
+from common.log import logger
+from channel.channel import Channel
+from concurrent.futures import ThreadPoolExecutor
+
+robot = werobot.WeRoBot(token=config.fetch(const.WECHAT_MP).get('token'))
+thread_pool = ThreadPoolExecutor(max_workers=8)
+cache = {}
+
+@robot.text
+def hello_world(msg):
+    logger.info('[WX_Public] receive public msg: {}, userId: {}'.format(msg.content, msg.source))
+    key = msg.content + '|' + msg.source
+    if cache.get(key):
+        cache.get(key)['req_times'] += 1
+    return WechatPublicAccount().handle(msg)
+
+
+class WechatPublicAccount(Channel):
+    def startup(self):
+        logger.info('[WX_Public] Wechat Public account service start!')
+        robot.config['PORT'] = config.fetch(const.WECHAT_MP).get('port')
+        robot.run()
+
+    def handle(self, msg, count=0):
+        context = dict()
+        context['from_user_id'] = msg.source
+        key = msg.content + '|' + msg.source
+        res = cache.get(key)
+        if not res:
+            thread_pool.submit(self._do_send, msg.content, context)
+            temp = {'flag': True, 'req_times': 1}
+            cache[key] = temp
+            if count < 10:
+                time.sleep(2)
+                return self.handle(msg, count+1)
+
+        elif res.get('flag', False) and res.get('data', None):
+            cache.pop(key)
+            return res['data']
+
+        elif res.get('flag', False) and not res.get('data', None):
+            if res.get('req_times') == 3 and count == 9:
+                return '不好意思我的CPU烧了,请再问我一次吧~'
+            if count < 10:
+                time.sleep(0.5)
+                return self.handle(msg, count+1)
+        return "请再说一次"
+
+
+    def _do_send(self, query, context):
+        reply_text = super().build_reply_content(query, context)
+        logger.info('[WX_Public] reply content: {}'.format(reply_text))
+        key = query + '|' + context['from_user_id']
+        cache[key] = {'flag': True, 'data': reply_text}
diff --git a/common/const.py b/common/const.py
new file mode 100644
index 00000000..1f525c40
--- /dev/null
+++ b/common/const.py
@@ -0,0 +1,6 @@
+# channel
+WECHAT = "wechat"
+WECHAT_MP = "wechat_mp"
+
+# model
+OPEN_AI = "openai"
diff --git a/common/log.py b/common/log.py
new file mode 100644
index 00000000..616e5eb5
--- /dev/null
+++ b/common/log.py
@@ -0,0 +1,18 @@
+# encoding:utf-8
+
+import logging
+import sys
+
+
+def _get_logger():
+    log = logging.getLogger('log')
+    log.setLevel(logging.INFO)
+    console_handle = logging.StreamHandler(sys.stdout)
+    console_handle.setFormatter(logging.Formatter('[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s',
+                                                  datefmt='%Y-%m-%d %H:%M:%S'))
+    log.addHandler(console_handle)
+    return log
+
+
+# 日志句柄
+logger = _get_logger()
\ No newline at end of file
diff --git a/config-template.json b/config-template.json
new file mode 100644
index 00000000..d524472d
--- /dev/null
+++ b/config-template.json
@@ -0,0 +1,23 @@
+{
+  "channel": "wechat",
+  "bot": "openai",
+
+  "openai": {
+    "api_key": "YOUR API KEY",
+    "conversation_max_tokens": 1000,
+    "character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。"
+  },
+
+  "wechat": {
+    "single_chat_prefix": ["bot", "@bot"],
+    "single_chat_reply_prefix": "[bot] ",
+    "group_chat_prefix": ["@bot"],
+    "group_name_white_list": ["ALL_GROUP"],
+    "image_create_prefix": ["画", "看", "找一张"]
+  },
+
+  "wechat_mp": {
+    "token": "YOUR TOKEN",
+    "port": "8088"
+  }
+}
diff --git a/config.py b/config.py
new file mode 100644
index 00000000..e203b8da
--- /dev/null
+++ b/config.py
@@ -0,0 +1,32 @@
+# encoding:utf-8
+
+import json
+import os
+
+config = {}
+
+
+def load_config():
+    global config
+    config_path = "config.json"
+    if not os.path.exists(config_path):
+        raise Exception('配置文件不存在,请根据config-template.json模板创建config.json文件')
+
+    config_str = read_file(config_path)
+    # 将json字符串反序列化为dict类型
+    config = json.loads(config_str)
+
+def get_root():
+    return os.path.dirname(os.path.abspath( __file__ ))
+
+
+def read_file(path):
+    with open(path, mode='r', encoding='utf-8') as f:
+        return f.read()
+
+
+def conf():
+    return config
+
+def fetch(model):
+    return config.get(model)
diff --git a/model/model.py b/model/model.py
new file mode 100644
index 00000000..28827667
--- /dev/null
+++ b/model/model.py
@@ -0,0 +1,13 @@
+"""
+Auto-reply chat robot abstract class
+"""
+
+
+class Model(object):
+    def reply(self, query, context=None):
+        """
+        model auto-reply content
+        :param query: received message
+        :return: reply content
+        """
+        raise NotImplementedError
diff --git a/model/model_factory.py b/model/model_factory.py
new file mode 100644
index 00000000..4501bff8
--- /dev/null
+++ b/model/model_factory.py
@@ -0,0 +1,19 @@
+"""
+model factory
+"""
+
+from common import const
+
+def create_bot(model_type):
+    """
+    create a model instance
+    :param model_type: model type code
+    :return: model instance
+    """
+
+    if model_type == const.OPEN_AI:
+        # OpenAI 官方对话模型API
+        from model.openai.open_ai_model import OpenAIModel
+        return OpenAIModel()
+
+    raise RuntimeError
diff --git a/model/openai/.DS_Store b/model/openai/.DS_Store
new file mode 100644
index 00000000..5008ddfc
Binary files /dev/null and b/model/openai/.DS_Store differ
diff --git a/model/openai/open_ai_model.py b/model/openai/open_ai_model.py
new file mode 100644
index 00000000..8fbd4bf3
--- /dev/null
+++ b/model/openai/open_ai_model.py
@@ -0,0 +1,160 @@
+# encoding:utf-8
+
+from model.model import Model
+from config import fetch
+from common import const
+from common.log import logger
+import openai
+import time
+
+user_session = dict()
+
+# OpenAI对话模型API (可用)
+class OpenAIModel(Model):
+    def __init__(self):
+        openai.api_key = fetch(const.OPEN_AI).get('api_key')
+
+
+    def reply(self, query, context=None):
+        # acquire reply content
+        if not context or not context.get('type') or context.get('type') == 'TEXT':
+            logger.info("[OPEN_AI] query={}".format(query))
+            from_user_id = context['from_user_id']
+            if query == '#清除记忆':
+                Session.clear_session(from_user_id)
+                return '记忆已清除'
+
+            new_query = Session.build_session_query(query, from_user_id)
+            logger.debug("[OPEN_AI] session query={}".format(new_query))
+
+            reply_content = self.reply_text(new_query, from_user_id, 0)
+            logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
+            if reply_content and query:
+                Session.save_session(query, reply_content, from_user_id)
+            return reply_content
+
+        elif context.get('type', None) == 'IMAGE_CREATE':
+            return self.create_img(query, 0)
+
+    def reply_text(self, query, user_id, retry_count=0):
+        try:
+            response = openai.Completion.create(
+                model="text-davinci-003",  # 对话模型的名称
+                prompt=query,
+                temperature=0.9,  # 值在[0,1]之间,越大表示回复越具有不确定性
+                max_tokens=1200,  # 回复最大的字符数
+                top_p=1,
+                frequency_penalty=0.0,  # [-2,2]之间,该值越大则更倾向于产生不同的内容
+                presence_penalty=0.0,  # [-2,2]之间,该值越大则更倾向于产生不同的内容
+                stop=["\n\n\n"]
+            )
+            res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '')
+            logger.info("[OPEN_AI] reply={}".format(res_content))
+            return res_content
+        except openai.error.RateLimitError as e:
+            # rate limit exception
+            logger.warn(e)
+            if retry_count < 1:
+                time.sleep(5)
+                logger.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1))
+                return self.reply_text(query, user_id, retry_count+1)
+            else:
+                return "提问太快啦,请休息一下再问我吧"
+        except Exception as e:
+            # unknown exception
+            logger.exception(e)
+            Session.clear_session(user_id)
+            return "请再问我一次吧"
+
+
+    def create_img(self, query, retry_count=0):
+        try:
+            logger.info("[OPEN_AI] image_query={}".format(query))
+            response = openai.Image.create(
+                prompt=query,  #图片描述
+                n=1,  #每次生成图片的数量
+                size="256x256"  #图片大小,可选有 256x256, 512x512, 1024x1024
+            )
+            image_url = response['data'][0]['url']
+            logger.info("[OPEN_AI] image_url={}".format(image_url))
+            return image_url
+        except openai.error.RateLimitError as e:
+            logger.warn(e)
+            if retry_count < 1:
+                time.sleep(5)
+                logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
+                return self.reply_text(query, retry_count+1)
+            else:
+                return "提问太快啦,请休息一下再问我吧"
+        except Exception as e:
+            logger.exception(e)
+            return None
+
+
+class Session(object):
+    @staticmethod
+    def build_session_query(query, user_id):
+        '''
+        build query with conversation history
+        e.g.  Q: xxx
+              A: xxx
+              Q: xxx
+        :param query: query content
+        :param user_id: from user id
+        :return: query content with conversation
+        '''
+        prompt = fetch(const.OPEN_AI).get("character_desc", "")
+        if prompt:
+            prompt += "<|endoftext|>\n\n\n"
+        session = user_session.get(user_id, None)
+        if session:
+            for conversation in session:
+                prompt += "Q: " + conversation["question"] + "\n\n\nA: " + conversation["answer"] + "<|endoftext|>\n"
+            prompt += "Q: " + query + "\nA: "
+            return prompt
+        else:
+            return prompt + "Q: " + query + "\nA: "
+
+    @staticmethod
+    def save_session(query, answer, user_id):
+        max_tokens = fetch(const.OPEN_AI).get("conversation_max_tokens")
+        if not max_tokens:
+            # default 1000
+            max_tokens = 1000
+        conversation = dict()
+        conversation["question"] = query
+        conversation["answer"] = answer
+        session = user_session.get(user_id)
+        logger.debug(conversation)
+        logger.debug(session)
+        if session:
+            # append conversation
+            session.append(conversation)
+        else:
+            # create session
+            queue = list()
+            queue.append(conversation)
+            user_session[user_id] = queue
+
+        # discard exceed limit conversation
+        Session.discard_exceed_conversation(user_session[user_id], max_tokens)
+
+
+    @staticmethod
+    def discard_exceed_conversation(session, max_tokens):
+        count = 0
+        count_list = list()
+        for i in range(len(session)-1, -1, -1):
+            # count tokens of conversation list
+            history_conv = session[i]
+            count += len(history_conv["question"]) + len(history_conv["answer"])
+            count_list.append(count)
+
+        for c in count_list:
+            if c > max_tokens:
+                # pop first conversation
+                session.pop(0)
+
+    @staticmethod
+    def clear_session(user_id):
+        user_session[user_id] = []