diff --git a/README.md b/README.md index 43fe4277..e98a79fb 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ **应用:** - - [ ] 终端 + - [x] 终端 - [ ] Web - [x] 个人微信 - [x] 公众号 (个人/企业) @@ -36,13 +36,13 @@ cd bot-on-anything/ ### 2.配置说明 -核心配置文件为 `config.json`,项目中提供了模板文件 `config-template.json` ,可以从模板复制生成最终生效的 `config.json` 文件: +核心配置文件为 `config.json`,在项目中提供了模板文件 `config-template.json` ,可以从模板复制生成最终生效的 `config.json` 文件: ```bash cp config-template.json config.json ``` -完整的配置文件结构如下: +每一个模型和应用都有自己的配置块,最终组成完整的配置文件,整体结构如下: ```bash { @@ -63,9 +63,9 @@ cp config-template.json config.json } } ``` -配置文件在最外层分成 `model` 和 `channel` 两部分,model 部分为模型配置,其中的 `type` 指定了选用哪个模型;`channel` 部分包含了应用渠道的配置,`type` 字段指定了接入哪个应用,同时下方对应的配置块也会生效。 +配置文件在最外层分成 `model` 和 `channel` 两部分,model部分为模型配置,其中的 `type` 指定了选用哪个模型;channel部分包含了应用渠道的配置,`type` 字段指定了接入哪个应用。 -在使用时只需要更改 `model` 和 `channel` 配置块下的 `type` 字段,即可在任意模型和应用间完成切换,连接不同的通路。下面将依次介绍各个 模型 及 应用 的配置和运行过程。 +在使用时只需要更改 model 和 channel 配置块下的 type 字段,即可在任意模型和应用间完成切换,连接不同的通路。下面将依次介绍各个 模型 及 应用 的配置和运行过程。 ## 二、选择模型 @@ -105,9 +105,17 @@ pip3 install --upgrade openai + `character_desc` 配置中保存着你对机器人说的一段话,他会记住这段话并作为他的设定,你可以为他定制任何人格 -## 三、运行应用 +## 三、选择应用 -### 1.个人微信 +### 1.命令行终端 + +配置模板中默认启动的应用即是终端,无需任何额外配置,直接在项目目录下通过命令行执行 `python3 app.py` 便可启动程序。用户通过命令行的输入与对话模型交互,且支持流式响应效果。 + +![terminal_demo.png](docs/images/terminal_demo.png) + + + +### 2.个人微信 与项目 [chatgpt-on-wechat](https://github.com/zhayujie/chatgpt-on-wechat) 的使用方式相同,目前接入个人微信可能导致账号被限制,暂时不建议使用。 @@ -131,9 +139,9 @@ pip3 install --upgrade openai 在项目根目录下执行 `python3 app.py` 即可启动程序,用手机扫码后完成登录,使用详情参考 [chatgpt-on-wechat](https://github.com/zhayujie/chatgpt-on-wechat)。 -### 2.个人订阅号 +### 3.个人订阅号 -**需要:**一台服务器,一个个人订阅号,一个已备案的域名。 +**需要:** 一台服务器,一个订阅号 #### 2.1 依赖安装 @@ -174,21 +182,28 @@ Hit Ctrl-C to quit. ![wx_mp_config.png](docs/images/wx_mp_config.png) -- **服务器地址 (URL)**:在浏览器访问该URL需要能访问到服务器上运行的python程序 (默认监听8088端口)。由于公众号只能配置 80/443端口,所以需要在服务器进行端口转发 (如使用nginx),或改为直接监听80端口,并将对应的域名地址配置在url处 (仅用ip不行)。 -- **令牌 (Token)**:需和配置中的token一致。 +**服务器地址 (URL) 配置**: 如果在浏览器上通过配置的URL 能够访问到服务器上的Python程序 (默认监听8088端口),则说明配置有效。由于公众号只能配置 80/443端口,可以修改配置为直接监听 80 端口 (需要sudo权限),或者使用反向代理进行转发 (如nginx)。 根据官方文档说明,此处填写公网ip或域名均可。 + +**令牌 (Token) 配置**:需和 `config.json` 配置中的token一致。 + +详细操作过程参考 [官方文档](https://developers.weixin.qq.com/doc/offiaccount/Getting_Started/Getting_Started_Guide.html) + #### 2.3 使用 用户关注订阅号后,发送消息即可。 > 注:用户发送消息后,微信后台会向配置的URL地址推送,但如果5s内未回复就会断开连接,同时重试3次,但往往请求openai接口不止5s。本项目中通过异步和缓存将5s超时限制优化至15s,但超出该时间仍无法正常回复。 同时每次5s连接断开时web框架会报错,待后续优化。 - -### 3.企业服务号 -在企业服务号中,通过先异步访问openai接口,再通过客服接口主动推送用户的方式,解决了个人订阅号的15s超时问题。 -企业服务号配置只需修改type为`wechat_mp_service`,配置块仍复用 `wechat_mp`,在基础上增加了 `app_id` 和 `app_secret` 两个配置项。 +### 4.企业服务号 + +**需要:** 一个服务器、一个已微信认证的服务号 + +在企业服务号中,通过先异步访问openai接口,再通过客服接口主动推送给用户的方式,解决了个人订阅号的15s超时问题。服务号的开发者模式配置和上述订阅号类似,详情参考 [官方文档](https://developers.weixin.qq.com/doc/offiaccount/Getting_Started/Getting_Started_Guide.html)。 + +企业服务号的 `config.json` 配置只需修改type为`wechat_mp_service`,但配置块仍复用 `wechat_mp`,在此基础上需要增加 `app_id` 和 `app_secret` 两个配置项。 ```bash "channel": { @@ -197,7 +212,7 @@ Hit Ctrl-C to quit. "wechat_mp": { "token": "YOUR TOKEN", # token值 "port": "8088", # 程序启动监听的端口 - "app_id": "YOUR APP ID", # appID + "app_id": "YOUR APP ID", # app ID "app_secret": "YOUR APP SECRET" # app secret } } diff --git a/app.py b/app.py index d5e03280..8e95a7d1 100644 --- a/app.py +++ b/app.py @@ -2,20 +2,24 @@ import config from channel import channel_factory -from common.log import logger +from common import log if __name__ == '__main__': try: # load config config.load_config() - logger.info("[INIT] load config: {}".format(config.conf())) + + model_type = config.conf().get("model").get("type") + channel_type = config.conf().get("channel").get("type") + + log.info("[INIT] Start up: {} on {}", model_type, channel_type) # create channel - channel = channel_factory.create_channel(config.conf().get("channel").get("type")) + channel = channel_factory.create_channel(channel_type) # startup channel channel.startup() except Exception as e: - logger.error("App startup failed!") - logger.exception(e) + log.error("App startup failed!") + log.exception(e) diff --git a/channel/channel_factory.py b/channel/channel_factory.py index 0758a7a2..dd622288 100644 --- a/channel/channel_factory.py +++ b/channel/channel_factory.py @@ -9,6 +9,10 @@ def create_channel(channel_type): :param channel_type: channel type code :return: channel instance """ + if channel_type== const.TERMINAL: + from channel.terminal.terminal_channel import TerminalChannel + return TerminalChannel() + if channel_type == const.WECHAT: from channel.wechat.wechat_channel import WechatChannel return WechatChannel() diff --git a/channel/terminal/terminal_channel.py b/channel/terminal/terminal_channel.py index e69de29b..c09f2797 100644 --- a/channel/terminal/terminal_channel.py +++ b/channel/terminal/terminal_channel.py @@ -0,0 +1,33 @@ +from channel.channel import Channel +from common import log + +import sys + +class TerminalChannel(Channel): + def startup(self): + # close log + log.close_log() + context = {"from_user_id": "User", "stream": True} + print("Please input your question\n") + while True: + try: + prompt = self.get_input("User:\n") + except KeyboardInterrupt: + print("\nExiting...") + sys.exit() + + print("Bot:") + sys.stdout.flush() + for res in super().build_reply_content(prompt, context): + print(res, end="") + sys.stdout.flush() + print("\n") + + + def get_input(self, prompt): + """ + Multi-line input function + """ + print(prompt, end="") + line = input() + return line diff --git a/channel/wechat/wechat_mp_service_channel.py b/channel/wechat/wechat_mp_service_channel.py index 14677246..02ab8cc8 100644 --- a/channel/wechat/wechat_mp_service_channel.py +++ b/channel/wechat/wechat_mp_service_channel.py @@ -18,8 +18,8 @@ class WechatServiceAccount(Channel): def startup(self): logger.info('[WX_Public] Wechat Public account service start!') robot.config['PORT'] = channel_conf(const.WECHAT_MP).get('port') - robot.config["APP_ID"] = "YOUR APP ID" - robot.config["APP_SECRET"] = "YOUR APP SECRET" + robot.config["APP_ID"] = channel_conf(const.WECHAT_MP).get('app_id') + robot.config["APP_SECRET"] = channel_conf(const.WECHAT_MP).get('app_secret') robot.run() def handle(self, msg, count=0): diff --git a/common/const.py b/common/const.py index 1ae175fc..6b10525a 100644 --- a/common/const.py +++ b/common/const.py @@ -1,4 +1,5 @@ # channel +TERMINAL = "terminal" WECHAT = "wechat" WECHAT_MP = "wechat_mp" WECHAT_MP_SERVICE = "wechat_mp_service" diff --git a/common/log.py b/common/log.py index 616e5eb5..2510bcb2 100644 --- a/common/log.py +++ b/common/log.py @@ -3,6 +3,7 @@ import logging import sys +SWITCH = True def _get_logger(): log = logging.getLogger('log') @@ -13,6 +14,41 @@ def _get_logger(): log.addHandler(console_handle) return log +def close_log(): + global SWITCH + SWITCH = False + + +def debug(arg, *args): + if SWITCH: + if len(args) == 0: + logger.debug(arg) + else: + logger.debug(arg.format(*args)) + +def info(arg, *args): + if SWITCH: + if len(args) == 0: + logger.info(arg) + else: + logger.info(arg.format(*args)) + + +def warn(arg, *args): + if len(args) == 0: + logger.warning(arg) + else: + logger.warning(arg.format(*args)) + +def error(arg, *args): + if len(args) == 0: + logger.error(arg) + else: + logger.error(arg.format(*args)) + +def exception(e): + logger.exception(e) + # 日志句柄 -logger = _get_logger() \ No newline at end of file +logger = _get_logger() diff --git a/config-template.json b/config-template.json index f6c2991e..c30b2102 100644 --- a/config-template.json +++ b/config-template.json @@ -8,7 +8,7 @@ } }, "channel": { - "type": "wechat_mp", + "type": "terminal", "single_chat_prefix": ["bot", "@bot"], "single_chat_reply_prefix": "[bot] ", "group_chat_prefix": ["@bot"], diff --git a/docs/images/terminal_demo.png b/docs/images/terminal_demo.png new file mode 100644 index 00000000..2464fa4f Binary files /dev/null and b/docs/images/terminal_demo.png differ diff --git a/docs/images/wx_mp_config.png b/docs/images/wx_mp_config.png index a84beadd..36ea91c5 100644 Binary files a/docs/images/wx_mp_config.png and b/docs/images/wx_mp_config.png differ diff --git a/model/openai/open_ai_model.py b/model/openai/open_ai_model.py index dd75ca2b..36dc3c45 100644 --- a/model/openai/open_ai_model.py +++ b/model/openai/open_ai_model.py @@ -3,7 +3,7 @@ from model.model import Model from config import model_conf from common import const -from common.log import logger +from common import log import openai import time @@ -18,17 +18,21 @@ def __init__(self): def reply(self, query, context=None): # acquire reply content if not context or not context.get('type') or context.get('type') == 'TEXT': - logger.info("[OPEN_AI] query={}".format(query)) + log.info("[OPEN_AI] query={}".format(query)) from_user_id = context['from_user_id'] if query == '#清除记忆': Session.clear_session(from_user_id) return '记忆已清除' new_query = Session.build_session_query(query, from_user_id) - logger.debug("[OPEN_AI] session query={}".format(new_query)) + log.debug("[OPEN_AI] session query={}".format(new_query)) + + if context.get('stream'): + # reply in stream + return self.reply_text_stream(query, new_query, from_user_id) reply_content = self.reply_text(new_query, from_user_id, 0) - logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content)) + log.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content)) if reply_content and query: Session.save_session(query, reply_content, from_user_id) return reply_content @@ -49,45 +53,98 @@ def reply_text(self, query, user_id, retry_count=0): stop=["\n\n\n"] ) res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '') - logger.info("[OPEN_AI] reply={}".format(res_content)) + log.info("[OPEN_AI] reply={}".format(res_content)) return res_content except openai.error.RateLimitError as e: # rate limit exception - logger.warn(e) + log.warn(e) if retry_count < 1: time.sleep(5) - logger.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1)) + log.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1)) return self.reply_text(query, user_id, retry_count+1) else: return "提问太快啦,请休息一下再问我吧" except Exception as e: # unknown exception - logger.exception(e) + log.exception(e) Session.clear_session(user_id) return "请再问我一次吧" + def reply_text_stream(self, query, new_query, user_id, retry_count=0): + try: + res = openai.Completion.create( + model="text-davinci-003", # 对话模型的名称 + prompt=new_query, + temperature=0.9, # 值在[0,1]之间,越大表示回复越具有不确定性 + max_tokens=1200, # 回复最大的字符数 + top_p=1, + frequency_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容 + presence_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容 + stop=["\n\n\n"], + stream=True + ) + return self._process_reply_stream(query, res, user_id) + + except openai.error.RateLimitError as e: + # rate limit exception + log.warn(e) + if retry_count < 1: + time.sleep(5) + log.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1)) + return self.reply_text(query, user_id, retry_count+1) + else: + return "提问太快啦,请休息一下再问我吧" + except Exception as e: + # unknown exception + log.exception(e) + Session.clear_session(user_id) + return "请再问我一次吧" + + + def _process_reply_stream( + self, + query: str, + reply: dict, + user_id: str + ) -> str: + full_response = "" + for response in reply: + if response.get("choices") is None or len(response["choices"]) == 0: + raise Exception("OpenAI API returned no choices") + if response["choices"][0].get("finish_details") is not None: + break + if response["choices"][0].get("text") is None: + raise Exception("OpenAI API returned no text") + if response["choices"][0]["text"] == "<|endoftext|>": + break + yield response["choices"][0]["text"] + full_response += response["choices"][0]["text"] + if query and full_response: + Session.save_session(query, full_response, user_id) + + def create_img(self, query, retry_count=0): try: - logger.info("[OPEN_AI] image_query={}".format(query)) + log.info("[OPEN_AI] image_query={}".format(query)) response = openai.Image.create( prompt=query, #图片描述 n=1, #每次生成图片的数量 size="256x256" #图片大小,可选有 256x256, 512x512, 1024x1024 ) image_url = response['data'][0]['url'] - logger.info("[OPEN_AI] image_url={}".format(image_url)) + log.info("[OPEN_AI] image_url={}".format(image_url)) return image_url except openai.error.RateLimitError as e: - logger.warn(e) + log.warn(e) if retry_count < 1: time.sleep(5) - logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1)) + log.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1)) return self.reply_text(query, retry_count+1) else: return "提问太快啦,请休息一下再问我吧" except Exception as e: - logger.exception(e) + log.exception(e) return None @@ -125,8 +182,8 @@ def save_session(query, answer, user_id): conversation["question"] = query conversation["answer"] = answer session = user_session.get(user_id) - logger.debug(conversation) - logger.debug(session) + log.debug(conversation) + log.debug(session) if session: # append conversation session.append(conversation)