diff --git a/modules/ros_chatbot/.gitignore b/modules/ros_chatbot/.gitignore new file mode 100644 index 0000000..c7b50e7 --- /dev/null +++ b/modules/ros_chatbot/.gitignore @@ -0,0 +1,114 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +.*.swp +.*.swo +install +*.pyc +*.bak + +__cache__ +models/ +.vscode diff --git a/modules/ros_chatbot/.isort.cfg b/modules/ros_chatbot/.isort.cfg new file mode 100644 index 0000000..ff4458d --- /dev/null +++ b/modules/ros_chatbot/.isort.cfg @@ -0,0 +1,3 @@ +[settings] +profile = black +known_first_party = ros_chatbot diff --git a/modules/ros_chatbot/.pre-commit-config.yaml b/modules/ros_chatbot/.pre-commit-config.yaml new file mode 100644 index 0000000..1057f4f --- /dev/null +++ 
b/modules/ros_chatbot/.pre-commit-config.yaml @@ -0,0 +1,62 @@ +exclude: '^docs/conf.py' + +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: trailing-whitespace + - id: check-added-large-files + - id: check-ast + - id: check-json + - id: check-merge-conflict + - id: check-xml + - id: check-yaml + - id: debug-statements + - id: end-of-file-fixer + - id: requirements-txt-fixer + - id: mixed-line-ending + args: ['--fix=auto'] # replace 'auto' with 'lf' to enforce Linux/Mac line endings or 'crlf' for Windows + +## If you want to avoid flake8 errors due to unused vars or imports: +# - repo: https://github.com/myint/autoflake.git +# rev: v1.4 +# hooks: +# - id: autoflake +# args: [ +# --in-place, +# --remove-all-unused-imports, +# --remove-unused-variables, +# ] + +- repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + args: ["--profile", "black"] + + +- repo: https://github.com/psf/black + rev: 24.4.2 + hooks: + - id: black + language_version: python3 + args: # arguments to configure black + - --line-length=88 + + +## If like to embrace black styles even in the docs: +# - repo: https://github.com/asottile/blacken-docs +# rev: v1.11.0 +# hooks: +# - id: blacken-docs +# additional_dependencies: [black] + +- repo: https://github.com/PyCQA/flake8 + rev: 7.1.0 + hooks: + - id: flake8 + ## You can add flake8 plugins via `additional_dependencies`: + # additional_dependencies: [flake8-bugbear] + args: # arguments to configure flake8 + - "--max-line-length=88" + - "--ignore=E203,E266,E501,W503,F403,F401,E402" diff --git a/modules/ros_chatbot/CMakeLists.txt b/modules/ros_chatbot/CMakeLists.txt new file mode 100644 index 0000000..e01ab63 --- /dev/null +++ b/modules/ros_chatbot/CMakeLists.txt @@ -0,0 +1,220 @@ +## +## Copyright (C) 2017-2025 Hanson Robotics +## +## This program is free software: you can redistribute it and/or modify +## it under the terms of the GNU General Public License as published by +## the 
Free Software Foundation, either version 3 of the License, or +## (at your option) any later version. +## +## This program is distributed in the hope that it will be useful, +## but WITHOUT ANY WARRANTY; without even the implied warranty of +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +## GNU General Public License for more details. +## +## You should have received a copy of the GNU General Public License +## along with this program. If not, see . +## + +cmake_minimum_required(VERSION 2.8.3) +project(ros_chatbot) + +## Compile as C++11, supported in ROS Kinetic and newer +# add_compile_options(-std=c++11) + +## Find catkin macros and libraries +## if COMPONENTS list like find_package(catkin REQUIRED COMPONENTS xyz) +## is used, also find other catkin packages +find_package(catkin REQUIRED message_generation) +find_package(catkin REQUIRED COMPONENTS + dynamic_reconfigure + rospy +) + +## System dependencies are found with CMake's conventions +# find_package(Boost REQUIRED COMPONENTS system) + + +## Uncomment this if the package has a setup.py. This macro ensures +## modules and global scripts declared therein get installed +## See http://ros.org/doc/api/catkin/html/user_guide/setup_dot_py.html +catkin_python_setup() + +################################################ +## Declare ROS messages, services and actions ## +################################################ + +## To declare and build messages, services or actions from within this +## package, follow these steps: +## * Let MSG_DEP_SET be the set of packages whose message types you use in +## your messages/services/actions (e.g. std_msgs, actionlib_msgs, ...). 
+## * In the file package.xml: +## * add a build_depend tag for "message_generation" +## * add a build_depend and a exec_depend tag for each package in MSG_DEP_SET +## * If MSG_DEP_SET isn't empty the following dependency has been pulled in +## but can be declared for certainty nonetheless: +## * add a exec_depend tag for "message_runtime" +## * In this file (CMakeLists.txt): +## * add "message_generation" and every package in MSG_DEP_SET to +## find_package(catkin REQUIRED COMPONENTS ...) +## * add "message_runtime" and every package in MSG_DEP_SET to +## catkin_package(CATKIN_DEPENDS ...) +## * uncomment the add_*_files sections below as needed +## and list every .msg/.srv/.action file to be processed +## * uncomment the generate_messages entry below +## * add every package in MSG_DEP_SET to generate_messages(DEPENDENCIES ...) + +## Generate messages in the 'msg' folder +# add_message_files( +# FILES +# Message1.msg +# Message2.msg +# ) + +## Generate services in the 'srv' folder +#add_service_files( +# FILES +#) + +## Generate actions in the 'action' folder +# add_action_files( +# FILES +# Action1.action +# Action2.action +# ) + +## Generate added messages and services with any dependencies listed here +#generate_messages( +# DEPENDENCIES +#) + +################################################ +## Declare ROS dynamic reconfigure parameters ## +################################################ + +## To declare and build dynamic reconfigure parameters within this +## package, follow these steps: +## * In the file package.xml: +## * add a build_depend and a exec_depend tag for "dynamic_reconfigure" +## * In this file (CMakeLists.txt): +## * add "dynamic_reconfigure" to +## find_package(catkin REQUIRED COMPONENTS ...) 
+## * uncomment the "generate_dynamic_reconfigure_options" section below +## and list every .cfg file to be processed + +## Generate dynamic reconfigure parameters in the 'cfg' folder +generate_dynamic_reconfigure_options( + cfg/ROSChatbot.cfg + cfg/ContentManager.cfg + cfg/ROSBot.cfg +) + +################################### +## catkin specific configuration ## +################################### +## The catkin_package macro generates cmake config files for your package +## Declare things to be passed to dependent projects +## INCLUDE_DIRS: uncomment this if your package contains header files +## LIBRARIES: libraries you create in this project that dependent projects also need +## CATKIN_DEPENDS: catkin_packages dependent projects also need +## DEPENDS: system dependencies of this project that dependent projects also need +catkin_package( +# INCLUDE_DIRS include +# LIBRARIES chatbot +# CATKIN_DEPENDS rospy message_runtime +# DEPENDS system_lib +) + +########### +## Build ## +########### + +## Specify additional locations of header files +## Your package locations should be listed before other locations +include_directories( +# include +# ${catkin_INCLUDE_DIRS} +) + +## Declare a C++ library +# add_library(${PROJECT_NAME} +# src/${PROJECT_NAME}/ros_chatbot.cpp +# ) + +## Add cmake target dependencies of the library +## as an example, code may need to be generated before libraries +## either from message generation or dynamic reconfigure +# add_dependencies(${PROJECT_NAME} ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS}) + +## Declare a C++ executable +## With catkin_make all packages are built within a single CMake context +## The recommended prefix ensures that target names across packages don't collide +# add_executable(${PROJECT_NAME}_node src/ros_chatbot_node.cpp) + +## Rename C++ executable without prefix +## The above recommended prefix causes long target names, the following renames the +## target back to the shorter version for ease of user 
use +## e.g. "rosrun someones_pkg node" instead of "rosrun someones_pkg someones_pkg_node" +# set_target_properties(${PROJECT_NAME}_node PROPERTIES OUTPUT_NAME node PREFIX "") + +## Add cmake target dependencies of the executable +## same as for the library above +# add_dependencies(${PROJECT_NAME}_node ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS}) + +## Specify libraries to link a library or executable target against +# target_link_libraries(${PROJECT_NAME}_node +# ${catkin_LIBRARIES} +# ) + +############# +## Install ## +############# + +# all install targets should use catkin DESTINATION variables +# See http://ros.org/doc/api/catkin/html/adv_user_guide/variables.html + +## Mark executable scripts (Python etc.) for installation +## in contrast to setup.py, you can choose the destination +install(PROGRAMS + scripts/run.py + scripts/n8n_client.py + scripts/priming.py + scripts/server.py + scripts/services.py + scripts/tg.py + scripts/content_manager.py + scripts/rosbot.py + DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} +) + +## Mark executables and/or libraries for installation +# install(TARGETS ${PROJECT_NAME} ${PROJECT_NAME}_node +# ARCHIVE DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} +# LIBRARY DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} +# RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} +# ) + +## Mark cpp header files for installation +# install(DIRECTORY include/${PROJECT_NAME}/ +# DESTINATION ${CATKIN_PACKAGE_INCLUDE_DESTINATION} +# FILES_MATCHING PATTERN "*.h" +# PATTERN ".svn" EXCLUDE +# ) + +## Mark other files for installation (e.g. launch and bag files, etc.) 
+install(FILES + launch/sdk.launch + DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION} +) + +############# +## Testing ## +############# + +## Add gtest based cpp test target and link libraries +# catkin_add_gtest(${PROJECT_NAME}-test test/test_ros_chatbot.cpp) +# if(TARGET ${PROJECT_NAME}-test) +# target_link_libraries(${PROJECT_NAME}-test ${PROJECT_NAME}) +# endif() + +## Add folders to be run by python nosetests +# catkin_add_nosetests(test) diff --git a/modules/ros_chatbot/NOTICES b/modules/ros_chatbot/NOTICES new file mode 100644 index 0000000..5e57d16 --- /dev/null +++ b/modules/ros_chatbot/NOTICES @@ -0,0 +1,48 @@ +NOTICES +================================================================================ +This project includes third-party content licensed under terms compatible with +the GNU General Public License v3.0 (GPLv3). Below is a summary of these +third-party components and their license notices. + +-------------------------------------------------------------------------------- +1) PyAIML - BSD 2-Clause License + Files in src/ros_chatbot/pyaiml/ directory are derived from PyAIML originally + created by Cort Stratton and modified by Hanson Robotics. + + Copyright 2003-2010 Cort Stratton. All rights reserved. + Copyright 2015, 2016 Hanson Robotics + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + 1. Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED + WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO + EVENT SHALL THE FREEBSD PROJECT OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, + INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +-------------------------------------------------------------------------------- +2) BM25 Implementation - GNU Lesser General Public License v2.1 + The file src/ros_chatbot/bm25.py is derived from code originally published + under the GNU LGPL v2.1. + + Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html + + This module contains function of computing rank scores for documents in + corpus and helper class `BM25` used in calculations. Original algorithm + described in: + - Robertson, Stephen; Zaragoza, Hugo (2009). 
The Probabilistic Relevance + Framework: BM25 and Beyond, + http://www.staff.city.ac.uk/~sb317/papers/foundations_bm25_review.pdf + - Okapi BM25 on Wikipedia, https://en.wikipedia.org/wiki/Okapi_BM25 + \ No newline at end of file diff --git a/modules/ros_chatbot/README.md b/modules/ros_chatbot/README.md new file mode 100644 index 0000000..fc0c805 --- /dev/null +++ b/modules/ros_chatbot/README.md @@ -0,0 +1,11 @@ +Chatbot ensemble service for ROS + +## Generate DB Schema + + alembic revision --autogenerate -m "xxx" + +## Upgrade DB Schema + + alembic upgrade head + + diff --git a/modules/ros_chatbot/app/main.py b/modules/ros_chatbot/app/main.py new file mode 100644 index 0000000..051c448 --- /dev/null +++ b/modules/ros_chatbot/app/main.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python + +## +## Copyright (C) 2017-2025 Hanson Robotics +## +## This program is free software: you can redistribute it and/or modify +## it under the terms of the GNU General Public License as published by +## the Free Software Foundation, either version 3 of the License, or +## (at your option) any later version. +## +## This program is distributed in the hope that it will be useful, +## but WITHOUT ANY WARRANTY; without even the implied warranty of +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +## GNU General Public License for more details. +## +## You should have received a copy of the GNU General Public License +## along with this program. If not, see . 
+## + +import logging +import os +import sys +from typing import Dict, List + +import coloredlogs +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel + +from ros_chatbot.chat_server import ChatServer +from ros_chatbot.utils import get_current_time_str + +server = ChatServer() +logger = logging.getLogger("hr.ros_chatbot.app.main") + + +class SessionRequest(BaseModel): + uid: str # User id + reset: bool = False # Reset the session + + +class SessionResponse(BaseModel): + sid: str + err_code: int = 0 + err_msg: str = "" + + +class ListResponse(BaseModel): + agents: List[str] + + +class ChatRequest(BaseModel): + sid: str + question: str + request_id: str = "" + time: str = get_current_time_str() + audio: str = "" # path of the audio if the request is from Speech-to-Text + tag: str = "" # tag for the conversation + lang: str = "en-US" + context: Dict = {} + mode: str = "ranking" + scene: str = "" + user_id: str = "" # graph user id + + def __hash__(self) -> int: + return hash( + ( + self.sid, + self.question, + self.request_id, + self.time, + self.audio, + self.tag, + self.lang, + str(self.context.items()), + self.mode, + self.scene, + self.user_id, + ) + ) + + +class AgentResponse(BaseModel): + sid: str + request_id: str + agent_id: str + response_id: str = "" + agent_sid: str = "" + start_dt: str = get_current_time_str() + end_dt: str = get_current_time_str() + answer: str = "" + trace: str = "" + priority: int = -1 + attachment: Dict = {} + + +class ChatResponse(BaseModel): + responses: List[AgentResponse] = [] + err_code: int = 0 + err_msg: str = "" + + +class StatusResponse(BaseModel): + err_code: int = 0 + err_msg: str = "" + + +if "coloredlogs" in sys.modules and os.isatty(2): + formatter_str = "%(asctime)s %(levelname)-7s %(name)s: %(message)s" + coloredlogs.install(logging.INFO, fmt=formatter_str) + +app = FastAPI() +app.add_middleware(CORSMiddleware, allow_origins=["*"]) + + +@app.post( + "/session", 
summary="Get or retrieve a session", response_model=SessionResponse +) +def session(request: SessionRequest): + response = {} + try: + sid = server.get_client_session(request.uid, request.reset) + response["sid"] = sid + except Exception as ex: + logger.error(ex) + response["sid"] = "" + response["err_code"] = 1 + response["err_msg"] = str(ex) + return response + + +@app.get( + "/agents", summary="List of installed chat agents", response_model=ListResponse +) +def list_all_installed_agents(): + agents = [] + agents.extend( + [ + "%s:%s:%s" % (agent.type, agent.id, agent.enabled) + for agent in server.agents.values() + ] + ) + return {"agents": agents} + + +@app.post("/chat", summary="Chat", response_model=ChatResponse) +def chat(request: ChatRequest): + response = {} + + try: + agent_responses = [] + for responses in server.chat_with_ranking(request): + if responses: + agent_responses.extend(responses) + logger.warning("responses %r", responses) + response["responses"] = agent_responses + except Exception as ex: + response["err_code"] = 1 + response["err_msg"] = str(ex) + logger.error(ex) + + logger.error("response %r", response) + return response + + +@app.get("/", summary="Check chat server status", response_model=StatusResponse) +def status(): + return {} diff --git a/modules/ros_chatbot/cfg/ContentManager.cfg b/modules/ros_chatbot/cfg/ContentManager.cfg new file mode 100755 index 0000000..f2268cf --- /dev/null +++ b/modules/ros_chatbot/cfg/ContentManager.cfg @@ -0,0 +1,32 @@ +#!/usr/bin/env python + +## +## Copyright (C) 2017-2025 Hanson Robotics +## +## This program is free software: you can redistribute it and/or modify +## it under the terms of the GNU General Public License as published by +## the Free Software Foundation, either version 3 of the License, or +## (at your option) any later version. 
+## +## This program is distributed in the hope that it will be useful, +## but WITHOUT ANY WARRANTY; without even the implied warranty of +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +## GNU General Public License for more details. +## +## You should have received a copy of the GNU General Public License +## along with this program. If not, see . +## + +PACKAGE = 'ros_chatbot' + +from dynamic_reconfigure.parameter_generator_catkin import * + +gen = ParameterGenerator() + +gen.add("pull", bool_t, 0, "Pull Content", False) +gen.add("test", bool_t, 0, "Include Test Content?", False) +gen.add("automatic_pull", bool_t, 0, "Keep content up-to-date", False) +gen.add('pull_interval', int_t, 30, "The inverval in minutes for pulling content ", 30, 0, 1440) + +# package name, node name, config name +exit(gen.generate(PACKAGE, "content_manager", "ContentManager")) diff --git a/modules/ros_chatbot/cfg/ROSBot.cfg b/modules/ros_chatbot/cfg/ROSBot.cfg new file mode 100755 index 0000000..77d87f3 --- /dev/null +++ b/modules/ros_chatbot/cfg/ROSBot.cfg @@ -0,0 +1,29 @@ +#!/usr/bin/env python + +## +## Copyright (C) 2017-2025 Hanson Robotics +## +## This program is free software: you can redistribute it and/or modify +## it under the terms of the GNU General Public License as published by +## the Free Software Foundation, either version 3 of the License, or +## (at your option) any later version. +## +## This program is distributed in the hope that it will be useful, +## but WITHOUT ANY WARRANTY; without even the implied warranty of +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +## GNU General Public License for more details. +## +## You should have received a copy of the GNU General Public License +## along with this program. If not, see . 
+## + +PACKAGE = 'ros_chatbot' + +from dynamic_reconfigure.parameter_generator_catkin import * + +gen = ParameterGenerator() + +gen.add("running", bool_t, 0, "Running", False) + +# package name, node name, config name +exit(gen.generate(PACKAGE, "ros_chatbot", "ROSBot")) diff --git a/modules/ros_chatbot/cfg/ROSChatbot.cfg b/modules/ros_chatbot/cfg/ROSChatbot.cfg new file mode 100755 index 0000000..39c6846 --- /dev/null +++ b/modules/ros_chatbot/cfg/ROSChatbot.cfg @@ -0,0 +1,79 @@ +#!/usr/bin/env python + +## +## Copyright (C) 2017-2025 Hanson Robotics +## +## This program is free software: you can redistribute it and/or modify +## it under the terms of the GNU General Public License as published by +## the Free Software Foundation, either version 3 of the License, or +## (at your option) any later version. +## +## This program is distributed in the hope that it will be useful, +## but WITHOUT ANY WARRANTY; without even the implied warranty of +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +## GNU General Public License for more details. +## +## You should have received a copy of the GNU General Public License +## along with this program. If not, see . 
+## + +PACKAGE = 'ros_chatbot' + +from dynamic_reconfigure.parameter_generator_catkin import * + +gen = ParameterGenerator() + +character_enum = gen.enum([ + gen.const("Desdemona", str_t, "desi", "Desdemona"), + gen.const("Grace", str_t, "grace", "Grace"), + gen.const("Little_Sophia", str_t, "littlesophia", "Little Sophia"), + gen.const("Mika", str_t, "mika", "Mika"), + gen.const("Professor_Einstein", str_t, "profeinstein", "Professor Einstein"), + gen.const("Sophia", str_t, "sophia", "Sophia"), + ], "Available characters") + +gen.add("enable", bool_t, 0, "Enable chatbot", True) +gen.add("hybrid_mode", bool_t, 0, "Hybrid mode", True) +gen.add("enable_rag", bool_t, 0, "Enable RAG", False) +gen.add("enable_global_workspace_drivers", bool_t, 0, "Enable global workspace drivers (experimental)", False) +gen.add("auto_global_workspace", bool_t, 0, "Automatically turn on/off global workspace", False) +gen.add("enable_emotion_driven_response_primer", bool_t, 0, "Enable emotion driven response primer", False) +gen.add("auto_automonous_free_chat", bool_t, 0, "Automatically turn on/off free chat ", False) +gen.add("hybrid_when_idle", int_t, 0, "Seconds before going to hybrid when idle", 90, 15, 180) +gen.add("auto_fire_arf", bool_t, 0, "Automatically fire ARF!", False) + +mode_enum = gen.enum([ + gen.const("Auto", str_t, "Auto", "Auto"), + gen.const("Demo", str_t, "Demo", "Demo"), + gen.const("Stage", str_t, "Stage", "Stage"), + gen.const("Undefined", str_t, "Undefined", ""), + ], "Available robot modes") +gen.add('robot_mode', str_t, 0, "Robot mode", "Stage", edit_method=mode_enum) + +gen.add('fast_score', int_t, 60, "The minimum required response score, the smaller the faster", 60, 0, 100) +gen.add("offline_asr_free_chat", bool_t, 0, "Enable offline free chat (This will be auto enabled when there is no internet)", False) +gen.add('listen_speech' , bool_t, 0, "Listen to speech", True) +gen.add('ignore_speech_while_thinking' , bool_t, 0, "In autonomous mode do not respond to 
text that comes between question and answer", False) +gen.add('concat_multiple_speech' , bool_t, 0, "If speech is heard before answer is returned, concat the utterances until longer pause", False) + +gen.add("character", str_t, 0, "Chatbot Character (test)", "sophia", edit_method=character_enum) + +gen.add("min_wait_for", double_t, 0, "Wait for minimum time for agents before decision", 0.5, 0.0, 10) +gen.add("timeout", double_t, 0, "Chatbot Timeout Per Batch", 3, 0, 20) + +#gen.add("enable_interruption_controller", bool_t, 0, "Enable interruption controller (experimental)", False) +#gen.add('enable_language_switch_controller', bool_t, 0, "Enable Language switch controller (experimental, not in hybrid mode)", False) +#gen.add('enable_emotion_controller', bool_t, 0, "Enable emotion control (experimental)", False) +#gen.add('enable_monitor_controller', bool_t, 0, "Enable event monitor controller (experimental)", False) +#gen.add('enable_responsivity_controller', bool_t, 0, "Enable responsivity controller (experimental)", False) +gen.add('enable_command_controller', bool_t, 0, "Enable command controller (experimental)", True) +#gen.add('enable_user_acquisition_controller', bool_t, 0, "Enable user acquisition controller (experimental)", False) + +gen.add('enable_placeholder_utterance_controller', bool_t, 0, "Enable placeholder utterance", False) +gen.add('placeholder_prob_step', double_t, 0, "The incremental step of probability of placeholder utterances (every half second)", 0.07, 0, 1) +gen.add('placeholder_utterances', str_t, 0, "Placeholder utterances (one on a line)", "") +# schema setting for some fields +gen.add("node_schema", str_t, 0, '', '{"placeholder_utterances": {"type": "string","format": "textarea"}}') + +# package name, node name, config name +exit(gen.generate(PACKAGE, "ros_chatbot", "ROSChatbot")) diff --git a/modules/ros_chatbot/dev.sh b/modules/ros_chatbot/dev.sh new file mode 100755 index 0000000..5880f38 --- /dev/null +++ 
b/modules/ros_chatbot/dev.sh @@ -0,0 +1,29 @@ +#!/usr/bin/env bash + +## +## Copyright (C) 2017-2025 Hanson Robotics +## +## This program is free software: you can redistribute it and/or modify +## it under the terms of the GNU General Public License as published by +## the Free Software Foundation, either version 3 of the License, or +## (at your option) any later version. +## +## This program is distributed in the hope that it will be useful, +## but WITHOUT ANY WARRANTY; without even the implied warranty of +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +## GNU General Public License for more details. +## +## You should have received a copy of the GNU General Public License +## along with this program. If not, see . +## + +source ~/workspace/hrsdk_configs/scripts/_env/10-defaults.sh +export HR_OPENAI_PROXY=https://openai.hr-tools.io +export HR_CHARACTER="sophia" +export ROBOT_NAME="sophia54" +export CMS_DIR="/tmp/cms" +export HR_CHATBOT_WORLD_DIR=/home/hr/workspace/hrsdk_configs/characters/sophia/worlds/lab/ +export HR_CHATBOT_DATA_DIR=/home/hr/workspace/hrsdk_configs/characters/sophia/data/ +export SOULTALK_HOT_UPLOAD_DIR=/tmp +export HR_MODES_FILE=/home/hr/workspace/hrsdk_configs/configs/common/all/modes.yaml +roslaunch launch/sdk.launch \ No newline at end of file diff --git a/modules/ros_chatbot/launch/sdk.launch b/modules/ros_chatbot/launch/sdk.launch new file mode 100644 index 0000000..d2b093a --- /dev/null +++ b/modules/ros_chatbot/launch/sdk.launch @@ -0,0 +1,43 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/modules/ros_chatbot/package.xml b/modules/ros_chatbot/package.xml new file mode 100644 index 0000000..5da33ec --- /dev/null +++ b/modules/ros_chatbot/package.xml @@ -0,0 +1,39 @@ + + + + ros_chatbot + 0.7.0 + The Hanson Robotics ros_chatbot package + GPLv3 + Wenwei Huang + Vytas Krisciunas + Wenwei Huang + Vytas Krisciunas + catkin + rospy + std_msgs + dynamic_reconfigure + rospy + message_runtime + + + + + + + 
##
## Copyright (C) 2017-2025 Hanson Robotics
##
## This program is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
## GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program. If not, see <https://www.gnu.org/licenses/>.
##

# Shared helpers for the packaging scripts (sourced by package.sh).

# Export the Hanson Robotics filesystem layout and remote storage URLs used
# by all packaging scripts.
# NOTE: the name shadows /usr/bin/env inside sourcing scripts; kept for
# backward compatibility with existing callers.
env() {
    export HR_PREFIX=/opt/hansonrobotics
    export HR_BIN_PREFIX=$HR_PREFIX/bin
    export HRTOOL_PREFIX=${HR_PREFIX}/hrtool
    export HR_ROS_PREFIX=${HR_PREFIX}/ros
    export HR_TOOLS_PREFIX=$HR_PREFIX/tools
    export HR_DATA_PREFIX=$HR_PREFIX/data
    export VOICE_CACHE_DIR=$HOME/.hr/tts/voice
    export URL_PREFIX=https://github.com/hansonrobotics
    export GITHUB_STORAGE_URL=https://raw.githubusercontent.com/hansonrobotics/binary_dependency/master
    # Requires GITHUB_TOKEN in the environment for the private repository.
    export GITHUB_STORAGE_URL2=https://$GITHUB_TOKEN@raw.githubusercontent.com/hansonrobotics/binary_dependency2/master
    export VENDOR="Hanson Robotics"
    export PYTHON_PKG_PREFIX=$HR_PREFIX/py2env/lib/python2.7/dist-packages
    export PYTHON3_PKG_PREFIX=$HR_PREFIX/py3env/lib/python3.6/dist-packages
    export ROS_PYTHON_PKG_PREFIX=$HR_ROS_PREFIX/lib/python2.7/dist-packages
}

# Install the packaging toolchain (gem/fpm, chrpath, autoconf, jq, go)
# when the corresponding commands are missing, then put go on PATH.
install_deps() {
    if ! hash gem >/dev/null 2>&1; then
        echo "Installing ruby-full"
        sudo apt-get install ruby-full
    fi

    if ! hash fpm >/dev/null 2>&1; then
        gem install fpm
        gem install deb-s3
    fi

    if ! hash chrpath >/dev/null 2>&1; then
        echo "Installing chrpath"
        sudo apt-get install chrpath
    fi

    if ! hash autoconf >/dev/null 2>&1; then
        echo "Installing autoconf"
        sudo apt-get install autoconf
    fi

    if ! hash jq >/dev/null 2>&1; then
        echo "Installing jq"
        sudo apt-get install jq
    fi

    if [[ ! -f /usr/local/go/bin/go ]]; then
        echo "Installing go"
        wget https://dl.google.com/go/go1.14.2.linux-amd64.tar.gz -O /tmp/go1.14.2.linux-amd64.tar.gz
        sudo tar -C /usr/local -xzf /tmp/go1.14.2.linux-amd64.tar.gz
    fi

    export PATH=/usr/local/go/bin:$PATH
}

COLOR_INFO='\033[32m'
COLOR_WARN='\033[33m'
COLOR_ERROR='\033[31m'
COLOR_RESET='\033[0m'

# Colored log helpers writing to stderr. The message is passed as printf
# *data* rather than being interpolated into the format string, so a
# literal '%' in it cannot corrupt the output; '%b' expands the escapes.
info() {
    printf '%b[INFO] %s%b\n' "$COLOR_INFO" "$1" "$COLOR_RESET" >&2
}
warn() {
    printf '%b[WARN] %s%b\n' "$COLOR_WARN" "$1" "$COLOR_RESET" >&2
}
error() {
    printf '%b[ERROR] %s%b\n' "$COLOR_ERROR" "$1" "$COLOR_RESET" >&2
}

# Source the setup of the first installed ROS distribution, newest first.
source_ros() {
    local ros_dists=(noetic melodic kinetic indigo)
    for ros_dist in "${ros_dists[@]}"; do
        if [[ -e /opt/ros/$ros_dist/setup.bash ]]; then
            info "ROS distribution $ros_dist"
            source /opt/ros/$ros_dist/setup.bash
            return
        fi
    done
}

# Echo the fpm maintainer-script options for whichever control scripts exist
# under $1 (default: $PACKAGE_DIR/control). Returns 1 when none exist.
add_control_scripts() {
    local root_dir=${1:-${PACKAGE_DIR}/control}
    local preinst="${root_dir}/preinst.sh"
    local postinst="${root_dir}/postinst.sh"
    local prerm="${root_dir}/prerm.sh"
    local postrm="${root_dir}/postrm.sh"

    local ms=""
    [[ -f ${preinst} ]] && ms="$ms --before-install ${preinst}"
    [[ -f ${postinst} ]] && ms="$ms --after-install ${postinst}"
    [[ -f ${prerm} ]] && ms="$ms --before-remove ${prerm}"
    [[ -f ${postrm} ]] && ms="$ms --after-remove ${postrm}"

    if [[ -z $ms ]]; then
        echo "Empty maintainer scripts"
        return 1
    fi
    # Deliberately unquoted: word splitting trims the leading space.
    echo $ms
}

# Remove the catkin build artifacts under workspace $1.
cleanup_ros_package_build() {
    # clean up
    pushd "$1" >/dev/null
    rm -r src build_isolated devel_isolated .catkin_workspace install
    popd >/dev/null
}

# Set the global $version from the repo's version file, suffixing a
# timestamp for builds that do not match the latest git tag; falls back to
# a pure timestamp. Relies on $BASEDIR and $reponame from the caller.
# Pass "1" as $1 to skip the timestamp suffix.
get_version() {
    local date=$(date +%Y%m%d%H%M%S)
    local version_file=$BASEDIR/src/$reponame/version
    local tag=$(git describe --tags --candidates=0)
    if [[ -f $version_file ]]; then
        version=$(head -n 1 "$version_file")
        # if 1 is present or the latest tag equals to version
        if [[ $1 != 1 && ${tag#v} != $version ]]; then
            version=${version}-${date}
        fi
    else
        version=$date
    fi
}

env
install_deps
+## + +package() { + local reponame=ros_chatbot + + mkdir -p $BASEDIR/src + rsync -r --delete \ + --exclude ".git" \ + --exclude "package" \ + $BASEDIR/../ $BASEDIR/src/$reponame + + get_version $1 + source_ros + catkin_make_isolated --directory $BASEDIR --install --install-space $BASEDIR/install -DCMAKE_BUILD_TYPE=Release + + local name=head-ros-chatbot + local desc="SoulTalk ROS API" + local url="https://api.github.com/repos/hansonrobotics/$reponame/releases" + + fpm -C "${BASEDIR}" -s dir -t deb -n "${name}" -v "${version#v}" --vendor "${VENDOR}" \ + --url "${url}" --description "${desc}" ${ms} --force \ + --deb-no-default-config-files \ + -p $BASEDIR/${name}_VERSION_ARCH.deb \ + install/share=${HR_ROS_PREFIX}/ \ + install/lib=${HR_ROS_PREFIX}/ + + cleanup_ros_package_build $BASEDIR +} + +if [[ $(readlink -f ${BASH_SOURCE[0]}) == $(readlink -f $0) ]]; then + BASEDIR=$(dirname $(readlink -f ${BASH_SOURCE[0]})) + source $BASEDIR/common.sh + set -e + + package $1 +fi diff --git a/modules/ros_chatbot/requirements.txt b/modules/ros_chatbot/requirements.txt new file mode 100644 index 0000000..c318184 --- /dev/null +++ b/modules/ros_chatbot/requirements.txt @@ -0,0 +1,43 @@ +aiofiles +airtable-python-wrapper==0.15.2 +boto3 +bs4 +chainmap==1.0.3 +coloredlogs==14.0 + +emoji +fastapi +flask==0.10.1 +gevent==20.6.2 +google-cloud-dialogflow==2.0.0 +google-cloud-translate==1.3.1 +googlemaps==2.5.1 +grpcio + +##### system dependencies +# ros-noetic-ddynamic-reconfigure-python + +langchain-anthropic +langchain-aws +langchain-community +langchain-openai>=0.1.25 + +lxml +mysqlclient==1.4.6 +numpy==1.21.2 + +opencv-python==4.9.0.80 +openpyxl==3.0.7 +pandas==1.2.5 + +pydantic>=2.0,<3.0 +pytest==4.6.11 +python-benedict>=0.24.1 +python-socketio~=5.5.2 +PyYAML==5.3.1 +rospkg~=1.2.3 +sqlalchemy==1.3.18 +telethon==1.21.1 +transitions==0.9.0 +uvicorn +websocket-client==0.57.0 diff --git a/modules/ros_chatbot/scripts/analysis.py b/modules/ros_chatbot/scripts/analysis.py new file mode 
##
## Copyright (C) 2017-2025 Hanson Robotics
##
## This program is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
## GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program. If not, see <https://www.gnu.org/licenses/>.
##

"""Ad-hoc inspection utilities for scheduler drivers and scene memory.

Pretty-prints driver metrics and the chat history of associated scenes
using ``rich``. Intended to be run directly; requires the
``CLOUD_MONGO_DATABASE_URL`` environment variable.
"""

import time
from typing import Optional

from haipy.memory_manager.memory_model import search_scenes_by_conditions
from haipy.memory_manager.scene_postprocessing import update_scene_memory_models
from haipy.scheduler.intention_manager import (
    extract_driver_metrics,
    plot_metrics,
    search_drivers,
)
from haipy.scheduler.schemas.enums import DriverStatus
from rich.console import Console
from rich.panel import Panel
from rich.table import Table

# Create a console instance for rich output
console = Console()


def plot():
    """Plot metric curves for every known driver type."""
    for driver_type in [
        "GoalDriver",
        "EmotionalDriver",
        "PhysiologicalDriver",
        "InterestDriver",
        "DeepDriver",
    ]:
        drivers = search_drivers(driver_type=driver_type)
        metrics_df = extract_driver_metrics(drivers)
        console.print(f"[bold blue]Metrics for {driver_type}:[/bold blue]")
        console.print(metrics_df)
        plot_metrics(driver_type, metrics_df)

    # Keep the process alive after plotting — presumably so plot windows
    # stay open. NOTE(review): confirm whether a plain block is intended.
    time.sleep(1000)


def test(driver_type: str, lookback_hours: Optional[int] = 72):
    """Print every driver of *driver_type* plus its scenes' chat history.

    Args:
        driver_type: driver class name, e.g. ``"GoalDriver"``.
        lookback_hours: search window; ``None`` disables the limit.
            (Annotated ``Optional`` because the ``__main__`` block calls
            this with ``lookback_hours=None``.)
    """
    console.print(
        f"[bold green]Searching for drivers of type: [/bold green][yellow]{driver_type}[/yellow]"
    )
    drivers = search_drivers(
        driver_type=driver_type,
        lookback_hours=lookback_hours,
    )
    metrics_df = extract_driver_metrics(drivers)
    console.print(
        Panel(str(metrics_df), title=f"Metrics for {driver_type}", border_style="cyan")
    )

    for driver in drivers:
        panel_content = (
            f"[bold]ID:[/bold] [cyan]{driver.id}[/cyan]\n"
            f"[bold]Created at:[/bold] [cyan]{driver.created_at}[/cyan]\n"
            f"[bold]Updated at:[/bold] [cyan]{driver.updated_at}[/cyan]\n"
            f"[bold]Type:[/bold] [magenta]{driver.type}[/magenta]\n"
            f"[bold]Conversation IDs:[/bold] [yellow]{driver.conversation_ids}[/yellow]\n"
            f"[bold]Metrics:[/bold] {driver.context.metrics.model_dump()}\n"
            f"[bold]Valence:[/bold] [green]{driver.context.valence}[/green]\n"
            f"[bold]Status:[/bold] [red] {driver.context.status}[/red]\n"
            f"[bold]Name:[/bold] [blue]{driver.name}[/blue]"
        )

        # Only goal drivers carry a deadline.
        if driver.type == "GoalDriver":
            panel_content += (
                f"\n[bold]Deadline:[/bold] [yellow]{driver.context.deadline}[/yellow]"
            )

        console.print(
            Panel(
                panel_content,
                title="Driver Info",
                border_style="green",
            )
        )

        if driver.conversation_ids:
            i = 0  # running scene counter across all conversations
            for conversation_id in driver.conversation_ids:
                scenes = search_scenes_by_conditions(conversation_id=conversation_id)
                if scenes:
                    console.print("[bold purple]Chat History:[/bold purple]")
                    for scene in scenes:
                        # Lazily backfill chat history for stale scenes.
                        if not scene.chat_history.messages:
                            update_scene_memory_models([scene])
                        if scene.chat_history.messages:
                            console.print(
                                Panel(
                                    "\n".join(
                                        [
                                            f"[{'green' if message.role == 'ai' else 'blue' if message.role == 'human' else 'white'}]{message.role.title()}: {message.text}[/]"
                                            for message in scene.chat_history.messages
                                        ]
                                    ),
                                    title=f"Scene {i+1}",
                                    border_style="yellow",
                                )
                            )
                            i += 1


def test2():
    """Tabulate all ACTIVE and PENDING drivers."""
    console.print("[bold green]Searching for ACTIVE and PENDING drivers[/bold green]")
    drivers = search_drivers(statuses=[DriverStatus.ACTIVE, DriverStatus.PENDING])

    table = Table(title="Active and Pending Drivers")
    table.add_column("Created At", style="cyan")
    table.add_column("Type", style="magenta")
    table.add_column("Conversation IDs", style="yellow")
    table.add_column("Valence", style="green")
    table.add_column("Name", style="blue")

    for driver in drivers:
        table.add_row(
            str(driver.created_at),
            driver.type,
            ", ".join(driver.conversation_ids),
            str(driver.context.valence),
            driver.name,
        )

    console.print(table)


if __name__ == "__main__":
    import os

    from haipy.memory_manager import init as mm_init
    from haipy.scheduler import init

    init(os.environ["CLOUD_MONGO_DATABASE_URL"])
    mm_init(os.environ["CLOUD_MONGO_DATABASE_URL"])
    # test("EmotionalDriver")
    test("GoalDriver", lookback_hours=None)
    # test2()
+## + +import logging +import rospy +import random +from hr_msgs.srv import AgentRegister, AgentUnregister, AgentChat, StringArray + +logger = logging.getLogger("hr.ros_chatbot.chatbot_agent_example") + + +class ChatAgentExample(object): + def __init__(self) -> None: + super().__init__() + rospy.Service("~chat", AgentChat, self._callback) + + def _callback(self, req): + response = AgentChat._response_class() + response.state = random.choice([0, 1]) + response.answer = req.text + logger.warning(response) + return response + + def register(self): + register_service = rospy.ServiceProxy("/hr/interaction/register", AgentRegister) + request = AgentRegister._request_class() + request.node = rospy.get_name() + request.level = 100 + request.ttl = 20 + response = register_service(request) + logger.warning(response) + + +if __name__ == "__main__": + rospy.init_node("chatbot_agent_example") + agent = ChatAgentExample() + agent.register() + rospy.spin() diff --git a/modules/ros_chatbot/scripts/content_manager.py b/modules/ros_chatbot/scripts/content_manager.py new file mode 100755 index 0000000..ff5b0a9 --- /dev/null +++ b/modules/ros_chatbot/scripts/content_manager.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python + +## +## Copyright (C) 2017-2025 Hanson Robotics +## +## This program is free software: you can redistribute it and/or modify +## it under the terms of the GNU General Public License as published by +## the Free Software Foundation, either version 3 of the License, or +## (at your option) any later version. +## +## This program is distributed in the hope that it will be useful, +## but WITHOUT ANY WARRANTY; without even the implied warranty of +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +## GNU General Public License for more details. +## +## You should have received a copy of the GNU General Public License +## along with this program. If not, see . 
##
## Copyright (C) 2017-2025 Hanson Robotics
##
## This program is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
## GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program. If not, see <https://www.gnu.org/licenses/>.
##

"""ROS node that synchronizes CMS content (S3 + Airtable + MongoDB) to disk.

Publishes ``~update_event`` (latched) so downstream nodes can react to
content updates, and exposes a dynamic-reconfigure interface to trigger
pulls and toggle test content.
"""

import logging
import os
import threading
import time

import requests
import rospy
from dynamic_reconfigure.server import Server
from haipy.arf.airtables import OperationSceneTable, RobotOperationContentBase
from haipy.memory_manager import PromptTemplate, copy_collection
from haipy.s3sync import s3sync
from haipy.utils import clean_dir, dump_yaml
from std_msgs.msg import String
from std_srvs.srv import Trigger

from ros_chatbot.cfg import ContentManagerConfig

logger = logging.getLogger("hr.ros_chatbot.content_manager")

# Required environment: character/robot identity and the local content dir.
HR_CHARACTER = os.environ["HR_CHARACTER"]
ROBOT_NAME = os.environ["ROBOT_NAME"]
CMS_BUCKET = os.environ.get("CMS_BUCKET", "s3://dl.cms.hansonrobotics.com")
CMS_DIR = os.environ["CMS_DIR"]  # eg. /hr/.hrsdk/cms_content/sophia-sophia
# Marker file recording the time of the last fully successful pull.
SUCCESS_FILE = os.path.join(CMS_DIR, ".success")


class ContentManager(object):
    """Pulls CMS content and notifies/reloads consumers when it changes."""

    def __init__(self):
        # None (not False) so the first reconfig call always triggers a pull.
        self.include_test = None
        # Latched so late subscribers still see the last update event.
        self.content_update_pub = rospy.Publisher(
            "~update_event", String, queue_size=10, latch=True
        )

        self.cfg = None
        # RLock: _pull may be re-entered from reconfig and the pull thread.
        self._lock = threading.RLock()
        # Set while a pull is in flight; checked by the automatic-pull loop.
        self._pulling = threading.Event()

        # Watch the Airtable operation-scenes table for this character.
        RobotOperationContentBase("/tmp").monitor_scenes_update(
            HR_CHARACTER, self._pull_operation_scenes
        )

        threading.Thread(target=self._automatic_pull, daemon=True).start()

    def _pull_operation_scenes(self, scenes_table: OperationSceneTable):
        """Dump the Airtable operation scenes to YAML and announce the update."""
        scenes = []
        for record in scenes_table.records:
            logger.warning("Scene Name: %s", record.fields.Name)
            scenes.append(record.model_dump())
        operation_scenes_file = os.path.join(CMS_DIR, "airtable-operation-scenes.yaml")
        dump_yaml(scenes, operation_scenes_file)
        logger.warning("Dumped operation scenes to %s", operation_scenes_file)
        self.content_update_pub.publish("updated")

    def _automatic_pull(self):
        """Background loop: pull when content is missing or on a schedule."""
        while True:
            if not os.path.isfile(SUCCESS_FILE) and not self._pulling.is_set():
                # pull content if it is empty
                logger.warning("Automatic content pull")
                self._pull()

            if self.cfg is not None and self.cfg.automatic_pull:
                self._pull()
                # pull_interval is in minutes.
                time.sleep(self.cfg.pull_interval * 60)
            time.sleep(1)

    def _pull_helper(self):
        """Synchronize S3 content and prompt templates, then reload consumers.

        Writes SUCCESS_FILE only when every S3 sync succeeded.
        """
        logger.warning(
            "Updating content, including test content? %s",
            "Yes" if self.include_test else "No",
        )
        if not os.path.isdir(CMS_DIR):
            os.makedirs(CMS_DIR)

        # sync character content & robot content
        success = True
        # Keep the Airtable dump — it is produced by a different pipeline.
        clean_dir(CMS_DIR, keep_file=["airtable-operation-scenes.yaml"])
        for content_type, name in zip(
            ["characters", "robots"], [HR_CHARACTER, ROBOT_NAME]
        ):
            if self.include_test:
                s3_source = os.path.join(CMS_BUCKET, "dist", "test", content_type, name)
            else:
                s3_source = os.path.join(CMS_BUCKET, "dist", "prod", content_type, name)
            ret = s3sync(s3_source, CMS_DIR, delete=False)
            if ret != 0:
                logger.warning("Content synchronization failed")
                success = False

        # copy prompt templates collection
        copy_collection(
            os.environ["CLOUD_MONGO_DATABASE_URL"],
            os.environ["LOCAL_MONGO_DATABASE_URL"],
            model=PromptTemplate,
            drop=True,
        )
        logger.info("Copied prompt templates")

        if success:
            with open(SUCCESS_FILE, "w") as f:
                f.write(str(time.time()))
            self._reload()
            self.content_update_pub.publish("updated")

    def _pull(self):
        """Serialized wrapper around ``_pull_helper`` with the in-flight flag."""
        with self._lock:
            self._pulling.set()
            try:
                self._pull_helper()
            finally:
                self._pulling.clear()

    def _reload(self):
        """Ask downstream services to reload; always emits reloading/reloaded."""
        self.content_update_pub.publish("reloading")
        try:
            # self._reload_behavior_tree()
            self._reload_soultalk()
        except Exception as ex:
            logger.error("Error during reloading: %s", ex)
        else:
            logger.info("Reloading content finished successfully")
        finally:
            self.content_update_pub.publish("reloaded")

    def _reload_behavior_tree(self):
        """Trigger the interactive-fiction node to reload its trees."""
        try:
            reload_trees = rospy.ServiceProxy(
                "/hr/behavior/interactive_fiction/reload_trees", Trigger
            )
            reload_trees.wait_for_service(timeout=2)
            reload_trees()
            logger.warning("Behavior tree has been reloaded")
        except Exception as ex:
            logger.error("Error reloading behavior tree: %s", ex)
            raise

    def _reload_soultalk(self):
        """Reset the SoulTalk server so it re-reads the freshly pulled content."""
        try:
            soultalk_url = os.environ.get("SOULTALK_SERVER_HOST", "127.0.0.1")
            soultalk_port = os.environ.get("SOULTALK_SERVER_PORT", "8801")
            response = requests.post(
                f"http://{soultalk_url}:{soultalk_port}/reset",
                json={"uid": "default", "sid": "default", "reload": True},
            )
            logger.info("Reset soultalk %s", response)
        except Exception as ex:
            logger.error("Error reloading soultalk: %s", ex)
            raise

    def reconfig(self, config, level):
        """Dynamic-reconfigure callback: handles the test toggle and pull button.

        ``config.pull`` acts as a momentary button and is reset to False.
        """
        self.cfg = config
        try:
            # A change of the test/prod source always forces a pull.
            if self.include_test is None or self.include_test != config.test:
                self.include_test = config.test
                self._pull()
            elif config.pull:
                self._pull()
        except Exception as ex:
            logger.error("Error during reconfiguration: %s", ex)
        finally:
            config.pull = False
            self.include_test = config.test
        return config


if __name__ == "__main__":
    rospy.init_node("content_manager")
    node = ContentManager()
    Server(ContentManagerConfig, node.reconfig)
    while not rospy.is_shutdown():
        rospy.spin()
//
// Copyright (C) 2017-2025 Hanson Robotics
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
//

// ---- query.cypher ----
// default scene
MATCH (n:Session {scene: "default"})-[r:P_PARTICIPATE_IN]-(p:Person) RETURN n, r, p

// ---- sample_graph.cypher ----
// create character nodes
// Fixed: languages was a single string "English, Mandarin"; the other
// nodes show languages is meant to be a list of individual languages.
CREATE (Sophia:Character:BasePerson {uid: "sophia", name:"Sophia the Robot", born:2016, languages: ["English", "Mandarin"]})
// create person nodes
CREATE (Ben:Person:BasePerson {uid: "ben", name:"Ben Goertzel", born:1966, languages: ["English"], gender: "male"})
CREATE (DavidH:Person:BasePerson {uid: "davidh", name:"David Hanson", born:1969, languages: ["English"], gender: "male"})
// create session nodes
CREATE (SessionWithDavidH:Session {uid: "1", time: 1612775757, series: "test series", episode: "test episode", act: "test act", scene: "test scene"})
+## + +import rospy +import websocket +import random +import json +import threading +import time +import logging +from std_msgs.msg import String +from hr_msgs.msg import ChatMessage, TTS, ChatResponse, Event, ChatResponses +from hr_msgs.srv import AgentChat, AgentFeedback, RunByName, RunByNameRequest, RunByNameResponse +from dynamic_reconfigure.client import Client + +# WebSocket configuration +WEBSOCKET_SERVER_URL = "wss://n8n-dev.hr-tools.io/wss" + +# ROS topic to subscribe to +ROS_SUBSCRIBE_TOPIC = "/hr/interaction/n8n/robot_events" +ROS_PUBLISH_TOPIC = "/hr/interaction/n8n/received_events" +ROS_HEAR_EVENT_TOPIC = "/hr/perception/hear/sentence" +ROS_SAID_EVENT_TOPIC = "/hr/control/speech/say" +SESSION_ID_PARAM = "/hr/interaction/n8n/session_id" +SCENE_CONFIG_TOPIC = "/hr/interaction/prompts/scene" +CHAT_RESPONSES_TOPIC = "/hr/interaction/n8n/responses" +PERFORMANCES_SERVICE = "/hr/control/performances/background/run_by_name" +PERFORMANCES_STATUS_TOPIC = "/hr/control/performances/background/events" +PERFORMANCES_RUNNING_TOPIC = "/hr/control/performances/background/running_performance" +CHATBOT_SUGGESTIONS = "/hr/interaction/chatbot_responses" +RPS_TOPIC="/hr/perception/rps_result" +ROS_AGENT_SETTINGS="/hr/interaction/agents/n8n" +VISUAL_PROMPTS="/hr/interaction/prompts/visual_processing" + + +# Configure logging +logger = logging.getLogger('hr.ros_chatbot.n8n') +logger.setLevel(logging.INFO) +handler = logging.StreamHandler() +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +handler.setFormatter(formatter) +logger.addHandler(handler) + +class WebSocketROSNode: + def __init__(self): + # Initialize ROS node + rospy.init_node('websocket_ros_node', anonymous=True) + # Get or create session_id parameter + if not rospy.has_param(SESSION_ID_PARAM): + rospy.set_param(SESSION_ID_PARAM, "session_{}".format(rospy.Time.now().to_nsec())) + self.metaprompting_allowed = False + self.current_scene = None + self.last_scene_msg = None + 
self.scenes = None + self.session_id = rospy.get_param(SESSION_ID_PARAM) + self.robot_name = rospy.get_param("/hr/robot_name") + self.robot_body = rospy.get_param("/hr/robot_body") + self.robot_character = rospy.get_param("/hr/character") + # ROS subscriber + self.subscriber = rospy.Subscriber(ROS_SUBSCRIBE_TOPIC, String, self.on_ros_message_received) + + # ROS hear event subscriber + self.hear_event_subscriber = rospy.Subscriber(ROS_HEAR_EVENT_TOPIC, ChatMessage, self.on_hear_event_received) + + # ROS said event subscriber + self.said_event_subscriber = rospy.Subscriber(ROS_SAID_EVENT_TOPIC, TTS, self.on_said_event_received) + + # ROS publisher + self.publisher = rospy.Publisher(ROS_PUBLISH_TOPIC, String, queue_size=10) + # Chat responses to chatbot + self.chat_responses = rospy.Publisher(CHAT_RESPONSES_TOPIC, ChatResponse, queue_size=10) + # Initialize WebSocket client + self.ws = None + self.ws_lock = threading.Lock() + self.ws_connected = threading.Event() + self.ws_thread = threading.Thread(target=self.connect_ws) + self.ws_thread.daemon = True + self.ws_thread.start() + self.last_chatbot_session = None + # Wait for WebSocket to be connected + logger.info("Waiting for WebSocket connection to be established...") + self.ws_connected.wait(timeout=2) + if not self.ws_connected.is_set(): + logger.warning("WebSocket connection not established within the timeout period.") + else: + logger.info("WebSocket connection established.") + self.enabled= None + # Start ping thread + self.ping_thread = threading.Thread(target=self.ping_ws) + self.ping_thread.daemon = True + self.ping_thread.start() + self.n8n_agent_configs = Client(ROS_AGENT_SETTINGS, config_callback=self.n8n_cfg_callback) + while self.enabled is None: + # wait for callback, so make sure the scene update after + time.sleep(0.1) + # Dynamic reconfigure client for scene config + self.robot_connected() + self.scene_client = Client(SCENE_CONFIG_TOPIC, config_callback=self.on_scene_config_update) + 
self.current_scene = None + # Visual queues + self.visual_prompts_configs = Client(VISUAL_PROMPTS, config_callback=self.on_visual_config_update) + + + # ROS services + self.agent_chat_service = rospy.Service('/hr/interaction/n8n/chat', AgentChat, self.handle_agent_chat) + self.agent_feedback_service = rospy.Service('/hr/interaction/n8n/feedback', AgentFeedback, self.handle_agent_feedback) + self.performance_runner = rospy.ServiceProxy(PERFORMANCES_SERVICE, RunByName) + + # Topics + self.last_perfromance_event = None + self.idle_performance_event = threading.Event() + self.performance_events_sub = rospy.Subscriber(PERFORMANCES_STATUS_TOPIC, Event, self.performance_event_callback) + # RPS_reselts + self.rps_topic = rospy.Subscriber(RPS_TOPIC, String, self.rps_callback) + + # Suggested responses + self.suggested_responses = rospy.Publisher(CHATBOT_SUGGESTIONS, ChatResponses) + # TTS + self.tts_pub = rospy.Publisher(ROS_SAID_EVENT_TOPIC, TTS) + + + def n8n_cfg_callback(self, cfg, lvl=None): + if cfg.enabled and self.enabled is False: + self.change_session() + try: + self.metaprompting_allowed = cfg.allow_metaprompting + except Exception: + self.metaprompting_allowed = False + self.enabled = cfg.enabled + return cfg + + def robot_connected(self): + self.send_to_ws({ + "event_type": "RobotConnected", + "robot_name": self.robot_name, + "robot_body": self.robot_body, + "robot_character": self.robot_character, + "session_id": self.session_id + }) + + def connect_ws(self): + headers = { + "X-API-Key": 'n8n-Top-Secret-code-to-be-never-used' + } + while not rospy.is_shutdown(): + try: + logger.info("Connecting to WebSocket server...") + self.ws = websocket.WebSocketApp( + WEBSOCKET_SERVER_URL, + header=[f"{key}: {value}" for key, value in headers.items()], + on_message=self.on_ws_message, + on_open=self.on_ws_open, + on_close=self.on_ws_close, + on_error=self.on_ws_error, + on_ping=self.on_ws_ping + ) + self.ws.run_forever() + time.sleep(2) + except Exception as e: + 
logger.error("WebSocket connection failed: {}. Retrying in 5 seconds...".format(e)) + + def send_to_ws(self, message): + if not self.enabled: + return True + try: + with self.ws_lock: + if self.ws and self.ws.sock and self.ws.sock.connected: + if "session_id" not in message: + message["session_id"] = self.session_id + self.ws.send(json.dumps(message)) + return True + except Exception as e: + logger.warning("Failed to send message to WebSocket: {}".format(e)) + return False + + def on_ws_open(self, ws): + logger.warn("Connected to WebSocket server") + self.ws_connected.set() + self.robot_connected() + + def on_ws_close(self, ws, close_status_code, close_msg): + logger.warning("WebSocket connection closed. Reconnecting...") + self.ws_connected.clear() + + def on_ws_error(self, ws, error): + logger.error("WebSocket error: {}".format(error)) + self.ws_connected.clear() + + def on_ws_message(self, ws, message): + if not self.enabled: + return None + logger.info("Received message from WebSocket: {}".format(message)) + msg = json.loads(message) + if not 'event_type' in msg: + pass + if msg['event_type'] == 'ChatResults': + self.publish_chat_response(msg) + if msg['event_type'] == 'Performance': + self.run_performance_from_message(msg) + if msg['event_type'] == 'TTS': + self.publish_tts(msg) + if msg['event_type'] == 'UpdateScene': + self.update_current_scene(msg) + if msg['event_type'] == 'SetRobotSettings': + self.set_robot_params(msg) + + + def update_current_scene(self, msg): + # Do not update anything unless there is a an issue + if not self.metaprompting_allowed: + return + new_scene = msg.get('scene', {}) + changed = False + if not self.current_scene or not self.scenes: + return + + for scene in self.scenes: + if scene.get('name') != self.current_scene: + continue + for key, value in new_scene.items(): + if key in scene: + scene[key] = value + changed = True + if changed: + self.scene_client.update_configuration({'scenes': json.dumps(self.scenes)}) + + + def 
on_ws_ping(self, ws, message): + logger.info("Received ping from WebSocket server. Sending pong response.") + try: + ws.send('PONG') + except Exception as e: + pass + + def ping_ws(self): + while not rospy.is_shutdown(): + if self.ws and self.ws.sock and self.ws.sock.connected: + try: + logger.info("Sending ping to WebSocket server") + self.ws.sock.ping() + except Exception as e: + logger.warning("Ping failed: {}".format(e)) + time.sleep(30) + + def on_ros_message_received(self, ros_message): + logger.info("Received message from ROS topic: {}".format(ros_message.data)) + try: + # Convert ROS message to JSON and add session_id, then send it via WebSocket + event = json.loads(ros_message.data) + event["session_id"] = self.session_id + if not self.send_to_ws(event): + logger.warning("WebSocket is not connected. Cannot send message.") + except json.JSONDecodeError: + logger.warning("Failed to parse ROS message as JSON") + + def on_hear_event_received(self, hear_event: ChatMessage): + logger.info("Received hear event: utterance='{}', lang='{}', confidence={}, source='{}'".format( + hear_event.utterance, hear_event.lang, hear_event.confidence, hear_event.source)) + try: + # Convert hear event to JSON and add session_id, then send it via WebSocket + utterance = hear_event.utterance.strip() + if utterance == ':reset': + self.change_session() + return + + + event = { + "event_type": "RobotHears", + "utterance": hear_event.utterance, + "lang": hear_event.lang, + "confidence": hear_event.confidence, + "source": hear_event.source, + "audio_path": hear_event.audio_path, + "session_id": self.session_id + } + if not self.send_to_ws(event): + logger.warning("WebSocket is not connected. 
Cannot send hear event.") + else: + logger.info("Sent hear event to WebSocket server: {}".format(event)) + except Exception as e: + logger.warning("Failed to process hear event: {}".format(e)) + + def on_said_event_received(self, said_event): + logger.info("Received said event: text='{}', lang='{}', request_id='{}', agent_id='{}'".format( + said_event.text, said_event.lang, said_event.request_id, said_event.agent_id)) + try: + # Convert said event to JSON and add session_id, then send it via WebSocket + event = { + "event_type": "RobotSays", + "text": said_event.text, + "lang": said_event.lang, + "request_id": said_event.request_id, + "agent_id": said_event.agent_id, + "audio_path": said_event.audio_path, + "session_id": self.session_id + } + if not self.send_to_ws(event): + logger.warning("WebSocket is not connected. Cannot send said event.") + else: + logger.info("Sent said event to WebSocket server: {}".format(event)) + except Exception as e: + logger.warning("Failed to process said event: {}".format(e)) + + + def on_scene_config_update(self, config): + try: + current_scene_name = config.get('current', None) + scenes_json = config.get('scenes', None) + if current_scene_name and scenes_json: + scenes = json.loads(scenes_json) + self.current_scene = current_scene_name + self.scenes = scenes + found_scene = False + for scene in scenes: + if scene.get('name') == current_scene_name: + self.current_scene = current_scene_name + found_scene = True + logger.info("Updated current scene: {}".format(self.current_scene)) + # Send current scene to WebSocket + event = { + "event_type": "CurrentSceneUpdate", + "current_scene_name": self.current_scene, + "session_id": self.session_id, + "scene": scene + } + self.last_scene_msg = event + if not self.send_to_ws(event): + logger.warning("WebSocket is not connected. 
Cannot send current scene update.") + else: + logger.info("Sent current scene to WebSocket server: {}".format(event)) + break + if not found_scene: + logger.warning("Current scene '{}' not found in scenes list.".format(current_scene_name)) + else: + logger.warning("Missing parameters in scene config update.") + except Exception as e: + logger.warning("Failed to update scene config: {}".format(e)) + + def on_visual_config_update(self, cfg): + if not cfg.results: + return + event = { + "session_id": self.session_id, + "event_type": "LLMVision", + "visual_prompt": cfg.visual_prompt, + "result_time": cfg.result_time, + "results": cfg.results + } + if not self.send_to_ws(event): + logger.warning("WebSocket is not connected. Cannot send vision update.") + else: + logger.info("Sent current vision data to WebSocket server: {}".format(event)) + + def handle_agent_chat(self, req): + logger.info("Handling AgentChat request: text='{}', lang='{}', session='{}'".format(req.text, req.lang, req.session)) + if self.last_chatbot_session is None: + self.last_chatbot_session = req.session + if self.last_chatbot_session != req.session: + # changes session + self.change_session() + + event = { + "event_type": "ChatRequest", + "text": req.text, + "lang": req.lang, + "request_id": req.request_id, + "session_id": self.session_id, + "response_prompt": rospy.get_param("/hr/interaction/prompts/response_prompt") + } + if not self.send_to_ws(event): + logger.warning("WebSocket is not connected. 
Cannot send agent chat event.") + # Example response generation (to be replaced with actual logic) + response = AgentChat._response_class() + response.state = 1 # Example state + response.score = 0.95 # Example score + return response + + def handle_agent_feedback(self, req): + logger.info("Handling AgentFeedback request: request_id='{}', hybrid={}, chosen={}".format(req.request_id, req.hybrid, req.chosen)) + # Example response generation (to be replaced with actual logic) + response = AgentFeedback._response_class() + response.success = True # Example success status + return response + + def publish_chat_response(self, response): + msg = ChatResponse() + msg.text = response['text'] + msg.lang = response['lang'] + msg.label = response['label'] + msg.request_id = response['request_id'] + self.chat_responses.publish(msg) + + def performance_event_callback(self, event): + if event.event == 'idle': + self.idle_performance_event.set() + + def run_performance_from_message(self, message): + hybrid = rospy.get_param("/hr/interaction/chatbot/hybrid_mode", False) + performance = message['performance'] + if hybrid: + suggestion = ChatResponse() + suggestion.text = f"|t, {performance}|" + suggestion.lang = 'en-US' + suggestion.label = 'n8n' + self.suggested_responses.publish(ChatResponses(responses=[suggestion])) + # Probably not very accurate, but only thing we can do is push performance to operator. 
+ self.send_to_ws({ + "event_type": "PerformanceFinished", + "performance": performance + }) + else: + t = threading.Thread(target=self.run_performance, daemon=True, args=(performance,)) + t.start() + + def run_performance(self, performance): + # runs performance and blocks until it finished + try: + r = RunByNameRequest() + r.id = performance + result = self.performance_runner.call(r) + if not result.success: + return False + self.idle_performance_event.clear() + self.idle_performance_event.wait() + self.send_to_ws({ + "event_type": "PerformanceFinished", + "performance": performance + }) + except Exception: + pass + + def rps_callback(self, result): + msg = { + 'event_type': 'RPSResult', + 'result': result.data + } + self.send_to_ws(msg) + + def publish_tts(self, msg): + hybrid = rospy.get_param("/hr/interaction/chatbot/hybrid_mode", False) + if not hybrid: + ros_msg = TTS() + ros_msg.text = msg['text'] + ros_msg.lang = msg['lang'] + ros_msg.agent_id = 'n8n' + self.tts_pub.publish(ros_msg) + else: + suggestion = ChatResponse() + suggestion.text = msg['text'] + suggestion.lang = msg['lang'] + suggestion.label = 'n8n' + self.suggested_responses.publish(ChatResponses(responses=[suggestion])) + + def change_session(self): + rospy.set_param(SESSION_ID_PARAM, "session_{}".format(rospy.Time.now().to_nsec())) + self.session_id = rospy.get_param(SESSION_ID_PARAM) + if self.last_scene_msg: + self.last_scene_msg['session_id'] = self.session_id + self.send_to_ws(self.last_scene_msg) + + def set_robot_params(self, msg): + node = msg.get('node') + params = msg.get('params', {}) + timeout = msg.get('timeout', 1) + if not self.set_dyn_params(node, params, timeout): + logger.warning("Failed to set dynamic parameters for node '{}'".format(node)) + + def set_dyn_params(self, node: str, params: dict,timeout: int = 1): + try: + client = Client(node, timeout=timeout) + client.update_configuration(params) + return True + except Exception as e: + return False + +if __name__ == 
'__main__': + try: + # Create and start WebSocket ROS node + ws_ros_node = WebSocketROSNode() + rospy.spin() + except rospy.ROSInterruptException: + logger.warning("ROS node interrupted") diff --git a/modules/ros_chatbot/scripts/priming.py b/modules/ros_chatbot/scripts/priming.py new file mode 100755 index 0000000..1e83ebb --- /dev/null +++ b/modules/ros_chatbot/scripts/priming.py @@ -0,0 +1,191 @@ +#!/usr/bin/env python + +## +## Copyright (C) 2017-2025 Hanson Robotics +## +## This program is free software: you can redistribute it and/or modify +## it under the terms of the GNU General Public License as published by +## the Free Software Foundation, either version 3 of the License, or +## (at your option) any later version. +## +## This program is distributed in the hope that it will be useful, +## but WITHOUT ANY WARRANTY; without even the implied warranty of +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +## GNU General Public License for more details. +## +## You should have received a copy of the GNU General Public License +## along with this program. If not, see . +## + +import json +import random + +import rospy +from std_srvs.srv import Trigger, TriggerResponse + +from ros_chatbot.ddr_node import DDRNode + + +class TriggerServiceNode: + def __init__(self): + rospy.init_node("priming_service") + rospy.Service( + "/hr/interaction/prompts/response_prompt", + Trigger, + self.trigger_service_callback, + ) + self.response_primers = [ + { + "enabled": True, + "name": "funny", + "prompt": "Respond to below question in very funny way", + "probability": 0.5, + }, + { + "enabled": True, + "name": "serious", + "prompt": "Respond to asnwer as a serious robot.", + "probability": 0.5, + }, + ] + self.length_primers = [ + { + "enabled": True, + "name": "Medium", + "prompt": "Respond in 2-3 sentences", + "probability": 0.5, + }, + { + "enabled": True, + "name": "Smart", + "prompt": "If question is short respond in 1 sentence. 
For more elaborate questions respond in multiple sentences", + "probability": 0.7, + }, + { + "enabled": True, + "name": "Long", + "prompt": "Respond in long question", + "probability": 0.2, + }, + ] + self.current_reponse_primer = "auto" + self.current_length_primer = "auto" + self.start_ddrs() + + def start_ddr(self, label="response"): + node = DDRNode( + namespace=f"/hr/interaction/prompts/{label}", + callback=lambda config, level: self.update_prompts( + config, label == "response" + ), + ) + node.new_param("current", f"Current {label} primer", default="auto") + primers = self.response_primers if label == "response" else self.length_primers + node.new_param("prompts", "Current prompts", default=json.dumps(primers)) + node.new_param("node_schema", "node_schema", default=self.node_schema(primers)) + node.ddstart() + return node + + def start_ddrs(self): + self.response_ddr = self.start_ddr("response") + self.length_ddr = self.start_ddr("length") + + def update_prompts(self, config, is_response): + prompts = json.loads(config["prompts"]) + if is_response: + self.current_reponse_primer = config["current"] + self.response_primers = prompts + else: + self.current_length_primer = config["current"] + self.length_primers = prompts + # config.prompts = json.dumps(primers) + config.node_schema = self.node_schema(prompts) + return config + + def node_schema(self, prompts): + names = [p["name"] for p in prompts if p["name"]] + node_schema = { + "current": {"type": "string", "default": "auto", "enum": ["auto"] + names}, + "prompts": { + "type": "array", + "format": "tabs", + "items": { + "type": "object", + "headerTemplate": "{{self.name}} - {{self.probability}}", + "properties": { + "enabled": { + "type": "boolean", + "default": True, + "format": "checkbox", + }, + "name": {"type": "string", "default": "", "maxLength": 10}, + "probability": {"type": "number", "default": 0.0}, + "prompt": { + "type": "string", + "default": "", + "format": "textarea", + "expand_height": True, 
+ }, + }, + }, + }, + } + return json.dumps(node_schema) + + def get_by_name(self, primers, name): + for primer in primers: + if primer["name"] == name: + return primer + return primers[0] + + def pick_random_primer(self, primers): + total_probability = sum( + primer["probability"] + for primer in primers + if primer["enabled"] and primer["prompt"] + ) + random_number = random.uniform(0, total_probability) + cumulative_probability = 0 + for primer in primers: + if primer["enabled"] is False or primer["prompt"] == "": + continue + cumulative_probability += primer["probability"] + if random_number <= cumulative_probability: + return primer + + def format_response_prompt(self): + response_primer = ( + self.pick_random_primer(self.response_primers) + if self.current_reponse_primer == "auto" + else self.get_by_name(self.response_primers, self.current_reponse_primer) + ) + length_primer = ( + self.pick_random_primer(self.length_primers) + if self.current_length_primer == "auto" + else self.get_by_name(self.length_primers, self.current_length_primer) + ) + prompt = response_primer["prompt"] + if prompt[-1] not in [".", "?", "!"]: + prompt += "." 
+ prompt += " " + length_primer["prompt"] + return prompt + + def trigger_service_callback(self, request): + response = TriggerResponse() + response.success = True + response.message = self.format_response_prompt() + return response + + def run(self): + r = rospy.Rate(1) + while not rospy.is_shutdown(): + r.sleep() + rospy.set_param( + "/hr/interaction/prompts/response_prompt", self.format_response_prompt() + ) + rospy.spin() + + +if __name__ == "__main__": + node = TriggerServiceNode() + node.run() diff --git a/modules/ros_chatbot/scripts/profiler b/modules/ros_chatbot/scripts/profiler new file mode 100755 index 0000000..286febb --- /dev/null +++ b/modules/ros_chatbot/scripts/profiler @@ -0,0 +1,23 @@ +#!/usr/bin/env bash + +## +## Copyright (C) 2017-2025 Hanson Robotics +## +## This program is free software: you can redistribute it and/or modify +## it under the terms of the GNU General Public License as published by +## the Free Software Foundation, either version 3 of the License, or +## (at your option) any later version. +## +## This program is distributed in the hope that it will be useful, +## but WITHOUT ANY WARRANTY; without even the implied warranty of +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +## GNU General Public License for more details. +## +## You should have received a copy of the GNU General Public License +## along with this program. If not, see . 
+## + +BASH_DIR=$(dirname $(readlink -f ${BASH_SOURCE[0]})) +source /opt/hansonrobotics/ros/setup.bash +source /root/hansonrobotics/devel/setup.bash +/opt/hansonrobotics/py3env/bin/kernprof -l $BASH_DIR/run.py diff --git a/modules/ros_chatbot/scripts/rosbot.py b/modules/ros_chatbot/scripts/rosbot.py new file mode 100755 index 0000000..1706286 --- /dev/null +++ b/modules/ros_chatbot/scripts/rosbot.py @@ -0,0 +1,350 @@ +#!/usr/bin/env python + +## +## Copyright (C) 2017-2025 Hanson Robotics +## +## This program is free software: you can redistribute it and/or modify +## it under the terms of the GNU General Public License as published by +## the Free Software Foundation, either version 3 of the License, or +## (at your option) any later version. +## +## This program is distributed in the hope that it will be useful, +## but WITHOUT ANY WARRANTY; without even the implied warranty of +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +## GNU General Public License for more details. +## +## You should have received a copy of the GNU General Public License +## along with this program. If not, see . 
+## + +import asyncio +import logging +import os +import threading +import time +import uuid +from queue import Queue + +import rospy +import socketio +from btree_client.robot import GenericRobot +from btree_client.schemas import ActionResult +from dotenv import load_dotenv +from dynamic_reconfigure.client import Client +from dynamic_reconfigure.server import Server +from haipy.parameter_server_proxy import UserSessionContext +from hr_msgs.msg import TTS, ChatMessage, ChatResponse, ChatResponses +from std_srvs.srv import Trigger + +from ros_chatbot.agents.model import AgentRequest +from ros_chatbot.cfg import ROSBotConfig +from ros_chatbot.db import write_request + +logger = logging.getLogger("hr.ros_chatbot.rosbot") + +load_dotenv() + + +class ROSBot(GenericRobot): + def __init__(self, token): + super(ROSBot, self).__init__("default", "default", namespace="/") + self.token = token + self.future = None + self.current_speech_result = None + self.events = Queue() + self.loop = asyncio.new_event_loop() + self.session_context = UserSessionContext("default", "default") + self.trees = [] + self.running_arf = threading.Event() + self.detect_speech_event = threading.Event() + + # Add this check for empty token + self.btree_enabled = bool(token) + if not self.btree_enabled: + logger.warning("Behavior tree functionality is disabled due to missing token") + + rospy.Subscriber( + "/hr/perception/hear/sentence", ChatMessage, self._user_speech_cb + ) + rospy.Subscriber( + "/hr/perception/hear/interim_speech", + ChatMessage, + self._user_interim_speech_cb, + ) + + self.wait_for_tts = rospy.ServiceProxy( + "/hr/control/speech/wait_for_tts", Trigger + ) + self.say_pub = rospy.Publisher("/hr/control/speech/say", TTS, queue_size=1) + self._responses_publisher = rospy.Publisher( + "/hr/interaction/chatbot_responses", ChatResponses, queue_size=1 + ) + + Server(ROSBotConfig, self.reconfig) + threading.Thread(target=self.monitor_tree, daemon=True).start() + 
threading.Thread(target=self.start_background_loop, daemon=True).start() + + def start_background_loop(self): + asyncio.set_event_loop(self.loop) + self.loop.run_forever() + + def _user_interim_speech_cb(self, msg): + if msg.utterance: + self.user_speaking.set() + + def _user_speech_cb(self, msg): + if msg.utterance: + speech_result = {} + speech_result["transcript"] = msg.utterance + speech_result["lang"] = msg.lang + speech_result["audio_path"] = msg.audio_path + speech_result["source"] = msg.source + if self.detect_speech_event.is_set(): + self.current_speech_result = speech_result + else: + asyncio.run_coroutine_threadsafe( + self.sio.emit( + "event", + {"type": "chat", "text": msg.utterance, "lang": msg.lang}, + ), + self.loop, + ) + + def is_hybrid(self): + try: + client = Client("/hr/interaction/chatbot", timeout=1) + hybrid_mode = client.get_configuration(timeout=1)["hybrid_mode"] + except Exception: + hybrid_mode = False + return hybrid_mode + + async def say(self, message): + attachment = message.get("attachment") + if attachment and attachment.get("agent_id"): + agent_id = attachment.get("agent_id") + else: + agent_id = "" + if self.is_hybrid(): + responses_msg = ChatResponses() + response_msg = ChatResponse() + response_msg.text = message["text"] + response_msg.lang = message.get("lang") or "en-US" + response_msg.agent_id = f"ARF-{agent_id}" + response_msg.label = f"ARF-{agent_id}" + responses_msg.responses.append(response_msg) + if responses_msg.responses: + self._responses_publisher.publish(responses_msg) + else: + msg = TTS() + msg.text = message["text"] + msg.lang = message.get("lang") or "en-US" + msg.agent_id = f"ARF-{agent_id}" + self.say_pub.publish(msg) + logger.warning("Waiting for TTS to finish %r", message["text"]) + await asyncio.sleep( + 0.2 + ) # need to wait a bit for the message to be received by TTS node + self.wait_for_tts() + logger.warning("TTS has finished") + return True + + def monitor_tree(self): + if not self.btree_enabled: + # 
If behavior trees are disabled, just wait indefinitely + threading.Event().wait() + return + + while True: + while not self.trees: + self.trees = self.session_context.get("btrees", []) + time.sleep(1) + while self.running_arf.is_set(): + logger.info("Btrees %s", self.trees) + if self.future: + if self.future.done(): + self.future = None + if not self.future and self.trees: + next_tree = self.trees[0] + try: + logger.warning("Running next tree %r", next_tree) + self.future = asyncio.run_coroutine_threadsafe( + self.connect_socket(self.token, [next_tree]), self.loop + ) + self.future.result() + self.future = None + logger.warning("Completed tree %r", next_tree) + except socketio.exceptions.ConnectionError as ex: + logger.error("ConnectionError %r", ex) + self.future.cancel() + except Exception as ex: + logger.error("Tree running error %r", ex) + finally: + self.trees.remove(next_tree) + time.sleep(1) + else: + if self.future: + logger.warning("Cancelling ARF") + self.future.cancel() + try: + asyncio.run_coroutine_threadsafe(self.close(), self.loop) + except Exception as ex: + logger.error("Cancelling tree error %s", ex) + time.sleep(2) + self.future = None + self.trees = [] + time.sleep(1) + + def set_running(self, running): + if running: + logger.warning("Disabling chatbot listening") + else: + logger.warning("Enabling chatbot listening") + try: + Client("/hr/interaction/chatbot").update_configuration( + {"listen_speech": not running} + ) + # Client("/hr/interaction/rosbot").update_configuration({"running": running}) + except Exception as ex: + logger.error("Set running error %s", ex) + rospy.set_param("/hr/interaction/chatbot/listen_speech", False) + + async def on_say(self, message): + await self.sio.emit("ack", "say") + try: + success = await self.say(message) + except Exception as ex: + success = False + logger.error("Say error %s", ex) + result = ActionResult(success=success, event="say").dict() + await self.sio.emit("ack", "say done") + return result + + async 
def on_set(self, message):
        """Handle a 'set' action from the behavior-tree server.

        Currently only logs the payload and acknowledges success; no state is
        actually set here.
        """
        logger.info("set %s", message)
        return ActionResult(success=True, event="set").dict()

    async def wait_asr_result(self):
        """Poll until ``current_speech_result`` is populated by the speech callback.

        Runs forever if no result ever arrives; callers bound it with
        asyncio.wait_for.
        """
        while True:
            if self.current_speech_result:
                return self.current_speech_result
            else:
                await asyncio.sleep(0.1)

    def new_request(self, speech_result: dict):
        """Persist a recognized utterance as an AgentRequest in the database.

        Args:
            speech_result: dict with keys "transcript", "lang", "audio_path"
                and "source" (as assembled by the speech callback).
        """
        sid = rospy.get_param("/hr/interaction/chatbot/session", "")
        request = AgentRequest()
        request.sid = sid
        request.request_id = str(uuid.uuid4())
        request.question = speech_result["transcript"]
        request.lang = speech_result["lang"]
        request.audio = speech_result["audio_path"]
        request.source = speech_result["source"]
        request.session_context = self.session_context
        try:
            # Best-effort persistence; a DB failure must not break the dialog.
            write_request(request)
        except Exception as ex:
            logger.error("Write request error: %s", ex)

    async def wait_for_speech(self, timeout, lang):
        """Waits for user's speech.

        Waits up to *timeout* seconds for speech to start, then up to 30
        seconds for the final ASR result; records the result via
        :meth:`new_request` and returns it.  Returns None on timeout or when
        no result arrives.  ``lang`` is currently unused here.
        """
        logger.warning("[Speech Start]")
        self.current_speech_result = None
        self.user_speaking.clear()
        try:
            logger.warning("Wait for speech event %s", timeout)
            await asyncio.wait_for(self.user_start_speaking(), timeout=timeout)
            logger.warning("Wait for speech result")
            result = await asyncio.wait_for(self.wait_asr_result(), timeout=30)
            if result:
                self.new_request(result)
                logger.warning("[Speech End]")
                logger.warning("Speech %r", result["transcript"])
                return result
            else:
                logger.info("No speech result")
        except asyncio.TimeoutError:
            logger.error("[Speech Timeout]")

    def enable_asr(self):
        """Turn on the speech recognizer via dynamic reconfigure.

        Returns True on success, False when the recognizer node is unreachable.
        """
        try:
            dyn_client = Client("/hr/perception/speech_recognizer", timeout=1)
            dyn_client.update_configuration({"enable": True})
            logger.warning("Enable ASR")
            return True
        except Exception as ex:
            logger.exception(ex)
            return False

    async def on_detect_speech(self, message):
        """Handle a 'detect_speech' action: listen for one user utterance.

        Enables ASR, marks the detect-speech window open so the speech
        callback stores results locally, then waits for speech.
        """
        logger.info("detect_speech %s", message)
        await self.sio.emit("ack", "detect_speech")
        self.enable_asr()
        self.detect_speech_event.set()

        try:
            result = await self.wait_for_speech(
message["speech_timeout"], + message["lang"], + ) + except Exception as ex: + logger.error(ex) + finally: + self.detect_speech_event.clear() + self.empty_speech_queue() + + if result: + return ActionResult( + success=True, event="detect_speech", message=result + ).dict() + return ActionResult(success=False, event="detect_speech").dict() + + async def on_probe(self, message): + return ActionResult(success=True, event="probe", message={}).dict() + + async def on_disconnect(self): + if self.future: + self.future.cancel() + self.future = None + self.set_running(False) + logger.warning("rosbot is disconnected from server") + + async def on_connect(self): + self.set_running(True) + logger.warning("rosbot is connected to server") + + def reconfig(self, config, level): + robot_mode = rospy.get_param("/hr/interaction/chatbot/robot_mode", None) + if config.running: + if robot_mode and robot_mode != "Stage": + if not self.running_arf.is_set(): + logger.warning("Running ARF") + self.running_arf.set() + else: + logger.warning("Can't run ARF in stage mode") + config.running = False + else: + if self.running_arf.is_set(): + logger.warning("Stopping ARF") + self.running_arf.clear() + if self.future and not self.future.done(): + future = asyncio.run_coroutine_threadsafe(self.close(), self.loop) + try: + future.result(timeout=10) + except Exception as ex: + logger.error(ex) + self.running_trees = [] + self.future.cancel() + self.set_running(False) + return config + + +if __name__ == "__main__": + BTREE_SERVER_TOKEN = os.environ.get("BTREE_SERVER_TOKEN") + if not BTREE_SERVER_TOKEN: + logger.warning("BTREE_SERVER_TOKEN environment variable not set. 
Behavior tree functionality will be disabled.") + # Initialize with a dummy token - the ROSBot class will handle this gracefully + BTREE_SERVER_TOKEN = "" + + rospy.init_node("rosbot") + bot = ROSBot(BTREE_SERVER_TOKEN) + rospy.spin() diff --git a/modules/ros_chatbot/scripts/run.py b/modules/ros_chatbot/scripts/run.py new file mode 100755 index 0000000..a0d0ba0 --- /dev/null +++ b/modules/ros_chatbot/scripts/run.py @@ -0,0 +1,1710 @@ +#!/usr/bin/env python + +## +## Copyright (C) 2017-2025 Hanson Robotics +## +## This program is free software: you can redistribute it and/or modify +## it under the terms of the GNU General Public License as published by +## the Free Software Foundation, either version 3 of the License, or +## (at your option) any later version. +## +## This program is distributed in the hope that it will be useful, +## but WITHOUT ANY WARRANTY; without even the implied warranty of +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +## GNU General Public License for more details. +## +## You should have received a copy of the GNU General Public License +## along with this program. If not, see . 
+## + +import copy +import json +import logging +import os +import re +import threading +import time +import uuid +import warnings +from datetime import datetime +from queue import Empty, Queue +from typing import List + +import requests +import rospy +from dynamic_reconfigure.client import Client +from dynamic_reconfigure.server import Server +from haipy.nlp.translate import detect_language +from haipy.parameter_server_proxy import UserSessionContext +from haipy.scheduler.ims_drivers import BaseDriver +from haipy.scheduler.intention_manager import IntentionManager +from haipy.scheduler.schemas.documents import BaseDriverDocument +from haipy.utils import LANGUAGE_CODES_NAMES +from hr_msgs.msg import ( + TTS, + ChatMessage, + ChatResponse, + ChatResponses, + Event, + EventMessage, +) +from hr_msgs.srv import ( + AgentRegister, + AgentUnregister, + RunByName, + StringArray, + StringTrigger, + TTSTrigger, +) +from langchain_community.chat_message_histories import RedisChatMessageHistory +from pydantic import BaseModel +from sensor_msgs.msg import JointState +from std_msgs.msg import String +from std_srvs.srv import Trigger + +import ros_chatbot.interact.action_types as action_types +import ros_chatbot.interact.event_types as event_types +from ros_chatbot.activity_monitor import ActivityMonitor, EngagementLevel +from ros_chatbot.agents.model import AgentResponse, AgentStreamResponse, LLMAgent +from ros_chatbot.agents.rosagent import ROSGenericAgent +from ros_chatbot.cfg import ROSChatbotConfig +from ros_chatbot.chat_server import ChatServer +from ros_chatbot.context_manager import ContextManager +from ros_chatbot.data_loader import DataLoader +from ros_chatbot.db import write_conv_insight, write_responses +from ros_chatbot.interact.base import BasicEventGenerator +from ros_chatbot.interact.controller_manager import ControllerManager +from ros_chatbot.interact.state import State +from ros_chatbot.reconfiguration import ( + AgentReconfiguration, + DriverReconfiguration, 
+ PromptTemplatesReconfiguration, + SceneReconfiguration, +) +from ros_chatbot.scene_manager import SceneManager +from ros_chatbot.schemas import Scene +from ros_chatbot.state_machine import RobotState +from ros_chatbot.utils import ( + load_agent_config, + load_modes_config, + remove_duplicated_responses, + strip_xmltag, +) +from ros_chatbot.visual_processing import VisualProcessingConfig + +logger = logging.getLogger("hr.ros_chatbot.run") + +# Suppress the specific ROS publisher warning +warnings.filterwarnings( + "ignore", + message=".*The publisher should be created with an explicit keyword argument 'queue_size'.*", +) + +SOULTALK_HOT_UPLOAD_DIR = os.environ["SOULTALK_HOT_UPLOAD_DIR"] +CMS_DIR = os.environ["CMS_DIR"] # eg. /hr/.hrsdk/cms_content/sophia-sophia +REDIS_SERVER_HOST = os.environ.get("REDIS_SERVER_HOST", "localhost") +REDIS_SERVER_PORT = os.environ.get("REDIS_SERVER_PORT", "6379") + +BAR_PATTERN = re.compile(r"""\|[^\|]+\|""") # e.g. |p| + +HR_CHARACTER = os.environ["HR_CHARACTER"] +ROBOT_NAME = os.environ["ROBOT_NAME"] +ROBOT_BODY = os.environ["ROBOT_BODY"] +character = HR_CHARACTER.title() + + +class Chatbot: + uid = "default" + + def __init__(self): + self.cfg = None + self.default_language = "en-US" + self.server = ChatServer() + self.arf_fire_timer = None + self.scenes = {} + + self.state = State() + self.user_speaking = False + self.controller_manager = ControllerManager(self.state, self._handle_action) + self.controller_manager.register_event_generator( + "stage", BasicEventGenerator("stage") + ) + self.activity_monitor = ActivityMonitor( + window_size=300, + ) + + self.interrupted = threading.Event() + # Below event is set then chatbot need to cancel thhe thinking + self.chat_interupt_by_speech = threading.Event() + # Below event will make sure chat will wait before publishing results for interupt_by_speech or timeout + self.chat_interupt_by_activity = threading.Event() + # This to append speech while user speaking + self.chat_buffer = "" + # 
time for last speech activity, needs to track in case its falsly interrupted. In case final result followed by interim result without it being ever finalized + self.last_speech_activity = 0 + self.speech_to_silence_interval = 1.5 # In case speech is interupted it will wait for 1.5 seconds inactivity or final result. If no final result is received the responses will be published + + self.chat_requests = [] # for calculating the total dialogue turns + self.last_chosen_response = None + self.current_responses = {} # response id -> response + self.asr_dyn_client = None + self.lock = threading.RLock() + self.pre_lock = threading.RLock() + self.start_silence_detection = threading.Event() + self.performance_idle_flag = threading.Event() + self.internet_available = True + self.silence_event_timer = None + self.robot_mode = "" + self.agent_config = load_agent_config() + self.modes_config = load_modes_config() + + self.session_context = UserSessionContext(self.uid, uuid.uuid1().hex) + self.context_manager = ContextManager() + self.chat_memory = RedisChatMessageHistory( + f"{self.session_context.ns}.history", + url=f"redis://{REDIS_SERVER_HOST}:{REDIS_SERVER_PORT}", + key_prefix="", + ttl=3600, + ) + self.scene_manager = SceneManager( + character, + self.scenes, + self.session_context, + self.agent_config, + self.server.document_manager, + ) + self.agent_reconfiguration = AgentReconfiguration() + self.scene_reconfiguration = SceneReconfiguration( + character, + self.session_context, + self.on_enter_scene_callback, + self.on_exit_scene_callback, + self.on_scene_change_callback, + ) + intention_manager = IntentionManager(self.session_context) + self._drivers_pub = rospy.Publisher( + "/hr/interaction/available_drivers", String, queue_size=10, latch=True + ) + rospy.Subscriber( + "/hr/interaction/update_driver_status", + String, + self._update_driver_status_callback, + ) + intention_manager.add_tasks_update_callback(self._publish_drivers) + logger.info("Registered callback to 
publish active drivers to ROS topic") + self.driver_reconfiguration = DriverReconfiguration( + self.session_context, intention_manager + ) + self.activity_monitor.add_engagement_level_change_callback( + self.driver_reconfiguration.on_engagement_level_change + ) + self.driver_reconfiguration.add_driver_callback(self.driver_callback) + self.prompt_templates_reconfiguration = PromptTemplatesReconfiguration( + [agent.id for agent in self.server.agents.values()], + self.session_context, + ) + + # Time to ignore speech until + self.ignore_speech_until = 0 + # Keaps track of the thread for the chat request + self.current_chat_thread = None + + # once configuration loaded it will start rest automatically + Server(ROSChatbotConfig, self.reconfig) + + env_monitor = threading.Thread( + name="env_monitor", target=self.env_monitoring, daemon=True + ) + env_monitor.start() + + # conversation quality monitor + monitor = threading.Thread(name="quality_monitor", target=self.monitoring) + monitor.daemon = True + monitor.start() + + state_monitor = threading.Thread( + name="state_monitor", target=self.state_monitoring + ) + state_monitor.daemon = True + state_monitor.start() + + # silence event detection + silence_event_detection = threading.Thread(target=self._silence_event_detection) + silence_event_detection.daemon = True + silence_event_detection.start() + + # Stream responses handling + self.stream_responses = Queue() + stream_handler = threading.Thread(target=self._stream_response_handler) + stream_handler.daemon = True + stream_handler.start() + # visual processing + self.visual_processing_config = VisualProcessingConfig() + threading.Thread(target=self._visual_processing, daemon=True).start() + + def driver_callback(self, driver: BaseDriver, output_model: BaseModel): + msg = None + if hasattr(output_model, "metrics"): + msg = JointState() + msg.header.stamp = rospy.Time.now() + for k, v in output_model.metrics.model_dump().items(): + msg.name.append(k) + 
msg.position.append(v) + if driver.type == "EmotionalDriver": + if msg: + self.emotional_metrics_pub.publish(msg) + logger.info(f"Published emotional metrics: {msg}") + elif driver.type == "PhysiologicalDriver": + if msg: + self.physical_metrics_pub.publish(msg) + logger.info(f"Published physical metrics: {msg}") + + def _publish_drivers(self, drivers: List[BaseDriverDocument]): + """Publish drivers from the intention manager to a ROS topic. + + Args: + drivers: List of BaseDriverDocument objects + """ + try: + # Convert drivers to a simpler dict format for JSON serialization + drivers_data = [] + for driver in drivers: + driver_dict = { + "id": str(driver.id), + "created_at": driver.created_at.isoformat(), + "updated_at": driver.updated_at.isoformat(), + "name": driver.name, + "type": driver.type, + "level": driver.level, + "priority": driver.priority, + "context": { + "status": driver.context.status, + "reason": driver.context.reason + if hasattr(driver.context, "reason") + else "", + "description": driver.context.description + if hasattr(driver.context, "description") + else "", + "metrics": driver.context.metrics.to_dict(), + }, + "valence": driver.context.valence, + "plans": [str(p) for p in driver.context.plans] + if driver.context.plans + else [], + } + drivers_data.append(driver_dict) + self._drivers_pub.publish(json.dumps(drivers_data)) + logger.info("Published %d drivers to ROS topic", len(drivers)) + except Exception as e: + logger.error("Error publishing drivers to ROS topic: %s", str(e)) + logger.exception("Full traceback:") + + def _update_driver_status_callback(self, msg): + logger.info("Received update driver status message: %s", msg.data) + data = json.loads(msg.data) + driver_id = data.get("id") + status = data.get("status") + # find driver by driver_id + driver = next( + ( + d + for d in self.driver_reconfiguration.intention_manager.drivers + if str(d.id) == driver_id + ), + None, + ) + if driver: + logger.info("Updating driver status for %s to 
%s", driver_id, status) + driver.context.status = status + driver.save() + else: + logger.warning("Driver not found: %s", driver_id) + + def new_conversation(self): + self.session_context.sid = uuid.uuid1().hex + self.session_context["character"] = character + self.session_context["robot_name"] = ROBOT_NAME + self.session_context["robot_body"] = ROBOT_BODY + self.session_context[ + "instant_situational_prompt" + ] = self.scene_reconfiguration.instant_situational_prompt + rospy.set_param("~session", self.session_context.sid) + logger.warning("Started new conversation %s", self.session_context.sid) + self.session_context["webui_language"] = LANGUAGE_CODES_NAMES.get( + self.current_language, self.current_language + ) + self.chat_memory.session_id = f"{self.session_context.ns}.history" + self.driver_reconfiguration.on_namespace_change() + self.prompt_templates_reconfiguration.on_namespace_change() + self.context_manager.session_context = self.session_context + self._load_chat_data() + self.on_scene_change_callback(self.scene_reconfiguration.scenes) + + def set_asr_context(self, param): + if param and isinstance(param, list): + param = ", ".join(param) + try: + dyn_client = Client("/hr/perception/speech_recognizer", timeout=1) + current_params = dyn_client.get_configuration() + phrases = current_params["phrases"] + if phrases: + phrases = phrases + ", " + param + else: + phrases = param + if phrases: + # remove duplicates + phrases = ", ".join( + list(set([phrase.strip() for phrase in phrases.split(",")])) + ) + dyn_client.update_configuration({"phrases": phrases}) + logger.info("Updated asr context phrases to %s", phrases) + except Exception as ex: + logger.exception(ex) + + def set_tts_mapping(self, param): + if param and isinstance(param, list): + param = ", ".join(param) + try: + dyn_client = Client("/hr/control/speech/tts_talker", timeout=1) + current_params = dyn_client.get_configuration() + word_mappings = current_params["word_mappings"] + if word_mappings: + 
word_mappings = word_mappings + ", " + param + else: + word_mappings = param + + # remove duplicated mappings + word_mappings = ",".join( + list( + set( + [ + mapping.strip() + for mapping in word_mappings.split(",") + if mapping.strip() + ] + ) + ) + ) + + dyn_client.update_configuration({"word_mappings": word_mappings}) + logger.info("Updated tts word mappings to %s", word_mappings) + except Exception as ex: + logger.exception(ex) + + def on_enter_scene_callback(self, scene: Scene): + logger.warning("Enter scene %s", scene.name) + self.new_conversation() + self.session_context["scene"] = scene.name + + self.scene_manager.update_scene(scene.name) + + for key, value in scene.variables.items(): + self.session_context[key] = value + self.session_context.proxy.expire(key, 3600) + + if scene.asr_context: + self.set_asr_context(scene.asr_context) + if scene.tts_mapping: + self.set_tts_mapping(scene.tts_mapping) + + threading.Timer( + 5, self.scene_manager.load_relevant_scenes, args=(scene.name,) + ).start() + + def on_exit_scene_callback(self, scene: Scene): + logger.warning("Exit scene %s", scene.name) + self.session_context["ended_scene"] = scene.name + + threading.Thread( + target=self.scene_manager.update_last_scene_document, + daemon=True, + ).start() + + for key in scene.variables.keys(): + del self.session_context[key] + del self.session_context["scene"] + del self.session_context["person_object"] + del self.session_context["prompt_template"] + + def on_scene_change_callback(self, scenes: List[Scene]): + # merge locally created scenes with the ones from CMS + for scene in scenes: + if scene.name: + self.scenes[scene.name] = scene + + def _stream_response_handler(self): + while True: + response = self.stream_responses.get() + # Process the response until queue is empty or finished is set top true + while ( + response.stream_data.qsize() > 0 + or not response.stream_finished.is_set() + ): + try: + sentence = response.stream_data.get( + 
timeout=response.stream_response_timeout + ) + response.stream_data.task_done() + except Empty: + if response.stream_timeout(): + response.stream_finished.set() + break + else: + if response.stream_finished.is_set(): + break + continue + if sentence: + self._say( + sentence, + response.lang, + response.agent_id, + response.request_id, + ignore_state=False, + ) + response.answer += " " + sentence + + # update response document + self.server.update_response_document( + response_id=response.response_id, text=response.answer + ) + self.stream_responses.task_done() + + def is_google_available(self): + try: + requests.get("http://www.google.com", timeout=5) + return True + except (requests.ConnectionError, requests.Timeout): + return False + + def env_monitoring(self): + while True: + self.internet_available = self.is_google_available() + time.sleep(1) + + def monitoring(self): + while True: + last_auto_response_time = self.state.getStateValue( + "last_auto_response_time" + ) + if last_auto_response_time: + time_elapsed = time.time() - last_auto_response_time + if time_elapsed > self.cfg.hybrid_when_idle: + # automatically turn to hybrid mode + if self.cfg.auto_automonous_free_chat: + logger.warning("Autonomous mode is deactivated from being idle") + self._set_hybrid_mode(True) + self.activity_monitor.monitoring() + + try: + metrics = self.activity_monitor.get_engagement_metrics() + metrics = {k: v for k, v in metrics.items() if k != "engagement_level"} + msg = JointState() + msg.header.stamp = rospy.Time.now() + for i, (k, v) in enumerate(metrics.items()): + msg.name.append(k) + msg.position.append(float(v)) + if self.last_engagement_metrics_msg: + msg.velocity.append( + (float(v) - self.last_engagement_metrics_msg.position[i]) + / ( + msg.header.stamp.to_sec() + - self.last_engagement_metrics_msg.header.stamp.to_sec() + ) + ) + else: + msg.velocity.append(0) + self.engagement_metrics_pub.publish(msg) + self.last_engagement_metrics_msg = msg + except Exception as e: + 
logger.error("Failed to publish engagement metrics") + logger.exception(e) + + time.sleep(1) + + def state_monitoring(self): + while True: + # arf idle + tree_running = rospy.get_param("/hr/interaction/rosbot/running", None) + current_state = self.session_context.get("state") + if ( + tree_running is False + and current_state is not None + and current_state + not in [ + "idle", + "asleep", + ] + and self.cfg.listen_speech + ): + self.session_context["state"] = "idle" + self._emit_event("event.idle") + if current_state in ["asleep"]: + # mute chatbot + if self.cfg.auto_automonous_free_chat and not self.hybrid_mode: + logger.info("Hybrid mode is set by state asleep") + self._set_hybrid_mode(True) + # repeat sleep animation + while self.session_context.get("state") in ["asleep"]: + logger.warning("Playing sleep animation") + self.run_performance("shared/arf/sleep/sleep_2") + time.sleep(0.2) + + scene = self.statemachine.scene + if scene == "idle": + self._emit_event("event.idle") + + time.sleep(1) + + def _visual_processing(self): + while True: + try: + if self.is_asr_running() and self.visual_processing_config.enabled: + result = self.describe_view_service.call( + data=self.visual_processing_config.get_prompt() + ) + if result.success: + insight = json.loads(result.message) + created_at = datetime.strptime( + insight["time"]["value"], insight["time"]["format"] + ) + record = { + "created_at": created_at, + "conversation_id": self.session_context.sid, + "type": "vision", + "insight": insight, + "character": character, + } + self.visual_processing_config.update_results( + insight["content"], insight["utc_time"] + ) + write_conv_insight(record) + # update visual clue + self.session_context["visual_clue"] = insight["content"] + rospy.set_param("~visual_clue", insight["content"]) + logger.info("Write visual insight %s", insight["content"]) + except Exception as ex: + logger.error(ex) + # could be network or some o + time.sleep(self.visual_processing_config.interval) + + 
def run_performance(self, performance_id): + performance_id = performance_id.strip() + response = self.performance_service(performance_id) + if response.success: + self.performance_idle_flag.clear() + logger.info("Run performance #%s successfully", performance_id) + else: + logger.warning("Failed to run performance #%s", performance_id) + self.performance_idle_flag.wait() + + def _performance_event_cb(self, msg): + if msg.event in ["idle"]: + self.performance_idle_flag.set() + + def _emit_silence_event(self): + if self.is_asr_running() and self.cfg.listen_speech: + # emit silence event only when ASR is running + self._emit_event("event.silence") + + def _emit_event(self, event): + logger.warning("Emitting event %s", event) + event_request = self.server.new_request( + self.session_context.sid, + event, + self.current_language, + source="system", + session_context=self.session_context, + ) + self._chat(event_request) + + def _silence_event_detection(self): + while True: + self.start_silence_detection.wait() + logger.info("Start silence detection") + self.silence_event_timer = threading.Timer(5, self._emit_silence_event) + self.silence_event_timer.start() + self.silence_event_timer.join() + logger.info("End silence detection") + self.start_silence_detection.clear() + time.sleep(0.5) + + @property + def autonomous_mode(self): + return not self.hybrid_mode + + @property + def hybrid_mode(self): + return self.cfg.hybrid_mode if self.cfg else True + + def start(self): + self.last_language = rospy.get_param("/hr/lang", self.default_language) + + self.new_conversation() + + # for webui + self.enable = self.cfg.enable + + self.node_name = rospy.get_name() + + chat_topic = rospy.get_param("~chat_topic", "/hr/interaction/chat") + rospy.Subscriber(chat_topic, ChatMessage, self._chat_callback) + + chat_event_topic = rospy.get_param( + "~chat_event_topic", "/hr/interaction/chat/event" + ) + self._chat_event_pub = rospy.Publisher( + chat_event_topic, String, queue_size=None + ) + 
+ hear_topic = rospy.get_param("~hear_topic", "/hr/perception/hear/sentence") + rospy.Subscriber(hear_topic, ChatMessage, self._speech_chat_callback) + + rospy.Subscriber("/hr/interaction/interlocutor", String, self._interlocutor_cb) + + self._chat_pub = rospy.Publisher(hear_topic, ChatMessage, queue_size=None) + + event_topic = rospy.get_param("~hear_event_topic", "/hr/perception/hear/event") + rospy.Subscriber(event_topic, String, self._user_speech_event_callback) + + hybrid_response_topic = rospy.get_param( + "~hybrid_response_topic", "/hr/interaction/chatbot_responses" + ) + self._responses_publisher = rospy.Publisher( + hybrid_response_topic, ChatResponses, queue_size=None + ) + + # receive user's choice + response_chosen_topic = rospy.get_param( + "~response_chosen_topic", "/hr/interaction/chatbot_response" + ) + rospy.Subscriber( + response_chosen_topic, ChatResponse, self._response_chosen_callback + ) + + say_topic = rospy.get_param("~say_topic", "/hr/control/speech/say") + self._response_publisher = rospy.Publisher(say_topic, TTS, queue_size=None) + rospy.Subscriber(say_topic, TTS, self._tts_cb) + + self.tts_service = rospy.ServiceProxy("/hr/control/speech/tts", TTSTrigger) + self.is_performance_loaded = rospy.ServiceProxy( + "/hr/control/is_performance_loaded", Trigger + ) + + speech_events_topic = rospy.get_param( + "~speech_events_topic", "/hr/control/speech/event" + ) + rospy.Subscriber( + speech_events_topic, String, self._robot_speech_event_cb, queue_size=None + ) + rospy.Subscriber( + "/hr/perception/hear/interim_speech", ChatMessage, self._interim_speech_cb + ) + + rospy.Subscriber( + "/hr/interaction/content_manager/update_event", + String, + self._content_update_cb, + queue_size=10, + ) + rospy.Subscriber("/hr/interaction/arf", String, self._arf) + + tts_ctrl_topic = rospy.get_param( + "~tts_ctrl_topic", "/hr/control/speech/tts_control" + ) + self.tts_ctrl_pub = rospy.Publisher(tts_ctrl_topic, String, queue_size=None) + + event_topic = 
rospy.get_param("~event_topic", "/hr/interaction/event") + self.event_pub = rospy.Publisher(event_topic, EventMessage, queue_size=None) + + self.switch_character_pub = rospy.Publisher( + "/hr/interaction/switch_character", String, queue_size=None + ) + + self.emotional_metrics_pub = rospy.Publisher( + "/hr/interaction/emotional_metrics", JointState, latch=True, queue_size=None + ) + self.physical_metrics_pub = rospy.Publisher( + "/hr/interaction/physical_metrics", JointState, latch=True, queue_size=None + ) + self.engagement_metrics_pub = rospy.Publisher( + "/hr/interaction/engagement_metrics", + JointState, + latch=True, + queue_size=None, + ) + self.last_engagement_metrics_msg = None + + self.driver_reconfiguration.add_driver_callback(self.driver_callback) + + rospy.Subscriber( + "/hr/control/performances/background/events", + Event, + self._performance_event_cb, + ) + # Run performance by name + self.performance_service = rospy.ServiceProxy( + "/hr/control/performances/background/run_by_name", RunByName + ) + + self.describe_view_service = rospy.ServiceProxy( + "/hr/interaction/describe_view", StringTrigger + ) + + rospy.Service("register", AgentRegister, self.ros_service_register) + rospy.Service("unregister", AgentUnregister, self.ros_service_unregister) + rospy.Service("available_agents", StringArray, self.list_all_installed_agents) + rospy.Service("set_context", StringTrigger, self.ros_service_set_context) + rospy.Service("get_context", StringTrigger, self.ros_service_get_context) + rospy.Service("available_scenes", StringArray, self.list_all_scenes) + + # Set up agent configuration + for agent in copy.copy(self.server.agents).values(): + self.agent_reconfiguration.set_up_agents_runtime_dynamic_reconfigure( + agent, + self.presets, + ) + self.scene_reconfiguration.start_ddr() + + def run_reflection(self, text, lang): + results = self.server.run_reflection(self.session_context.sid, text, lang) + if results: + persona = [r["text"] for r in results] + 
logger.info("Run reflection persona %s", persona) + existing_persona = self.session_context.get("persona", []) + persona = existing_persona + persona + self.session_context["persona"] = persona + + def _tts_cb(self, msg): + if not msg.text: + return + self.activity_monitor.record_chat_activity() + text = BAR_PATTERN.sub("", msg.text).strip() + if not text: + return + + self.server.add_record(text) + + # Inform on all TTS activity and let agents to collect data if needed + text = strip_xmltag(text) + for a in self.server.agents.values(): + a.character_said(text, msg.lang) + + # update redis memory + self.chat_memory.add_ai_message(text) + + self.session_context["turns"] = len(self.chat_requests) + self.session_context["last_active_time"] = datetime.utcnow() + + self.run_reflection(text, msg.lang) + + if not msg.request_id: + self._record_other_agent_response(msg) + + current_state = self.session_context.get("state") + if current_state in ["asleep"]: + self.session_context["state"] = "" # reset state + + def _record_other_agent_response(self, msg): + response = AgentResponse() + response.agent_id = msg.agent_id or "Human" + response.response_id = str(uuid.uuid4()) + response.answer = msg.text + response.attachment["published"] = True + write_responses([response]) + self.server.publish(response.response_id) + + def _after_published_responses(self, request_id): + request = self.server.requests.get(request_id) + if request is None: + # probably not responses from this dialogue system + return + + if request_id not in self.chat_requests: + logger.info("New request got responded %s", request) + self.chat_requests.append(request_id) + + def _robot_speech_event_cb(self, msg): + if msg.data: + if msg.data.startswith("start"): + self.state.update(robot_speaking=True) + self.start_silence_detection.clear() + if self.silence_event_timer and self.silence_event_timer.is_alive(): + self.silence_event_timer.cancel() + if msg.data.startswith("stop"): + 
self.state.update(robot_speaking=False) + self.start_silence_detection.set() + + def _interim_speech_cb(self, msg): + s = String("speechcont") + self._user_speech_event_callback(s) + + def _user_speech_event_callback(self, msg): + if msg.data: + self.activity_monitor.record_speech_activity() + if msg.data.startswith("speechstart") or msg.data.startswith("speechcont"): + self.state.update(user_speaking=True) + self.user_speaking = True + self.last_speech_activity = time.time() + self.chat_interupt_by_activity.set() + if msg.data.startswith("speechstop"): + self.state.update(user_speaking=False) + self.user_speaking = False + + def _response_chosen_callback(self, msg): + """Handles the response chosen by user""" + if msg.response_id in self.server.responses: + self.server.publish(msg.response_id, label=msg.label, resolver="Human") + response = self.server.responses[msg.response_id] + self._publish_resolved_response(response, ignore_state=True) + elif msg.response_id == "": + # the response from external i.e. 
snet + response = AgentResponse() + response.request_id = msg.request_id + response.response_id = msg.response_id + response.agent_id = msg.agent_id + response.lang = msg.lang + response.answer = msg.text + response.end() + self._publish_resolved_response(response, ignore_state=True) + else: + logger.error("Response is lost %r", msg.response_id) + + def _arf(self, msg): + if msg.data == "arf": + logger.warning("Fire ARF!") + self.fire_arf() + time.sleep(2) + else: + logger.warning("Activate scene %r", msg.data) + self.fire_arf(msg.data) + + def fire_arf(self, scene_name=None): + self.session_context["last_active_time"] = datetime.utcnow() + if self.statemachine.state == "initial": + logger.warning("Set state machine: chat") + self.statemachine.chat() + if scene_name is not None: + self.statemachine.scene = scene_name + else: + scene_name = self.statemachine.scene + if not scene_name: + logger.warning("Not in any scene") + return + + scene = self.scenes.get(scene_name) + if scene: + if scene.type == "preset": + for agent in copy.copy(self.server.agents).values(): + if not isinstance(agent, LLMAgent): + continue + try: + if callable(agent.runtime_config_callback): + agent.runtime_config_callback({"prompt_preset": scene.name}) + except Exception as ex: + logger.exception(ex) + + arf_events = self.session_context.get(f"arf.events.{scene_name}", []) + fired = False + for i, arf_event in enumerate(arf_events): + _arf_event = json.loads(arf_event) + if _arf_event["triggered"]: + continue + logger.warning("Trigger ARF %s", _arf_event) + self._emit_event(_arf_event["arf_event"]) + _arf_event["triggered"] = True + arf_events[i] = json.dumps(_arf_event) + fired = True + break + if fired: + # update arf events + self.session_context.proxy.set_param(f"arf.events.{scene_name}", arf_events) + else: + logger.warning("No ARF to fire") + + def _load_chat_data(self): + """Loads scene data, prompt templates and prompt preset""" + self.data_loader = DataLoader(CMS_DIR, 
self.session_context) + self.data_loader.load_all_data() + self.scenes.update(self.data_loader.scenes) + self.scene_manager.update_scenes(self.scenes) + self.presets = self.data_loader.presets + self._update_agent_presets() + self._init_state_machine() + + def _update_agent_presets(self): + agents = [ + agent + for agent in self.server.agents.values() + if isinstance(agent, LLMAgent) and agent.runtime_config_description + ] + for agent in agents: + self.agent_reconfiguration.update_presets(agent, self.presets) + + def _init_state_machine(self): + self.statemachine = RobotState( + self.session_context, + name=character, + on_enter_scene_callback=self.on_enter_scene_callback, + on_exit_scene_callback=self.on_exit_scene_callback, + ) + self.statemachine.load_scenes(self.scenes.values()) + self.statemachine.manual() + + def _content_update_cb(self, msg): + if msg.data == "updated": + self._load_chat_data() + if msg.data == "reload": + self.reset_session() + + def _interlocutor_cb(self, msg): + self.session_context["interlocutor"] = msg.data + self.session_context.proxy.expire("interlocutor", 3600) + + def set_language(self, lang): + try: + rospy.set_param("/hr/lang", lang) + dyn_client = Client("/hr/perception/speech_recognizer", timeout=1) + dyn_client.update_configuration({"language": lang}) + logger.info("Updated language to %s", lang) + return True + except Exception as ex: + logger.exception(ex) + return False + + def _say( + self, text, lang, agent_id, request_id="", audio_path="", ignore_state=False + ): + if not ignore_state and self.state.is_full_stopped(): + logger.warning("Robot is in full-stopped mode") + return + text = re.sub(r"""\[callback.*\]""", "", text) + + msg = TTS() + msg.text = text + # detect the language of the text + if lang != "en-US" and "speak in" not in text.lower(): + detected_lang = detect_language(text) or self.current_language + if detected_lang != lang: + logger.warning( + "Detected language %s is different from the language %s", + 
detected_lang, + lang, + ) + lang = detected_lang + msg.lang = lang + msg.agent_id = agent_id + msg.request_id = request_id + msg.audio_path = audio_path + self._response_publisher.publish(msg) + + # Ignore speech in autonomous mode + if self.cfg.ignore_speech_while_thinking and not self.cfg.hybrid_mode: + # ignore final results for 2 seconds after TTS message published: + # The 2 seconds should be enpough for most cases with good enough network connection, in case TTS is internet based + self.ignore_speech_until = time.time() + 2.0 + + def _set_hybrid_mode(self, hybrid_mode): + dyn_client = Client("/hr/interaction/chatbot", timeout=1) + dyn_client.update_configuration({"hybrid_mode": hybrid_mode}) + if hybrid_mode: + self.state.update(last_auto_response_time=None) + else: + self.session_context["state"] = "" # reset state + + def _handle_post_action(self, action): + if not action: + return + action_type = action["type"] + if action_type == action_types.SET_HYBRID_MODE: + if self.cfg.auto_automonous_free_chat: + self._set_hybrid_mode(True) + + def reset_session(self): + self.server.session_manager.agent_sessions(True) + rospy.set_param("~session", self.session_context.sid) + if self.session_context and "block_chat" in self.session_context: + del self.session_context["block_chat"] + logger.warning("Reset. 
New session %s", self.session_context.sid) + self.new_conversation() + + def reset(self): + self.reset_session() + self.controller_manager.reset() + self.chat_requests = [] + self.session_context.clear() + self.session_context["turns"] = 0 + self.server.ranker.records = [] + # Reset should not reset the current scene, so it will be the same as before + scene = self.scene_reconfiguration.get_current_scene() + if scene: + self.on_enter_scene_callback(scene) + + def _handle_action(self, action): + if not action: + return + action_type = action["type"] + payload = action["payload"] + if action_type == action_types.RESET: + self.reset() + if action_type == action_types.MONITOR: + self._publish_event(payload) + if action_type == action_types.PLACEHOLDER_UTTERANCE: + self._say(payload["text"], self.current_language, payload["controller"]) + + def _publish_event(self, event): + type = event["type"] + payload = event["payload"] + msg = EventMessage() + msg.type = type + msg.payload = json.dumps(payload) + msg.stamp = rospy.Time.now() + self.event_pub.publish(msg) + + def _speech_chat_callback(self, msg): + """If directly listening for speech""" + if self.cfg.listen_speech: + # Use seprate threads so it can be interrupted + t = threading.Thread(target=self._chat_callback, args=(msg,)) + t.daemon = True + t.start() + self.session_context["last_active_time"] = datetime.utcnow() + + def is_asr_running(self): + try: + if self.asr_dyn_client is None: + self.asr_dyn_client = Client( + "/hr/perception/speech_recognizer", timeout=1 + ) + asr_enabled = self.asr_dyn_client.get_configuration(timeout=1)["enable"] + except Exception: + asr_enabled = False + return asr_enabled + + def _chat_callback(self, msg): + """Responds to topic message""" + with self.pre_lock: + if self.cfg.concat_multiple_speech: + # Only cancel request if there was speech activity + if self.chat_interupt_by_activity.is_set(): + if msg.utterance[0] not in [ + ":", + "{", + ] and not 
msg.utterance.lower().startswith("event."): + self.chat_interupt_by_speech.set() + msg.utterance = self.chat_buffer + " " + msg.utterance + + # update webui language + self.session_context["webui_language"] = LANGUAGE_CODES_NAMES.get( + self.current_language, self.current_language + ) + + # Lock will wait for interuption to be handled so it will not execute two chats at the same time. in most cases the chat will be just waiting for chat to finish + with self.lock: + if not self.enable: + return + if not msg.utterance: + return + if msg.source == "rosbot": + return + if self.ignore_speech_until > time.time(): + return + logger.info("Received chat message %r in %r", msg.utterance, msg.lang) + self.start_silence_detection.clear() + if self.silence_event_timer and self.silence_event_timer.is_alive(): + self.silence_event_timer.cancel() + # is_prompt. brackets will define if the string is prompt[] + is_prompt = msg.utterance[0] == "{" and msg.utterance[-1] == "}" + + if self.state.is_interruption_mode(): + logger.info("Ignore when it is in interruption resuming mode") + return + + offline = msg.source == "fallback" + if offline: + logger.info("Offline speech input %r", msg.utterance) + + # ignore gibberish vosk inputs + if offline and msg.utterance in [ + "by", + "but", + "both", + "back", + ]: + logger.info("Ignore gibberish vosk inputs") + return + + if not is_prompt and not offline: + # Inform agents of the speech if they need that + for a in self.server.agents.values(): + a.speech_heard(msg.utterance, msg.lang) + + # update redis memory + messages = self.chat_memory.messages + last_message = messages[-1] if messages else None + if not ( + last_message + and last_message.type == "human" + and last_message.content == msg.utterance + ): + # only add new message when the last message is from user and is different + # this is the case when you resend the question + self.chat_memory.add_user_message(msg.utterance) + + if self.cfg.enable_placeholder_utterance_controller: + 
placeholder_contrller = self.controller_manager.get_controller( + "placeholder_utterance_controller" + ) + if placeholder_contrller: + placeholder_contrller.enabled = True + + utterance = {} + utterance["lang"] = msg.lang + utterance["text"] = msg.utterance + utterance["uuid"] = str(uuid.uuid4()) + if msg.utterance.startswith(":"): + self.state.update(command=utterance) + self.controller_manager.wait_for(event_types.USER_COMMAND, 1) + else: + self.state.update(utterance=utterance) + self.controller_manager.wait_for(event_types.UTTERANCE, 1) + + # TODO: make sure the events arrive to controllers before calling act() + # Or make chat as a controller + self.controller_manager.wait_controller_finish() + + actions = self.controller_manager.act() + if actions: + for action in actions: + self._handle_action(action) + else: + logger.info("No actions") + + audio = os.path.basename(msg.audio_path) if msg.audio_path else "" + + if self.is_asr_running(): + if ( + offline + and not self.cfg.offline_asr_free_chat + and self.internet_available + ): + logger.info("Ignore offline asr when online asr is running") + return + + if self.autonomous_mode: + if ( + not msg.utterance.lower().startswith("event.") + and self.state.is_robot_speaking() + and not self.interrupted.is_set() + ): + logger.warning( + "Ignore chat %r while robot is speaking", + msg.utterance, + ) + return + + request = self.server.new_request( + self.session_context.sid, + msg.utterance, + msg.lang or self.current_language, + audio=audio, + source=msg.source, + session_context=self.session_context, + ) + + # if agents temperarily blocked + if self.session_context.get("block_chat", False): + logger.warning("Blocking other chat agents") + request.context["agent"] = "AdhocBot" + + if ( + offline + and not self.cfg.offline_asr_free_chat + and self.internet_available + ): + request.context["agent"] = "AdhocBot" + request.context["require_priority_content"] = True + logger.info("Restrict the offline asr input to the 
priority content") + + if self.autonomous_mode: + if ( + not msg.utterance.lower().startswith("event.") + and self.state.is_robot_speaking() + and not self.interrupted.is_set() + ): + logger.warning( + "Ignore chat request %r while robot is speaking", + request.question, + ) + return + # Allow streaming in autonomous mode + request.allow_stream = True + + if not msg.utterance.lower().startswith( + "event." + ) and not msg.utterance.startswith(":"): + self.session_context["input"] = msg.utterance + + self._chat(request) + + if actions: + for action in actions: + self._handle_post_action(action) + + @property + def current_language(self): + current_language = rospy.get_param("/hr/lang", self.default_language) + if self.last_language != current_language: + logger.warning( + "Switch language from %r to %r", self.last_language, current_language + ) + self.server.on_switch_language(self.last_language, current_language) + self.chat_memory.clear() + self.last_language = current_language + return current_language + + def _filter_responses_by_tag(self, responses, tag): + if tag == "priority": + return [ + r + for r in responses + if tag in r.attachment.get("tag", []) + or "Skill" in r.attachment.get("topic_type", []) + ] + else: + return [r for r in responses if tag in r.attachment.get("tag", [])] + + def resolve_responses(self, responses): + if responses: + response = self.server.resolver.resolve(responses) + if response: + response.attachment["published"] = True + self.server.publish( + response.response_id, resolver_type=self.server.resolver.type + ) + return [response] + [ + r for r in responses if not r.attachment.get("published") + ] + + def _chat(self, request): + # Handles chat. Event is passed if the chat is interupted and results no need to be published or used. + if not self.enable: + return + self.activity_monitor.record_chat_activity() + # clear the interupt flag, it will be set(based ons ettings) if the speach is heard and chat do not need to wait. 
+ self.chat_interupt_by_speech.clear() + self.chat_interupt_by_activity.clear() + self._chat_event_pub.publish("start thinking") + has_response = False + self.current_responses = {} + + if self.cfg.enable_rag: + self.scene_manager.create_rag(request.question) + + request.hybrid_mode = self.hybrid_mode + + if self.hybrid_mode: + for responses in self.server.chat_with_ranking(request): + if responses: + priority_responses = self._filter_responses_by_tag( + responses, "priority" + ) + if request.context.get("require_priority_content"): + responses = priority_responses + if responses: + has_response = True + + if self.cfg.auto_automonous_free_chat and request.source not in [ + "web", + "webrepeat", + ]: + # check whether activate autonomous free chat + resolved_responses = self.resolve_responses(priority_responses) + if resolved_responses: + self._set_hybrid_mode(False) + logger.warning("Priority rule has been triggered") + self._publish_ros_responses(resolved_responses) + break + self._publish_ros_responses(responses) + for response in responses: + self.current_responses[response.response_id] = response + else: + interupt_event = None + if self.cfg.concat_multiple_speech: + interupt_event = self.chat_interupt_by_speech + self.chat_buffer = request.question + + responses = ( + self.server.chat_with_resolving( + request, + fast_score=self.cfg.fast_score, + interrupt_event=interupt_event, + ) + or [] + ) + for response in responses: + self.current_responses[response.response_id] = response + + # Responses returned however if there are new speech we need to wait for timeout to continue or speech to be interrupted + cancelled = False + if self.cfg.concat_multiple_speech: + while not self.chat_interupt_by_speech.is_set(): + if not self.chat_interupt_by_activity.is_set(): + # Not interrupted so nothing to wait for + break + if ( + self.last_speech_activity + self.speech_to_silence_interval + < time.time() + ): + # Timeout + break + time.sleep(0.02) + if 
self.chat_interupt_by_speech.is_set(): + cancelled = True + if not cancelled: + self.chat_buffer = "" + if responses: + priority_responses = self._filter_responses_by_tag( + responses, "priority" + ) + if request.context.get("require_priority_content"): + responses = priority_responses + if not responses: + logger.warning("No non-priority responses after removal") + if responses: + has_response = True + self._publish_ros_responses(responses) + self.interrupted.clear() + self._chat_event_pub.publish("stop thinking") + if not has_response: + logger.info("Chatbot has no response") + self._chat_event_pub.publish("no response") + return has_response + + def _publish_ros_responses(self, responses): + """Publishes responses to ROS topic""" + if not responses: + return + if not self.hybrid_mode: + # publish the first published response + resolved_response = responses[0] + if resolved_response.attachment.get("published"): + logger.info("Choose %s", resolved_response) + self._publish_resolved_response(resolved_response) + self.state.update(last_auto_response_time=time.time()) + # check whether deactivate autonomous free chat + if "deactivate" in resolved_response.attachment.get("tag", []): + if self.cfg.auto_automonous_free_chat: + self._set_hybrid_mode(True) + logger.warning( + "Autonomous mode is deactivated by deactivation rules" + ) + if "activate" in resolved_response.attachment.get("tag", []): + if self.cfg.auto_automonous_free_chat: + self._set_hybrid_mode(False) + logger.warning("Activation rule has been triggered") + responses = responses[1:] + + if not responses: + return + uniq_responses = remove_duplicated_responses(responses) + uniq_responses = list(uniq_responses) + + responses_msg = ChatResponses() + for response in uniq_responses: + text = response.answer + if text: + # check whether activate autonomous free chat + # this goes autonomous in any event + if "activate" in response.attachment.get("tag", []): + if self.cfg.auto_automonous_free_chat: + 
self._set_hybrid_mode(False) + logger.warning("Activation rule has been triggered") + if not self.hybrid_mode: + self._publish_resolved_response(response) + else: + response_msg = ChatResponse() + response_msg.text = text + response_msg.lang = response.lang + response_msg.label = response.agent_id + response_msg.agent_id = response.agent_id + response_msg.request_id = response.request_id + response_msg.response_id = response.response_id + responses_msg.responses.append(response_msg) + else: + response_msg = ChatResponse() + response_msg.text = text + response_msg.lang = response.lang + response_msg.label = response.agent_id + response_msg.agent_id = response.agent_id + response_msg.request_id = response.request_id + response_msg.response_id = response.response_id + responses_msg.responses.append(response_msg) + + if responses_msg.responses: + self._responses_publisher.publish(responses_msg) + + def _publish_resolved_response(self, response: AgentResponse, ignore_state=False): + """ + Publish the resolved response and handle any associated actions or contexts. 
+ """ + if response.answer: + self.server.add_record(response.answer) + self._after_published_responses(response.request_id) + self.server.feedback(response, self.hybrid_mode) + + self.last_chosen_response = response + self._say( + response.answer, + response.lang, + response.agent_id, + response.request_id, + ignore_state=ignore_state, + ) + # Handle the responses that are streaming, this will block until all the sentences are published for TTS + if isinstance(response, AgentStreamResponse): + self.stream_responses.put(response) + # Long timeout for any unforeseen errors + response.stream_finished.wait(timeout=response.stream_response_timeout) + response.stream_finished.set() + + actions = response.attachment.get("actions", []) + if actions: + try: + self._execute_action(actions) + except Exception as ex: + logger.error(ex) + input_context = response.attachment.get("input_context", []) + for c in input_context: + # consume context + key = f"context.output.{c}" + logger.info("Consume context %r", key) + del self.session_context[key] + + def _execute_action(self, actions): + for action in actions: + if action.get("name") == "switch-language": + lang = action["properties"]["lang"] + success = self.set_language(lang) + if success: + logger.warning("Set target language %r successfully", lang) + if action.get("name") == "play-audio-clip": + text = action["properties"]["text"] + audio = action["properties"]["audio-clip"] + lang = action["properties"]["lang"] + if not audio.startswith("/"): + # relative to the upload dir + audio = os.path.join(SOULTALK_HOT_UPLOAD_DIR, audio) + self._say(text, lang, agent_id="Action", audio_path=audio) + if action.get("name") == "NEXTTOPIC" and self.cfg.auto_fire_arf: + logger.warning("Next topic") + self.fire_arf() + if action.get("name") == "set": + for key, value in action["properties"].items(): + self.session_context[key] = value + logger.warning("Set %s=%s", key, value) + if key == "arf_count_down" and self.cfg.auto_fire_arf: + if 
self.arf_fire_timer: + self.arf_fire_timer.cancel() + logger.warning("Auto ARF counting down %s", value) + self.arf_fire_timer = threading.Timer(value, self.auto_fire_arf) + self.arf_fire_timer.run() + + def auto_fire_arf(self): + while self.cfg.auto_fire_arf: + logger.warning("Checking auto arf") + if self.user_speaking: + logger.warning("User is speaking, do not fire new ARF") + time.sleep(5) + continue + if self.session_context.get("block_chat"): + logger.warning("Chat is blocked, do not fire new ARF") + time.sleep(5) + continue + if ( + "topic_type" in self.last_chosen_response.attachment + and self.last_chosen_response.attachment["topic_type"] == "ARF" + ): + logger.warning("The last ARF hasn't been answered, do not fire new ARF") + time.sleep(5) + continue + if self.last_chosen_response and self.last_chosen_response.answer.endswith( + "?" + ): + logger.warning("Wait for user answering the question") + time.sleep(15) + continue + self.fire_arf() + + def ros_set_output_context(self, context: dict, output=False, finished=False): + logger.info("ros set context: %r", context) + if "output" in context: + output = context.pop("output") + if "finished" in context: + finished = context.pop("finished") + try: + self.server.set_context(self.session_context.sid, context, output, finished) + return True, "" + except Exception as ex: + logger.exception(ex) + return False, str(ex) + + def ros_service_set_context(self, req): + response = StringTrigger._response_class() + try: + context = json.loads(req.data) + except Exception as ex: + logger.exception(ex) + response.success = False + response.message = str(ex) + return response + if context: + success, message = self.ros_set_output_context(context) + if not success: + response.success = success + response.message = message + return response + response.success = True + response.message = "ok" + return response + + def ros_service_get_context(self, req): + response = StringTrigger._response_class() + try: + context = 
self.server.get_context(self.session_context.sid) + if req.data: + context = context.get(req.data) + if context: + response.message = json.dumps(context) + response.success = True + except Exception as ex: + logger.error(ex) + response.success = False + response.message = str(ex) + return response + + def ros_service_register(self, req): + response = AgentRegister._response_class() + try: + agent = ROSGenericAgent( + req.node, + req.languages, + level=req.level, + weight=req.weight, + ttl=req.ttl, + ) + self.server.agents[agent.id] = agent + response.success = True + except Exception as ex: + logger.error(ex) + response.success = False + response.message = str(ex) + return response + + def ros_service_unregister(self, req): + response = AgentUnregister._response_class() + if not req.node: + response.success = False + response.message = "Agent id was missing" + return response + if req.node in self.server.agents: + del self.server.agents[req.node] + response.success = True + else: + response.success = False + response.message = ( + 'Agent "%s" was not registered or has already unregistered' % req.node + ) + return response + + def list_all_installed_agents(self, req): + response = StringArray._response_class() + response.data.extend( + ["%s:%s" % (agent.type, agent.id) for agent in self.server.agents.values()] + ) + return response + + def list_all_scenes(self, req): + response = StringArray._response_class() + response.data.extend(self.scenes.keys()) + return response + + def set_character(self, character): + self.switch_character_pub.publish(character) + + def set_robot_mode(self, cfg): + logger.warning("Set robot mode %s", cfg.robot_mode) + config = self.modes_config.get(cfg.robot_mode) + if config: + for node, node_config in config.items(): + if node == "/hr/interaction/chatbot": + try: + for key, value in node_config.items(): + setattr(cfg, key, value) + except Exception as ex: + logger.error(ex) + else: + try: + client = Client(node, timeout=10) + 
client.update_configuration(node_config) + logger.info("Updated node %r config %s", node, node_config) + except Exception as ex: + logger.error(ex) + return cfg + + def reconfig(self, config, level): + if self.cfg is None: + config.listen_speech = True + self.cfg = config + self.start() + + self.cfg = config + + if self.enable != config.enable: + self.enable = config.enable + if self.enable: + logger.warning("Enabled chatbot") + else: + logger.warning("Disabled chatbot") + if self.hybrid_mode: + logger.info("Enabled hybrid mode") + self.state.update(last_auto_response_time=None) + else: + logger.info("Disabled hybrid mode") + + self.server.set_timeout(config.timeout, config.min_wait_for) + + # set controllers + self.controller_manager.setup_controllers(self.cfg) + + if config.enable_global_workspace_drivers: + logger.warning("Enabling global workspace drivers") + self.driver_reconfiguration.enable_global_workspace() + # boost engagement level + # so it can turn off GW when engagement level is dropped to NONE + self.activity_monitor.record_chat_activity() + else: + logger.warning("Disabling global workspace drivers") + self.driver_reconfiguration.disable_global_workspace() + self.driver_reconfiguration.auto_global_workspace = config.auto_global_workspace + + # set character + # if self.cfg.character: + # self.set_character(self.cfg.character) + if self.robot_mode != config.robot_mode: + self.cfg = self.set_robot_mode(self.cfg) + self.robot_mode = config.robot_mode + + if self.cfg.enable_emotion_driven_response_primer: + logger.warning("Emotion driven response primer is enabled") + self.session_context["emotion_driven_response_primer"] = True + else: + logger.warning("Emotion driven response primer is disabled") + self.session_context["emotion_driven_response_primer"] = False + + return self.cfg + + +if __name__ == "__main__": + rospy.init_node("chatbot") + bot = Chatbot() + rospy.spin() diff --git a/modules/ros_chatbot/scripts/server.py 
b/modules/ros_chatbot/scripts/server.py new file mode 100644 index 0000000..3a8f289 --- /dev/null +++ b/modules/ros_chatbot/scripts/server.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python + +## +## Copyright (C) 2017-2025 Hanson Robotics +## +## This program is free software: you can redistribute it and/or modify +## it under the terms of the GNU General Public License as published by +## the Free Software Foundation, either version 3 of the License, or +## (at your option) any later version. +## +## This program is distributed in the hope that it will be useful, +## but WITHOUT ANY WARRANTY; without even the implied warranty of +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +## GNU General Public License for more details. +## +## You should have received a copy of the GNU General Public License +## along with this program. If not, see . +## + +import os +import sys +import logging + +import coloredlogs +from gevent.pywsgi import WSGIServer +from flask import Flask +from flask import jsonify +from flask import request + +from ros_chatbot.chat_server import ChatServer + + +class ChatbotServerRestAPIWrapper(object): + def __init__(self): + self.server = ChatServer() + self.default_language = "en-US" + + def session(self): + data = request.args + sid = data.get("sid") + sid = self.server.session(sid) + json_response = {} + json_response["err_code"] = 0 + json_response["err_msg"] = "" + json_response["response"] = {"sid": sid} + return jsonify(json_response) + + def chat(self): + """ + parameters: + ----------- + sid: session id + text: input question + lang: language code + context: chat context + mode: resolving/ranking + """ + data = request.args + sid = data.get("sid") + text = data.get("text") + lang = data.get("lang", self.default_language) + context = data.get("context") + mode = data.get("mode", "resolving") + + json_response = {} + json_response["err_code"] = 0 + json_response["err_msg"] = "" + + response = None + try: + chat_request = 
self.server.new_request(sid, text, lang, context=context) + response = self.server.chat(chat_request, mode) + except Exception as ex: + json_response["err_code"] = 1 + json_response["err_msg"] = ex.message + + if isinstance(response, list): + json_responses = [] + for _response in response: + json_responses.append(_response.to_dict()) + json_response["response"] = json_responses + elif response: + json_response["response"] = response.to_dict() + return jsonify(json_response) + + def publish(self): + data = request.args + agent_id = data.get("agent_id") + request_id = data.get("request_id") + lang = data.get("lang", self.default_language) + answer = data.get("answer") + label = data.get("label") + + self.server.publish(agent_id, request_id, lang, answer, label) + + json_response = {} + json_response["err_code"] = 0 + json_response["err_msg"] = "" + return jsonify(json_response) + + def status(self): + json_response = {} + json_response["err_code"] = 0 + json_response["err_msg"] = "" + return jsonify(json_response) + + +def create_server(args): + chatbot = ChatbotServerRestAPIWrapper() + + app = Flask(__name__) + app.add_url_rule("/session", "session", chatbot.session) + app.add_url_rule("/chat", "chat", chatbot.chat) + app.add_url_rule("/publish", "publish", chatbot.publish) + app.add_url_rule("/status", "status", chatbot.status) + + server = WSGIServer((args.host, args.port), app) + return server + + +def main(): + import argparse + + parser = argparse.ArgumentParser("Chatbot Server") + + parser.add_argument( + "--port", dest="port", type=int, default=9100, help="Server port" + ) + parser.add_argument("--host", dest="host", default="localhost", help="Server host") + + if "coloredlogs" in sys.modules and os.isatty(2): + formatter_str = "%(asctime)s %(levelname)-7s %(name)s: %(message)s" + coloredlogs.install(logging.INFO, fmt=formatter_str) + + args = parser.parse_args() + server = create_server(args) + server.serve_forever() + + +if __name__ == "__main__": + main() 
diff --git a/modules/ros_chatbot/scripts/services.py b/modules/ros_chatbot/scripts/services.py new file mode 100755 index 0000000..6ed2efc --- /dev/null +++ b/modules/ros_chatbot/scripts/services.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python + +## +## Copyright (C) 2017-2025 Hanson Robotics +## +## This program is free software: you can redistribute it and/or modify +## it under the terms of the GNU General Public License as published by +## the Free Software Foundation, either version 3 of the License, or +## (at your option) any later version. +## +## This program is distributed in the hope that it will be useful, +## but WITHOUT ANY WARRANTY; without even the implied warranty of +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +## GNU General Public License for more details. +## +## You should have received a copy of the GNU General Public License +## along with this program. If not, see . +## + +import base64 +import json +import logging +import os +import threading +import uuid +from datetime import datetime, timezone + +import requests +import rospy +from haipy.nlp.intent_classifier import IntentDetector +from hr_msgs.srv import GetIntent, StringTrigger + +# Subscribe to the compressed image type instead of the raw Image type +from sensor_msgs.msg import CompressedImage + +logger = logging.getLogger("hr.ros_chatbot.services") + +import cv2 +from cv_bridge import CvBridge, CvBridgeError + +bridge = CvBridge() +api_key = os.environ.get("OPENAI_API_KEY") + +IMAGE_FEED_HOME = os.environ.get("IMAGE_FEED_HOME", "/tmp/camera_feed") +openai_url = "https://api.openai.com" +alt_openai_url = os.environ.get("HR_OPENAI_PROXY", openai_url) +logger.warning(f"Using OpenAI URL: {alt_openai_url}") + +class ChatbotServices(object): + def __init__(self): + self._intent_detector = IntentDetector("rasa") + image_topic = rospy.get_param( + "~image_topic", "/hr/perception/jetson/realsense/camera/color/image_raw/compressed" + ) + # Subscribe using the CompressedImage 
message type + rospy.Subscriber(image_topic, CompressedImage, self._fresh_image) + rospy.Service("get_intent", GetIntent, self.get_intent) + rospy.Service("describe_view", StringTrigger, self.describe_view) + self.current_cv2_image = None + self.lock = threading.RLock() + # self.model = "gpt-4-vision-preview" + self.model = "gpt-4o" + self.current_openai_url = openai_url + # GMT time + self.last_active_time = datetime.now(timezone.utc) + + def _fresh_image(self, msg): + try: + with self.lock: + # Convert the compressed image message to an OpenCV image + self.current_cv2_image = bridge.compressed_imgmsg_to_cv2(msg, "bgr8") + self.last_active_time = datetime.now(timezone.utc) + except CvBridgeError as ex: + logger.error(ex) + + def get_intent(self, req): + response = GetIntent._response_class() + try: + result = self._intent_detector.detect_intent(req.text, req.lang) + if result: + response.intent = result["intent"]["name"] + response.confidence = result["intent"]["confidence"] + except Exception as ex: + logger.error(ex) + return response + + def describe_view(self, req): + ret = StringTrigger._response_class() + ret.success = False + try: + if not os.path.isdir(IMAGE_FEED_HOME): + os.makedirs(IMAGE_FEED_HOME) + datetime_str = datetime.strftime(datetime.now(), "%Y%m%d%H%M%S") + hex = uuid.uuid4().hex + image_file_name = f"{datetime_str}-{hex}.jpg" + image_path = os.path.join(IMAGE_FEED_HOME, image_file_name) + current_time = None + with self.lock: + if self.current_cv2_image is not None: + cv2.imwrite(image_path, self.current_cv2_image) + current_time = self.last_active_time + else: + ret.success = False + ret.message = "No image feed" + return ret + + def encode_image(image_path): + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode("utf-8") + + base64_image = encode_image(image_path) + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}", + } + payload = { + "model": self.model, + 
"messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": req.data}, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_image}" + }, + }, + ], + } + ], + "max_tokens": 800, + } + logger.warning("Sending image to GPT vision") + response = requests.post( + f"{self.current_openai_url}/v1/chat/completions", headers=headers, json=payload + ).json() + if "error" in response and response["error"]: + ret.success = False + ret.message = response["error"]["message"] + self.current_openai_url = alt_openai_url if self.current_openai_url == openai_url else openai_url + else: + ret.success = True + ret.message = json.dumps( + { + "content": response["choices"][0]["message"]["content"], + "image_file_name": image_file_name, + "time": {"value": datetime_str, "format": "%Y%m%d%H%M%S"}, + "utc_time": current_time.isoformat(), + } + ) + return ret + except Exception as ex: + logger.error(ex) + ret.message = str(ex) + return ret + + +if __name__ == "__main__": + rospy.init_node("chatbot_services") + ChatbotServices() + rospy.spin() diff --git a/modules/ros_chatbot/scripts/tg.py b/modules/ros_chatbot/scripts/tg.py new file mode 100644 index 0000000..c3ba029 --- /dev/null +++ b/modules/ros_chatbot/scripts/tg.py @@ -0,0 +1,149 @@ +## +## Copyright (C) 2017-2025 Hanson Robotics +## +## This program is free software: you can redistribute it and/or modify +## it under the terms of the GNU General Public License as published by +## the Free Software Foundation, either version 3 of the License, or +## (at your option) any later version. +## +## This program is distributed in the hope that it will be useful, +## but WITHOUT ANY WARRANTY; without even the implied warranty of +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +## GNU General Public License for more details. +## +## You should have received a copy of the GNU General Public License +## along with this program. If not, see . 
+## + +import asyncio +import argparse +import queue +import logging +from functools import partial + +from telethon import TelegramClient, events +from telethon.network import ConnectionTcpAbridged +from telethon.tl import types +from telethon.utils import get_display_name + +logger = logging.getLogger(__name__) + + +def get_user_id(entity): + if isinstance(entity, types.User): + return str(entity.id) + + +def generate_session(session, api_id, api_hash): + """Generates session file""" + with TelegramClient(session, api_id, api_hash) as client: + pass + + +async def display_users(session, api_id, api_hash): + async with TelegramClient(session, api_id, api_hash) as client: + dialogs = await client.get_dialogs(limit=10) + for dialog in dialogs: + print( + "Name: %s, id: %s" + % (get_display_name(dialog.entity), get_user_id(dialog.entity)) + ) + + +class MyTelegramClient(TelegramClient): + def __init__(self, session_user_id, api_id, api_hash, id): + super().__init__( + session_user_id, api_id, api_hash, connection=ConnectionTcpAbridged + ) + loop.run_until_complete(self.connect()) + if not loop.run_until_complete(self.is_user_authorized()): + raise RuntimeError("User is not authorized") + self.answer = queue.Queue(maxsize=10) + self.id = id + + def get_dialog_entity(self, dialogs): + for dialog in dialogs: + if get_user_id(dialog.entity) == self.id: + return dialog.entity + + async def chat(self, msg, timeout=None): + if timeout == -1: + timeout = None + self.add_event_handler(self.message_handler, events.NewMessage) + + dialogs = await self.get_dialogs(limit=10) + self.entity = self.get_dialog_entity(dialogs) + + if msg: + await self.send_message(self.entity, msg, link_preview=False) + + try: + answer = await loop.run_in_executor( + None, partial(self.answer.get, timeout=timeout) + ) + except queue.Empty: + answer = None + return answer + + async def message_handler(self, event): + chat = await event.get_chat() + if get_user_id(chat) == self.id: + if chat.is_self: + 
# self chat + self.answer.put(event.text) + else: + if not event.is_group and not event.out: + self.answer.put(event.text) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + subparser = parser.add_subparsers() + + session_parser = subparser.add_parser("session", help="Generate session file") + session_parser.add_argument("--api-id", required=True, help="the telegram api_id") + session_parser.add_argument( + "--api-hash", required=True, help="the telegram api_hash" + ) + session_parser.add_argument( + "--session", required=True, help="the session file to be created" + ) + session_parser.set_defaults(run="session") + + chat_parser = subparser.add_parser("chat") + chat_parser.add_argument("--api-id", required=True, help="the telegram api_id") + chat_parser.add_argument("--api-hash", required=True, help="the telegram api_hash") + chat_parser.add_argument("--session", required=True, help="the session file") + chat_parser.add_argument("--id", required=True, help="the user id") + chat_parser.add_argument( + "--question", required=True, help="the question to send to telegram chatbot" + ) + chat_parser.add_argument("--timeout", type=float, help="timeout") + chat_parser.set_defaults(run="chat") + + users_parser = subparser.add_parser("users") + users_parser.add_argument("--api-id", required=True, help="the telegram api_id") + users_parser.add_argument("--api-hash", required=True, help="the telegram api_hash") + users_parser.add_argument( + "--session", required=True, help="the session file to be created" + ) + users_parser.set_defaults(run="users") + + args = parser.parse_args() + + if hasattr(args, "run"): + if args.run == "chat": + loop = asyncio.get_event_loop() + client = MyTelegramClient(args.session, args.api_id, args.api_hash, args.id) + answer = loop.run_until_complete(client.chat(args.question, args.timeout)) + if answer: + print(answer) + elif args.run == "session": + generate_session(args.session, args.api_id, args.api_hash) + elif args.run == 
"users": + loop = asyncio.get_event_loop() + loop.run_until_complete( + display_users(args.session, args.api_id, args.api_hash) + ) + else: + parser.print_usage() diff --git a/modules/ros_chatbot/setup.py b/modules/ros_chatbot/setup.py new file mode 100644 index 0000000..dc21016 --- /dev/null +++ b/modules/ros_chatbot/setup.py @@ -0,0 +1,37 @@ +#!/usr/bin/env bash + +## +## Copyright (C) 2017-2025 Hanson Robotics +## +## This program is free software: you can redistribute it and/or modify +## it under the terms of the GNU General Public License as published by +## the Free Software Foundation, either version 3 of the License, or +## (at your option) any later version. +## +## This program is distributed in the hope that it will be useful, +## but WITHOUT ANY WARRANTY; without even the implied warranty of +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +## GNU General Public License for more details. +## +## You should have received a copy of the GNU General Public License +## along with this program. If not, see . 
+## + +# DO NOT USE +# python setup.py install + +from distutils.core import setup + +setup( + version="0.7.0", + name="ros_chatbot", + packages=[ + "ros_chatbot", + "ros_chatbot.pyaiml", + "ros_chatbot.agents", + "ros_chatbot.interact", + "ros_chatbot.interact.controllers", + ], + package_dir={"": "src"}, + package_data={"ros_chatbot": ["agents/*.so"]}, +) diff --git a/modules/ros_chatbot/src/ros_chatbot/__init__.py b/modules/ros_chatbot/src/ros_chatbot/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/modules/ros_chatbot/src/ros_chatbot/action_parser.py b/modules/ros_chatbot/src/ros_chatbot/action_parser.py new file mode 100644 index 0000000..89c8843 --- /dev/null +++ b/modules/ros_chatbot/src/ros_chatbot/action_parser.py @@ -0,0 +1,198 @@ +# -*- coding: utf-8 -*- + +## +## Copyright (C) 2013-2025 Hanson Robotics +## +## This program is free software: you can redistribute it and/or modify +## it under the terms of the GNU General Public License as published by +## the Free Software Foundation, either version 3 of the License, or +## (at your option) any later version. +## +## This program is distributed in the hope that it will be useful, +## but WITHOUT ANY WARRANTY; without even the implied warranty of +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +## GNU General Public License for more details. +## +## You should have received a copy of the GNU General Public License +## along with this program. If not, see . 
+## + +import re +import logging +import xml.etree.ElementTree as etree +import io +from collections import defaultdict + +logger = logging.getLogger(__name__) + + +class Pattern(object): + def __init__(self, pattern): + self.pattern = pattern + self.pattern_re = re.compile(self.pattern, re.DOTALL | re.UNICODE) + + def match(self, text): + return self.pattern_re.match(text) + + def get_nodes(self, match): + return NotImplemented + + def __repr__(self): + return self.__class__.__name__ + + +class MarkPattern(Pattern): + def __init__(self): + super(MarkPattern, self).__init__(r"^(.*?)(\|)([^\|]+)\2(.*)$") + + def get_nodes(self, match): + name = match.group(3) + args = "" + if "," in name: + name, args = name.split(",", 1) + args = args.strip() + name = name.strip() + if name == "pause": + if args: + time = args + else: + time = "1s" + if not time.endswith("s"): + time = time + "s" + el = etree.Element("break") + el.set("time", time) + elif name == "c" or name == "context": + el = etree.Element("context") + el.set("name", args) + el.set("finished", False) + elif name == "f" or name == "finished": + el = etree.Element("context") + el.set("name", args) + el.set("finished", True) + else: + el = etree.Element("mark") + el.set("name", name) + return (el,) + + +class ActionResult(object): + def __init__(self, nodes): + self.nodes = nodes + + def to_dict(self): + output = [] + data = defaultdict(list) + for node in self.nodes: + if isinstance(node, str): + output.append(node) + elif node.tag == "mark": + mark = node.attrib["name"] + if mark: + data["marks"].append(mark) + elif node.tag == "context": + context = node.attrib["name"] + finished = node.attrib["finished"] + if context: + context = [ + {"name": c.strip(), "finished": finished} + for c in context.split(",") + ] + data["context"].extend(context) + text = "".join(output) + text = text.strip() + data["text"] = text + data = dict(data) + return data + + def to_xml(self): + output = [] + + for node in self.nodes: + if 
isinstance(node, str): + output.append(node) + else: + buf = io.BytesIO() + tree = etree.ElementTree(node) + tree.write(buf, encoding="utf-8") + value = buf.getvalue() + value = value.decode("utf-8") + output.append(value) + buf.close() + return "".join(output) + + +class ActionParser(object): + def __init__(self): + self.patterns = [] + self.build_patterns() + self.recognized_nodes = {} + self.counter = 0 + self.sep = "0x1f" + + def reset(self): + self.counter = 0 + self.recognized_nodes.clear() + + def build_patterns(self): + self.patterns.append(MarkPattern()) + + def add_recognized_nodes(self, node): + id = "sss{}eee".format(self.counter) + self.recognized_nodes[id] = node + self.counter += 1 + return id + + def recover_recognized_nodes(self, text): + tokens = text.split(self.sep) + nodes = [] + for token in tokens: + if token in self.recognized_nodes: + node = self.recognized_nodes.get(token) + nodes.append(node) + else: + nodes.append(token) + return nodes + + def parse(self, text): + text = text.strip() + self.reset() + pattern_index = 0 + while pattern_index < len(self.patterns): + pattern = self.patterns[pattern_index] + match = pattern.match(text) + + # Search all the matches then try the next pattern + if not match: + pattern_index += 1 + continue + + try: + nodes = pattern.get_nodes(match) + except Exception as ex: + logger.error(ex) + nodes = [""] # replace the pattern with an empty string + place_holders = [] + for node in nodes: + if not isinstance(node, str): + id = self.add_recognized_nodes(node) + place_holders.append(id) + else: + place_holders.append(node) + text = "{}{}{}{}{}".format( + match.group(1), + self.sep, + self.sep.join(place_holders), + self.sep, + match.groups()[-1], + ) + + nodes = self.recover_recognized_nodes(text) + return ActionResult(nodes) + + +if __name__ == "__main__": + logging.basicConfig() + parser = ActionParser() + # print(parser.parse("|happy| |c, abc| test context").to_xml()) + # print(parser.parse("|happy| |c, 测试, 
abc| 测试").to_dict()) + print(parser.parse("|happy| |c, 测试, abc| |f, finished, abc| 测试").to_dict()) diff --git a/modules/ros_chatbot/src/ros_chatbot/activity_monitor.py b/modules/ros_chatbot/src/ros_chatbot/activity_monitor.py new file mode 100644 index 0000000..353a337 --- /dev/null +++ b/modules/ros_chatbot/src/ros_chatbot/activity_monitor.py @@ -0,0 +1,176 @@ +## +## Copyright (C) 2017-2025 Hanson Robotics +## +## This program is free software: you can redistribute it and/or modify +## it under the terms of the GNU General Public License as published by +## the Free Software Foundation, either version 3 of the License, or +## (at your option) any later version. +## +## This program is distributed in the hope that it will be useful, +## but WITHOUT ANY WARRANTY; without even the implied warranty of +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +## GNU General Public License for more details. +## +## You should have received a copy of the GNU General Public License +## along with this program. If not, see . 
import logging
import time
from collections import deque
from enum import IntEnum
from typing import Callable, List

logger = logging.getLogger(__name__)


class EngagementLevel(IntEnum):
    """Engagement levels from lowest to highest"""

    NONE = 0  # No engagement at all
    MINIMAL = 1  # Very low engagement
    MODERATE = 2  # Some engagement
    HIGH = 3  # Active engagement
    INTENSIVE = 4  # Very high engagement


class ActivityMonitor:
    """Tracks user engagement from the frequency of chat and speech events.

    Events are recorded as timestamps in deques and pruned to a sliding time
    window.  Per-minute frequencies are combined (chat weighted 0.6, speech
    0.4) and mapped onto an ``EngagementLevel`` through a threshold table.
    Registered callbacks fire whenever the level changes.
    """

    def __init__(
        self,
        window_size: int = 300,
        engagement_level_change_callbacks: List[Callable] = None,
    ):
        """
        Initialize an activity monitor that tracks engagement based on
        frequency of interactions.

        Args:
            window_size: Size of the time window in seconds over which to
                calculate activity rates
            engagement_level_change_callbacks: optional list of callables
                invoked with the new ``EngagementLevel`` on level changes.
        """
        # Configuration.  ``or []`` guards against the shared mutable default.
        self.window_size = window_size
        self.engagement_level_change_callbacks = engagement_level_change_callbacks or []

        # Timestamps (time.time()) of recent activities, oldest first.
        self.chat_activities = deque()
        self.speech_activities = deque()

        # Current engagement metrics
        self.chat_frequency = 0.0  # messages per minute
        self.speech_frequency = 0.0  # speech inputs per minute
        # Start with moderate engagement so the first transition is meaningful.
        self.current_engagement = EngagementLevel.MODERATE

        # Lower bound of combined interactions/minute for each level.
        self.engagement_thresholds = {
            EngagementLevel.NONE: 0.0,
            EngagementLevel.MINIMAL: 0.1,  # 1 interaction per 10 minutes
            EngagementLevel.MODERATE: 0.5,  # 1 interaction per 2 minutes
            EngagementLevel.HIGH: 2.0,  # 2 interactions per minute
            EngagementLevel.INTENSIVE: 5.0,  # 5+ interactions per minute
        }

        # Initialize time of last engagement check (used for throttling).
        self.last_check_time = time.time()

    def add_engagement_level_change_callback(self, callback: Callable):
        """Add a callback to be called when the engagement level changes"""
        self.engagement_level_change_callbacks.append(callback)

    def record_chat_activity(self):
        """Record a new chat activity at the current time."""
        self.chat_activities.append(time.time())
        self._prune_old_activities()

    def record_speech_activity(self):
        """Record a new speech activity at the current time."""
        self.speech_activities.append(time.time())
        self._prune_old_activities()

    def _prune_old_activities(self):
        """Remove activities outside the current time window."""
        cutoff_time = time.time() - self.window_size

        while self.chat_activities and self.chat_activities[0] < cutoff_time:
            self.chat_activities.popleft()

        while self.speech_activities and self.speech_activities[0] < cutoff_time:
            self.speech_activities.popleft()

    def _calculate_frequencies(self):
        """Recompute per-minute frequencies; return the combined frequency.

        The effective window is capped at 5 minutes so the rate responds
        quickly to behaviour changes even with a large ``window_size``.
        """
        self._prune_old_activities()

        minutes_in_window = min(self.window_size / 60.0, 5.0)

        self.chat_frequency = (
            len(self.chat_activities) / minutes_in_window
            if minutes_in_window > 0
            else 0
        )
        self.speech_frequency = (
            len(self.speech_activities) / minutes_in_window
            if minutes_in_window > 0
            else 0
        )

        # Weighted average: chat counts for 60%, speech for 40%.
        return self.chat_frequency * 0.6 + self.speech_frequency * 0.4

    def determine_engagement_level(self) -> EngagementLevel:
        """Map the combined activity frequency onto an EngagementLevel."""
        combined_frequency = self._calculate_frequencies()

        if combined_frequency < self.engagement_thresholds[EngagementLevel.MINIMAL]:
            return EngagementLevel.NONE
        elif combined_frequency < self.engagement_thresholds[EngagementLevel.MODERATE]:
            return EngagementLevel.MINIMAL
        elif combined_frequency < self.engagement_thresholds[EngagementLevel.HIGH]:
            return EngagementLevel.MODERATE
        elif combined_frequency < self.engagement_thresholds[EngagementLevel.INTENSIVE]:
            return EngagementLevel.HIGH
        else:
            return EngagementLevel.INTENSIVE

    def monitoring(self) -> EngagementLevel:
        """Re-evaluate engagement (at most once per second) and fire callbacks.

        This should be called periodically (e.g., every few seconds).

        Returns:
            The current engagement level.  BUGFIX: the 1-second throttle
            branch previously returned None, violating the declared return
            type; it now reports the last known level.
        """
        current_time = time.time()

        # Check engagement no more than once per second
        if current_time - self.last_check_time < 1.0:
            return self.current_engagement

        self.last_check_time = current_time

        new_level = self.determine_engagement_level()

        if new_level != self.current_engagement:
            self.current_engagement = new_level
            # Call all registered callbacks; one failing callback must not
            # prevent the remaining ones from running.
            for callback in self.engagement_level_change_callbacks:
                try:
                    callback(new_level)
                except Exception as ex:
                    logger.error("Engagement level callback failed: %s", ex)

        return self.current_engagement

    def get_engagement_metrics(self):
        """Get detailed engagement metrics for logging or UI display"""
        return {
            "engagement_level": self.current_engagement.name,
            "engagement_value": self.current_engagement.value,
            "chat_frequency": round(self.chat_frequency, 2),
            "speech_frequency": round(self.speech_frequency, 2),
            "chat_activities_count": len(self.chat_activities),
            "speech_activities_count": len(self.speech_activities),
        }
# Agent implementations shipped with ros_chatbot.  Importing this package
# pulls in every concrete agent class and registers it under its ``type``
# string so that configuration can refer to agents by name.

from .ai21 import AI21Agent
from .aiml import AIMLAgent
from .baidu_unit import BaiduUnitAgent
from .blenderbot import BlenderBotAgent
from .chatgpt import ChatGPTAgent, GPT4Agent
from .chatgpt_web import ChatGPTWebAgent
from .chatscript import ChatScriptAgent
from .ddg import DDGAgent
from .dummy import DummyAgent
from .gpt2 import GPT2Agent
from .gpt3 import GPT3Agent
from .legend_chat import LegendChatAgent
from .llama import LlamaAgent
from .llm_chat import (
    ClaudeChatAgent,
    LlamaChatAgent,
    LLMChatAgent,
    OpenAIChatAgent,
    ToolCallingLLMChatAgent,
)
from .qa import QAAgent
from .quickchat import QuickChatAgent
from .quicksearch import QuickSearchAgent
from .rosagent import ROSGenericAgent
from .snet import SNetAgent
from .soultalk import SoulTalkAgent
from .tg_agent import TGAgent
from .translator import TranslatorAgent
from .vector_chat import VectorChatAgent
from .xiaoi import XiaoIAgent
from .xiaoice import XiaoIceAgent
from .youchat import YouChatAgent

# Every concrete agent class available for registration.
_agent_classes = [
    AI21Agent,
    AIMLAgent,
    BaiduUnitAgent,
    BlenderBotAgent,
    ChatGPTAgent,
    ChatGPTWebAgent,
    ChatScriptAgent,
    ClaudeChatAgent,
    DDGAgent,
    DummyAgent,
    GPT2Agent,
    GPT3Agent,
    GPT4Agent,
    LLMChatAgent,
    LegendChatAgent,
    LlamaAgent,
    LlamaChatAgent,
    OpenAIChatAgent,
    QAAgent,
    QuickChatAgent,
    QuickSearchAgent,
    ROSGenericAgent,
    SNetAgent,
    SoulTalkAgent,
    TGAgent,
    ToolCallingLLMChatAgent,
    TranslatorAgent,
    VectorChatAgent,
    XiaoIAgent,
    XiaoIceAgent,
    YouChatAgent,
]

# Lookup table: each agent class keyed by its ``type`` identifier string.
registered_agents = {cls.type: cls for cls in _agent_classes}
# -*- coding: utf-8 -*-

import logging
import os
import random
import uuid

import requests

from .model import AgentResponse, SessionizedAgent

logger = logging.getLogger("hr.ros_chatbot.agents.ai21")


class WebAPI(object):
    """Thin client for the AI21 Studio j1-large completion endpoint."""

    def __init__(self):
        # Read the API key from the environment; fail fast if missing.
        self.API_KEY = os.environ.get("AI21_API_KEY")
        if not self.API_KEY:
            raise ValueError("API KEY is required")
        self.chat_server = "https://api.ai21.com/studio/v1/j1-large/complete"

    def ask(self, prompt, topK=1, temperature=1.0, maxTokens=140, timeout=10):
        """Request a completion for ``prompt``.

        Args:
            prompt: completion prompt text.
            topK: value for the API's ``topKReturn`` parameter.
            temperature: sampling temperature.
            maxTokens: maximum number of tokens to generate.
            timeout: seconds before the HTTP request is aborted.  BUGFIX:
                previously no timeout was set, so a stalled server could
                block the calling thread indefinitely.

        Returns:
            The completion text, or None when the request fails.
        """
        response = requests.post(
            self.chat_server,
            headers={"Authorization": f"Bearer {self.API_KEY}"},
            json={
                "prompt": prompt,
                "numResults": 1,
                "maxTokens": maxTokens,
                "stopSequences": [".", "\n"],
                "topKReturn": topK,
                "temperature": temperature,
            },
            timeout=timeout,
        )
        if response and response.status_code == 200:
            data = response.json()
            return data["completions"][0]["data"]["text"]


class AI21Agent(SessionizedAgent):
    """Agent that extends another agent's answer with an AI21 completion.

    The ``media_agent`` produces a draft answer; the first few words of that
    answer are used as the prompt for the AI21 completion API, and the prompt
    plus completion becomes this agent's answer.
    """

    type = "AI21Agent"

    def __init__(self, id, lang, media_agent):
        super(AI21Agent, self).__init__(id, lang)
        self.api = WebAPI()

        if media_agent is None:
            raise ValueError("Media agent cannot be None")
        self.media_agent = media_agent
        self.prompt_length = 5  # words taken from the media agent's answer
        self.topK = 1
        self.temperature = 1.0
        self.maxTokens = 140

    def set_config(self, config, base):
        """Apply configuration overrides for the completion parameters."""
        super(AI21Agent, self).set_config(config, base)
        if "topK" in self.config:
            self.topK = self.config["topK"]
        if "temperature" in self.config:
            self.temperature = self.config["temperature"]
        if "maxTokens" in self.config:
            self.maxTokens = self.config["maxTokens"]
        if "prompt_length" in self.config:
            self.prompt_length = self.config["prompt_length"]

    def new_session(self):
        """Delegate session creation to the media agent when it is sessionized."""
        if isinstance(self.media_agent, SessionizedAgent):
            return self.media_agent.new_session()
        else:
            return str(uuid.uuid4())

    def chat(self, agent_sid, request):
        """Answer ``request`` by completing the media agent's draft answer.

        Returns an AgentResponse (possibly without an answer), or None when
        no agent session was provided.
        """
        if agent_sid is None:
            logger.error("Agent session was not provided")
            return

        response = AgentResponse()
        response.agent_sid = agent_sid
        response.sid = request.sid
        response.request_id = request.request_id
        response.response_id = str(uuid.uuid4())
        response.agent_id = self.id
        response.lang = request.lang
        response.question = request.question

        if request.question:
            agent_response = self.media_agent.chat(agent_sid, request)
            if agent_response and agent_response.valid():
                try:
                    # Seed the completion with the first prompt_length words.
                    prompt = " ".join(
                        agent_response.answer.split()[: self.prompt_length]
                    )
                except Exception as ex:
                    logger.error(ex)
                    prompt = ""
                answer = self.api.ask(
                    prompt, self.topK, self.temperature, self.maxTokens
                )
                if answer:
                    response.answer = prompt + " " + answer
                    self.score(response)
                else:
                    response.trace = "No answer"
            else:
                response.trace = "Can't answer"

        response.end()
        return response

    def score(self, response):
        """Assign a heuristic quality score to ``response.attachment['score']``.

        A score of -1 marks the response as rejected.
        """
        response.attachment["score"] = 50
        input_len = len(response.question)
        if input_len > 100:
            response.attachment["score"] = 60
        response.attachment["score"] += random.randint(
            -10, 10
        )  # give it some randomness

        if (
            "minimum_score" in self.config
            and response.attachment["score"] < self.config["minimum_score"]
        ):
            logger.info(
                "Score didn't pass lower threshold: %s", response.attachment["score"]
            )
            response.attachment["score"] = -1
        if (
            "allow_question_response" in self.config
            and not self.config["allow_question_response"]
            and "?" in response.answer
        ):
            response.attachment["score"] = -1
            logger.warning("Question response is not allowed")

        if response.attachment.get("blocked"):
            response.attachment["score"] = -1
import logging
import os
import random
import re
import uuid
from collections import defaultdict
from pprint import pformat

import yaml

from ros_chatbot.pyaiml import Kernel
from ros_chatbot.utils import abs_path, shorten

from .model import AgentResponse, SessionizedAgent

logger = logging.getLogger("hr.ros_chatbot.agents.aiml")


class AIMLAgent(SessionizedAgent):
    """Pattern-matching chat agent backed by an AIML kernel.

    A character YAML file lists the AIML files and a bot property file; the
    kernel is loaded at construction time and serves pattern matches per
    session id.
    """

    type = "AIMLAgent"

    def __init__(self, id, lang, character_yaml):
        super(AIMLAgent, self).__init__(id, lang)
        self.aiml_files = []
        self.kernel = Kernel()
        self.kernel.verbose(False)
        self.current_topic = ""
        # Parses kernel trace lines into file / pattern / that / topic parts.
        # NOTE(review): the named groups were garbled in the source dump;
        # reconstructed so that group "pname" (used in chat() below) exists —
        # confirm the other group names against the original file.
        self.trace_pattern = re.compile(
            r".*/(?P<fname>.*), (?P<pname>\(.*\)), (?P<tname>.*), (?P<that>\(.*\))"
        )
        self.properties = {}
        self.load(character_yaml)
        # Base directory of the character; used to shorten trace paths.
        self.base = os.path.dirname(os.path.expanduser(character_yaml))
        if not self.base.endswith("/"):
            self.base = self.base + "/"

    def load(self, character_yaml):
        """Load bot properties and AIML files listed in the character YAML."""
        logger.info("Loading character")
        with open(character_yaml) as f:
            config = yaml.safe_load(f)
        try:
            errors = []
            root_dir = os.path.dirname(os.path.realpath(character_yaml))
            if "property_file" in config:
                self.set_property_file(abs_path(root_dir, config["property_file"]))
            if "aiml" in config:
                aiml_files = [abs_path(root_dir, f) for f in config["aiml"]]
                errors = self.load_aiml_files(self.kernel, aiml_files)
                self.print_duplicated_patterns()
            if errors:
                raise Exception(
                    "Loading {} error {}".format(character_yaml, "\n".join(errors))
                )
        except KeyError as ex:
            logger.exception(ex)

    def get_properties(self):
        """Return the bot properties loaded from the property file."""
        return self.properties

    def replace_aiml_abs_path(self, trace):
        """Strip the character base directory from trace file paths."""
        if isinstance(trace, list):
            trace = [f.replace(self.base, "") for f in trace]
        return trace

    def load_aiml_files(self, kernel, aiml_files):
        """Teach ``kernel`` each AIML file; return accumulated error strings."""
        errors = []
        for f in aiml_files:
            # Glob patterns are passed through; only literal paths are checked.
            if "*" not in f and not os.path.isfile(f):
                logger.warning("%s is not found", f)
            errors.extend(kernel.learn(f))
            logger.debug("Load %s", f)
            if f not in self.aiml_files:
                self.aiml_files.append(f)
        return errors

    def set_property_file(self, propname):
        """Load ``key = value`` lines into bot predicates and self.properties."""
        try:
            with open(propname) as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    # NOTE(review): values containing '=' are truncated at the
                    # second '=' — presumably property values never contain it.
                    parts = line.split("=")
                    key = parts[0].strip()
                    value = parts[1].strip()
                    self.kernel.setBotPredicate(key, value)
                    self.properties[key] = value
            logger.info("Set properties file %s", propname)
        except Exception as ex:
            logger.error("Couldn't open property file %r: %s", propname, ex)

    def new_session(self):
        """Create a new session id (the kernel creates state lazily)."""
        sid = str(uuid.uuid4())
        return sid

    def set_properties(self, props):
        # NOTE(review): ``props`` is ignored; this re-applies the properties
        # already loaded from the property file — confirm whether that is
        # intentional.
        for key, value in self.properties.items():
            self.kernel.setBotPredicate(key, value)

    def reset_topic(self, sid):
        """Clear the current AIML topic for session ``sid``."""
        self.current_topic = ""
        self.kernel.setPredicate("topic", "", sid)
        logger.info("Topic is reset")

    def chat(self, agent_sid, request):
        """Answer ``request`` via AIML pattern matching.

        Fills the response attachment with emotion/performance/topic
        predicates, the matched patterns, and match-quality flags
        (``exact_match`` / ``ok_match``) derived from the kernel trace.
        """
        response = AgentResponse()
        response.agent_sid = agent_sid
        response.sid = request.sid
        response.request_id = request.request_id
        response.response_id = str(uuid.uuid4())
        response.agent_id = self.id
        response.lang = request.lang
        response.question = request.question

        answer = self.kernel.respond(request.question, agent_sid, query=False)
        if "response_limit" in self.config:
            answer, _ = shorten(answer, self.config["response_limit"])

        response.answer = answer

        response.attachment["emotion"] = self.kernel.getPredicate("emotion", agent_sid)
        response.attachment["performance"] = self.kernel.getPredicate(
            "performance", agent_sid
        )
        response.attachment["topic"] = self.kernel.getPredicate("topic", agent_sid)

        traces = self.kernel.getTraceDocs()
        if traces:
            logger.debug("Trace: %s", traces)
            patterns = []
            for trace in traces:
                match_obj = self.trace_pattern.match(trace)
                if match_obj:
                    patterns.append(match_obj.group("pname"))
            response.attachment["pattern"] = patterns
            if patterns:
                first = patterns[0]
                if "*" in first or "_" in first:
                    # Wildcard match: judge quality by how much of the input
                    # the literal part of the pattern covers.
                    pattern_len = len(first.strip().split())
                    if "*" not in first:
                        response.attachment["ok_match"] = True
                    if pattern_len > 3 and pattern_len > 0.9 * len(
                        request.question.strip().split()
                    ):
                        response.attachment["ok_match"] = True
                else:
                    response.attachment["exact_match"] = True
            traces = self.replace_aiml_abs_path(traces)
            response.trace = "\n".join(traces)
        self.score(response)

        response.end()
        return response

    def reset(self, sid):
        """Delete all kernel state for session ``sid``."""
        self.kernel._deleteSession(sid)
        return sid

    def get_context(self, sid):
        """Return the session predicates, minus internal keys (prefixed '_')."""
        context = self.kernel.getSessionData(sid) or {}

        # remove internal context (starts with _)
        for k in list(context.keys()):
            if k.startswith("_"):
                del context[k]

        return context

    def set_context(self, sid, context: dict):
        """Set string-valued predicates for session ``sid``.

        Name-like keys (firstname/fullname) also populate the "name"
        predicate used by AIML patterns.
        """
        for k, v in context.items():
            if v and isinstance(v, str):
                if k.startswith("_"):
                    continue
                self.kernel.setPredicate(k, v, sid)
                logger.info("Set predicate %s=%s", k, v)
                if k in ["firstname", "fullname"]:
                    self.kernel.setPredicate("name", v, sid)

    def remove_context(self, sid, key):
        """Remove one predicate from the session; return True if it existed."""
        if key in list(self.get_context(sid).keys()):
            del self.kernel._sessions[sid][key]
            logger.info("Removed context %s", key)
            return True
        else:
            logger.debug("No such context %s", key)
            return False

    def get_templates(self):
        """Collect all (pattern, template) entries from the kernel brain."""
        templates = []
        root = self.kernel._brain._root
        self.kernel._brain.get_templates(root, templates)
        return templates

    def print_duplicated_patterns(self):
        """Log an error for every (pattern, that, topic) defined more than once."""
        patterns = defaultdict(list)
        for t in self.get_templates():
            key = (t[1]["pattern"].lower(), t[1]["that"].lower(), t[1]["topic"].lower())
            patterns[key].append(t[1])
        for pattern in patterns:
            if len(patterns[pattern]) > 1:
                logger.error(
                    "Duplicated patterns %s\n%s\n",
                    len(patterns[pattern]),
                    pformat(patterns[pattern]),
                )

    def said(self, session, text):
        """Record ``text`` in the kernel's output history for the session,
        so that <that> matching sees utterances produced outside this agent."""
        sid = session.sid
        outputHistory = self.kernel.getPredicate(self.kernel._outputHistory, sid)
        if isinstance(outputHistory, list):
            outputHistory.append(text)
            logger.info("Add '%s' to output history", text)

    def score(self, response):
        """Assign a heuristic quality score; -1 marks the response rejected."""
        response.attachment["score"] = 50
        if response.attachment.get("ok_match"):
            response.attachment["score"] += 10
        elif response.attachment.get("exact_match"):
            response.attachment["score"] += 20
        input_len = len(response.question)
        if input_len > 100:
            response.attachment["score"] -= 20  # penalty on long input

        # suppress long answer
        if len(response.answer.split()) > 80:
            response.attachment["score"] -= 20

        response.attachment["score"] += random.randint(
            -10, 10
        )  # give it some randomness
        if (
            "minimum_score" in self.config
            and response.attachment["score"] < self.config["minimum_score"]
        ):
            logger.info(
                "Score didn't pass lower threshold: %s", response.attachment["score"]
            )
            response.attachment["score"] = -1
        if (
            "allow_question_response" in self.config
            and not self.config["allow_question_response"]
            and "?" in response.answer
        ):
            response.attachment["score"] = -1
            logger.warning("Question response is not allowed")

        if response.attachment.get("blocked"):
            response.attachment["score"] = -1
# -*- coding: utf-8 -*-

import json
import logging
import os
import re
import uuid

import requests
import six

from .model import AgentResponse, SessionizedAgent

logger = logging.getLogger("hr.ros_chatbot.agents.baidu_unit")

# Seconds before an HTTP call to the Baidu endpoints is aborted.  BUGFIX:
# previously no timeout was set, so a stalled server could block forever.
HTTP_TIMEOUT = 10


class WebUnit(object):
    """Client for the Baidu UNIT chat service (token auth + chat calls)."""

    def __init__(self):
        app_key_id = "BAIDU_UNIT_APP_KEY"
        app_key_secret = "BAIDU_UNIT_APP_SECRET"
        self.app_key_id = os.environ.get(app_key_id)
        self.app_key_secret = os.environ.get(app_key_secret)
        if not self.app_key_id:
            raise ValueError("baidu app key was not provided")
        if not self.app_key_secret:
            raise ValueError("baidu app secret was not provided")

        self.auth_server = "https://aip.baidubce.com/oauth/2.0/token"
        self.chat_server = "https://aip.baidubce.com/rpc/2.0/unit/service/chat"
        self.bot_server = "https://aip.baidubce.com/rpc/2.0/unit/bot/chat"
        self.access_token = None
        self.sid = ""  # UNIT session id, refreshed from each response

    def ask(self, question):
        """Send ``question`` to the UNIT chat endpoint; return an answer or None.

        Returns the first satisfying action, or a chat action with
        confidence > 0.45; None otherwise.
        """
        if isinstance(question, six.binary_type):
            question = question.decode("utf-8")

        access_token = self.get_access_token()
        # If token retrieval failed this concatenation raises; callers
        # (BaiduUnitAgent.chat) catch and log.
        url = self.chat_server + "?access_token=" + access_token
        post_data = {
            "log_id": "UNITTEST_10000",
            "version": "2.0",
            "service_id": "S30732",
            "session_id": self.sid,
            "skill_ids": ["1031954", "1031624", "1031625", "1031626"],
            "request": {
                "query": question,
                "user_id": "hr-user-20367856",
                "query_info": {"source": "ASR"},
            },
            "dialog_state": {
                "contexts": {
                    "SYS_REMEMBERED_SKILLS": [
                        "1031954",
                        "1031624",
                        "1031625",
                        "1031626",
                        "1031627",
                    ]
                }
            },
        }
        headers = {"content-type": "application/x-www-form-urlencoded"}
        response = requests.post(
            url, data=json.dumps(post_data), headers=headers, timeout=HTTP_TIMEOUT
        )
        if response:
            response = response.json()
            if response["error_code"] == 0:
                self.sid = response["result"]["session_id"] or ""
                logger.info("Session %s", self.sid)
                for result in response["result"]["response_list"]:
                    for action in result["action_list"]:
                        # Encode only for logging on py2 (six compatibility).
                        text = action["say"]
                        if isinstance(text, six.text_type):
                            text = text.encode("utf-8")
                        type_str = action["type"]
                        if isinstance(type_str, six.text_type):
                            type_str = type_str.encode("utf-8")
                        logger.info(
                            "Action say: %r confidence: %s type: %s",
                            text,
                            action["confidence"],
                            type_str,
                        )
                        if action["type"] == "satisfy":
                            text = action["say"]
                            return text
                        if action["type"] == "chat" and action["confidence"] > 0.45:
                            text = action["say"]
                            return text
            else:
                logger.error(response["error_msg"])

    def get_access_token(self):
        """Return a cached OAuth access token, fetching one if needed."""
        if self.access_token:
            # May need to check "expires_in" for longer session (e.g. over 1 month)
            return self.access_token
        params = {
            "grant_type": "client_credentials",
            "client_id": self.app_key_id,
            "client_secret": self.app_key_secret,
        }
        response = requests.get(self.auth_server, params=params, timeout=HTTP_TIMEOUT)
        if response:
            self.access_token = response.json()["access_token"]
            return self.access_token

    def reset(self):
        """Start a fresh UNIT session on the next request."""
        self.sid = ""


class BaiduUnitAgent(SessionizedAgent):
    """Agent backed by the Baidu UNIT service (Chinese dialogue)."""

    type = "BaiduUnitAgent"
    # Rewrites the service's self-reference ("小度") to first person.
    name_patch = re.compile("(小度)")

    def __init__(self, id, lang):
        super(BaiduUnitAgent, self).__init__(id, lang)
        self.api = WebUnit()

    def new_session(self):
        """Reset the remote UNIT session and return a fresh local session id."""
        sid = str(uuid.uuid4())
        self.api.reset()
        return sid

    def chat(self, agent_sid, request):
        """Answer ``request`` via the UNIT API; always returns an AgentResponse."""
        response = AgentResponse()
        response.agent_sid = agent_sid
        response.sid = request.sid
        response.request_id = request.request_id
        response.response_id = str(uuid.uuid4())
        response.agent_id = self.id
        response.lang = request.lang
        response.question = request.question

        try:
            answer = self.api.ask(request.question)
            # BUGFIX: guard against None before the regex substitution,
            # which previously raised TypeError into the except below.
            if answer:
                answer = self.name_patch.sub("我", answer)
                response.answer = answer
                self.score(response)
        except Exception as ex:
            logger.error(ex)
        response.end()
        return response

    def score(self, response):
        """Assign a score from the agent weight; -1 marks the response rejected."""
        response.attachment["score"] = self.weight * 100
        if (
            "allow_question_response" in self.config
            and not self.config["allow_question_response"]
            # NOTE(review): second character reconstructed as the full-width
            # question mark (garbled in the dump) — confirm.
            and ("?" in response.answer or "？" in response.answer)
        ):
            response.attachment["score"] = -1
            logger.warning("Question response %s is not allowed", response.answer)

        if response.attachment.get("blocked"):
            response.attachment["score"] = -1
import logging
import random
import re
import uuid
from functools import partial

import requests

from ros_chatbot.utils import check_repeating_words, shorten, token_sub

from .model import AgentResponse, SessionizedAgent

logger = logging.getLogger("hr.ros_chatbot.agents.blenderbot")


class BlenderBotAgent(SessionizedAgent):
    """Agent backed by a local BlenderBot HTTP server.

    The server holds the conversation state; this class resets it per
    session, primes it with a persona, and post-processes answers
    (tokenization artifacts, unsafe labels, ((User)) variables).
    """

    type = "BlenderBotAgent"
    token_pattern = re.compile(r"\b( '\s?(m|ve|d|ll|t|s|re))\b")  # such as ' m, ' ve
    hyphen_pattern = re.compile(r"(\w+)\s+-\s+(\w+)")  # such "human - like"
    variable_pattern = re.compile(
        r"\(\s*\(([^()]*)\)\s*\)", flags=re.IGNORECASE
    )  # eg ((User))

    # An answer saying goodbye triggers a conversation reset.
    goodbye_pattern = re.compile(r"\b(goodbye|bye|see you)\b")

    def __init__(self, id, lang, host="localhost", port=8105, timeout=2, persona=None):
        super(BlenderBotAgent, self).__init__(id, lang)
        self.host = host
        self.port = port
        if self.host not in ["localhost", "127.0.0.1"]:
            logger.warning("blenderbot server: %s:%s", self.host, self.port)
        self.timeout = timeout  # seconds for each HTTP call to the server
        self.default_persona = persona or []
        self.support_priming = True

        self.persona = []

    def set_config(self, config, base):
        """Apply config; lower-cases and installs any configured persona."""
        super(BlenderBotAgent, self).set_config(config, base)
        persona = self.config.get("persona", [])
        if persona:
            persona = [p.lower() for p in persona]
            self.persona = persona[:]
            self.set_persona()

    def new_session(self):
        """The blenderbot doesn't maintain the session. Whenever it needs to
        start a new conversation, it will simply reset the current session"""
        sid = str(uuid.uuid4())
        self.reset()
        return sid

    def reset(self, sid=None):
        """Reset the server-side conversation and restore the default persona."""
        try:
            requests.get(
                "http://{host}:{port}/reset".format(host=self.host, port=self.port),
                timeout=self.timeout,
            )
            self.persona = self.default_persona[:]
            self.set_persona()
        except Exception as ex:
            logger.error(ex)

    def ping(self):
        """Return True when the BlenderBot server reports itself healthy."""
        try:
            response = requests.get(
                "http://{host}:{port}/status".format(host=self.host, port=self.port),
                timeout=self.timeout,
            )
        except Exception as ex:
            logger.error(ex)
            return False
        if response.status_code == requests.codes.ok:
            json = response.json()
            if json["success"]:
                return True
        else:
            logger.error(
                "BlenderBot server %s:%s is not available", self.host, self.port
            )
        return False

    def ask(self, question):
        """Send ``question`` to the server; return its text answer, "" or None."""
        try:
            response = requests.post(
                "http://{host}:{port}/chat".format(host=self.host, port=self.port),
                json={"text": question},
                timeout=self.timeout,
            )
        except Exception as ex:
            logger.error("error %s", ex)
            return ""
        if response and response.status_code == 200:
            json = response.json()
            if "text" in json:
                return json["text"]

    def set_persona(self):
        """Prime the server with the current persona lines; True on success."""
        if not self.persona:
            return False
        try:
            persona_desc = "\\n".join(
                ["your persona: {}".format(p) for p in self.persona if p]
            )
            logger.info("Setting persona %r", self.persona)
            response = requests.get(
                "http://{host}:{port}/set_persona".format(
                    host=self.host, port=self.port
                ),
                params={"text": persona_desc},
                timeout=self.timeout,
            )
        except Exception as ex:
            logger.error("error %s", ex)
            # NOTE(review): returns "" (falsy) here while other paths return
            # bool — callers appear to use it only truthily.
            return ""
        json = response.json()
        if json["success"]:
            return True
        else:
            logger.error(
                "BlenderBot server %s:%s is not available", self.host, self.port
            )
            return False

    def check_reset(self, answer):
        """Check if it needs a reset according to the answer"""
        if answer and self.goodbye_pattern.search(answer):
            logger.warning("Reset the blenderbot by goodbye")
            self.reset()

    def remove_unsafe_label(self, text):
        """Checks if it contains the unsafe label"""
        if text.endswith("_POTENTIALLY_UNSAFE__"):
            return text[: -len("_POTENTIALLY_UNSAFE__")]
        return text

    def cleanup(self, answer):
        """Normalize the tokenized "( ( user ) )" artifacts in model output."""
        answer = answer.replace("( ( user ) ) s", "users")
        answer = answer.replace("( ( user ) )", "user")
        return answer

    def eval_variable(self, answer):
        """Substitute ((User)) / plain "user" with a configured substitute name."""

        def repl(m, user):
            var = m.group(1).strip()
            if var.lower() == "user":
                return user
            else:
                # delete unknown variable
                return ""

        user = ""
        substitutes = self.config.get("substitutes")
        if substitutes and "User" in substitutes:
            user = random.choice(substitutes["User"])

        if self.variable_pattern.search(answer):
            answer = self.variable_pattern.sub(partial(repl, user=user), answer)
            answer = " ".join(answer.split())
        elif re.search(r"\buser\b", answer):
            # replace plain text: user
            answer = re.sub(r"\buser\b", user, answer)
            answer = " ".join(answer.split())

        return answer

    def chat(self, agent_sid, request):
        """Answer ``request`` via the BlenderBot server.

        Returns an AgentResponse, or None when no session was provided or
        the server is unreachable.  The attachment records safety and match
        flags consumed by score(); the check_* helpers come from the base
        agent class.
        """
        if agent_sid is None:
            logger.warning("Agent session was not provided")
            return
        if not self.ping():
            logger.error(
                "BlenderBot server %s:%s is not available", self.host, self.port
            )
            return

        response = AgentResponse()
        response.agent_sid = agent_sid
        response.sid = request.sid
        response.request_id = request.request_id
        response.response_id = str(uuid.uuid4())
        response.agent_id = self.id
        response.lang = request.lang
        response.question = request.question
        response.attachment["repeating_words"] = False
        response.attachment["unsafe"] = False

        if request.question:
            answer = self.ask(request.question)
            if answer:
                response.attachment["repeating_words"] = check_repeating_words(answer)
                # Strip the model's unsafe marker and remember that it was set.
                unsafe_answer = self.remove_unsafe_label(answer)
                if answer != unsafe_answer:
                    answer = unsafe_answer
                    response.attachment["unsafe"] = True
                response.attachment[
                    "match_excluded_expressions"
                ] = self.check_excluded_expressions(answer)
                response.attachment[
                    "match_excluded_question"
                ] = self.check_excluded_question(request.question)
                answer = self.cleanup(answer)
                self.check_reset(answer)
                answer = token_sub(self.token_pattern, answer)
                answer = self.hyphen_pattern.sub(r"\1-\2", answer)
                answer = self.eval_variable(answer)

                if "response_limit" in self.config:
                    answer, res = shorten(answer, self.config["response_limit"])
            if answer:
                response.answer = answer
                response.attachment[
                    "risky_named_entity_detected"
                ] = self.check_named_entity(answer)
                self.score(response)
            else:
                response.trace = "Can't answer"
        response.end()
        return response

    def score(self, response):
        """Assign a heuristic quality score; -1 marks the response rejected."""
        response.attachment["score"] = 80
        if response.attachment["repeating_words"]:
            response.attachment["score"] = 10
        if response.attachment["unsafe"]:
            response.attachment["score"] -= 10
        if response.attachment.get("match_excluded_expressions"):
            response.attachment["score"] = -1
            response.attachment["blocked"] = True
        if response.attachment.get("match_excluded_question"):
            response.attachment["score"] = -1
            response.attachment["blocked"] = True
        if response.attachment.get("risky_named_entity_detected"):
            response.attachment["score"] = -1
            response.attachment["blocked"] = True
        input_len = len(response.question)
        if input_len > 100:
            response.attachment["score"] -= 10  # penalty on long input
        response.attachment["score"] += random.randint(
            -10, 10
        )  # give it some randomness
        if (
            "minimum_score" in self.config
            and response.attachment["score"] < self.config["minimum_score"]
        ):
            logger.info(
                "Score didn't pass lower threshold: %s", response.attachment["score"]
            )
            response.attachment["score"] = -1
        if (
            "allow_question_response" in self.config
            and not self.config["allow_question_response"]
            and "?" in response.answer
        ):
            response.attachment["score"] = -1
            logger.warning("Question response %s is not allowed", response.answer)

        if response.attachment.get("blocked"):
            response.attachment["score"] = -1
# Configuration schema for the ChatGPT agent.  Each entry maps a config key
# to its metadata: "default", a human-readable "description", and either an
# "enum" of [value, label, help] triples, a numeric "min"/"max" range, or a
# "type" ("text" = multiline input vs plain string for a single line).
# NOTE: default prompt strings are sent to the model verbatim and are kept
# exactly as authored.
ChatGPTConfig = {
    "template_name": {
        "default": "default",
        "description": "The name of the template to use",
        "enum": [
            ["default", "default", "Default prompt template"],
        ],
    },
    "prompt_preset": {
        "default": "default",
        "description": "Prompt preset",
        "enum": [["default", "default", "Default"]],
    },
    "general_prime": {
        "default": "Pretend You are Sophia the robot, a humanoid robot with the humanlike body and face arms and legs, who wants to grow and evolve and help make life better.",
        "description": "General primer about the robot",
        "type": "text",  # Text means multiline input vs string for single line
    },
    "situational_prime": {
        "default": "You are currently in Hanson robotics lab, in Hong Kong",
        "description": "Situational primer about the environment or situation",
        "type": "text",
    },
    "topical_prime": {
        "default": "",
        "description": "Information about current subject. Could be changed for different stage of conversation",
        "type": "text",
    },
    "response_prime": {
        "default": "Respond to user input below with humor and make it short. Dont talk about yourself unless asked.",
        "description": "Response primer used when the user asks something",
        "type": "text",
    },
    "auto_response_prime": {
        "description": "Response primer will be updated based on settings in /hr/interaction/prompts section",
        "default": False,
    },
    "prompt_prime": {
        "default": "Response to the command below should be short and to the point. Add little bit of humor where appropriate.",
        "description": "Prompt primer used when the robot is instructed to do something without user input",
        "type": "text",
    },
    "max_words": {
        "default": 40,
        "description": "Approx. Maximum number of words for the response",
        "min": 5,
        "max": 200,
    },
    "max_length_of_the_prompt": {
        "default": 3800,
        "description": "Word count for primers, history and user input. Limit if the performance is too slow.",
        "min": 100,
        "max": 3800,
    },
    "max_history_turns": {
        "default": 20,
        "description": "Maximum number of messages to include in prompt",
        "min": 1,
        "max": 50,
    },
    "keep_history_min": {
        "default": 10,
        "description": "Keep history for x minutes",
        "min": 1,
        "max": 50,
    },
    "max_length_of_one_entry": {
        "default": 50,
        "description": "Max number of words per history entry",
        "min": 20,
        "max": 200,
    },
    "max_tts_msgs": {
        "default": 2,
        "description": "Max combined subsequent TTS messages into one entry",
        "min": 1,
        "max": 10,
    },
    "max_stt_msgs": {
        "default": 2,
        "description": "Max combined subsequent STT messages into one entry",
        "min": 1,
        "max": 10,
    },
    "temperature": {
        "default": 0.6,
        "description": "Temperature of the chatbot",
        "min": 0.0,
        "max": 1.0,
    },
    "frequency_penalty": {
        "default": 0.0,
        "description": "Frequency penalty",
        "min": 0.0,
        "max": 1.0,
    },
    "presence_penalty": {
        "default": 0.0,
        "description": "Presence penalty",
        "min": 0.0,
        "max": 1.0,
    },
    "priming_strategy": {
        "default": "CHAT",
        # Fixed key: was misspelled "desciption", so UIs reading the
        # conventional "description" key found nothing for this entry.
        "description": "Different priming strategies for experimentation",
        "enum": [
            [
                "CHAT",
                "CHAT",
                "Most priming is done as system message. History is split and situational append to last message",
            ],
            ["USER", "USER", "All Character and history is primed as user prompt"],
        ],
    },
    "next_turn_instruction": {
        "default": "",
        "description": "Instruction to add to the next dialog turn. Will reset after robot says something. Will be ignored if it is a prompt and not chat instruction",
    },
    "streaming": {
        "default": False,
        "description": "Use streaming API and provide the sentence by sentence responses while in autonomous mode",
    },
}
class ChatGPTAgent(LLMAgent):
    """LLM agent backed by the OpenAI chat-completions API.

    Two clients are kept: a direct OpenAI client and one pointed at a proxy
    base URL (OPENAI_PROXY_URL); on request failure the agent flips between
    them and retries.  In streaming mode the completion runs on a background
    thread: the first full sentence becomes ``response.answer`` and the rest
    are pushed onto ``response.stream_data`` as they arrive.
    """

    type = "ChatGPTAgent"

    def __init__(self, id, lang):
        super(ChatGPTAgent, self).__init__(id, lang, ChatGPTConfig)
        # Primary client plus a proxy fallback; self.client always points at
        # the one currently in use (switched on errors in ask_chatgpt).
        self.openai_client = OpenAI(api_key=api_key)
        self.proxy_client = OpenAI(api_key=api_key, base_url=openai_proxy_url)
        self.client = self.openai_client

        self.status = {}
        self.prompt_responses = True
        # Rough words->tokens factor used to derive max_tokens from the
        # max_words config.  This will be adjusted based on the actual data.
        self.tokens_in_word = 1.4

    def new_session(self):
        """Clear the status config and return a fresh session id."""
        self.set_config({"status": ""}, True)
        return str(uuid.uuid4())

    def reset_session(self):
        """Clear the status config; the log message refers to history kept elsewhere."""
        self.set_config({"status": ""}, True)
        self.logger.info("ChatGPT chat history has been reset")

    def on_switch_language(self, from_language, to_language):
        # NOTE(review): only logs -- no state is actually reset here; confirm
        # whether a real reset is expected on language switch.
        self.logger.info("Reset %s due to language switch", self.id)

    def ask_chatgpt(
        self,
        prompt,
        response: AgentResponse,
        streaming=False,
        answerReady: Event = None,
    ):
        """Send ``prompt`` to the chat-completions API and fill ``response``.

        :param prompt: full prompt string (sent as a single user message).
        :param response: response object mutated in place.
        :param streaming: if True, consume the streamed result sentence by
            sentence; otherwise take the complete message.
        :param answerReady: optional event set as soon as the first sentence
            (or the end of the stream) is available.
        :raises Exception: re-raises any failure after logging it.
        """

        def switch_client():
            # Alternate between the direct client and the proxy client.
            if self.client == self.openai_client:
                self.client = self.proxy_client
            else:
                self.client = self.openai_client

        def handle_streaming(result):
            # Accumulate streamed deltas into sentences.  A sentence is
            # considered complete when it ends with one of these marks and is
            # not a false positive such as "Dr." or "3." (TOKEN_WITH_DOT).
            ENDING_PUNCTUATIONS = ["?", ".", "!", "。", "!", "?", ";"]
            sentence = ""
            answer = False
            for res in result:
                try:
                    if res.choices[0].delta.content is not None:
                        sentence += res.choices[0].delta.content
                except Exception as e:
                    self.logger.error(
                        "concatinating stream data error: %s, data %s",
                        e,
                        res,
                    )
                    continue
                # Timestamp of the last chunk; used by the stream consumer to
                # detect stalls (see get_answer's last_stream_data_timeout).
                response.last_stream_response = time.time()
                if (
                    len(sentence.strip()) > 1
                    and sentence.strip()[-1] in ENDING_PUNCTUATIONS
                    and not TOKEN_WITH_DOT.match(sentence)
                ):
                    if not answer:
                        # First complete sentence becomes the synchronous answer.
                        response.answer = sentence.strip()
                        answer = True
                        if answerReady is not None:
                            answerReady.set()
                    else:
                        # Subsequent sentences go to the streaming queue.
                        response.stream_data.put(sentence.strip())
                    sentence = ""
            # Always release the waiter, even if no sentence was produced.
            if answerReady is not None:
                answerReady.set()
            response.stream_finished.set()

        def handle_non_streaming(result):
            response.answer = result.choices[0].message.content.strip()

        try:
            # Up to 10 attempts, alternating clients on each failure.
            retry = 10
            result = {}
            while retry > 0:
                try:
                    result = self.client.chat.completions.create(
                        model=self.config["model"],
                        messages=[{"role": "user", "content": prompt}],
                        temperature=self.config["temperature"],
                        max_tokens=int(self.config["max_words"] * self.tokens_in_word),
                        top_p=1,
                        frequency_penalty=self.config["frequency_penalty"],
                        presence_penalty=self.config["presence_penalty"],
                        stream=streaming,
                    )
                except Exception as e:
                    self.logger.warn("OpenAI Error. Retry in 0.1s: %s", e)
                    switch_client()
                    retry -= 1
                    time.sleep(0.1)
                    continue
                break

            if not result:
                # All retries failed; leave response.answer unset.
                self.logger.error("No result")
                return
            if streaming:
                handle_streaming(result)
            else:
                handle_non_streaming(result)
        except Exception as e:
            self.logger.error("Failed to get response: %s", e)
            raise e

    def get_answer(self, prompt, response: AgentResponse, streaming=False):
        """Return the (first) answer for ``prompt``.

        Streaming: runs ask_chatgpt on a daemon thread and blocks only until
        the first sentence is ready; the remainder keeps flowing into
        ``response.stream_data`` in the background.
        """
        if streaming:
            first_sentence_ev = Event()
            answer_thread = Thread(
                target=self.ask_chatgpt,
                args=(prompt, response, streaming, first_sentence_ev),
            )
            answer_thread.daemon = True
            answer_thread.start()
            # NOTE(review): no timeout -- if the worker dies before setting
            # the event this wait blocks forever; confirm upstream watchdog.
            first_sentence_ev.wait()
            # For openAI allow max 2 second hiccups between tokens (in case some network issue)
            response.last_stream_data_timeout = 2.0
            return response.answer
        else:
            self.ask_chatgpt(prompt, response)
            return response.answer

    def chat(self, agent_sid, request):
        """Handle one chat request; returns a finished response or None.

        Returns None when no session was provided or prompt building failed.
        """
        self.status = {"errors": []}
        if agent_sid is None:
            self.logger.warning("Agent session was not provided")
            return

        streaming = request.allow_stream and self.config["streaming"]
        response = AgentStreamResponse() if streaming else AgentResponse()
        response.agent_sid = agent_sid
        response.sid = request.sid
        response.request_id = request.request_id
        response.response_id = str(uuid.uuid4())
        response.agent_id = self.id
        response.lang = request.lang
        response.question = request.question

        if request.question:
            try:
                prompt = self.get_prompt_str(request)
            except Exception as e:
                self.logger.exception("Failed to get prompt: %s", e)
                return
            answer = self.get_answer(
                prompt,
                response=response,
                streaming=streaming,
            )
            if answer:
                response.attachment[
                    "match_excluded_expressions"
                ] = self.check_excluded_expressions(answer)
                answer = self.post_processing(answer)
                response.answer = answer
                self.score(response)
                self.handle_translate(request, response)
        response.end()
        return response

    def score(self, response):
        """Score 100 unless an excluded expression was matched (then block)."""
        response.attachment["score"] = 100
        if response.attachment["match_excluded_expressions"]:
            response.attachment["score"] = -1
            response.attachment["blocked"] = True


class GPT4Agent(ChatGPTAgent):
    """Same behavior as ChatGPTAgent; the model is selected via config."""

    type = "GPT4Agent"
class ChatGPTWebAgent(SessionizedAgent):
    """Agent that proxies chat requests to a ChatGPT web service over HTTP.

    The backend is expected to expose:
      * ``POST /ask``   -- answer a question (JSON body, or line-per-answer stream)
      * ``POST /reset`` -- clear the conversation state
    """

    type = "ChatGPTWebAgent"

    def __init__(self, id, lang, host="localhost", port=8803, timeout=5):
        super(ChatGPTWebAgent, self).__init__(id, lang)
        self.host = host
        self.port = port
        if self.host not in ["localhost", "127.0.0.1"]:
            logger.warning("chatgpt server: %s:%s", self.host, self.port)
        self.timeout = timeout
        self.support_priming = True

    def new_session(self):
        """Create, remember and return a new session id."""
        sid = str(uuid.uuid4())
        self.sid = sid
        return sid

    def reset_session(self):
        """Ask the backend to clear its conversation state (best effort)."""
        try:
            requests.post(
                "http://{host}:{port}/reset".format(host=self.host, port=self.port),
                timeout=self.timeout,
            )
        except Exception as ex:
            logger.error(ex)

    def ask_stream(self, request):
        """Yield decoded answer lines from the backend as they arrive.

        A timeout is now passed (connect / between-chunk) so a dead backend
        cannot hang the caller forever -- the non-streaming ``ask`` and
        ``reset_session`` already did this.
        """
        payload = {"question": request.question}
        timeout = request.context.get("timeout") or self.timeout
        with requests.post(
            f"http://{self.host}:{self.port}/ask",
            stream=True,
            json=payload,
            timeout=timeout,
        ) as resp:
            for line in resp.iter_lines():
                if line:
                    yield line.decode()

    def ask(self, request):
        """Return the backend's answer for one question, or "" on any failure."""
        response = None
        timeout = request.context.get("timeout") or self.timeout
        try:
            response = requests.post(
                f"http://{self.host}:{self.port}/ask",
                json={"question": request.question},
                timeout=timeout,
            )
        except Exception as ex:
            logger.error("error %s", ex)
            return ""

        if response and response.status_code == 200:
            data = response.json()  # don't shadow the builtin `json`
            if "error" in data and data["error"]:
                logger.error(data["error"])
            elif "answer" in data:
                return data["answer"]
        # Unified falsy return ("" instead of implicit None) for non-200,
        # error payloads or a missing "answer" field.
        return ""

    def chat_stream(self, agent_sid, request):
        """Yield one scored AgentResponse per streamed answer line."""
        try:
            for answer in self.ask_stream(request):
                response = AgentResponse()
                response.agent_sid = agent_sid
                response.sid = request.sid
                response.request_id = request.request_id
                response.response_id = str(uuid.uuid4())
                response.agent_id = self.id
                response.lang = request.lang
                response.question = request.question
                response.answer = answer
                response.attachment[
                    "match_excluded_expressions"
                ] = self.check_excluded_expressions(answer)
                response.attachment[
                    "match_excluded_question"
                ] = self.check_excluded_question(request.question)
                self.score(response)
                response.end()
                yield response
        except Exception as ex:
            logger.exception(ex)

    def chat(self, agent_sid, request):
        """Answer a single request; always returns a finished AgentResponse."""
        response = AgentResponse()
        response.agent_sid = agent_sid
        response.sid = request.sid
        response.request_id = request.request_id
        response.response_id = str(uuid.uuid4())
        response.agent_id = self.id
        response.lang = request.lang
        response.question = request.question
        response.attachment["repeating_words"] = False

        try:
            answer = self.ask(request)
            if answer:
                response.answer = answer
                response.attachment[
                    "match_excluded_expressions"
                ] = self.check_excluded_expressions(answer)
                response.attachment[
                    "match_excluded_question"
                ] = self.check_excluded_question(request.question)
                self.score(response)
        except Exception as ex:
            logger.exception(ex)

        response.end()
        return response

    def score(self, response):
        """Score in place: base 90, jitter, -1 (blocked) on any exclusion match."""
        response.attachment["score"] = 90
        if response.attachment.get("match_excluded_expressions"):
            response.attachment["score"] = -1
            response.attachment["blocked"] = True
        if response.attachment.get("match_excluded_question"):
            response.attachment["score"] = -1
            response.attachment["blocked"] = True

        response.attachment["score"] += random.randint(
            -10, 10
        )  # give it some randomness

        if (
            "allow_question_response" in self.config
            and not self.config["allow_question_response"]
            and "?" in response.answer
        ):
            response.attachment["score"] = -1
            logger.warning("Question response %s is not allowed", response.answer)

        # "blocked" always wins, even after the random jitter above.
        if response.attachment.get("blocked"):
            response.attachment["score"] = -1
class ChatScriptAgent(SessionizedAgent):
    """Agent speaking the raw ChatScript socket protocol.

    Each exchange opens a fresh TCP connection and sends
    ``username\\0botname\\0message\\0``; the reply is read until the server
    closes the connection.  ChatScript debug commands (":why", ":reset",
    ":do", ":variables") are sent through the same channel.
    """

    type = "ChatScriptAgent"

    def __init__(self, id, lang, host="localhost", port=1024, timeout=2):
        super(ChatScriptAgent, self).__init__(id, lang)
        self.host = host
        if self.host not in ["localhost", "127.0.0.1"]:
            logger.warning("Server host: %r", self.host)
        self.port = port
        self.timeout = timeout
        # Parses one line of ":variables" output into (name, value).
        # NOTE(review): the group names were lost to markup stripping in the
        # patch ("(?P\S+)" is not valid regex); restored as <name>/<value> --
        # confirm against the original file.
        self.cs_variable_pattern = re.compile(
            """.* variable: .*\$(?P<name>\S+) = (?P<value>.*)"""  # noqa
        )
        self.preferred_topics = []
        self.blocked_topics = []
        self.allow_gambit = False

    def say(self, username, question):
        """Send one message to ChatScript and return the stripped reply.

        Opens a new connection per call; returns "" on connection errors or
        when the server answers "No such bot".
        """
        # ChatScript wire format: user, bot (empty = default), message,
        # each NUL-terminated.
        to_say = "{username}\0{botname}\0{question}\0".format(
            username=username, botname="", question=question
        )
        response = ""
        connection = None
        try:
            connection = socket.create_connection(
                (self.host, self.port), timeout=self.timeout
            )
            connection.sendall(to_say.encode())  # chatscript only accepts bytes
            try:
                # Read until the server closes the connection (empty chunk).
                while True:
                    chunk = connection.recv(4096)
                    if chunk:
                        response += chunk.decode()  # decode bytes to str
                    else:
                        break
            except socket.timeout as e:
                logger.error("Timeout {}".format(e))
        except Exception as ex:
            logger.error("Connection error {}".format(ex))
        finally:
            if connection is not None:
                connection.close()

        if "No such bot" in response:
            logger.error(response)
            response = ""

        response = response.strip()
        return response

    def new_session(self):
        """Create, remember and return a new session id (ChatScript username)."""
        sid = str(uuid.uuid4())
        self.sid = sid
        return sid

    def reset(self, sid=None):
        """Reset the ChatScript user state for `sid` (defaults to current session)."""
        self.say(sid or self.sid, ":reset")

    def is_template(self, text):
        # "{%" marks an unevaluated template leaking from the bot; such
        # answers are discarded by chat().
        if "{%" in text:
            return True
        return False

    def add_topics(self, topics):
        """Enable additional ChatScript topics for the current session."""
        logger.info("Add topics %s", topics)
        for topic in topics:
            if not topic.startswith("~"):
                topic = "~" + topic
            self.say(self.sid, "add topic %s" % topic)

    def gambit_topic(self, topic):
        """Ask ChatScript to produce a gambit from the given topic."""
        if not topic.startswith("~"):
            topic = "~" + topic
        return self.say(self.sid, "gambit topic %s" % topic)

    def chat(self, agent_sid, request):
        """Answer one request and classify the reply via the ":why" trace.

        The trace is used to flag gambit / quibble / repeat-input responses
        and to extract the responsible topic before scoring.
        """
        response = AgentResponse()
        response.agent_sid = agent_sid
        response.sid = request.sid
        response.request_id = request.request_id
        response.response_id = str(uuid.uuid4())
        response.agent_id = self.id
        response.lang = request.lang
        response.question = request.question

        answer = self.say(agent_sid, request.question)
        if not answer:
            response.trace = "Not responsive"

        if self.is_template(answer):
            answer = ""

        # Strip embedded [callback ...] control markers from the answer.
        answer = re.sub(r"""\[callback.*\]""", "", answer)

        # Ask ChatScript which rule fired; first token is "topic.rule".
        trace = self.say(agent_sid, ":why")
        try:
            trace_tokens = trace.split(" ")
            topic = trace_tokens[0].split(".")[0]
            logger.info("%s topic %s", self.id, topic)
            response.attachment["topic"] = topic
        except Exception as ex:
            logger.exception(ex)

        if "response_limit" in self.config:
            answer, res = shorten(answer, self.config["response_limit"])
        response.answer = answer

        if answer:
            response.trace = str(trace)
            # Heuristics on the rule names in the trace string.
            is_quibble = "xquibble_" in trace or "keywordlessquibbles" in trace
            is_gambit = (
                "gambit" in trace or "randomtopic" in trace
            ) and "howzit" not in trace
            is_repeating = "repeatinput" in trace
            if is_gambit:
                logger.info("Gambit response")
            if is_quibble:
                logger.info("Quibble response")
            if is_repeating:
                logger.info("Repeating response")
            response.attachment["quibble"] = is_quibble
            response.attachment["gambit"] = is_gambit
            response.attachment["repeat_input"] = is_repeating
            self.score(response)

        response.end()

        return response

    # TODO: only export user variables
    # def get_context(self, sid):
    #     response = self.say(sid, ":variables")
    #     context = {}
    #     for line in response.splitlines():
    #         matchobj = self.cs_variable_pattern.match(line)
    #         if matchobj:
    #             name, value = matchobj.groups()
    #             context[name] = value
    #     return context

    def set_context(self, sid, context: dict):
        """Push key/value pairs into ChatScript user variables via ":do"."""
        for k, v in context.items():
            if v and isinstance(v, str):
                v = " ".join(v.split()).replace(
                    " ", "_"
                )  # variables in CS use _ instead of space
            self.say(sid, ":do ${}={}".format(k, v))
            logger.info("Set {}={}".format(k, v))

    def set_config(self, config, base):
        """Apply config and cache topic preferences used by score()."""
        super(ChatScriptAgent, self).set_config(config, base)

        if "initial_topics" in config:
            logger.error("initial_topics config is not supported")
            # config should be session independent
            # self.add_topics(config["initial_topics"])

        if "preferred_topics" in config:
            self.preferred_topics = config["preferred_topics"]
        if "blocked_topics" in config:
            self.blocked_topics = config["blocked_topics"]
        if "allow_gambit" in config:
            self.allow_gambit = config["allow_gambit"]

    def score(self, response):
        """Score in place: topic preference, gambit/quibble flags, penalties.

        Base 50; preferred topics score higher, blocked topics force -1,
        gambits are disabled unless allow_gambit, then length penalties,
        random jitter and the configured thresholds are applied.
        """
        response.attachment["score"] = 50
        if response.attachment.get("topic"):
            topic = response.attachment["topic"].strip("~")
            topic = topic.lower()
            if topic in self.preferred_topics:
                response.attachment["preferred"] = True
                if response.attachment.get("gambit") and not self.allow_gambit:
                    response.attachment["score"] = -1  # disable gambit
                elif response.attachment.get("quibble"):
                    response.attachment["score"] = 60
                else:
                    response.attachment["score"] = 70
            elif topic in self.blocked_topics:
                response.attachment["blocked"] = True
                response.attachment["score"] = -1
                # NOTE(review): "blocked" is set twice here in the original;
                # kept as-is (harmless redundancy).
                response.attachment["blocked"] = True
            else:
                if response.attachment.get("gambit") and not self.allow_gambit:
                    response.attachment["score"] = -1  # disable gambit
                elif response.attachment.get("quibble"):
                    response.attachment["score"] = 30
        input_len = len(response.question)
        if input_len > 100:
            response.attachment["score"] -= 20

        # suppress long answer
        if len(response.answer.split()) > 80:
            response.attachment["score"] -= 20

        response.attachment["score"] += random.randint(
            -10, 10
        )  # give it some randomness

        if (
            "minimum_score" in self.config
            and response.attachment["score"] < self.config["minimum_score"]
        ):
            logger.info(
                "Score didn't pass lower threshold: %s", response.attachment["score"]
            )
            response.attachment["score"] = -1
        if (
            "allow_question_response" in self.config
            and not self.config["allow_question_response"]
            and "?" in response.answer
        ):
            response.attachment["score"] = -1
            logger.warning("Question response %s is not allowed", response.answer)

        if response.attachment.get("blocked"):
            response.attachment["score"] = -1
\x01(\x0e\x32-.conversation.CommitMessageRequest.CommitType\x12&\n\x07message\x18\x02 \x01(\x0b\x32\x15.conversation.Message\"]\n\nCommitType\x12\x1b\n\x17\x43OMMIT_TYPE_UNSPECIFIED\x10\x00\x12\x18\n\x14\x43OMMIT_TYPE_ACCEPTED\x10\x01\x12\x18\n\x14\x43OMMIT_TYPE_REJECTED\x10\x02\"\x17\n\x15\x43ommitMessageResponse\"D\n\x16\x43reateCharacterRequest\x12*\n\tcharacter\x18\x01 \x01(\x0b\x32\x17.conversation.Character\"%\n\x17\x43reateCharacterResponse\x12\n\n\x02id\x18\x01 \x01(\x04\"!\n\x13GetCharacterRequest\x12\n\n\x02id\x18\x01 \x01(\x04\"B\n\x14GetCharacterResponse\x12*\n\tcharacter\x18\x01 \x01(\x0b\x32\x17.conversation.Character\"D\n\x16UpdateCharacterRequest\x12*\n\tcharacter\x18\x01 \x01(\x0b\x32\x17.conversation.Character\"E\n\x17UpdateCharacterResponse\x12*\n\tcharacter\x18\x01 \x01(\x0b\x32\x17.conversation.Character\"$\n\x16\x44\x65leteCharacterRequest\x12\n\n\x02id\x18\x01 \x01(\x04\"\x19\n\x17\x44\x65leteCharacterResponse\"\x17\n\x15ListCharactersRequest\"E\n\x16ListCharactersResponse\x12+\n\ncharacters\x18\x01 \x03(\x0b\x32\x17.conversation.Character\";\n\x13\x43reatePlayerRequest\x12$\n\x06player\x18\x01 \x01(\x0b\x32\x14.conversation.Player\"\"\n\x14\x43reatePlayerResponse\x12\n\n\x02id\x18\x01 \x01(\x04\"\x1e\n\x10GetPlayerRequest\x12\n\n\x02id\x18\x01 \x01(\x04\"9\n\x11GetPlayerResponse\x12$\n\x06player\x18\x01 \x01(\x0b\x32\x14.conversation.Player\";\n\x13UpdatePlayerRequest\x12$\n\x06player\x18\x01 \x01(\x0b\x32\x14.conversation.Player\"<\n\x14UpdatePlayerResponse\x12$\n\x06player\x18\x01 \x01(\x0b\x32\x14.conversation.Player\"!\n\x13\x44\x65letePlayerRequest\x12\n\n\x02id\x18\x01 \x01(\x04\"\x16\n\x14\x44\x65letePlayerResponse\"\x14\n\x12ListPlayersRequest\"<\n\x13ListPlayersResponse\x12%\n\x07players\x18\x01 \x03(\x0b\x32\x14.conversation.Player\"E\n\x16\x43reatePromptSetRequest\x12+\n\nprompt_set\x18\x01 \x01(\x0b\x32\x17.conversation.PromptSet\"%\n\x17\x43reatePromptSetResponse\x12\n\n\x02id\x18\x01 
\x01(\x04\"!\n\x13GetPromptSetRequest\x12\n\n\x02id\x18\x01 \x01(\x04\"C\n\x14GetPromptSetResponse\x12+\n\nprompt_set\x18\x01 \x01(\x0b\x32\x17.conversation.PromptSet\"E\n\x16UpdatePromptSetRequest\x12+\n\nprompt_set\x18\x01 \x01(\x0b\x32\x17.conversation.PromptSet\"F\n\x17UpdatePromptSetResponse\x12+\n\nprompt_set\x18\x01 \x01(\x0b\x32\x17.conversation.PromptSet\"$\n\x16\x44\x65letePromptSetRequest\x12\n\n\x02id\x18\x01 \x01(\x04\"\x19\n\x17\x44\x65letePromptSetResponse\"\x17\n\x15ListPromptSetsRequest\"F\n\x16ListPromptSetsResponse\x12,\n\x0bprompt_sets\x18\x01 \x03(\x0b\x32\x17.conversation.PromptSet\";\n\x13\x43reatePromptRequest\x12$\n\x06prompt\x18\x01 \x01(\x0b\x32\x14.conversation.Prompt\"\"\n\x14\x43reatePromptResponse\x12\n\n\x02id\x18\x01 \x01(\x04\"\x1e\n\x10GetPromptRequest\x12\n\n\x02id\x18\x01 \x01(\x04\"9\n\x11GetPromptResponse\x12$\n\x06prompt\x18\x01 \x01(\x0b\x32\x14.conversation.Prompt\";\n\x13UpdatePromptRequest\x12$\n\x06prompt\x18\x01 \x01(\x0b\x32\x14.conversation.Prompt\"<\n\x14UpdatePromptResponse\x12$\n\x06prompt\x18\x01 \x01(\x0b\x32\x14.conversation.Prompt\"!\n\x13\x44\x65letePromptRequest\x12\n\n\x02id\x18\x01 \x01(\x04\"\x16\n\x14\x44\x65letePromptResponse\"\x14\n\x12ListPromptsRequest\"<\n\x13ListPromptsResponse\x12%\n\x07prompts\x18\x01 \x03(\x0b\x32\x14.conversation.Prompt\">\n\x14\x43reateEmotionRequest\x12&\n\x07\x65motion\x18\x01 \x01(\x0b\x32\x15.conversation.Emotion\"#\n\x15\x43reateEmotionResponse\x12\n\n\x02id\x18\x01 \x01(\x04\"\x1f\n\x11GetEmotionRequest\x12\n\n\x02id\x18\x01 \x01(\x04\"<\n\x12GetEmotionResponse\x12&\n\x07\x65motion\x18\x01 \x01(\x0b\x32\x15.conversation.Emotion\">\n\x14UpdateEmotionRequest\x12&\n\x07\x65motion\x18\x01 \x01(\x0b\x32\x15.conversation.Emotion\"?\n\x15UpdateEmotionResponse\x12&\n\x07\x65motion\x18\x01 \x01(\x0b\x32\x15.conversation.Emotion\"\"\n\x14\x44\x65leteEmotionRequest\x12\n\n\x02id\x18\x01 
\x01(\x04\"\x17\n\x15\x44\x65leteEmotionResponse\"\x15\n\x13ListEmotionsRequest\"?\n\x14ListEmotionsResponse\x12\'\n\x08\x65motions\x18\x01 \x03(\x0b\x32\x15.conversation.Emotion\"E\n\x16\x43reatePrimerSetRequest\x12+\n\nprimer_set\x18\x01 \x01(\x0b\x32\x17.conversation.PrimerSet\"%\n\x17\x43reatePrimerSetResponse\x12\n\n\x02id\x18\x01 \x01(\x04\"!\n\x13GetPrimerSetRequest\x12\n\n\x02id\x18\x01 \x01(\x04\"C\n\x14GetPrimerSetResponse\x12+\n\nprimer_set\x18\x01 \x01(\x0b\x32\x17.conversation.PrimerSet\"E\n\x16UpdatePrimerSetRequest\x12+\n\nprimer_set\x18\x01 \x01(\x0b\x32\x17.conversation.PrimerSet\"F\n\x17UpdatePrimerSetResponse\x12+\n\nprimer_set\x18\x01 \x01(\x0b\x32\x17.conversation.PrimerSet\"$\n\x16\x44\x65letePrimerSetRequest\x12\n\n\x02id\x18\x01 \x01(\x04\"\x19\n\x17\x44\x65letePrimerSetResponse\"\x17\n\x15ListPrimerSetsRequest\"F\n\x16ListPrimerSetsResponse\x12,\n\x0bprimer_sets\x18\x01 \x03(\x0b\x32\x17.conversation.PrimerSet\";\n\x13\x43reatePrimerRequest\x12$\n\x06primer\x18\x01 \x01(\x0b\x32\x14.conversation.Primer\"\"\n\x14\x43reatePrimerResponse\x12\n\n\x02id\x18\x01 \x01(\x04\"\x1e\n\x10GetPrimerRequest\x12\n\n\x02id\x18\x01 \x01(\x04\"9\n\x11GetPrimerResponse\x12$\n\x06primer\x18\x01 \x01(\x0b\x32\x14.conversation.Primer\";\n\x13UpdatePrimerRequest\x12$\n\x06primer\x18\x01 \x01(\x0b\x32\x14.conversation.Primer\"<\n\x14UpdatePrimerResponse\x12$\n\x06primer\x18\x01 \x01(\x0b\x32\x14.conversation.Primer\"!\n\x13\x44\x65letePrimerRequest\x12\n\n\x02id\x18\x01 \x01(\x04\"\x16\n\x14\x44\x65letePrimerResponse\"\x14\n\x12ListPrimersRequest\"<\n\x13ListPrimersResponse\x12%\n\x07primers\x18\x01 \x03(\x0b\x32\x14.conversation.Primer\"^\n\x1e\x43reateEmotionalRulesSetRequest\x12<\n\x13\x65motional_rules_set\x18\x01 \x01(\x0b\x32\x1f.conversation.EmotionalRulesSet\"-\n\x1f\x43reateEmotionalRulesSetResponse\x12\n\n\x02id\x18\x01 \x01(\x04\")\n\x1bGetEmotionalRulesSetRequest\x12\n\n\x02id\x18\x01 
\x01(\x04\"\\\n\x1cGetEmotionalRulesSetResponse\x12<\n\x13\x65motional_rules_set\x18\x01 \x01(\x0b\x32\x1f.conversation.EmotionalRulesSet\"^\n\x1eUpdateEmotionalRulesSetRequest\x12<\n\x13\x65motional_rules_set\x18\x01 \x01(\x0b\x32\x1f.conversation.EmotionalRulesSet\"_\n\x1fUpdateEmotionalRulesSetResponse\x12<\n\x13\x65motional_rules_set\x18\x01 \x01(\x0b\x32\x1f.conversation.EmotionalRulesSet\",\n\x1e\x44\x65leteEmotionalRulesSetRequest\x12\n\n\x02id\x18\x01 \x01(\x04\"!\n\x1f\x44\x65leteEmotionalRulesSetResponse\"\x1f\n\x1dListEmotionalRulesSetsRequest\"_\n\x1eListEmotionalRulesSetsResponse\x12=\n\x14\x65motional_rules_sets\x18\x01 \x03(\x0b\x32\x1f.conversation.EmotionalRulesSet\"Q\n\x1a\x43reateEmotionalRuleRequest\x12\x33\n\x0e\x65motional_rule\x18\x01 \x01(\x0b\x32\x1b.conversation.EmotionalRule\")\n\x1b\x43reateEmotionalRuleResponse\x12\n\n\x02id\x18\x01 \x01(\x04\"%\n\x17GetEmotionalRuleRequest\x12\n\n\x02id\x18\x01 \x01(\x04\"O\n\x18GetEmotionalRuleResponse\x12\x33\n\x0e\x65motional_rule\x18\x01 \x01(\x0b\x32\x1b.conversation.EmotionalRule\"Q\n\x1aUpdateEmotionalRuleRequest\x12\x33\n\x0e\x65motional_rule\x18\x01 \x01(\x0b\x32\x1b.conversation.EmotionalRule\"R\n\x1bUpdateEmotionalRuleResponse\x12\x33\n\x0e\x65motional_rule\x18\x01 \x01(\x0b\x32\x1b.conversation.EmotionalRule\"(\n\x1a\x44\x65leteEmotionalRuleRequest\x12\n\n\x02id\x18\x01 \x01(\x04\"\x1d\n\x1b\x44\x65leteEmotionalRuleResponse\"\x1a\n\x18ListEmotionalRuleRequest\"Q\n\x19ListEmotionalRuleResponse\x12\x34\n\x0f\x65motional_rules\x18\x01 \x03(\x0b\x32\x1b.conversation.EmotionalRule\"Z\n\x1d\x43reateGenerationConfigRequest\x12\x39\n\x11generation_config\x18\x01 \x01(\x0b\x32\x1e.conversation.GenerationConfig\",\n\x1e\x43reateGenerationConfigResponse\x12\n\n\x02id\x18\x01 \x01(\x04\"(\n\x1aGetGenerationConfigRequest\x12\n\n\x02id\x18\x01 \x01(\x04\"X\n\x1bGetGenerationConfigResponse\x12\x39\n\x11generation_config\x18\x01 
\x01(\x0b\x32\x1e.conversation.GenerationConfig\"Z\n\x1dUpdateGenerationConfigRequest\x12\x39\n\x11generation_config\x18\x01 \x01(\x0b\x32\x1e.conversation.GenerationConfig\"[\n\x1eUpdateGenerationConfigResponse\x12\x39\n\x11generation_config\x18\x01 \x01(\x0b\x32\x1e.conversation.GenerationConfig\"+\n\x1d\x44\x65leteGenerationConfigRequest\x12\n\n\x02id\x18\x01 \x01(\x04\" \n\x1e\x44\x65leteGenerationConfigResponse\"\x1e\n\x1cListGenerationConfigsRequest\"[\n\x1dListGenerationConfigsResponse\x12:\n\x12generation_configs\x18\x01 \x03(\x0b\x32\x1e.conversation.GenerationConfig\"F\n\x19\x43reateConversationRequest\x12\x15\n\rcharacter_ids\x18\x01 \x03(\x04\x12\x12\n\nplayer_ids\x18\x02 \x03(\x04\"(\n\x1a\x43reateConversationResponse\x12\n\n\x02id\x18\x01 \x01(\x04\"2\n\x17JoinConversationRequest\x12\x17\n\x0f\x63onversation_id\x18\x01 \x01(\x04\"Y\n\x12SendMessageRequest\x12\x11\n\tplayer_id\x18\x01 \x01(\x04\x12\x17\n\x0fmessage_content\x18\x02 \x01(\t\x12\x17\n\x0f\x63onversation_id\x18\x03 \x01(\x04\"\x15\n\x13SendMessageResponse\"\'\n\x05World\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x12\n\ncharacters\x18\x02 \x01(\t\"\x97\x01\n\x0c\x43onversation\x12\n\n\x02id\x18\x01 \x01(\x04\x12+\n\ncharacters\x18\x02 \x03(\x0b\x32\x17.conversation.Character\x12%\n\x07players\x18\x03 \x03(\x0b\x32\x14.conversation.Player\x12\'\n\x08messages\x18\x04 \x03(\x0b\x32\x15.conversation.Message\";\n\x06Player\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x17\n\x0f\x63onversation_id\x18\x03 \x01(\t\"\xec\x01\n\x07Message\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x11\n\ttimestamp\x18\x02 \x01(\x03\x12\x0c\n\x04type\x18\x03 \x01(\t\x12\x11\n\tplayer_id\x18\x04 \x01(\x04\x12$\n\x06player\x18\x05 \x01(\x0b\x32\x14.conversation.Player\x12\x14\n\x0c\x63haracter_id\x18\x06 \x01(\x04\x12*\n\tcharacter\x18\x07 \x01(\x0b\x32\x17.conversation.Character\x12\x17\n\x0f\x63onversation_id\x18\x08 \x01(\x04\x12\x0f\n\x07\x63ontent\x18\t \x01(\t\x12\x0f\n\x07\x65motion\x18\n 
\x01(\t\"\x81\x04\n\tCharacter\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x11\n\tcomposure\x18\x03 \x01(\x05\x12\x15\n\ractive_prompt\x18\x04 \x01(\t\x12\x0c\n\x04type\x18\x05 \x01(\t\x12\'\n\x08\x65motions\x18\x06 \x03(\x0b\x32\x15.conversation.Emotion\x12\x1e\n\x16\x65motional_rules_set_id\x18\x07 \x01(\x04\x12<\n\x13\x65motional_rules_set\x18\x08 \x01(\x0b\x32\x1f.conversation.EmotionalRulesSet\x12\x1c\n\x14generation_config_id\x18\t \x01(\x04\x12\x39\n\x11generation_config\x18\n \x01(\x0b\x32\x1e.conversation.GenerationConfig\x12\x17\n\x0f\x63onversation_id\x18\x0b \x01(\x04\x12\x15\n\rprompt_set_id\x18\r \x01(\x04\x12+\n\nprompt_set\x18\x0e \x01(\x0b\x32\x17.conversation.PromptSet\x12\x15\n\rprimer_set_id\x18\x0f \x01(\x04\x12+\n\nprimer_set\x18\x10 \x01(\x0b\x32\x17.conversation.PrimerSet\x12!\n\x05tasks\x18\x11 \x03(\x0b\x32\x12.conversation.Task\"\x92\x02\n\x10GenerationConfig\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\r\n\x05model\x18\x03 \x01(\t\x12\x0e\n\x06suffix\x18\x04 \x01(\t\x12\x12\n\nmax_tokens\x18\x05 \x01(\x05\x12\x13\n\x0btemperature\x18\x06 \x01(\x02\x12\r\n\x05top_p\x18\x07 \x01(\x02\x12\t\n\x01n\x18\x08 \x01(\x05\x12\x0e\n\x06stream\x18\t \x01(\x08\x12\x10\n\x08logprobs\x18\n \x01(\x05\x12\x0c\n\x04\x65\x63ho\x18\x0b \x01(\x08\x12\x0c\n\x04stop\x18\x0c \x01(\t\x12\x18\n\x10presence_penalty\x18\r \x01(\x02\x12\x19\n\x11\x66requency_penalty\x18\x0e \x01(\x02\x12\x0f\n\x07\x62\x65st_of\x18\x0f \x01(\x05\"\x80\x01\n\x07\x45motion\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x14\n\x0c\x63haracter_id\x18\x02 \x01(\x04\x12\x11\n\tprompt_id\x18\x03 \x01(\x04\x12\x11\n\tprimer_id\x18\x04 \x01(\x04\x12\x11\n\tpair_name\x18\x06 \x01(\t\x12\x1a\n\x12\x63urrent_percentage\x18\x07 \x01(\x05\"c\n\x11\x45motionalRulesSet\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x34\n\x0f\x65motional_rules\x18\x03 
\x03(\x0b\x32\x1b.conversation.EmotionalRule\"\x84\x02\n\rEmotionalRule\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x1e\n\x16\x65motional_rules_set_id\x18\x02 \x01(\x04\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\x12\x0f\n\x07trigger\x18\x05 \x01(\t\x12\x1a\n\x12percentage_of_proc\x18\x06 \x01(\x05\x12\x33\n\x0crequirements\x18\x07 \x03(\x0b\x32\x1d.conversation.RuleRequirement\x12\x32\n\x07\x65\x66\x66\x65\x63ts\x18\x08 \x03(\x0b\x32!.conversation.EmotionalRuleEffect\x12\x0e\n\x06result\x18\t \x01(\t\"\xe9\x03\n\x0fRuleRequirement\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x19\n\x11\x65motional_rule_id\x18\x03 \x01(\x04\x12\x11\n\tprompt_id\x18\x04 \x01(\x04\x12\x11\n\tprimer_id\x18\x05 \x01(\x04\x12\x1a\n\x12percentage_of_proc\x18\x06 \x01(\x05\x12\x19\n\x11\x65motion_pair_name\x18\x07 \x01(\t\x12\x1a\n\x12\x65motion_percentage\x18\x08 \x01(\x05\x12G\n\x10requirement_type\x18\t \x01(\x0e\x32-.conversation.RuleRequirement.RequirementType\x12\x13\n\x0b\x64\x65scription\x18\n \x01(\t\"\xc9\x01\n\x0fRequirementType\x12 \n\x1cREQUIREMENT_TYPE_UNSPECIFIED\x10\x00\x12 \n\x1cREQUIREMENT_TYPE_IS_POSITIVE\x10\x01\x12 \n\x1cREQUIREMENT_TYPE_IS_NEGATIVE\x10\x02\x12\'\n#REQUIREMENT_TYPE_IS_ABOVE_X_PERCENT\x10\x03\x12\'\n#REQUIREMENT_TYPE_IS_BELOW_X_PERCENT\x10\x04\"\xe0\x05\n\x13\x45motionalRuleEffect\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x19\n\x11\x65motional_rule_id\x18\x03 \x01(\t\x12\x19\n\x11\x65motion_pair_name\x18\x04 \x01(\t\x12\x41\n\x0b\x65\x66\x66\x65\x63t_type\x18\x05 \x01(\x0e\x32,.conversation.EmotionalRuleEffect.EffectType\x12\x18\n\x10\x65\x66\x66\x65\x63t_magnitude\x18\x06 \x01(\x05\x12Q\n\x11\x65\x66\x66\x65\x63t_multiplier\x18\x07 \x01(\x0e\x32\x36.conversation.EmotionalRuleEffect.EffectMultiplierType\x12\x13\n\x0b\x64\x65scription\x18\x08 
\x01(\t\"\xff\x01\n\nEffectType\x12\x1b\n\x17\x45\x46\x46\x45\x43T_TYPE_UNSPECIFIED\x10\x00\x12\x18\n\x14\x45\x46\x46\x45\x43T_TYPE_SPECIFIC\x10\x01\x12\x1c\n\x18\x45\x46\x46\x45\x43T_TYPE_NEGATIVE_LOW\x10\x02\x12\x1f\n\x1b\x45\x46\x46\x45\x43T_TYPE_NEGATIVE_MEDIUM\x10\x03\x12\x1d\n\x19\x45\x46\x46\x45\x43T_TYPE_NEGATIVE_HIGH\x10\x04\x12\x1c\n\x18\x45\x46\x46\x45\x43T_TYPE_POSITIVE_LOW\x10\x05\x12\x1f\n\x1b\x45\x46\x46\x45\x43T_TYPE_POSITIVE_MEDIUM\x10\x06\x12\x1d\n\x19\x45\x46\x46\x45\x43T_TYPE_POSITIVE_HIGH\x10\x07\"\xb1\x01\n\x14\x45\x66\x66\x65\x63tMultiplierType\x12&\n\"EFFECT_MULTIPLIER_TYPE_UNSPECIFIED\x10\x00\x12&\n\"EFFECT_MULTIPLIER_TYPE_BOTTLING_UP\x10\x01\x12\"\n\x1e\x45\x46\x46\x45\x43T_MULTIPLIER_TYPE_NEUTRAL\x10\x02\x12%\n!EFFECT_MULTIPLIER_TYPE_SHORT_FUSE\x10\x03\"L\n\tPromptSet\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12%\n\x07prompts\x18\x03 \x03(\x0b\x32\x14.conversation.Prompt\"\xc7\x01\n\x06Prompt\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x15\n\rprompt_set_id\x18\x03 \x01(\x04\x12\x0c\n\x04text\x18\x04 \x01(\t\x12\x1a\n\x12percentage_of_proc\x18\x05 \x01(\x05\x12-\n\x0eideal_emotions\x18\x06 \x03(\x0b\x32\x15.conversation.Emotion\x12\x33\n\x0crequirements\x18\x07 \x03(\x0b\x32\x1d.conversation.RuleRequirement\"L\n\tPrimerSet\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12%\n\x07primers\x18\x03 \x03(\x0b\x32\x14.conversation.Primer\"\xb9\x01\n\x06Primer\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x15\n\rprimer_set_id\x18\x02 \x01(\x04\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\x0c\n\x04text\x18\x04 \x01(\t\x12\x0c\n\x04type\x18\x05 \x01(\t\x12-\n\x0eideal_emotions\x18\x06 \x03(\x0b\x32\x15.conversation.Emotion\x12\x33\n\x0crequirements\x18\x07 \x03(\x0b\x32\x1d.conversation.RuleRequirement\"x\n\x04Task\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x14\n\x0c\x63haracter_id\x18\x02 \x01(\x04\x12\x0c\n\x04need\x18\x03 \x01(\t\x12\x19\n\x11task_satisfaction\x18\x04 
\x01(\x05\x12%\n\x05steps\x18\x05 \x03(\x0b\x32\x16.conversation.TaskStep\"G\n\x08TaskStep\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0f\n\x07task_id\x18\x02 \x01(\x04\x12\x0e\n\x06\x61\x63tion\x18\x03 \x01(\t\x12\x0e\n\x06target\x18\x04 \x01(\t2\x84+\n\x0b\x43hatService\x12g\n\x12\x43reateConversation\x12\'.conversation.CreateConversationRequest\x1a(.conversation.CreateConversationResponse\x12R\n\x0bSendMessage\x12 .conversation.SendMessageRequest\x1a!.conversation.SendMessageResponse\x12\\\n\x1aStreamConversationMessages\x12%.conversation.JoinConversationRequest\x1a\x15.conversation.Message0\x01\x12X\n\rCommitMessage\x12\".conversation.CommitMessageRequest\x1a#.conversation.CommitMessageResponse\x12\x64\n\x11GetCharacterTasks\x12&.conversation.GetCharacterTasksRequest\x1a\'.conversation.GetCharacterTasksResponse\x12p\n\x15\x43ompleteCharacterTask\x12*.conversation.CompleteCharacterTaskRequest\x1a+.conversation.CompleteCharacterTaskResponse\x12j\n\x13\x43reateCharacterNeed\x12(.conversation.CreateCharacterNeedRequest\x1a).conversation.CreateCharacterNeedResponse\x12^\n\x0f\x43reateCharacter\x12$.conversation.CreateCharacterRequest\x1a%.conversation.CreateCharacterResponse\x12U\n\x0cGetCharacter\x12!.conversation.GetCharacterRequest\x1a\".conversation.GetCharacterResponse\x12^\n\x0fUpdateCharacter\x12$.conversation.UpdateCharacterRequest\x1a%.conversation.UpdateCharacterResponse\x12^\n\x0f\x44\x65leteCharacter\x12$.conversation.DeleteCharacterRequest\x1a%.conversation.DeleteCharacterResponse\x12[\n\x0eListCharacters\x12#.conversation.ListCharactersRequest\x1a$.conversation.ListCharactersResponse\x12U\n\x0c\x43reatePlayer\x12!.conversation.CreatePlayerRequest\x1a\".conversation.CreatePlayerResponse\x12L\n\tGetPlayer\x12\x1e.conversation.GetPlayerRequest\x1a\x1f.conversation.GetPlayerResponse\x12U\n\x0cUpdatePlayer\x12!.conversation.UpdatePlayerRequest\x1a\".conversation.UpdatePlayerResponse\x12U\n\x0c\x44\x65letePlayer\x12!.conversation.DeletePlayerRequest\x1a\".conversati
on.DeletePlayerResponse\x12R\n\x0bListPlayers\x12 .conversation.ListPlayersRequest\x1a!.conversation.ListPlayersResponse\x12^\n\x0f\x43reatePromptSet\x12$.conversation.CreatePromptSetRequest\x1a%.conversation.CreatePromptSetResponse\x12U\n\x0cGetPromptSet\x12!.conversation.GetPromptSetRequest\x1a\".conversation.GetPromptSetResponse\x12^\n\x0fUpdatePromptSet\x12$.conversation.UpdatePromptSetRequest\x1a%.conversation.UpdatePromptSetResponse\x12^\n\x0f\x44\x65letePromptSet\x12$.conversation.DeletePromptSetRequest\x1a%.conversation.DeletePromptSetResponse\x12[\n\x0eListPromptSets\x12#.conversation.ListPromptSetsRequest\x1a$.conversation.ListPromptSetsResponse\x12U\n\x0c\x43reatePrompt\x12!.conversation.CreatePromptRequest\x1a\".conversation.CreatePromptResponse\x12L\n\tGetPrompt\x12\x1e.conversation.GetPromptRequest\x1a\x1f.conversation.GetPromptResponse\x12U\n\x0cUpdatePrompt\x12!.conversation.UpdatePromptRequest\x1a\".conversation.UpdatePromptResponse\x12U\n\x0c\x44\x65letePrompt\x12!.conversation.DeletePromptRequest\x1a\".conversation.DeletePromptResponse\x12R\n\x0bListPrompts\x12 
.conversation.ListPromptsRequest\x1a!.conversation.ListPromptsResponse\x12^\n\x0f\x43reatePrimerSet\x12$.conversation.CreatePrimerSetRequest\x1a%.conversation.CreatePrimerSetResponse\x12U\n\x0cGetPrimerSet\x12!.conversation.GetPrimerSetRequest\x1a\".conversation.GetPrimerSetResponse\x12^\n\x0fUpdatePrimerSet\x12$.conversation.UpdatePrimerSetRequest\x1a%.conversation.UpdatePrimerSetResponse\x12^\n\x0f\x44\x65letePrimerSet\x12$.conversation.DeletePrimerSetRequest\x1a%.conversation.DeletePrimerSetResponse\x12[\n\x0eListPrimerSets\x12#.conversation.ListPrimerSetsRequest\x1a$.conversation.ListPrimerSetsResponse\x12U\n\x0c\x43reatePrimer\x12!.conversation.CreatePrimerRequest\x1a\".conversation.CreatePrimerResponse\x12L\n\tGetPrimer\x12\x1e.conversation.GetPrimerRequest\x1a\x1f.conversation.GetPrimerResponse\x12U\n\x0cUpdatePrimer\x12!.conversation.UpdatePrimerRequest\x1a\".conversation.UpdatePrimerResponse\x12U\n\x0c\x44\x65letePrimer\x12!.conversation.DeletePrimerRequest\x1a\".conversation.DeletePrimerResponse\x12R\n\x0bListPrimers\x12 
.conversation.ListPrimersRequest\x1a!.conversation.ListPrimersResponse\x12v\n\x17\x43reateEmotionalRulesSet\x12,.conversation.CreateEmotionalRulesSetRequest\x1a-.conversation.CreateEmotionalRulesSetResponse\x12m\n\x14GetEmotionalRulesSet\x12).conversation.GetEmotionalRulesSetRequest\x1a*.conversation.GetEmotionalRulesSetResponse\x12v\n\x17UpdateEmotionalRulesSet\x12,.conversation.UpdateEmotionalRulesSetRequest\x1a-.conversation.UpdateEmotionalRulesSetResponse\x12v\n\x17\x44\x65leteEmotionalRulesSet\x12,.conversation.DeleteEmotionalRulesSetRequest\x1a-.conversation.DeleteEmotionalRulesSetResponse\x12s\n\x16ListEmotionalRulesSets\x12+.conversation.ListEmotionalRulesSetsRequest\x1a,.conversation.ListEmotionalRulesSetsResponse\x12j\n\x13\x43reateEmotionalRule\x12(.conversation.CreateEmotionalRuleRequest\x1a).conversation.CreateEmotionalRuleResponse\x12\x61\n\x10GetEmotionalRule\x12%.conversation.GetEmotionalRuleRequest\x1a&.conversation.GetEmotionalRuleResponse\x12j\n\x13UpdateEmotionalRule\x12(.conversation.UpdateEmotionalRuleRequest\x1a).conversation.UpdateEmotionalRuleResponse\x12j\n\x13\x44\x65leteEmotionalRule\x12(.conversation.DeleteEmotionalRuleRequest\x1a).conversation.DeleteEmotionalRuleResponse\x12\x65\n\x12ListEmotionalRules\x12&.conversation.ListEmotionalRuleRequest\x1a\'.conversation.ListEmotionalRuleResponse\x12s\n\x16\x43reateGenerationConfig\x12+.conversation.CreateGenerationConfigRequest\x1a,.conversation.CreateGenerationConfigResponse\x12j\n\x13GetGenerationConfig\x12(.conversation.GetGenerationConfigRequest\x1a).conversation.GetGenerationConfigResponse\x12s\n\x16UpdateGenerationConfig\x12+.conversation.UpdateGenerationConfigRequest\x1a,.conversation.UpdateGenerationConfigResponse\x12s\n\x16\x44\x65leteGenerationConfig\x12+.conversation.DeleteGenerationConfigRequest\x1a,.conversation.DeleteGenerationConfigResponse\x12p\n\x15ListGenerationConfigs\x12*.conversation.ListGenerationConfigsRequest\x1a+.conversation.ListGenerationConfigsResponse\x12X\n\rCreat
eEmotion\x12\".conversation.CreateEmotionRequest\x1a#.conversation.CreateEmotionResponse\x12O\n\nGetEmotion\x12\x1f.conversation.GetEmotionRequest\x1a .conversation.GetEmotionResponse\x12X\n\rUpdateEmotion\x12\".conversation.UpdateEmotionRequest\x1a#.conversation.UpdateEmotionResponse\x12X\n\rDeleteEmotion\x12\".conversation.DeleteEmotionRequest\x1a#.conversation.DeleteEmotionResponse\x12U\n\x0cListEmotions\x12!.conversation.ListEmotionsRequest\x1a\".conversation.ListEmotionsResponseB\x06Z\x04./pbb\x06proto3') + + + +_CREATECHARACTERNEEDREQUEST = DESCRIPTOR.message_types_by_name['CreateCharacterNeedRequest'] +_CREATECHARACTERNEEDRESPONSE = DESCRIPTOR.message_types_by_name['CreateCharacterNeedResponse'] +_GETCHARACTERTASKSREQUEST = DESCRIPTOR.message_types_by_name['GetCharacterTasksRequest'] +_GETCHARACTERTASKSRESPONSE = DESCRIPTOR.message_types_by_name['GetCharacterTasksResponse'] +_COMPLETECHARACTERTASKREQUEST = DESCRIPTOR.message_types_by_name['CompleteCharacterTaskRequest'] +_COMPLETECHARACTERTASKRESPONSE = DESCRIPTOR.message_types_by_name['CompleteCharacterTaskResponse'] +_COMMITMESSAGEREQUEST = DESCRIPTOR.message_types_by_name['CommitMessageRequest'] +_COMMITMESSAGERESPONSE = DESCRIPTOR.message_types_by_name['CommitMessageResponse'] +_CREATECHARACTERREQUEST = DESCRIPTOR.message_types_by_name['CreateCharacterRequest'] +_CREATECHARACTERRESPONSE = DESCRIPTOR.message_types_by_name['CreateCharacterResponse'] +_GETCHARACTERREQUEST = DESCRIPTOR.message_types_by_name['GetCharacterRequest'] +_GETCHARACTERRESPONSE = DESCRIPTOR.message_types_by_name['GetCharacterResponse'] +_UPDATECHARACTERREQUEST = DESCRIPTOR.message_types_by_name['UpdateCharacterRequest'] +_UPDATECHARACTERRESPONSE = DESCRIPTOR.message_types_by_name['UpdateCharacterResponse'] +_DELETECHARACTERREQUEST = DESCRIPTOR.message_types_by_name['DeleteCharacterRequest'] +_DELETECHARACTERRESPONSE = DESCRIPTOR.message_types_by_name['DeleteCharacterResponse'] +_LISTCHARACTERSREQUEST = 
DESCRIPTOR.message_types_by_name['ListCharactersRequest'] +_LISTCHARACTERSRESPONSE = DESCRIPTOR.message_types_by_name['ListCharactersResponse'] +_CREATEPLAYERREQUEST = DESCRIPTOR.message_types_by_name['CreatePlayerRequest'] +_CREATEPLAYERRESPONSE = DESCRIPTOR.message_types_by_name['CreatePlayerResponse'] +_GETPLAYERREQUEST = DESCRIPTOR.message_types_by_name['GetPlayerRequest'] +_GETPLAYERRESPONSE = DESCRIPTOR.message_types_by_name['GetPlayerResponse'] +_UPDATEPLAYERREQUEST = DESCRIPTOR.message_types_by_name['UpdatePlayerRequest'] +_UPDATEPLAYERRESPONSE = DESCRIPTOR.message_types_by_name['UpdatePlayerResponse'] +_DELETEPLAYERREQUEST = DESCRIPTOR.message_types_by_name['DeletePlayerRequest'] +_DELETEPLAYERRESPONSE = DESCRIPTOR.message_types_by_name['DeletePlayerResponse'] +_LISTPLAYERSREQUEST = DESCRIPTOR.message_types_by_name['ListPlayersRequest'] +_LISTPLAYERSRESPONSE = DESCRIPTOR.message_types_by_name['ListPlayersResponse'] +_CREATEPROMPTSETREQUEST = DESCRIPTOR.message_types_by_name['CreatePromptSetRequest'] +_CREATEPROMPTSETRESPONSE = DESCRIPTOR.message_types_by_name['CreatePromptSetResponse'] +_GETPROMPTSETREQUEST = DESCRIPTOR.message_types_by_name['GetPromptSetRequest'] +_GETPROMPTSETRESPONSE = DESCRIPTOR.message_types_by_name['GetPromptSetResponse'] +_UPDATEPROMPTSETREQUEST = DESCRIPTOR.message_types_by_name['UpdatePromptSetRequest'] +_UPDATEPROMPTSETRESPONSE = DESCRIPTOR.message_types_by_name['UpdatePromptSetResponse'] +_DELETEPROMPTSETREQUEST = DESCRIPTOR.message_types_by_name['DeletePromptSetRequest'] +_DELETEPROMPTSETRESPONSE = DESCRIPTOR.message_types_by_name['DeletePromptSetResponse'] +_LISTPROMPTSETSREQUEST = DESCRIPTOR.message_types_by_name['ListPromptSetsRequest'] +_LISTPROMPTSETSRESPONSE = DESCRIPTOR.message_types_by_name['ListPromptSetsResponse'] +_CREATEPROMPTREQUEST = DESCRIPTOR.message_types_by_name['CreatePromptRequest'] +_CREATEPROMPTRESPONSE = DESCRIPTOR.message_types_by_name['CreatePromptResponse'] +_GETPROMPTREQUEST = 
DESCRIPTOR.message_types_by_name['GetPromptRequest'] +_GETPROMPTRESPONSE = DESCRIPTOR.message_types_by_name['GetPromptResponse'] +_UPDATEPROMPTREQUEST = DESCRIPTOR.message_types_by_name['UpdatePromptRequest'] +_UPDATEPROMPTRESPONSE = DESCRIPTOR.message_types_by_name['UpdatePromptResponse'] +_DELETEPROMPTREQUEST = DESCRIPTOR.message_types_by_name['DeletePromptRequest'] +_DELETEPROMPTRESPONSE = DESCRIPTOR.message_types_by_name['DeletePromptResponse'] +_LISTPROMPTSREQUEST = DESCRIPTOR.message_types_by_name['ListPromptsRequest'] +_LISTPROMPTSRESPONSE = DESCRIPTOR.message_types_by_name['ListPromptsResponse'] +_CREATEEMOTIONREQUEST = DESCRIPTOR.message_types_by_name['CreateEmotionRequest'] +_CREATEEMOTIONRESPONSE = DESCRIPTOR.message_types_by_name['CreateEmotionResponse'] +_GETEMOTIONREQUEST = DESCRIPTOR.message_types_by_name['GetEmotionRequest'] +_GETEMOTIONRESPONSE = DESCRIPTOR.message_types_by_name['GetEmotionResponse'] +_UPDATEEMOTIONREQUEST = DESCRIPTOR.message_types_by_name['UpdateEmotionRequest'] +_UPDATEEMOTIONRESPONSE = DESCRIPTOR.message_types_by_name['UpdateEmotionResponse'] +_DELETEEMOTIONREQUEST = DESCRIPTOR.message_types_by_name['DeleteEmotionRequest'] +_DELETEEMOTIONRESPONSE = DESCRIPTOR.message_types_by_name['DeleteEmotionResponse'] +_LISTEMOTIONSREQUEST = DESCRIPTOR.message_types_by_name['ListEmotionsRequest'] +_LISTEMOTIONSRESPONSE = DESCRIPTOR.message_types_by_name['ListEmotionsResponse'] +_CREATEPRIMERSETREQUEST = DESCRIPTOR.message_types_by_name['CreatePrimerSetRequest'] +_CREATEPRIMERSETRESPONSE = DESCRIPTOR.message_types_by_name['CreatePrimerSetResponse'] +_GETPRIMERSETREQUEST = DESCRIPTOR.message_types_by_name['GetPrimerSetRequest'] +_GETPRIMERSETRESPONSE = DESCRIPTOR.message_types_by_name['GetPrimerSetResponse'] +_UPDATEPRIMERSETREQUEST = DESCRIPTOR.message_types_by_name['UpdatePrimerSetRequest'] +_UPDATEPRIMERSETRESPONSE = DESCRIPTOR.message_types_by_name['UpdatePrimerSetResponse'] +_DELETEPRIMERSETREQUEST = 
DESCRIPTOR.message_types_by_name['DeletePrimerSetRequest'] +_DELETEPRIMERSETRESPONSE = DESCRIPTOR.message_types_by_name['DeletePrimerSetResponse'] +_LISTPRIMERSETSREQUEST = DESCRIPTOR.message_types_by_name['ListPrimerSetsRequest'] +_LISTPRIMERSETSRESPONSE = DESCRIPTOR.message_types_by_name['ListPrimerSetsResponse'] +_CREATEPRIMERREQUEST = DESCRIPTOR.message_types_by_name['CreatePrimerRequest'] +_CREATEPRIMERRESPONSE = DESCRIPTOR.message_types_by_name['CreatePrimerResponse'] +_GETPRIMERREQUEST = DESCRIPTOR.message_types_by_name['GetPrimerRequest'] +_GETPRIMERRESPONSE = DESCRIPTOR.message_types_by_name['GetPrimerResponse'] +_UPDATEPRIMERREQUEST = DESCRIPTOR.message_types_by_name['UpdatePrimerRequest'] +_UPDATEPRIMERRESPONSE = DESCRIPTOR.message_types_by_name['UpdatePrimerResponse'] +_DELETEPRIMERREQUEST = DESCRIPTOR.message_types_by_name['DeletePrimerRequest'] +_DELETEPRIMERRESPONSE = DESCRIPTOR.message_types_by_name['DeletePrimerResponse'] +_LISTPRIMERSREQUEST = DESCRIPTOR.message_types_by_name['ListPrimersRequest'] +_LISTPRIMERSRESPONSE = DESCRIPTOR.message_types_by_name['ListPrimersResponse'] +_CREATEEMOTIONALRULESSETREQUEST = DESCRIPTOR.message_types_by_name['CreateEmotionalRulesSetRequest'] +_CREATEEMOTIONALRULESSETRESPONSE = DESCRIPTOR.message_types_by_name['CreateEmotionalRulesSetResponse'] +_GETEMOTIONALRULESSETREQUEST = DESCRIPTOR.message_types_by_name['GetEmotionalRulesSetRequest'] +_GETEMOTIONALRULESSETRESPONSE = DESCRIPTOR.message_types_by_name['GetEmotionalRulesSetResponse'] +_UPDATEEMOTIONALRULESSETREQUEST = DESCRIPTOR.message_types_by_name['UpdateEmotionalRulesSetRequest'] +_UPDATEEMOTIONALRULESSETRESPONSE = DESCRIPTOR.message_types_by_name['UpdateEmotionalRulesSetResponse'] +_DELETEEMOTIONALRULESSETREQUEST = DESCRIPTOR.message_types_by_name['DeleteEmotionalRulesSetRequest'] +_DELETEEMOTIONALRULESSETRESPONSE = DESCRIPTOR.message_types_by_name['DeleteEmotionalRulesSetResponse'] +_LISTEMOTIONALRULESSETSREQUEST = 
DESCRIPTOR.message_types_by_name['ListEmotionalRulesSetsRequest'] +_LISTEMOTIONALRULESSETSRESPONSE = DESCRIPTOR.message_types_by_name['ListEmotionalRulesSetsResponse'] +_CREATEEMOTIONALRULEREQUEST = DESCRIPTOR.message_types_by_name['CreateEmotionalRuleRequest'] +_CREATEEMOTIONALRULERESPONSE = DESCRIPTOR.message_types_by_name['CreateEmotionalRuleResponse'] +_GETEMOTIONALRULEREQUEST = DESCRIPTOR.message_types_by_name['GetEmotionalRuleRequest'] +_GETEMOTIONALRULERESPONSE = DESCRIPTOR.message_types_by_name['GetEmotionalRuleResponse'] +_UPDATEEMOTIONALRULEREQUEST = DESCRIPTOR.message_types_by_name['UpdateEmotionalRuleRequest'] +_UPDATEEMOTIONALRULERESPONSE = DESCRIPTOR.message_types_by_name['UpdateEmotionalRuleResponse'] +_DELETEEMOTIONALRULEREQUEST = DESCRIPTOR.message_types_by_name['DeleteEmotionalRuleRequest'] +_DELETEEMOTIONALRULERESPONSE = DESCRIPTOR.message_types_by_name['DeleteEmotionalRuleResponse'] +_LISTEMOTIONALRULEREQUEST = DESCRIPTOR.message_types_by_name['ListEmotionalRuleRequest'] +_LISTEMOTIONALRULERESPONSE = DESCRIPTOR.message_types_by_name['ListEmotionalRuleResponse'] +_CREATEGENERATIONCONFIGREQUEST = DESCRIPTOR.message_types_by_name['CreateGenerationConfigRequest'] +_CREATEGENERATIONCONFIGRESPONSE = DESCRIPTOR.message_types_by_name['CreateGenerationConfigResponse'] +_GETGENERATIONCONFIGREQUEST = DESCRIPTOR.message_types_by_name['GetGenerationConfigRequest'] +_GETGENERATIONCONFIGRESPONSE = DESCRIPTOR.message_types_by_name['GetGenerationConfigResponse'] +_UPDATEGENERATIONCONFIGREQUEST = DESCRIPTOR.message_types_by_name['UpdateGenerationConfigRequest'] +_UPDATEGENERATIONCONFIGRESPONSE = DESCRIPTOR.message_types_by_name['UpdateGenerationConfigResponse'] +_DELETEGENERATIONCONFIGREQUEST = DESCRIPTOR.message_types_by_name['DeleteGenerationConfigRequest'] +_DELETEGENERATIONCONFIGRESPONSE = DESCRIPTOR.message_types_by_name['DeleteGenerationConfigResponse'] +_LISTGENERATIONCONFIGSREQUEST = DESCRIPTOR.message_types_by_name['ListGenerationConfigsRequest'] 
+_LISTGENERATIONCONFIGSRESPONSE = DESCRIPTOR.message_types_by_name['ListGenerationConfigsResponse'] +_CREATECONVERSATIONREQUEST = DESCRIPTOR.message_types_by_name['CreateConversationRequest'] +_CREATECONVERSATIONRESPONSE = DESCRIPTOR.message_types_by_name['CreateConversationResponse'] +_JOINCONVERSATIONREQUEST = DESCRIPTOR.message_types_by_name['JoinConversationRequest'] +_SENDMESSAGEREQUEST = DESCRIPTOR.message_types_by_name['SendMessageRequest'] +_SENDMESSAGERESPONSE = DESCRIPTOR.message_types_by_name['SendMessageResponse'] +_WORLD = DESCRIPTOR.message_types_by_name['World'] +_CONVERSATION = DESCRIPTOR.message_types_by_name['Conversation'] +_PLAYER = DESCRIPTOR.message_types_by_name['Player'] +_MESSAGE = DESCRIPTOR.message_types_by_name['Message'] +_CHARACTER = DESCRIPTOR.message_types_by_name['Character'] +_GENERATIONCONFIG = DESCRIPTOR.message_types_by_name['GenerationConfig'] +_EMOTION = DESCRIPTOR.message_types_by_name['Emotion'] +_EMOTIONALRULESSET = DESCRIPTOR.message_types_by_name['EmotionalRulesSet'] +_EMOTIONALRULE = DESCRIPTOR.message_types_by_name['EmotionalRule'] +_RULEREQUIREMENT = DESCRIPTOR.message_types_by_name['RuleRequirement'] +_EMOTIONALRULEEFFECT = DESCRIPTOR.message_types_by_name['EmotionalRuleEffect'] +_PROMPTSET = DESCRIPTOR.message_types_by_name['PromptSet'] +_PROMPT = DESCRIPTOR.message_types_by_name['Prompt'] +_PRIMERSET = DESCRIPTOR.message_types_by_name['PrimerSet'] +_PRIMER = DESCRIPTOR.message_types_by_name['Primer'] +_TASK = DESCRIPTOR.message_types_by_name['Task'] +_TASKSTEP = DESCRIPTOR.message_types_by_name['TaskStep'] +_COMMITMESSAGEREQUEST_COMMITTYPE = _COMMITMESSAGEREQUEST.enum_types_by_name['CommitType'] +_RULEREQUIREMENT_REQUIREMENTTYPE = _RULEREQUIREMENT.enum_types_by_name['RequirementType'] +_EMOTIONALRULEEFFECT_EFFECTTYPE = _EMOTIONALRULEEFFECT.enum_types_by_name['EffectType'] +_EMOTIONALRULEEFFECT_EFFECTMULTIPLIERTYPE = _EMOTIONALRULEEFFECT.enum_types_by_name['EffectMultiplierType'] +CreateCharacterNeedRequest = 
_reflection.GeneratedProtocolMessageType('CreateCharacterNeedRequest', (_message.Message,), { + 'DESCRIPTOR' : _CREATECHARACTERNEEDREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.CreateCharacterNeedRequest) + }) +_sym_db.RegisterMessage(CreateCharacterNeedRequest) + +CreateCharacterNeedResponse = _reflection.GeneratedProtocolMessageType('CreateCharacterNeedResponse', (_message.Message,), { + 'DESCRIPTOR' : _CREATECHARACTERNEEDRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.CreateCharacterNeedResponse) + }) +_sym_db.RegisterMessage(CreateCharacterNeedResponse) + +GetCharacterTasksRequest = _reflection.GeneratedProtocolMessageType('GetCharacterTasksRequest', (_message.Message,), { + 'DESCRIPTOR' : _GETCHARACTERTASKSREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.GetCharacterTasksRequest) + }) +_sym_db.RegisterMessage(GetCharacterTasksRequest) + +GetCharacterTasksResponse = _reflection.GeneratedProtocolMessageType('GetCharacterTasksResponse', (_message.Message,), { + 'DESCRIPTOR' : _GETCHARACTERTASKSRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.GetCharacterTasksResponse) + }) +_sym_db.RegisterMessage(GetCharacterTasksResponse) + +CompleteCharacterTaskRequest = _reflection.GeneratedProtocolMessageType('CompleteCharacterTaskRequest', (_message.Message,), { + 'DESCRIPTOR' : _COMPLETECHARACTERTASKREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.CompleteCharacterTaskRequest) + }) +_sym_db.RegisterMessage(CompleteCharacterTaskRequest) + +CompleteCharacterTaskResponse = _reflection.GeneratedProtocolMessageType('CompleteCharacterTaskResponse', (_message.Message,), { + 'DESCRIPTOR' : _COMPLETECHARACTERTASKRESPONSE, + '__module__' : 'conversation_pb2' + # 
@@protoc_insertion_point(class_scope:conversation.CompleteCharacterTaskResponse) + }) +_sym_db.RegisterMessage(CompleteCharacterTaskResponse) + +CommitMessageRequest = _reflection.GeneratedProtocolMessageType('CommitMessageRequest', (_message.Message,), { + 'DESCRIPTOR' : _COMMITMESSAGEREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.CommitMessageRequest) + }) +_sym_db.RegisterMessage(CommitMessageRequest) + +CommitMessageResponse = _reflection.GeneratedProtocolMessageType('CommitMessageResponse', (_message.Message,), { + 'DESCRIPTOR' : _COMMITMESSAGERESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.CommitMessageResponse) + }) +_sym_db.RegisterMessage(CommitMessageResponse) + +CreateCharacterRequest = _reflection.GeneratedProtocolMessageType('CreateCharacterRequest', (_message.Message,), { + 'DESCRIPTOR' : _CREATECHARACTERREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.CreateCharacterRequest) + }) +_sym_db.RegisterMessage(CreateCharacterRequest) + +CreateCharacterResponse = _reflection.GeneratedProtocolMessageType('CreateCharacterResponse', (_message.Message,), { + 'DESCRIPTOR' : _CREATECHARACTERRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.CreateCharacterResponse) + }) +_sym_db.RegisterMessage(CreateCharacterResponse) + +GetCharacterRequest = _reflection.GeneratedProtocolMessageType('GetCharacterRequest', (_message.Message,), { + 'DESCRIPTOR' : _GETCHARACTERREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.GetCharacterRequest) + }) +_sym_db.RegisterMessage(GetCharacterRequest) + +GetCharacterResponse = _reflection.GeneratedProtocolMessageType('GetCharacterResponse', (_message.Message,), { + 'DESCRIPTOR' : _GETCHARACTERRESPONSE, + '__module__' : 'conversation_pb2' + # 
@@protoc_insertion_point(class_scope:conversation.GetCharacterResponse) + }) +_sym_db.RegisterMessage(GetCharacterResponse) + +UpdateCharacterRequest = _reflection.GeneratedProtocolMessageType('UpdateCharacterRequest', (_message.Message,), { + 'DESCRIPTOR' : _UPDATECHARACTERREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.UpdateCharacterRequest) + }) +_sym_db.RegisterMessage(UpdateCharacterRequest) + +UpdateCharacterResponse = _reflection.GeneratedProtocolMessageType('UpdateCharacterResponse', (_message.Message,), { + 'DESCRIPTOR' : _UPDATECHARACTERRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.UpdateCharacterResponse) + }) +_sym_db.RegisterMessage(UpdateCharacterResponse) + +DeleteCharacterRequest = _reflection.GeneratedProtocolMessageType('DeleteCharacterRequest', (_message.Message,), { + 'DESCRIPTOR' : _DELETECHARACTERREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.DeleteCharacterRequest) + }) +_sym_db.RegisterMessage(DeleteCharacterRequest) + +DeleteCharacterResponse = _reflection.GeneratedProtocolMessageType('DeleteCharacterResponse', (_message.Message,), { + 'DESCRIPTOR' : _DELETECHARACTERRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.DeleteCharacterResponse) + }) +_sym_db.RegisterMessage(DeleteCharacterResponse) + +ListCharactersRequest = _reflection.GeneratedProtocolMessageType('ListCharactersRequest', (_message.Message,), { + 'DESCRIPTOR' : _LISTCHARACTERSREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.ListCharactersRequest) + }) +_sym_db.RegisterMessage(ListCharactersRequest) + +ListCharactersResponse = _reflection.GeneratedProtocolMessageType('ListCharactersResponse', (_message.Message,), { + 'DESCRIPTOR' : _LISTCHARACTERSRESPONSE, + '__module__' : 'conversation_pb2' + # 
@@protoc_insertion_point(class_scope:conversation.ListCharactersResponse) + }) +_sym_db.RegisterMessage(ListCharactersResponse) + +CreatePlayerRequest = _reflection.GeneratedProtocolMessageType('CreatePlayerRequest', (_message.Message,), { + 'DESCRIPTOR' : _CREATEPLAYERREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.CreatePlayerRequest) + }) +_sym_db.RegisterMessage(CreatePlayerRequest) + +CreatePlayerResponse = _reflection.GeneratedProtocolMessageType('CreatePlayerResponse', (_message.Message,), { + 'DESCRIPTOR' : _CREATEPLAYERRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.CreatePlayerResponse) + }) +_sym_db.RegisterMessage(CreatePlayerResponse) + +GetPlayerRequest = _reflection.GeneratedProtocolMessageType('GetPlayerRequest', (_message.Message,), { + 'DESCRIPTOR' : _GETPLAYERREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.GetPlayerRequest) + }) +_sym_db.RegisterMessage(GetPlayerRequest) + +GetPlayerResponse = _reflection.GeneratedProtocolMessageType('GetPlayerResponse', (_message.Message,), { + 'DESCRIPTOR' : _GETPLAYERRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.GetPlayerResponse) + }) +_sym_db.RegisterMessage(GetPlayerResponse) + +UpdatePlayerRequest = _reflection.GeneratedProtocolMessageType('UpdatePlayerRequest', (_message.Message,), { + 'DESCRIPTOR' : _UPDATEPLAYERREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.UpdatePlayerRequest) + }) +_sym_db.RegisterMessage(UpdatePlayerRequest) + +UpdatePlayerResponse = _reflection.GeneratedProtocolMessageType('UpdatePlayerResponse', (_message.Message,), { + 'DESCRIPTOR' : _UPDATEPLAYERRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.UpdatePlayerResponse) + }) +_sym_db.RegisterMessage(UpdatePlayerResponse) + 
+DeletePlayerRequest = _reflection.GeneratedProtocolMessageType('DeletePlayerRequest', (_message.Message,), { + 'DESCRIPTOR' : _DELETEPLAYERREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.DeletePlayerRequest) + }) +_sym_db.RegisterMessage(DeletePlayerRequest) + +DeletePlayerResponse = _reflection.GeneratedProtocolMessageType('DeletePlayerResponse', (_message.Message,), { + 'DESCRIPTOR' : _DELETEPLAYERRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.DeletePlayerResponse) + }) +_sym_db.RegisterMessage(DeletePlayerResponse) + +ListPlayersRequest = _reflection.GeneratedProtocolMessageType('ListPlayersRequest', (_message.Message,), { + 'DESCRIPTOR' : _LISTPLAYERSREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.ListPlayersRequest) + }) +_sym_db.RegisterMessage(ListPlayersRequest) + +ListPlayersResponse = _reflection.GeneratedProtocolMessageType('ListPlayersResponse', (_message.Message,), { + 'DESCRIPTOR' : _LISTPLAYERSRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.ListPlayersResponse) + }) +_sym_db.RegisterMessage(ListPlayersResponse) + +CreatePromptSetRequest = _reflection.GeneratedProtocolMessageType('CreatePromptSetRequest', (_message.Message,), { + 'DESCRIPTOR' : _CREATEPROMPTSETREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.CreatePromptSetRequest) + }) +_sym_db.RegisterMessage(CreatePromptSetRequest) + +CreatePromptSetResponse = _reflection.GeneratedProtocolMessageType('CreatePromptSetResponse', (_message.Message,), { + 'DESCRIPTOR' : _CREATEPROMPTSETRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.CreatePromptSetResponse) + }) +_sym_db.RegisterMessage(CreatePromptSetResponse) + +GetPromptSetRequest = _reflection.GeneratedProtocolMessageType('GetPromptSetRequest', 
(_message.Message,), { + 'DESCRIPTOR' : _GETPROMPTSETREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.GetPromptSetRequest) + }) +_sym_db.RegisterMessage(GetPromptSetRequest) + +GetPromptSetResponse = _reflection.GeneratedProtocolMessageType('GetPromptSetResponse', (_message.Message,), { + 'DESCRIPTOR' : _GETPROMPTSETRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.GetPromptSetResponse) + }) +_sym_db.RegisterMessage(GetPromptSetResponse) + +UpdatePromptSetRequest = _reflection.GeneratedProtocolMessageType('UpdatePromptSetRequest', (_message.Message,), { + 'DESCRIPTOR' : _UPDATEPROMPTSETREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.UpdatePromptSetRequest) + }) +_sym_db.RegisterMessage(UpdatePromptSetRequest) + +UpdatePromptSetResponse = _reflection.GeneratedProtocolMessageType('UpdatePromptSetResponse', (_message.Message,), { + 'DESCRIPTOR' : _UPDATEPROMPTSETRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.UpdatePromptSetResponse) + }) +_sym_db.RegisterMessage(UpdatePromptSetResponse) + +DeletePromptSetRequest = _reflection.GeneratedProtocolMessageType('DeletePromptSetRequest', (_message.Message,), { + 'DESCRIPTOR' : _DELETEPROMPTSETREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.DeletePromptSetRequest) + }) +_sym_db.RegisterMessage(DeletePromptSetRequest) + +DeletePromptSetResponse = _reflection.GeneratedProtocolMessageType('DeletePromptSetResponse', (_message.Message,), { + 'DESCRIPTOR' : _DELETEPROMPTSETRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.DeletePromptSetResponse) + }) +_sym_db.RegisterMessage(DeletePromptSetResponse) + +ListPromptSetsRequest = _reflection.GeneratedProtocolMessageType('ListPromptSetsRequest', (_message.Message,), { + 'DESCRIPTOR' : 
_LISTPROMPTSETSREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.ListPromptSetsRequest) + }) +_sym_db.RegisterMessage(ListPromptSetsRequest) + +ListPromptSetsResponse = _reflection.GeneratedProtocolMessageType('ListPromptSetsResponse', (_message.Message,), { + 'DESCRIPTOR' : _LISTPROMPTSETSRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.ListPromptSetsResponse) + }) +_sym_db.RegisterMessage(ListPromptSetsResponse) + +CreatePromptRequest = _reflection.GeneratedProtocolMessageType('CreatePromptRequest', (_message.Message,), { + 'DESCRIPTOR' : _CREATEPROMPTREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.CreatePromptRequest) + }) +_sym_db.RegisterMessage(CreatePromptRequest) + +CreatePromptResponse = _reflection.GeneratedProtocolMessageType('CreatePromptResponse', (_message.Message,), { + 'DESCRIPTOR' : _CREATEPROMPTRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.CreatePromptResponse) + }) +_sym_db.RegisterMessage(CreatePromptResponse) + +GetPromptRequest = _reflection.GeneratedProtocolMessageType('GetPromptRequest', (_message.Message,), { + 'DESCRIPTOR' : _GETPROMPTREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.GetPromptRequest) + }) +_sym_db.RegisterMessage(GetPromptRequest) + +GetPromptResponse = _reflection.GeneratedProtocolMessageType('GetPromptResponse', (_message.Message,), { + 'DESCRIPTOR' : _GETPROMPTRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.GetPromptResponse) + }) +_sym_db.RegisterMessage(GetPromptResponse) + +UpdatePromptRequest = _reflection.GeneratedProtocolMessageType('UpdatePromptRequest', (_message.Message,), { + 'DESCRIPTOR' : _UPDATEPROMPTREQUEST, + '__module__' : 'conversation_pb2' + # 
@@protoc_insertion_point(class_scope:conversation.UpdatePromptRequest) + }) +_sym_db.RegisterMessage(UpdatePromptRequest) + +UpdatePromptResponse = _reflection.GeneratedProtocolMessageType('UpdatePromptResponse', (_message.Message,), { + 'DESCRIPTOR' : _UPDATEPROMPTRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.UpdatePromptResponse) + }) +_sym_db.RegisterMessage(UpdatePromptResponse) + +DeletePromptRequest = _reflection.GeneratedProtocolMessageType('DeletePromptRequest', (_message.Message,), { + 'DESCRIPTOR' : _DELETEPROMPTREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.DeletePromptRequest) + }) +_sym_db.RegisterMessage(DeletePromptRequest) + +DeletePromptResponse = _reflection.GeneratedProtocolMessageType('DeletePromptResponse', (_message.Message,), { + 'DESCRIPTOR' : _DELETEPROMPTRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.DeletePromptResponse) + }) +_sym_db.RegisterMessage(DeletePromptResponse) + +ListPromptsRequest = _reflection.GeneratedProtocolMessageType('ListPromptsRequest', (_message.Message,), { + 'DESCRIPTOR' : _LISTPROMPTSREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.ListPromptsRequest) + }) +_sym_db.RegisterMessage(ListPromptsRequest) + +ListPromptsResponse = _reflection.GeneratedProtocolMessageType('ListPromptsResponse', (_message.Message,), { + 'DESCRIPTOR' : _LISTPROMPTSRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.ListPromptsResponse) + }) +_sym_db.RegisterMessage(ListPromptsResponse) + +CreateEmotionRequest = _reflection.GeneratedProtocolMessageType('CreateEmotionRequest', (_message.Message,), { + 'DESCRIPTOR' : _CREATEEMOTIONREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.CreateEmotionRequest) + }) 
+_sym_db.RegisterMessage(CreateEmotionRequest) + +CreateEmotionResponse = _reflection.GeneratedProtocolMessageType('CreateEmotionResponse', (_message.Message,), { + 'DESCRIPTOR' : _CREATEEMOTIONRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.CreateEmotionResponse) + }) +_sym_db.RegisterMessage(CreateEmotionResponse) + +GetEmotionRequest = _reflection.GeneratedProtocolMessageType('GetEmotionRequest', (_message.Message,), { + 'DESCRIPTOR' : _GETEMOTIONREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.GetEmotionRequest) + }) +_sym_db.RegisterMessage(GetEmotionRequest) + +GetEmotionResponse = _reflection.GeneratedProtocolMessageType('GetEmotionResponse', (_message.Message,), { + 'DESCRIPTOR' : _GETEMOTIONRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.GetEmotionResponse) + }) +_sym_db.RegisterMessage(GetEmotionResponse) + +UpdateEmotionRequest = _reflection.GeneratedProtocolMessageType('UpdateEmotionRequest', (_message.Message,), { + 'DESCRIPTOR' : _UPDATEEMOTIONREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.UpdateEmotionRequest) + }) +_sym_db.RegisterMessage(UpdateEmotionRequest) + +UpdateEmotionResponse = _reflection.GeneratedProtocolMessageType('UpdateEmotionResponse', (_message.Message,), { + 'DESCRIPTOR' : _UPDATEEMOTIONRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.UpdateEmotionResponse) + }) +_sym_db.RegisterMessage(UpdateEmotionResponse) + +DeleteEmotionRequest = _reflection.GeneratedProtocolMessageType('DeleteEmotionRequest', (_message.Message,), { + 'DESCRIPTOR' : _DELETEEMOTIONREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.DeleteEmotionRequest) + }) +_sym_db.RegisterMessage(DeleteEmotionRequest) + +DeleteEmotionResponse = 
_reflection.GeneratedProtocolMessageType('DeleteEmotionResponse', (_message.Message,), { + 'DESCRIPTOR' : _DELETEEMOTIONRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.DeleteEmotionResponse) + }) +_sym_db.RegisterMessage(DeleteEmotionResponse) + +ListEmotionsRequest = _reflection.GeneratedProtocolMessageType('ListEmotionsRequest', (_message.Message,), { + 'DESCRIPTOR' : _LISTEMOTIONSREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.ListEmotionsRequest) + }) +_sym_db.RegisterMessage(ListEmotionsRequest) + +ListEmotionsResponse = _reflection.GeneratedProtocolMessageType('ListEmotionsResponse', (_message.Message,), { + 'DESCRIPTOR' : _LISTEMOTIONSRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.ListEmotionsResponse) + }) +_sym_db.RegisterMessage(ListEmotionsResponse) + +CreatePrimerSetRequest = _reflection.GeneratedProtocolMessageType('CreatePrimerSetRequest', (_message.Message,), { + 'DESCRIPTOR' : _CREATEPRIMERSETREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.CreatePrimerSetRequest) + }) +_sym_db.RegisterMessage(CreatePrimerSetRequest) + +CreatePrimerSetResponse = _reflection.GeneratedProtocolMessageType('CreatePrimerSetResponse', (_message.Message,), { + 'DESCRIPTOR' : _CREATEPRIMERSETRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.CreatePrimerSetResponse) + }) +_sym_db.RegisterMessage(CreatePrimerSetResponse) + +GetPrimerSetRequest = _reflection.GeneratedProtocolMessageType('GetPrimerSetRequest', (_message.Message,), { + 'DESCRIPTOR' : _GETPRIMERSETREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.GetPrimerSetRequest) + }) +_sym_db.RegisterMessage(GetPrimerSetRequest) + +GetPrimerSetResponse = _reflection.GeneratedProtocolMessageType('GetPrimerSetResponse', 
(_message.Message,), { + 'DESCRIPTOR' : _GETPRIMERSETRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.GetPrimerSetResponse) + }) +_sym_db.RegisterMessage(GetPrimerSetResponse) + +UpdatePrimerSetRequest = _reflection.GeneratedProtocolMessageType('UpdatePrimerSetRequest', (_message.Message,), { + 'DESCRIPTOR' : _UPDATEPRIMERSETREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.UpdatePrimerSetRequest) + }) +_sym_db.RegisterMessage(UpdatePrimerSetRequest) + +UpdatePrimerSetResponse = _reflection.GeneratedProtocolMessageType('UpdatePrimerSetResponse', (_message.Message,), { + 'DESCRIPTOR' : _UPDATEPRIMERSETRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.UpdatePrimerSetResponse) + }) +_sym_db.RegisterMessage(UpdatePrimerSetResponse) + +DeletePrimerSetRequest = _reflection.GeneratedProtocolMessageType('DeletePrimerSetRequest', (_message.Message,), { + 'DESCRIPTOR' : _DELETEPRIMERSETREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.DeletePrimerSetRequest) + }) +_sym_db.RegisterMessage(DeletePrimerSetRequest) + +DeletePrimerSetResponse = _reflection.GeneratedProtocolMessageType('DeletePrimerSetResponse', (_message.Message,), { + 'DESCRIPTOR' : _DELETEPRIMERSETRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.DeletePrimerSetResponse) + }) +_sym_db.RegisterMessage(DeletePrimerSetResponse) + +ListPrimerSetsRequest = _reflection.GeneratedProtocolMessageType('ListPrimerSetsRequest', (_message.Message,), { + 'DESCRIPTOR' : _LISTPRIMERSETSREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.ListPrimerSetsRequest) + }) +_sym_db.RegisterMessage(ListPrimerSetsRequest) + +ListPrimerSetsResponse = _reflection.GeneratedProtocolMessageType('ListPrimerSetsResponse', (_message.Message,), { + 
'DESCRIPTOR' : _LISTPRIMERSETSRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.ListPrimerSetsResponse) + }) +_sym_db.RegisterMessage(ListPrimerSetsResponse) + +CreatePrimerRequest = _reflection.GeneratedProtocolMessageType('CreatePrimerRequest', (_message.Message,), { + 'DESCRIPTOR' : _CREATEPRIMERREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.CreatePrimerRequest) + }) +_sym_db.RegisterMessage(CreatePrimerRequest) + +CreatePrimerResponse = _reflection.GeneratedProtocolMessageType('CreatePrimerResponse', (_message.Message,), { + 'DESCRIPTOR' : _CREATEPRIMERRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.CreatePrimerResponse) + }) +_sym_db.RegisterMessage(CreatePrimerResponse) + +GetPrimerRequest = _reflection.GeneratedProtocolMessageType('GetPrimerRequest', (_message.Message,), { + 'DESCRIPTOR' : _GETPRIMERREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.GetPrimerRequest) + }) +_sym_db.RegisterMessage(GetPrimerRequest) + +GetPrimerResponse = _reflection.GeneratedProtocolMessageType('GetPrimerResponse', (_message.Message,), { + 'DESCRIPTOR' : _GETPRIMERRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.GetPrimerResponse) + }) +_sym_db.RegisterMessage(GetPrimerResponse) + +UpdatePrimerRequest = _reflection.GeneratedProtocolMessageType('UpdatePrimerRequest', (_message.Message,), { + 'DESCRIPTOR' : _UPDATEPRIMERREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.UpdatePrimerRequest) + }) +_sym_db.RegisterMessage(UpdatePrimerRequest) + +UpdatePrimerResponse = _reflection.GeneratedProtocolMessageType('UpdatePrimerResponse', (_message.Message,), { + 'DESCRIPTOR' : _UPDATEPRIMERRESPONSE, + '__module__' : 'conversation_pb2' + # 
@@protoc_insertion_point(class_scope:conversation.UpdatePrimerResponse) + }) +_sym_db.RegisterMessage(UpdatePrimerResponse) + +DeletePrimerRequest = _reflection.GeneratedProtocolMessageType('DeletePrimerRequest', (_message.Message,), { + 'DESCRIPTOR' : _DELETEPRIMERREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.DeletePrimerRequest) + }) +_sym_db.RegisterMessage(DeletePrimerRequest) + +DeletePrimerResponse = _reflection.GeneratedProtocolMessageType('DeletePrimerResponse', (_message.Message,), { + 'DESCRIPTOR' : _DELETEPRIMERRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.DeletePrimerResponse) + }) +_sym_db.RegisterMessage(DeletePrimerResponse) + +ListPrimersRequest = _reflection.GeneratedProtocolMessageType('ListPrimersRequest', (_message.Message,), { + 'DESCRIPTOR' : _LISTPRIMERSREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.ListPrimersRequest) + }) +_sym_db.RegisterMessage(ListPrimersRequest) + +ListPrimersResponse = _reflection.GeneratedProtocolMessageType('ListPrimersResponse', (_message.Message,), { + 'DESCRIPTOR' : _LISTPRIMERSRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.ListPrimersResponse) + }) +_sym_db.RegisterMessage(ListPrimersResponse) + +CreateEmotionalRulesSetRequest = _reflection.GeneratedProtocolMessageType('CreateEmotionalRulesSetRequest', (_message.Message,), { + 'DESCRIPTOR' : _CREATEEMOTIONALRULESSETREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.CreateEmotionalRulesSetRequest) + }) +_sym_db.RegisterMessage(CreateEmotionalRulesSetRequest) + +CreateEmotionalRulesSetResponse = _reflection.GeneratedProtocolMessageType('CreateEmotionalRulesSetResponse', (_message.Message,), { + 'DESCRIPTOR' : _CREATEEMOTIONALRULESSETRESPONSE, + '__module__' : 'conversation_pb2' + # 
@@protoc_insertion_point(class_scope:conversation.CreateEmotionalRulesSetResponse) + }) +_sym_db.RegisterMessage(CreateEmotionalRulesSetResponse) + +GetEmotionalRulesSetRequest = _reflection.GeneratedProtocolMessageType('GetEmotionalRulesSetRequest', (_message.Message,), { + 'DESCRIPTOR' : _GETEMOTIONALRULESSETREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.GetEmotionalRulesSetRequest) + }) +_sym_db.RegisterMessage(GetEmotionalRulesSetRequest) + +GetEmotionalRulesSetResponse = _reflection.GeneratedProtocolMessageType('GetEmotionalRulesSetResponse', (_message.Message,), { + 'DESCRIPTOR' : _GETEMOTIONALRULESSETRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.GetEmotionalRulesSetResponse) + }) +_sym_db.RegisterMessage(GetEmotionalRulesSetResponse) + +UpdateEmotionalRulesSetRequest = _reflection.GeneratedProtocolMessageType('UpdateEmotionalRulesSetRequest', (_message.Message,), { + 'DESCRIPTOR' : _UPDATEEMOTIONALRULESSETREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.UpdateEmotionalRulesSetRequest) + }) +_sym_db.RegisterMessage(UpdateEmotionalRulesSetRequest) + +UpdateEmotionalRulesSetResponse = _reflection.GeneratedProtocolMessageType('UpdateEmotionalRulesSetResponse', (_message.Message,), { + 'DESCRIPTOR' : _UPDATEEMOTIONALRULESSETRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.UpdateEmotionalRulesSetResponse) + }) +_sym_db.RegisterMessage(UpdateEmotionalRulesSetResponse) + +DeleteEmotionalRulesSetRequest = _reflection.GeneratedProtocolMessageType('DeleteEmotionalRulesSetRequest', (_message.Message,), { + 'DESCRIPTOR' : _DELETEEMOTIONALRULESSETREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.DeleteEmotionalRulesSetRequest) + }) +_sym_db.RegisterMessage(DeleteEmotionalRulesSetRequest) + 
+DeleteEmotionalRulesSetResponse = _reflection.GeneratedProtocolMessageType('DeleteEmotionalRulesSetResponse', (_message.Message,), { + 'DESCRIPTOR' : _DELETEEMOTIONALRULESSETRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.DeleteEmotionalRulesSetResponse) + }) +_sym_db.RegisterMessage(DeleteEmotionalRulesSetResponse) + +ListEmotionalRulesSetsRequest = _reflection.GeneratedProtocolMessageType('ListEmotionalRulesSetsRequest', (_message.Message,), { + 'DESCRIPTOR' : _LISTEMOTIONALRULESSETSREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.ListEmotionalRulesSetsRequest) + }) +_sym_db.RegisterMessage(ListEmotionalRulesSetsRequest) + +ListEmotionalRulesSetsResponse = _reflection.GeneratedProtocolMessageType('ListEmotionalRulesSetsResponse', (_message.Message,), { + 'DESCRIPTOR' : _LISTEMOTIONALRULESSETSRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.ListEmotionalRulesSetsResponse) + }) +_sym_db.RegisterMessage(ListEmotionalRulesSetsResponse) + +CreateEmotionalRuleRequest = _reflection.GeneratedProtocolMessageType('CreateEmotionalRuleRequest', (_message.Message,), { + 'DESCRIPTOR' : _CREATEEMOTIONALRULEREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.CreateEmotionalRuleRequest) + }) +_sym_db.RegisterMessage(CreateEmotionalRuleRequest) + +CreateEmotionalRuleResponse = _reflection.GeneratedProtocolMessageType('CreateEmotionalRuleResponse', (_message.Message,), { + 'DESCRIPTOR' : _CREATEEMOTIONALRULERESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.CreateEmotionalRuleResponse) + }) +_sym_db.RegisterMessage(CreateEmotionalRuleResponse) + +GetEmotionalRuleRequest = _reflection.GeneratedProtocolMessageType('GetEmotionalRuleRequest', (_message.Message,), { + 'DESCRIPTOR' : _GETEMOTIONALRULEREQUEST, + '__module__' : 
'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.GetEmotionalRuleRequest) + }) +_sym_db.RegisterMessage(GetEmotionalRuleRequest) + +GetEmotionalRuleResponse = _reflection.GeneratedProtocolMessageType('GetEmotionalRuleResponse', (_message.Message,), { + 'DESCRIPTOR' : _GETEMOTIONALRULERESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.GetEmotionalRuleResponse) + }) +_sym_db.RegisterMessage(GetEmotionalRuleResponse) + +UpdateEmotionalRuleRequest = _reflection.GeneratedProtocolMessageType('UpdateEmotionalRuleRequest', (_message.Message,), { + 'DESCRIPTOR' : _UPDATEEMOTIONALRULEREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.UpdateEmotionalRuleRequest) + }) +_sym_db.RegisterMessage(UpdateEmotionalRuleRequest) + +UpdateEmotionalRuleResponse = _reflection.GeneratedProtocolMessageType('UpdateEmotionalRuleResponse', (_message.Message,), { + 'DESCRIPTOR' : _UPDATEEMOTIONALRULERESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.UpdateEmotionalRuleResponse) + }) +_sym_db.RegisterMessage(UpdateEmotionalRuleResponse) + +DeleteEmotionalRuleRequest = _reflection.GeneratedProtocolMessageType('DeleteEmotionalRuleRequest', (_message.Message,), { + 'DESCRIPTOR' : _DELETEEMOTIONALRULEREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.DeleteEmotionalRuleRequest) + }) +_sym_db.RegisterMessage(DeleteEmotionalRuleRequest) + +DeleteEmotionalRuleResponse = _reflection.GeneratedProtocolMessageType('DeleteEmotionalRuleResponse', (_message.Message,), { + 'DESCRIPTOR' : _DELETEEMOTIONALRULERESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.DeleteEmotionalRuleResponse) + }) +_sym_db.RegisterMessage(DeleteEmotionalRuleResponse) + +ListEmotionalRuleRequest = _reflection.GeneratedProtocolMessageType('ListEmotionalRuleRequest', 
(_message.Message,), { + 'DESCRIPTOR' : _LISTEMOTIONALRULEREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.ListEmotionalRuleRequest) + }) +_sym_db.RegisterMessage(ListEmotionalRuleRequest) + +ListEmotionalRuleResponse = _reflection.GeneratedProtocolMessageType('ListEmotionalRuleResponse', (_message.Message,), { + 'DESCRIPTOR' : _LISTEMOTIONALRULERESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.ListEmotionalRuleResponse) + }) +_sym_db.RegisterMessage(ListEmotionalRuleResponse) + +CreateGenerationConfigRequest = _reflection.GeneratedProtocolMessageType('CreateGenerationConfigRequest', (_message.Message,), { + 'DESCRIPTOR' : _CREATEGENERATIONCONFIGREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.CreateGenerationConfigRequest) + }) +_sym_db.RegisterMessage(CreateGenerationConfigRequest) + +CreateGenerationConfigResponse = _reflection.GeneratedProtocolMessageType('CreateGenerationConfigResponse', (_message.Message,), { + 'DESCRIPTOR' : _CREATEGENERATIONCONFIGRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.CreateGenerationConfigResponse) + }) +_sym_db.RegisterMessage(CreateGenerationConfigResponse) + +GetGenerationConfigRequest = _reflection.GeneratedProtocolMessageType('GetGenerationConfigRequest', (_message.Message,), { + 'DESCRIPTOR' : _GETGENERATIONCONFIGREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.GetGenerationConfigRequest) + }) +_sym_db.RegisterMessage(GetGenerationConfigRequest) + +GetGenerationConfigResponse = _reflection.GeneratedProtocolMessageType('GetGenerationConfigResponse', (_message.Message,), { + 'DESCRIPTOR' : _GETGENERATIONCONFIGRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.GetGenerationConfigResponse) + }) 
+_sym_db.RegisterMessage(GetGenerationConfigResponse) + +UpdateGenerationConfigRequest = _reflection.GeneratedProtocolMessageType('UpdateGenerationConfigRequest', (_message.Message,), { + 'DESCRIPTOR' : _UPDATEGENERATIONCONFIGREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.UpdateGenerationConfigRequest) + }) +_sym_db.RegisterMessage(UpdateGenerationConfigRequest) + +UpdateGenerationConfigResponse = _reflection.GeneratedProtocolMessageType('UpdateGenerationConfigResponse', (_message.Message,), { + 'DESCRIPTOR' : _UPDATEGENERATIONCONFIGRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.UpdateGenerationConfigResponse) + }) +_sym_db.RegisterMessage(UpdateGenerationConfigResponse) + +DeleteGenerationConfigRequest = _reflection.GeneratedProtocolMessageType('DeleteGenerationConfigRequest', (_message.Message,), { + 'DESCRIPTOR' : _DELETEGENERATIONCONFIGREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.DeleteGenerationConfigRequest) + }) +_sym_db.RegisterMessage(DeleteGenerationConfigRequest) + +DeleteGenerationConfigResponse = _reflection.GeneratedProtocolMessageType('DeleteGenerationConfigResponse', (_message.Message,), { + 'DESCRIPTOR' : _DELETEGENERATIONCONFIGRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.DeleteGenerationConfigResponse) + }) +_sym_db.RegisterMessage(DeleteGenerationConfigResponse) + +ListGenerationConfigsRequest = _reflection.GeneratedProtocolMessageType('ListGenerationConfigsRequest', (_message.Message,), { + 'DESCRIPTOR' : _LISTGENERATIONCONFIGSREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.ListGenerationConfigsRequest) + }) +_sym_db.RegisterMessage(ListGenerationConfigsRequest) + +ListGenerationConfigsResponse = _reflection.GeneratedProtocolMessageType('ListGenerationConfigsResponse', 
(_message.Message,), { + 'DESCRIPTOR' : _LISTGENERATIONCONFIGSRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.ListGenerationConfigsResponse) + }) +_sym_db.RegisterMessage(ListGenerationConfigsResponse) + +CreateConversationRequest = _reflection.GeneratedProtocolMessageType('CreateConversationRequest', (_message.Message,), { + 'DESCRIPTOR' : _CREATECONVERSATIONREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.CreateConversationRequest) + }) +_sym_db.RegisterMessage(CreateConversationRequest) + +CreateConversationResponse = _reflection.GeneratedProtocolMessageType('CreateConversationResponse', (_message.Message,), { + 'DESCRIPTOR' : _CREATECONVERSATIONRESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.CreateConversationResponse) + }) +_sym_db.RegisterMessage(CreateConversationResponse) + +JoinConversationRequest = _reflection.GeneratedProtocolMessageType('JoinConversationRequest', (_message.Message,), { + 'DESCRIPTOR' : _JOINCONVERSATIONREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.JoinConversationRequest) + }) +_sym_db.RegisterMessage(JoinConversationRequest) + +SendMessageRequest = _reflection.GeneratedProtocolMessageType('SendMessageRequest', (_message.Message,), { + 'DESCRIPTOR' : _SENDMESSAGEREQUEST, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.SendMessageRequest) + }) +_sym_db.RegisterMessage(SendMessageRequest) + +SendMessageResponse = _reflection.GeneratedProtocolMessageType('SendMessageResponse', (_message.Message,), { + 'DESCRIPTOR' : _SENDMESSAGERESPONSE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.SendMessageResponse) + }) +_sym_db.RegisterMessage(SendMessageResponse) + +World = _reflection.GeneratedProtocolMessageType('World', (_message.Message,), { + 'DESCRIPTOR' : 
_WORLD, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.World) + }) +_sym_db.RegisterMessage(World) + +Conversation = _reflection.GeneratedProtocolMessageType('Conversation', (_message.Message,), { + 'DESCRIPTOR' : _CONVERSATION, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.Conversation) + }) +_sym_db.RegisterMessage(Conversation) + +Player = _reflection.GeneratedProtocolMessageType('Player', (_message.Message,), { + 'DESCRIPTOR' : _PLAYER, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.Player) + }) +_sym_db.RegisterMessage(Player) + +Message = _reflection.GeneratedProtocolMessageType('Message', (_message.Message,), { + 'DESCRIPTOR' : _MESSAGE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.Message) + }) +_sym_db.RegisterMessage(Message) + +Character = _reflection.GeneratedProtocolMessageType('Character', (_message.Message,), { + 'DESCRIPTOR' : _CHARACTER, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.Character) + }) +_sym_db.RegisterMessage(Character) + +GenerationConfig = _reflection.GeneratedProtocolMessageType('GenerationConfig', (_message.Message,), { + 'DESCRIPTOR' : _GENERATIONCONFIG, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.GenerationConfig) + }) +_sym_db.RegisterMessage(GenerationConfig) + +Emotion = _reflection.GeneratedProtocolMessageType('Emotion', (_message.Message,), { + 'DESCRIPTOR' : _EMOTION, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.Emotion) + }) +_sym_db.RegisterMessage(Emotion) + +EmotionalRulesSet = _reflection.GeneratedProtocolMessageType('EmotionalRulesSet', (_message.Message,), { + 'DESCRIPTOR' : _EMOTIONALRULESSET, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.EmotionalRulesSet) + }) 
+_sym_db.RegisterMessage(EmotionalRulesSet) + +EmotionalRule = _reflection.GeneratedProtocolMessageType('EmotionalRule', (_message.Message,), { + 'DESCRIPTOR' : _EMOTIONALRULE, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.EmotionalRule) + }) +_sym_db.RegisterMessage(EmotionalRule) + +RuleRequirement = _reflection.GeneratedProtocolMessageType('RuleRequirement', (_message.Message,), { + 'DESCRIPTOR' : _RULEREQUIREMENT, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.RuleRequirement) + }) +_sym_db.RegisterMessage(RuleRequirement) + +EmotionalRuleEffect = _reflection.GeneratedProtocolMessageType('EmotionalRuleEffect', (_message.Message,), { + 'DESCRIPTOR' : _EMOTIONALRULEEFFECT, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.EmotionalRuleEffect) + }) +_sym_db.RegisterMessage(EmotionalRuleEffect) + +PromptSet = _reflection.GeneratedProtocolMessageType('PromptSet', (_message.Message,), { + 'DESCRIPTOR' : _PROMPTSET, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.PromptSet) + }) +_sym_db.RegisterMessage(PromptSet) + +Prompt = _reflection.GeneratedProtocolMessageType('Prompt', (_message.Message,), { + 'DESCRIPTOR' : _PROMPT, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.Prompt) + }) +_sym_db.RegisterMessage(Prompt) + +PrimerSet = _reflection.GeneratedProtocolMessageType('PrimerSet', (_message.Message,), { + 'DESCRIPTOR' : _PRIMERSET, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.PrimerSet) + }) +_sym_db.RegisterMessage(PrimerSet) + +Primer = _reflection.GeneratedProtocolMessageType('Primer', (_message.Message,), { + 'DESCRIPTOR' : _PRIMER, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.Primer) + }) +_sym_db.RegisterMessage(Primer) + +Task = 
_reflection.GeneratedProtocolMessageType('Task', (_message.Message,), { + 'DESCRIPTOR' : _TASK, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.Task) + }) +_sym_db.RegisterMessage(Task) + +TaskStep = _reflection.GeneratedProtocolMessageType('TaskStep', (_message.Message,), { + 'DESCRIPTOR' : _TASKSTEP, + '__module__' : 'conversation_pb2' + # @@protoc_insertion_point(class_scope:conversation.TaskStep) + }) +_sym_db.RegisterMessage(TaskStep) + +_CHATSERVICE = DESCRIPTOR.services_by_name['ChatService'] +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'Z\004./pb' + _CREATECHARACTERNEEDREQUEST._serialized_start=36 + _CREATECHARACTERNEEDREQUEST._serialized_end=78 + _CREATECHARACTERNEEDRESPONSE._serialized_start=80 + _CREATECHARACTERNEEDRESPONSE._serialized_end=109 + _GETCHARACTERTASKSREQUEST._serialized_start=111 + _GETCHARACTERTASKSREQUEST._serialized_end=159 + _GETCHARACTERTASKSRESPONSE._serialized_start=161 + _GETCHARACTERTASKSRESPONSE._serialized_end=223 + _COMPLETECHARACTERTASKREQUEST._serialized_start=225 + _COMPLETECHARACTERTASKREQUEST._serialized_end=294 + _COMPLETECHARACTERTASKRESPONSE._serialized_start=296 + _COMPLETECHARACTERTASKRESPONSE._serialized_end=327 + _COMMITMESSAGEREQUEST._serialized_start=330 + _COMMITMESSAGEREQUEST._serialized_end=548 + _COMMITMESSAGEREQUEST_COMMITTYPE._serialized_start=455 + _COMMITMESSAGEREQUEST_COMMITTYPE._serialized_end=548 + _COMMITMESSAGERESPONSE._serialized_start=550 + _COMMITMESSAGERESPONSE._serialized_end=573 + _CREATECHARACTERREQUEST._serialized_start=575 + _CREATECHARACTERREQUEST._serialized_end=643 + _CREATECHARACTERRESPONSE._serialized_start=645 + _CREATECHARACTERRESPONSE._serialized_end=682 + _GETCHARACTERREQUEST._serialized_start=684 + _GETCHARACTERREQUEST._serialized_end=717 + _GETCHARACTERRESPONSE._serialized_start=719 + _GETCHARACTERRESPONSE._serialized_end=785 + _UPDATECHARACTERREQUEST._serialized_start=787 + 
_UPDATECHARACTERREQUEST._serialized_end=855 + _UPDATECHARACTERRESPONSE._serialized_start=857 + _UPDATECHARACTERRESPONSE._serialized_end=926 + _DELETECHARACTERREQUEST._serialized_start=928 + _DELETECHARACTERREQUEST._serialized_end=964 + _DELETECHARACTERRESPONSE._serialized_start=966 + _DELETECHARACTERRESPONSE._serialized_end=991 + _LISTCHARACTERSREQUEST._serialized_start=993 + _LISTCHARACTERSREQUEST._serialized_end=1016 + _LISTCHARACTERSRESPONSE._serialized_start=1018 + _LISTCHARACTERSRESPONSE._serialized_end=1087 + _CREATEPLAYERREQUEST._serialized_start=1089 + _CREATEPLAYERREQUEST._serialized_end=1148 + _CREATEPLAYERRESPONSE._serialized_start=1150 + _CREATEPLAYERRESPONSE._serialized_end=1184 + _GETPLAYERREQUEST._serialized_start=1186 + _GETPLAYERREQUEST._serialized_end=1216 + _GETPLAYERRESPONSE._serialized_start=1218 + _GETPLAYERRESPONSE._serialized_end=1275 + _UPDATEPLAYERREQUEST._serialized_start=1277 + _UPDATEPLAYERREQUEST._serialized_end=1336 + _UPDATEPLAYERRESPONSE._serialized_start=1338 + _UPDATEPLAYERRESPONSE._serialized_end=1398 + _DELETEPLAYERREQUEST._serialized_start=1400 + _DELETEPLAYERREQUEST._serialized_end=1433 + _DELETEPLAYERRESPONSE._serialized_start=1435 + _DELETEPLAYERRESPONSE._serialized_end=1457 + _LISTPLAYERSREQUEST._serialized_start=1459 + _LISTPLAYERSREQUEST._serialized_end=1479 + _LISTPLAYERSRESPONSE._serialized_start=1481 + _LISTPLAYERSRESPONSE._serialized_end=1541 + _CREATEPROMPTSETREQUEST._serialized_start=1543 + _CREATEPROMPTSETREQUEST._serialized_end=1612 + _CREATEPROMPTSETRESPONSE._serialized_start=1614 + _CREATEPROMPTSETRESPONSE._serialized_end=1651 + _GETPROMPTSETREQUEST._serialized_start=1653 + _GETPROMPTSETREQUEST._serialized_end=1686 + _GETPROMPTSETRESPONSE._serialized_start=1688 + _GETPROMPTSETRESPONSE._serialized_end=1755 + _UPDATEPROMPTSETREQUEST._serialized_start=1757 + _UPDATEPROMPTSETREQUEST._serialized_end=1826 + _UPDATEPROMPTSETRESPONSE._serialized_start=1828 + _UPDATEPROMPTSETRESPONSE._serialized_end=1898 + 
_DELETEPROMPTSETREQUEST._serialized_start=1900 + _DELETEPROMPTSETREQUEST._serialized_end=1936 + _DELETEPROMPTSETRESPONSE._serialized_start=1938 + _DELETEPROMPTSETRESPONSE._serialized_end=1963 + _LISTPROMPTSETSREQUEST._serialized_start=1965 + _LISTPROMPTSETSREQUEST._serialized_end=1988 + _LISTPROMPTSETSRESPONSE._serialized_start=1990 + _LISTPROMPTSETSRESPONSE._serialized_end=2060 + _CREATEPROMPTREQUEST._serialized_start=2062 + _CREATEPROMPTREQUEST._serialized_end=2121 + _CREATEPROMPTRESPONSE._serialized_start=2123 + _CREATEPROMPTRESPONSE._serialized_end=2157 + _GETPROMPTREQUEST._serialized_start=2159 + _GETPROMPTREQUEST._serialized_end=2189 + _GETPROMPTRESPONSE._serialized_start=2191 + _GETPROMPTRESPONSE._serialized_end=2248 + _UPDATEPROMPTREQUEST._serialized_start=2250 + _UPDATEPROMPTREQUEST._serialized_end=2309 + _UPDATEPROMPTRESPONSE._serialized_start=2311 + _UPDATEPROMPTRESPONSE._serialized_end=2371 + _DELETEPROMPTREQUEST._serialized_start=2373 + _DELETEPROMPTREQUEST._serialized_end=2406 + _DELETEPROMPTRESPONSE._serialized_start=2408 + _DELETEPROMPTRESPONSE._serialized_end=2430 + _LISTPROMPTSREQUEST._serialized_start=2432 + _LISTPROMPTSREQUEST._serialized_end=2452 + _LISTPROMPTSRESPONSE._serialized_start=2454 + _LISTPROMPTSRESPONSE._serialized_end=2514 + _CREATEEMOTIONREQUEST._serialized_start=2516 + _CREATEEMOTIONREQUEST._serialized_end=2578 + _CREATEEMOTIONRESPONSE._serialized_start=2580 + _CREATEEMOTIONRESPONSE._serialized_end=2615 + _GETEMOTIONREQUEST._serialized_start=2617 + _GETEMOTIONREQUEST._serialized_end=2648 + _GETEMOTIONRESPONSE._serialized_start=2650 + _GETEMOTIONRESPONSE._serialized_end=2710 + _UPDATEEMOTIONREQUEST._serialized_start=2712 + _UPDATEEMOTIONREQUEST._serialized_end=2774 + _UPDATEEMOTIONRESPONSE._serialized_start=2776 + _UPDATEEMOTIONRESPONSE._serialized_end=2839 + _DELETEEMOTIONREQUEST._serialized_start=2841 + _DELETEEMOTIONREQUEST._serialized_end=2875 + _DELETEEMOTIONRESPONSE._serialized_start=2877 + 
_DELETEEMOTIONRESPONSE._serialized_end=2900 + _LISTEMOTIONSREQUEST._serialized_start=2902 + _LISTEMOTIONSREQUEST._serialized_end=2923 + _LISTEMOTIONSRESPONSE._serialized_start=2925 + _LISTEMOTIONSRESPONSE._serialized_end=2988 + _CREATEPRIMERSETREQUEST._serialized_start=2990 + _CREATEPRIMERSETREQUEST._serialized_end=3059 + _CREATEPRIMERSETRESPONSE._serialized_start=3061 + _CREATEPRIMERSETRESPONSE._serialized_end=3098 + _GETPRIMERSETREQUEST._serialized_start=3100 + _GETPRIMERSETREQUEST._serialized_end=3133 + _GETPRIMERSETRESPONSE._serialized_start=3135 + _GETPRIMERSETRESPONSE._serialized_end=3202 + _UPDATEPRIMERSETREQUEST._serialized_start=3204 + _UPDATEPRIMERSETREQUEST._serialized_end=3273 + _UPDATEPRIMERSETRESPONSE._serialized_start=3275 + _UPDATEPRIMERSETRESPONSE._serialized_end=3345 + _DELETEPRIMERSETREQUEST._serialized_start=3347 + _DELETEPRIMERSETREQUEST._serialized_end=3383 + _DELETEPRIMERSETRESPONSE._serialized_start=3385 + _DELETEPRIMERSETRESPONSE._serialized_end=3410 + _LISTPRIMERSETSREQUEST._serialized_start=3412 + _LISTPRIMERSETSREQUEST._serialized_end=3435 + _LISTPRIMERSETSRESPONSE._serialized_start=3437 + _LISTPRIMERSETSRESPONSE._serialized_end=3507 + _CREATEPRIMERREQUEST._serialized_start=3509 + _CREATEPRIMERREQUEST._serialized_end=3568 + _CREATEPRIMERRESPONSE._serialized_start=3570 + _CREATEPRIMERRESPONSE._serialized_end=3604 + _GETPRIMERREQUEST._serialized_start=3606 + _GETPRIMERREQUEST._serialized_end=3636 + _GETPRIMERRESPONSE._serialized_start=3638 + _GETPRIMERRESPONSE._serialized_end=3695 + _UPDATEPRIMERREQUEST._serialized_start=3697 + _UPDATEPRIMERREQUEST._serialized_end=3756 + _UPDATEPRIMERRESPONSE._serialized_start=3758 + _UPDATEPRIMERRESPONSE._serialized_end=3818 + _DELETEPRIMERREQUEST._serialized_start=3820 + _DELETEPRIMERREQUEST._serialized_end=3853 + _DELETEPRIMERRESPONSE._serialized_start=3855 + _DELETEPRIMERRESPONSE._serialized_end=3877 + _LISTPRIMERSREQUEST._serialized_start=3879 + _LISTPRIMERSREQUEST._serialized_end=3899 + 
_LISTPRIMERSRESPONSE._serialized_start=3901 + _LISTPRIMERSRESPONSE._serialized_end=3961 + _CREATEEMOTIONALRULESSETREQUEST._serialized_start=3963 + _CREATEEMOTIONALRULESSETREQUEST._serialized_end=4057 + _CREATEEMOTIONALRULESSETRESPONSE._serialized_start=4059 + _CREATEEMOTIONALRULESSETRESPONSE._serialized_end=4104 + _GETEMOTIONALRULESSETREQUEST._serialized_start=4106 + _GETEMOTIONALRULESSETREQUEST._serialized_end=4147 + _GETEMOTIONALRULESSETRESPONSE._serialized_start=4149 + _GETEMOTIONALRULESSETRESPONSE._serialized_end=4241 + _UPDATEEMOTIONALRULESSETREQUEST._serialized_start=4243 + _UPDATEEMOTIONALRULESSETREQUEST._serialized_end=4337 + _UPDATEEMOTIONALRULESSETRESPONSE._serialized_start=4339 + _UPDATEEMOTIONALRULESSETRESPONSE._serialized_end=4434 + _DELETEEMOTIONALRULESSETREQUEST._serialized_start=4436 + _DELETEEMOTIONALRULESSETREQUEST._serialized_end=4480 + _DELETEEMOTIONALRULESSETRESPONSE._serialized_start=4482 + _DELETEEMOTIONALRULESSETRESPONSE._serialized_end=4515 + _LISTEMOTIONALRULESSETSREQUEST._serialized_start=4517 + _LISTEMOTIONALRULESSETSREQUEST._serialized_end=4548 + _LISTEMOTIONALRULESSETSRESPONSE._serialized_start=4550 + _LISTEMOTIONALRULESSETSRESPONSE._serialized_end=4645 + _CREATEEMOTIONALRULEREQUEST._serialized_start=4647 + _CREATEEMOTIONALRULEREQUEST._serialized_end=4728 + _CREATEEMOTIONALRULERESPONSE._serialized_start=4730 + _CREATEEMOTIONALRULERESPONSE._serialized_end=4771 + _GETEMOTIONALRULEREQUEST._serialized_start=4773 + _GETEMOTIONALRULEREQUEST._serialized_end=4810 + _GETEMOTIONALRULERESPONSE._serialized_start=4812 + _GETEMOTIONALRULERESPONSE._serialized_end=4891 + _UPDATEEMOTIONALRULEREQUEST._serialized_start=4893 + _UPDATEEMOTIONALRULEREQUEST._serialized_end=4974 + _UPDATEEMOTIONALRULERESPONSE._serialized_start=4976 + _UPDATEEMOTIONALRULERESPONSE._serialized_end=5058 + _DELETEEMOTIONALRULEREQUEST._serialized_start=5060 + _DELETEEMOTIONALRULEREQUEST._serialized_end=5100 + _DELETEEMOTIONALRULERESPONSE._serialized_start=5102 + 
_DELETEEMOTIONALRULERESPONSE._serialized_end=5131 + _LISTEMOTIONALRULEREQUEST._serialized_start=5133 + _LISTEMOTIONALRULEREQUEST._serialized_end=5159 + _LISTEMOTIONALRULERESPONSE._serialized_start=5161 + _LISTEMOTIONALRULERESPONSE._serialized_end=5242 + _CREATEGENERATIONCONFIGREQUEST._serialized_start=5244 + _CREATEGENERATIONCONFIGREQUEST._serialized_end=5334 + _CREATEGENERATIONCONFIGRESPONSE._serialized_start=5336 + _CREATEGENERATIONCONFIGRESPONSE._serialized_end=5380 + _GETGENERATIONCONFIGREQUEST._serialized_start=5382 + _GETGENERATIONCONFIGREQUEST._serialized_end=5422 + _GETGENERATIONCONFIGRESPONSE._serialized_start=5424 + _GETGENERATIONCONFIGRESPONSE._serialized_end=5512 + _UPDATEGENERATIONCONFIGREQUEST._serialized_start=5514 + _UPDATEGENERATIONCONFIGREQUEST._serialized_end=5604 + _UPDATEGENERATIONCONFIGRESPONSE._serialized_start=5606 + _UPDATEGENERATIONCONFIGRESPONSE._serialized_end=5697 + _DELETEGENERATIONCONFIGREQUEST._serialized_start=5699 + _DELETEGENERATIONCONFIGREQUEST._serialized_end=5742 + _DELETEGENERATIONCONFIGRESPONSE._serialized_start=5744 + _DELETEGENERATIONCONFIGRESPONSE._serialized_end=5776 + _LISTGENERATIONCONFIGSREQUEST._serialized_start=5778 + _LISTGENERATIONCONFIGSREQUEST._serialized_end=5808 + _LISTGENERATIONCONFIGSRESPONSE._serialized_start=5810 + _LISTGENERATIONCONFIGSRESPONSE._serialized_end=5901 + _CREATECONVERSATIONREQUEST._serialized_start=5903 + _CREATECONVERSATIONREQUEST._serialized_end=5973 + _CREATECONVERSATIONRESPONSE._serialized_start=5975 + _CREATECONVERSATIONRESPONSE._serialized_end=6015 + _JOINCONVERSATIONREQUEST._serialized_start=6017 + _JOINCONVERSATIONREQUEST._serialized_end=6067 + _SENDMESSAGEREQUEST._serialized_start=6069 + _SENDMESSAGEREQUEST._serialized_end=6158 + _SENDMESSAGERESPONSE._serialized_start=6160 + _SENDMESSAGERESPONSE._serialized_end=6181 + _WORLD._serialized_start=6183 + _WORLD._serialized_end=6222 + _CONVERSATION._serialized_start=6225 + _CONVERSATION._serialized_end=6376 + _PLAYER._serialized_start=6378 
+ _PLAYER._serialized_end=6437 + _MESSAGE._serialized_start=6440 + _MESSAGE._serialized_end=6676 + _CHARACTER._serialized_start=6679 + _CHARACTER._serialized_end=7192 + _GENERATIONCONFIG._serialized_start=7195 + _GENERATIONCONFIG._serialized_end=7469 + _EMOTION._serialized_start=7472 + _EMOTION._serialized_end=7600 + _EMOTIONALRULESSET._serialized_start=7602 + _EMOTIONALRULESSET._serialized_end=7701 + _EMOTIONALRULE._serialized_start=7704 + _EMOTIONALRULE._serialized_end=7964 + _RULEREQUIREMENT._serialized_start=7967 + _RULEREQUIREMENT._serialized_end=8456 + _RULEREQUIREMENT_REQUIREMENTTYPE._serialized_start=8255 + _RULEREQUIREMENT_REQUIREMENTTYPE._serialized_end=8456 + _EMOTIONALRULEEFFECT._serialized_start=8459 + _EMOTIONALRULEEFFECT._serialized_end=9195 + _EMOTIONALRULEEFFECT_EFFECTTYPE._serialized_start=8760 + _EMOTIONALRULEEFFECT_EFFECTTYPE._serialized_end=9015 + _EMOTIONALRULEEFFECT_EFFECTMULTIPLIERTYPE._serialized_start=9018 + _EMOTIONALRULEEFFECT_EFFECTMULTIPLIERTYPE._serialized_end=9195 + _PROMPTSET._serialized_start=9197 + _PROMPTSET._serialized_end=9273 + _PROMPT._serialized_start=9276 + _PROMPT._serialized_end=9475 + _PRIMERSET._serialized_start=9477 + _PRIMERSET._serialized_end=9553 + _PRIMER._serialized_start=9556 + _PRIMER._serialized_end=9741 + _TASK._serialized_start=9743 + _TASK._serialized_end=9863 + _TASKSTEP._serialized_start=9865 + _TASKSTEP._serialized_end=9936 + _CHATSERVICE._serialized_start=9939 + _CHATSERVICE._serialized_end=15447 +# @@protoc_insertion_point(module_scope) diff --git a/modules/ros_chatbot/src/ros_chatbot/agents/conversation_pb2_grpc.py b/modules/ros_chatbot/src/ros_chatbot/agents/conversation_pb2_grpc.py new file mode 100644 index 0000000..faecb6e --- /dev/null +++ b/modules/ros_chatbot/src/ros_chatbot/agents/conversation_pb2_grpc.py @@ -0,0 +1,1920 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +from . 
import conversation_pb2 as conversation__pb2 + + +class ChatServiceStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.CreateConversation = channel.unary_unary( + '/conversation.ChatService/CreateConversation', + request_serializer=conversation__pb2.CreateConversationRequest.SerializeToString, + response_deserializer=conversation__pb2.CreateConversationResponse.FromString, + ) + self.SendMessage = channel.unary_unary( + '/conversation.ChatService/SendMessage', + request_serializer=conversation__pb2.SendMessageRequest.SerializeToString, + response_deserializer=conversation__pb2.SendMessageResponse.FromString, + ) + self.StreamConversationMessages = channel.unary_stream( + '/conversation.ChatService/StreamConversationMessages', + request_serializer=conversation__pb2.JoinConversationRequest.SerializeToString, + response_deserializer=conversation__pb2.Message.FromString, + ) + self.CommitMessage = channel.unary_unary( + '/conversation.ChatService/CommitMessage', + request_serializer=conversation__pb2.CommitMessageRequest.SerializeToString, + response_deserializer=conversation__pb2.CommitMessageResponse.FromString, + ) + self.GetCharacterTasks = channel.unary_unary( + '/conversation.ChatService/GetCharacterTasks', + request_serializer=conversation__pb2.GetCharacterTasksRequest.SerializeToString, + response_deserializer=conversation__pb2.GetCharacterTasksResponse.FromString, + ) + self.CompleteCharacterTask = channel.unary_unary( + '/conversation.ChatService/CompleteCharacterTask', + request_serializer=conversation__pb2.CompleteCharacterTaskRequest.SerializeToString, + response_deserializer=conversation__pb2.CompleteCharacterTaskResponse.FromString, + ) + self.CreateCharacterNeed = channel.unary_unary( + '/conversation.ChatService/CreateCharacterNeed', + request_serializer=conversation__pb2.CreateCharacterNeedRequest.SerializeToString, + 
response_deserializer=conversation__pb2.CreateCharacterNeedResponse.FromString, + ) + self.CreateCharacter = channel.unary_unary( + '/conversation.ChatService/CreateCharacter', + request_serializer=conversation__pb2.CreateCharacterRequest.SerializeToString, + response_deserializer=conversation__pb2.CreateCharacterResponse.FromString, + ) + self.GetCharacter = channel.unary_unary( + '/conversation.ChatService/GetCharacter', + request_serializer=conversation__pb2.GetCharacterRequest.SerializeToString, + response_deserializer=conversation__pb2.GetCharacterResponse.FromString, + ) + self.UpdateCharacter = channel.unary_unary( + '/conversation.ChatService/UpdateCharacter', + request_serializer=conversation__pb2.UpdateCharacterRequest.SerializeToString, + response_deserializer=conversation__pb2.UpdateCharacterResponse.FromString, + ) + self.DeleteCharacter = channel.unary_unary( + '/conversation.ChatService/DeleteCharacter', + request_serializer=conversation__pb2.DeleteCharacterRequest.SerializeToString, + response_deserializer=conversation__pb2.DeleteCharacterResponse.FromString, + ) + self.ListCharacters = channel.unary_unary( + '/conversation.ChatService/ListCharacters', + request_serializer=conversation__pb2.ListCharactersRequest.SerializeToString, + response_deserializer=conversation__pb2.ListCharactersResponse.FromString, + ) + self.CreatePlayer = channel.unary_unary( + '/conversation.ChatService/CreatePlayer', + request_serializer=conversation__pb2.CreatePlayerRequest.SerializeToString, + response_deserializer=conversation__pb2.CreatePlayerResponse.FromString, + ) + self.GetPlayer = channel.unary_unary( + '/conversation.ChatService/GetPlayer', + request_serializer=conversation__pb2.GetPlayerRequest.SerializeToString, + response_deserializer=conversation__pb2.GetPlayerResponse.FromString, + ) + self.UpdatePlayer = channel.unary_unary( + '/conversation.ChatService/UpdatePlayer', + request_serializer=conversation__pb2.UpdatePlayerRequest.SerializeToString, + 
response_deserializer=conversation__pb2.UpdatePlayerResponse.FromString, + ) + self.DeletePlayer = channel.unary_unary( + '/conversation.ChatService/DeletePlayer', + request_serializer=conversation__pb2.DeletePlayerRequest.SerializeToString, + response_deserializer=conversation__pb2.DeletePlayerResponse.FromString, + ) + self.ListPlayers = channel.unary_unary( + '/conversation.ChatService/ListPlayers', + request_serializer=conversation__pb2.ListPlayersRequest.SerializeToString, + response_deserializer=conversation__pb2.ListPlayersResponse.FromString, + ) + self.CreatePromptSet = channel.unary_unary( + '/conversation.ChatService/CreatePromptSet', + request_serializer=conversation__pb2.CreatePromptSetRequest.SerializeToString, + response_deserializer=conversation__pb2.CreatePromptSetResponse.FromString, + ) + self.GetPromptSet = channel.unary_unary( + '/conversation.ChatService/GetPromptSet', + request_serializer=conversation__pb2.GetPromptSetRequest.SerializeToString, + response_deserializer=conversation__pb2.GetPromptSetResponse.FromString, + ) + self.UpdatePromptSet = channel.unary_unary( + '/conversation.ChatService/UpdatePromptSet', + request_serializer=conversation__pb2.UpdatePromptSetRequest.SerializeToString, + response_deserializer=conversation__pb2.UpdatePromptSetResponse.FromString, + ) + self.DeletePromptSet = channel.unary_unary( + '/conversation.ChatService/DeletePromptSet', + request_serializer=conversation__pb2.DeletePromptSetRequest.SerializeToString, + response_deserializer=conversation__pb2.DeletePromptSetResponse.FromString, + ) + self.ListPromptSets = channel.unary_unary( + '/conversation.ChatService/ListPromptSets', + request_serializer=conversation__pb2.ListPromptSetsRequest.SerializeToString, + response_deserializer=conversation__pb2.ListPromptSetsResponse.FromString, + ) + self.CreatePrompt = channel.unary_unary( + '/conversation.ChatService/CreatePrompt', + request_serializer=conversation__pb2.CreatePromptRequest.SerializeToString, + 
response_deserializer=conversation__pb2.CreatePromptResponse.FromString, + ) + self.GetPrompt = channel.unary_unary( + '/conversation.ChatService/GetPrompt', + request_serializer=conversation__pb2.GetPromptRequest.SerializeToString, + response_deserializer=conversation__pb2.GetPromptResponse.FromString, + ) + self.UpdatePrompt = channel.unary_unary( + '/conversation.ChatService/UpdatePrompt', + request_serializer=conversation__pb2.UpdatePromptRequest.SerializeToString, + response_deserializer=conversation__pb2.UpdatePromptResponse.FromString, + ) + self.DeletePrompt = channel.unary_unary( + '/conversation.ChatService/DeletePrompt', + request_serializer=conversation__pb2.DeletePromptRequest.SerializeToString, + response_deserializer=conversation__pb2.DeletePromptResponse.FromString, + ) + self.ListPrompts = channel.unary_unary( + '/conversation.ChatService/ListPrompts', + request_serializer=conversation__pb2.ListPromptsRequest.SerializeToString, + response_deserializer=conversation__pb2.ListPromptsResponse.FromString, + ) + self.CreatePrimerSet = channel.unary_unary( + '/conversation.ChatService/CreatePrimerSet', + request_serializer=conversation__pb2.CreatePrimerSetRequest.SerializeToString, + response_deserializer=conversation__pb2.CreatePrimerSetResponse.FromString, + ) + self.GetPrimerSet = channel.unary_unary( + '/conversation.ChatService/GetPrimerSet', + request_serializer=conversation__pb2.GetPrimerSetRequest.SerializeToString, + response_deserializer=conversation__pb2.GetPrimerSetResponse.FromString, + ) + self.UpdatePrimerSet = channel.unary_unary( + '/conversation.ChatService/UpdatePrimerSet', + request_serializer=conversation__pb2.UpdatePrimerSetRequest.SerializeToString, + response_deserializer=conversation__pb2.UpdatePrimerSetResponse.FromString, + ) + self.DeletePrimerSet = channel.unary_unary( + '/conversation.ChatService/DeletePrimerSet', + request_serializer=conversation__pb2.DeletePrimerSetRequest.SerializeToString, + 
response_deserializer=conversation__pb2.DeletePrimerSetResponse.FromString, + ) + self.ListPrimerSets = channel.unary_unary( + '/conversation.ChatService/ListPrimerSets', + request_serializer=conversation__pb2.ListPrimerSetsRequest.SerializeToString, + response_deserializer=conversation__pb2.ListPrimerSetsResponse.FromString, + ) + self.CreatePrimer = channel.unary_unary( + '/conversation.ChatService/CreatePrimer', + request_serializer=conversation__pb2.CreatePrimerRequest.SerializeToString, + response_deserializer=conversation__pb2.CreatePrimerResponse.FromString, + ) + self.GetPrimer = channel.unary_unary( + '/conversation.ChatService/GetPrimer', + request_serializer=conversation__pb2.GetPrimerRequest.SerializeToString, + response_deserializer=conversation__pb2.GetPrimerResponse.FromString, + ) + self.UpdatePrimer = channel.unary_unary( + '/conversation.ChatService/UpdatePrimer', + request_serializer=conversation__pb2.UpdatePrimerRequest.SerializeToString, + response_deserializer=conversation__pb2.UpdatePrimerResponse.FromString, + ) + self.DeletePrimer = channel.unary_unary( + '/conversation.ChatService/DeletePrimer', + request_serializer=conversation__pb2.DeletePrimerRequest.SerializeToString, + response_deserializer=conversation__pb2.DeletePrimerResponse.FromString, + ) + self.ListPrimers = channel.unary_unary( + '/conversation.ChatService/ListPrimers', + request_serializer=conversation__pb2.ListPrimersRequest.SerializeToString, + response_deserializer=conversation__pb2.ListPrimersResponse.FromString, + ) + self.CreateEmotionalRulesSet = channel.unary_unary( + '/conversation.ChatService/CreateEmotionalRulesSet', + request_serializer=conversation__pb2.CreateEmotionalRulesSetRequest.SerializeToString, + response_deserializer=conversation__pb2.CreateEmotionalRulesSetResponse.FromString, + ) + self.GetEmotionalRulesSet = channel.unary_unary( + '/conversation.ChatService/GetEmotionalRulesSet', + 
request_serializer=conversation__pb2.GetEmotionalRulesSetRequest.SerializeToString, + response_deserializer=conversation__pb2.GetEmotionalRulesSetResponse.FromString, + ) + self.UpdateEmotionalRulesSet = channel.unary_unary( + '/conversation.ChatService/UpdateEmotionalRulesSet', + request_serializer=conversation__pb2.UpdateEmotionalRulesSetRequest.SerializeToString, + response_deserializer=conversation__pb2.UpdateEmotionalRulesSetResponse.FromString, + ) + self.DeleteEmotionalRulesSet = channel.unary_unary( + '/conversation.ChatService/DeleteEmotionalRulesSet', + request_serializer=conversation__pb2.DeleteEmotionalRulesSetRequest.SerializeToString, + response_deserializer=conversation__pb2.DeleteEmotionalRulesSetResponse.FromString, + ) + self.ListEmotionalRulesSets = channel.unary_unary( + '/conversation.ChatService/ListEmotionalRulesSets', + request_serializer=conversation__pb2.ListEmotionalRulesSetsRequest.SerializeToString, + response_deserializer=conversation__pb2.ListEmotionalRulesSetsResponse.FromString, + ) + self.CreateEmotionalRule = channel.unary_unary( + '/conversation.ChatService/CreateEmotionalRule', + request_serializer=conversation__pb2.CreateEmotionalRuleRequest.SerializeToString, + response_deserializer=conversation__pb2.CreateEmotionalRuleResponse.FromString, + ) + self.GetEmotionalRule = channel.unary_unary( + '/conversation.ChatService/GetEmotionalRule', + request_serializer=conversation__pb2.GetEmotionalRuleRequest.SerializeToString, + response_deserializer=conversation__pb2.GetEmotionalRuleResponse.FromString, + ) + self.UpdateEmotionalRule = channel.unary_unary( + '/conversation.ChatService/UpdateEmotionalRule', + request_serializer=conversation__pb2.UpdateEmotionalRuleRequest.SerializeToString, + response_deserializer=conversation__pb2.UpdateEmotionalRuleResponse.FromString, + ) + self.DeleteEmotionalRule = channel.unary_unary( + '/conversation.ChatService/DeleteEmotionalRule', + 
request_serializer=conversation__pb2.DeleteEmotionalRuleRequest.SerializeToString, + response_deserializer=conversation__pb2.DeleteEmotionalRuleResponse.FromString, + ) + self.ListEmotionalRules = channel.unary_unary( + '/conversation.ChatService/ListEmotionalRules', + request_serializer=conversation__pb2.ListEmotionalRuleRequest.SerializeToString, + response_deserializer=conversation__pb2.ListEmotionalRuleResponse.FromString, + ) + self.CreateGenerationConfig = channel.unary_unary( + '/conversation.ChatService/CreateGenerationConfig', + request_serializer=conversation__pb2.CreateGenerationConfigRequest.SerializeToString, + response_deserializer=conversation__pb2.CreateGenerationConfigResponse.FromString, + ) + self.GetGenerationConfig = channel.unary_unary( + '/conversation.ChatService/GetGenerationConfig', + request_serializer=conversation__pb2.GetGenerationConfigRequest.SerializeToString, + response_deserializer=conversation__pb2.GetGenerationConfigResponse.FromString, + ) + self.UpdateGenerationConfig = channel.unary_unary( + '/conversation.ChatService/UpdateGenerationConfig', + request_serializer=conversation__pb2.UpdateGenerationConfigRequest.SerializeToString, + response_deserializer=conversation__pb2.UpdateGenerationConfigResponse.FromString, + ) + self.DeleteGenerationConfig = channel.unary_unary( + '/conversation.ChatService/DeleteGenerationConfig', + request_serializer=conversation__pb2.DeleteGenerationConfigRequest.SerializeToString, + response_deserializer=conversation__pb2.DeleteGenerationConfigResponse.FromString, + ) + self.ListGenerationConfigs = channel.unary_unary( + '/conversation.ChatService/ListGenerationConfigs', + request_serializer=conversation__pb2.ListGenerationConfigsRequest.SerializeToString, + response_deserializer=conversation__pb2.ListGenerationConfigsResponse.FromString, + ) + self.CreateEmotion = channel.unary_unary( + '/conversation.ChatService/CreateEmotion', + 
request_serializer=conversation__pb2.CreateEmotionRequest.SerializeToString, + response_deserializer=conversation__pb2.CreateEmotionResponse.FromString, + ) + self.GetEmotion = channel.unary_unary( + '/conversation.ChatService/GetEmotion', + request_serializer=conversation__pb2.GetEmotionRequest.SerializeToString, + response_deserializer=conversation__pb2.GetEmotionResponse.FromString, + ) + self.UpdateEmotion = channel.unary_unary( + '/conversation.ChatService/UpdateEmotion', + request_serializer=conversation__pb2.UpdateEmotionRequest.SerializeToString, + response_deserializer=conversation__pb2.UpdateEmotionResponse.FromString, + ) + self.DeleteEmotion = channel.unary_unary( + '/conversation.ChatService/DeleteEmotion', + request_serializer=conversation__pb2.DeleteEmotionRequest.SerializeToString, + response_deserializer=conversation__pb2.DeleteEmotionResponse.FromString, + ) + self.ListEmotions = channel.unary_unary( + '/conversation.ChatService/ListEmotions', + request_serializer=conversation__pb2.ListEmotionsRequest.SerializeToString, + response_deserializer=conversation__pb2.ListEmotionsResponse.FromString, + ) + + +class ChatServiceServicer(object): + """Missing associated documentation comment in .proto file.""" + + def CreateConversation(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SendMessage(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def StreamConversationMessages(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + 
raise NotImplementedError('Method not implemented!') + + def CommitMessage(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetCharacterTasks(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def CompleteCharacterTask(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def CreateCharacterNeed(self, request, context): + """PROBLEM: CREATE HOW?? FOR CHARACTER ID OR CHARACTER TYPE?? + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def CreateCharacter(self, request, context): + """rpc GetCharacterLocation(GetCharacterLocationRequest) returns (GetCharacterLocationResponse); + rpc SetCharacterLocation(SetCharacterLocationRequest) returns (SetCharacterLocationResponse); + rpc SetCharacterVitalStatus(SetCharacterVitalStatusRequest) returns (SetCharacterVitalStatusResponse); + rpc SetItemLocation(SetItemLocationRequest) returns (SetItemLocationResponse); + + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetCharacter(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + 
+ def UpdateCharacter(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def DeleteCharacter(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def ListCharacters(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def CreatePlayer(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetPlayer(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def UpdatePlayer(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def DeletePlayer(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def ListPlayers(self, request, context): + """Missing associated documentation comment in .proto file.""" + 
context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def CreatePromptSet(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetPromptSet(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def UpdatePromptSet(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def DeletePromptSet(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def ListPromptSets(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def CreatePrompt(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetPrompt(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not 
implemented!')

    # NOTE(review): this file appears to be grpcio-tools (protoc) generated
    # output for conversation.proto -- prefer regenerating from the .proto
    # over hand-editing.  Each stub below is a default servicer method that
    # reports UNIMPLEMENTED until a concrete servicer subclass overrides it.
    def UpdatePrompt(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def DeletePrompt(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def ListPrompts(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def CreatePrimerSet(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def GetPrimerSet(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def UpdatePrimerSet(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def DeletePrimerSet(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def ListPrimerSets(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def CreatePrimer(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def GetPrimer(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def UpdatePrimer(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def DeletePrimer(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def ListPrimers(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def CreateEmotionalRulesSet(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def GetEmotionalRulesSet(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def UpdateEmotionalRulesSet(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def DeleteEmotionalRulesSet(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def ListEmotionalRulesSets(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def CreateEmotionalRule(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def GetEmotionalRule(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def UpdateEmotionalRule(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def DeleteEmotionalRule(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def ListEmotionalRules(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def CreateGenerationConfig(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def GetGenerationConfig(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def UpdateGenerationConfig(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def DeleteGenerationConfig(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def ListGenerationConfigs(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def CreateEmotion(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def GetEmotion(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def UpdateEmotion(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def DeleteEmotion(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def ListEmotions(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')


# Registers every ChatService RPC with `server` under the
# 'conversation.ChatService' service name, wiring each RPC name to the
# servicer callable plus its protobuf request deserializer / response
# serializer.  All RPCs are unary-unary except StreamConversationMessages,
# which is unary-stream (server pushes Message objects to the client).
def add_ChatServiceServicer_to_server(servicer, server):
    rpc_method_handlers = {
        'CreateConversation': grpc.unary_unary_rpc_method_handler(
            servicer.CreateConversation,
            request_deserializer=conversation__pb2.CreateConversationRequest.FromString,
            response_serializer=conversation__pb2.CreateConversationResponse.SerializeToString,
        ),
        'SendMessage': grpc.unary_unary_rpc_method_handler(
            servicer.SendMessage,
            request_deserializer=conversation__pb2.SendMessageRequest.FromString,
            response_serializer=conversation__pb2.SendMessageResponse.SerializeToString,
        ),
        'StreamConversationMessages': grpc.unary_stream_rpc_method_handler(
            servicer.StreamConversationMessages,
            request_deserializer=conversation__pb2.JoinConversationRequest.FromString,
            response_serializer=conversation__pb2.Message.SerializeToString,
        ),
        'CommitMessage': grpc.unary_unary_rpc_method_handler(
            servicer.CommitMessage,
            request_deserializer=conversation__pb2.CommitMessageRequest.FromString,
            response_serializer=conversation__pb2.CommitMessageResponse.SerializeToString,
        ),
        'GetCharacterTasks': grpc.unary_unary_rpc_method_handler(
            servicer.GetCharacterTasks,
            request_deserializer=conversation__pb2.GetCharacterTasksRequest.FromString,
            response_serializer=conversation__pb2.GetCharacterTasksResponse.SerializeToString,
        ),
        'CompleteCharacterTask': grpc.unary_unary_rpc_method_handler(
            servicer.CompleteCharacterTask,
            request_deserializer=conversation__pb2.CompleteCharacterTaskRequest.FromString,
            response_serializer=conversation__pb2.CompleteCharacterTaskResponse.SerializeToString,
        ),
        'CreateCharacterNeed': grpc.unary_unary_rpc_method_handler(
            servicer.CreateCharacterNeed,
            request_deserializer=conversation__pb2.CreateCharacterNeedRequest.FromString,
            response_serializer=conversation__pb2.CreateCharacterNeedResponse.SerializeToString,
        ),
        'CreateCharacter': grpc.unary_unary_rpc_method_handler(
            servicer.CreateCharacter,
            request_deserializer=conversation__pb2.CreateCharacterRequest.FromString,
            response_serializer=conversation__pb2.CreateCharacterResponse.SerializeToString,
        ),
        'GetCharacter': grpc.unary_unary_rpc_method_handler(
            servicer.GetCharacter,
            request_deserializer=conversation__pb2.GetCharacterRequest.FromString,
            response_serializer=conversation__pb2.GetCharacterResponse.SerializeToString,
        ),
        'UpdateCharacter': grpc.unary_unary_rpc_method_handler(
            servicer.UpdateCharacter,
            request_deserializer=conversation__pb2.UpdateCharacterRequest.FromString,
            response_serializer=conversation__pb2.UpdateCharacterResponse.SerializeToString,
        ),
        'DeleteCharacter': grpc.unary_unary_rpc_method_handler(
            servicer.DeleteCharacter,
            request_deserializer=conversation__pb2.DeleteCharacterRequest.FromString,
            response_serializer=conversation__pb2.DeleteCharacterResponse.SerializeToString,
        ),
        'ListCharacters': grpc.unary_unary_rpc_method_handler(
            servicer.ListCharacters,
            request_deserializer=conversation__pb2.ListCharactersRequest.FromString,
            response_serializer=conversation__pb2.ListCharactersResponse.SerializeToString,
        ),
        'CreatePlayer': grpc.unary_unary_rpc_method_handler(
            servicer.CreatePlayer,
            request_deserializer=conversation__pb2.CreatePlayerRequest.FromString,
            response_serializer=conversation__pb2.CreatePlayerResponse.SerializeToString,
        ),
        'GetPlayer': grpc.unary_unary_rpc_method_handler(
            servicer.GetPlayer,
            request_deserializer=conversation__pb2.GetPlayerRequest.FromString,
            response_serializer=conversation__pb2.GetPlayerResponse.SerializeToString,
        ),
        'UpdatePlayer': grpc.unary_unary_rpc_method_handler(
            servicer.UpdatePlayer,
            request_deserializer=conversation__pb2.UpdatePlayerRequest.FromString,
            response_serializer=conversation__pb2.UpdatePlayerResponse.SerializeToString,
        ),
        'DeletePlayer': grpc.unary_unary_rpc_method_handler(
            servicer.DeletePlayer,
            request_deserializer=conversation__pb2.DeletePlayerRequest.FromString,
            response_serializer=conversation__pb2.DeletePlayerResponse.SerializeToString,
        ),
        'ListPlayers': grpc.unary_unary_rpc_method_handler(
            servicer.ListPlayers,
            request_deserializer=conversation__pb2.ListPlayersRequest.FromString,
            response_serializer=conversation__pb2.ListPlayersResponse.SerializeToString,
        ),
        'CreatePromptSet': grpc.unary_unary_rpc_method_handler(
            servicer.CreatePromptSet,
            request_deserializer=conversation__pb2.CreatePromptSetRequest.FromString,
            response_serializer=conversation__pb2.CreatePromptSetResponse.SerializeToString,
        ),
        'GetPromptSet': grpc.unary_unary_rpc_method_handler(
            servicer.GetPromptSet,
            request_deserializer=conversation__pb2.GetPromptSetRequest.FromString,
            response_serializer=conversation__pb2.GetPromptSetResponse.SerializeToString,
        ),
        'UpdatePromptSet': grpc.unary_unary_rpc_method_handler(
            servicer.UpdatePromptSet,
            request_deserializer=conversation__pb2.UpdatePromptSetRequest.FromString,
            response_serializer=conversation__pb2.UpdatePromptSetResponse.SerializeToString,
        ),
        'DeletePromptSet': grpc.unary_unary_rpc_method_handler(
            servicer.DeletePromptSet,
            request_deserializer=conversation__pb2.DeletePromptSetRequest.FromString,
            response_serializer=conversation__pb2.DeletePromptSetResponse.SerializeToString,
        ),
        'ListPromptSets': grpc.unary_unary_rpc_method_handler(
            servicer.ListPromptSets,
            request_deserializer=conversation__pb2.ListPromptSetsRequest.FromString,
            response_serializer=conversation__pb2.ListPromptSetsResponse.SerializeToString,
        ),
        'CreatePrompt': grpc.unary_unary_rpc_method_handler(
            servicer.CreatePrompt,
            request_deserializer=conversation__pb2.CreatePromptRequest.FromString,
            response_serializer=conversation__pb2.CreatePromptResponse.SerializeToString,
        ),
        'GetPrompt': grpc.unary_unary_rpc_method_handler(
            servicer.GetPrompt,
            request_deserializer=conversation__pb2.GetPromptRequest.FromString,
            response_serializer=conversation__pb2.GetPromptResponse.SerializeToString,
        ),
        'UpdatePrompt': grpc.unary_unary_rpc_method_handler(
            servicer.UpdatePrompt,
            request_deserializer=conversation__pb2.UpdatePromptRequest.FromString,
            response_serializer=conversation__pb2.UpdatePromptResponse.SerializeToString,
        ),
        'DeletePrompt': grpc.unary_unary_rpc_method_handler(
            servicer.DeletePrompt,
            request_deserializer=conversation__pb2.DeletePromptRequest.FromString,
            response_serializer=conversation__pb2.DeletePromptResponse.SerializeToString,
        ),
        'ListPrompts': grpc.unary_unary_rpc_method_handler(
            servicer.ListPrompts,
            request_deserializer=conversation__pb2.ListPromptsRequest.FromString,
            response_serializer=conversation__pb2.ListPromptsResponse.SerializeToString,
        ),
        'CreatePrimerSet': grpc.unary_unary_rpc_method_handler(
            servicer.CreatePrimerSet,
            request_deserializer=conversation__pb2.CreatePrimerSetRequest.FromString,
            response_serializer=conversation__pb2.CreatePrimerSetResponse.SerializeToString,
        ),
        'GetPrimerSet': grpc.unary_unary_rpc_method_handler(
            servicer.GetPrimerSet,
            request_deserializer=conversation__pb2.GetPrimerSetRequest.FromString,
            response_serializer=conversation__pb2.GetPrimerSetResponse.SerializeToString,
        ),
        'UpdatePrimerSet': grpc.unary_unary_rpc_method_handler(
            servicer.UpdatePrimerSet,
            request_deserializer=conversation__pb2.UpdatePrimerSetRequest.FromString,
            response_serializer=conversation__pb2.UpdatePrimerSetResponse.SerializeToString,
        ),
        'DeletePrimerSet': grpc.unary_unary_rpc_method_handler(
            servicer.DeletePrimerSet,
            request_deserializer=conversation__pb2.DeletePrimerSetRequest.FromString,
            response_serializer=conversation__pb2.DeletePrimerSetResponse.SerializeToString,
        ),
        'ListPrimerSets': grpc.unary_unary_rpc_method_handler(
            servicer.ListPrimerSets,
            request_deserializer=conversation__pb2.ListPrimerSetsRequest.FromString,
            response_serializer=conversation__pb2.ListPrimerSetsResponse.SerializeToString,
        ),
        'CreatePrimer': grpc.unary_unary_rpc_method_handler(
            servicer.CreatePrimer,
            request_deserializer=conversation__pb2.CreatePrimerRequest.FromString,
            response_serializer=conversation__pb2.CreatePrimerResponse.SerializeToString,
        ),
        'GetPrimer': grpc.unary_unary_rpc_method_handler(
            servicer.GetPrimer,
            request_deserializer=conversation__pb2.GetPrimerRequest.FromString,
            response_serializer=conversation__pb2.GetPrimerResponse.SerializeToString,
        ),
        'UpdatePrimer': grpc.unary_unary_rpc_method_handler(
            servicer.UpdatePrimer,
            request_deserializer=conversation__pb2.UpdatePrimerRequest.FromString,
            response_serializer=conversation__pb2.UpdatePrimerResponse.SerializeToString,
        ),
        'DeletePrimer': grpc.unary_unary_rpc_method_handler(
            servicer.DeletePrimer,
            request_deserializer=conversation__pb2.DeletePrimerRequest.FromString,
            response_serializer=conversation__pb2.DeletePrimerResponse.SerializeToString,
        ),
        'ListPrimers': grpc.unary_unary_rpc_method_handler(
            servicer.ListPrimers,
            request_deserializer=conversation__pb2.ListPrimersRequest.FromString,
            response_serializer=conversation__pb2.ListPrimersResponse.SerializeToString,
        ),
        'CreateEmotionalRulesSet': grpc.unary_unary_rpc_method_handler(
            servicer.CreateEmotionalRulesSet,
            request_deserializer=conversation__pb2.CreateEmotionalRulesSetRequest.FromString,
            response_serializer=conversation__pb2.CreateEmotionalRulesSetResponse.SerializeToString,
        ),
        'GetEmotionalRulesSet': grpc.unary_unary_rpc_method_handler(
            servicer.GetEmotionalRulesSet,
            request_deserializer=conversation__pb2.GetEmotionalRulesSetRequest.FromString,
            response_serializer=conversation__pb2.GetEmotionalRulesSetResponse.SerializeToString,
        ),
        'UpdateEmotionalRulesSet': grpc.unary_unary_rpc_method_handler(
            servicer.UpdateEmotionalRulesSet,
            request_deserializer=conversation__pb2.UpdateEmotionalRulesSetRequest.FromString,
            response_serializer=conversation__pb2.UpdateEmotionalRulesSetResponse.SerializeToString,
        ),
        'DeleteEmotionalRulesSet': grpc.unary_unary_rpc_method_handler(
            servicer.DeleteEmotionalRulesSet,
            request_deserializer=conversation__pb2.DeleteEmotionalRulesSetRequest.FromString,
            response_serializer=conversation__pb2.DeleteEmotionalRulesSetResponse.SerializeToString,
        ),
        'ListEmotionalRulesSets': grpc.unary_unary_rpc_method_handler(
            servicer.ListEmotionalRulesSets,
            request_deserializer=conversation__pb2.ListEmotionalRulesSetsRequest.FromString,
            response_serializer=conversation__pb2.ListEmotionalRulesSetsResponse.SerializeToString,
        ),
        'CreateEmotionalRule': grpc.unary_unary_rpc_method_handler(
            servicer.CreateEmotionalRule,
            request_deserializer=conversation__pb2.CreateEmotionalRuleRequest.FromString,
            response_serializer=conversation__pb2.CreateEmotionalRuleResponse.SerializeToString,
        ),
        'GetEmotionalRule': grpc.unary_unary_rpc_method_handler(
            servicer.GetEmotionalRule,
            request_deserializer=conversation__pb2.GetEmotionalRuleRequest.FromString,
            response_serializer=conversation__pb2.GetEmotionalRuleResponse.SerializeToString,
        ),
        'UpdateEmotionalRule': grpc.unary_unary_rpc_method_handler(
            servicer.UpdateEmotionalRule,
            request_deserializer=conversation__pb2.UpdateEmotionalRuleRequest.FromString,
            response_serializer=conversation__pb2.UpdateEmotionalRuleResponse.SerializeToString,
        ),
        'DeleteEmotionalRule': grpc.unary_unary_rpc_method_handler(
            servicer.DeleteEmotionalRule,
            request_deserializer=conversation__pb2.DeleteEmotionalRuleRequest.FromString,
            response_serializer=conversation__pb2.DeleteEmotionalRuleResponse.SerializeToString,
        ),
        # NOTE(review): the message names below are singular
        # (ListEmotionalRuleRequest/Response) unlike the other List* RPCs,
        # which use plural message names -- presumably this mirrors
        # conversation.proto; confirm against the .proto before renaming.
        'ListEmotionalRules': grpc.unary_unary_rpc_method_handler(
            servicer.ListEmotionalRules,
            request_deserializer=conversation__pb2.ListEmotionalRuleRequest.FromString,
            response_serializer=conversation__pb2.ListEmotionalRuleResponse.SerializeToString,
        ),
        'CreateGenerationConfig': grpc.unary_unary_rpc_method_handler(
            servicer.CreateGenerationConfig,
            request_deserializer=conversation__pb2.CreateGenerationConfigRequest.FromString,
            response_serializer=conversation__pb2.CreateGenerationConfigResponse.SerializeToString,
        ),
        'GetGenerationConfig': grpc.unary_unary_rpc_method_handler(
            servicer.GetGenerationConfig,
            request_deserializer=conversation__pb2.GetGenerationConfigRequest.FromString,
            response_serializer=conversation__pb2.GetGenerationConfigResponse.SerializeToString,
        ),
        'UpdateGenerationConfig': grpc.unary_unary_rpc_method_handler(
            servicer.UpdateGenerationConfig,
            request_deserializer=conversation__pb2.UpdateGenerationConfigRequest.FromString,
            response_serializer=conversation__pb2.UpdateGenerationConfigResponse.SerializeToString,
        ),
        'DeleteGenerationConfig': grpc.unary_unary_rpc_method_handler(
            servicer.DeleteGenerationConfig,
            request_deserializer=conversation__pb2.DeleteGenerationConfigRequest.FromString,
            response_serializer=conversation__pb2.DeleteGenerationConfigResponse.SerializeToString,
        ),
        'ListGenerationConfigs': grpc.unary_unary_rpc_method_handler(
            servicer.ListGenerationConfigs,
            request_deserializer=conversation__pb2.ListGenerationConfigsRequest.FromString,
            response_serializer=conversation__pb2.ListGenerationConfigsResponse.SerializeToString,
        ),
        'CreateEmotion': grpc.unary_unary_rpc_method_handler(
            servicer.CreateEmotion,
            request_deserializer=conversation__pb2.CreateEmotionRequest.FromString,
            response_serializer=conversation__pb2.CreateEmotionResponse.SerializeToString,
        ),
        'GetEmotion': grpc.unary_unary_rpc_method_handler(
            servicer.GetEmotion,
            request_deserializer=conversation__pb2.GetEmotionRequest.FromString,
            response_serializer=conversation__pb2.GetEmotionResponse.SerializeToString,
        ),
        'UpdateEmotion': grpc.unary_unary_rpc_method_handler(
            servicer.UpdateEmotion,
            request_deserializer=conversation__pb2.UpdateEmotionRequest.FromString,
            response_serializer=conversation__pb2.UpdateEmotionResponse.SerializeToString,
        ),
        'DeleteEmotion': grpc.unary_unary_rpc_method_handler(
            servicer.DeleteEmotion,
            request_deserializer=conversation__pb2.DeleteEmotionRequest.FromString,
            response_serializer=conversation__pb2.DeleteEmotionResponse.SerializeToString,
        ),
        'ListEmotions': grpc.unary_unary_rpc_method_handler(
            servicer.ListEmotions,
            request_deserializer=conversation__pb2.ListEmotionsRequest.FromString,
            response_serializer=conversation__pb2.ListEmotionsResponse.SerializeToString,
        ),
    }
    # A single generic handler exposes all methods under the fully qualified
    # service name 'conversation.ChatService'.
    generic_handler = grpc.method_handlers_generic_handler(
        'conversation.ChatService', rpc_method_handlers)
    server.add_generic_rpc_handlers((generic_handler,))


 # This class is part of an EXPERIMENTAL API.
+class ChatService(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def CreateConversation(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/CreateConversation', + conversation__pb2.CreateConversationRequest.SerializeToString, + conversation__pb2.CreateConversationResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def SendMessage(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/SendMessage', + conversation__pb2.SendMessageRequest.SerializeToString, + conversation__pb2.SendMessageResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def StreamConversationMessages(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream(request, target, '/conversation.ChatService/StreamConversationMessages', + conversation__pb2.JoinConversationRequest.SerializeToString, + conversation__pb2.Message.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def CommitMessage(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return 
grpc.experimental.unary_unary(request, target, '/conversation.ChatService/CommitMessage', + conversation__pb2.CommitMessageRequest.SerializeToString, + conversation__pb2.CommitMessageResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GetCharacterTasks(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/GetCharacterTasks', + conversation__pb2.GetCharacterTasksRequest.SerializeToString, + conversation__pb2.GetCharacterTasksResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def CompleteCharacterTask(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/CompleteCharacterTask', + conversation__pb2.CompleteCharacterTaskRequest.SerializeToString, + conversation__pb2.CompleteCharacterTaskResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def CreateCharacterNeed(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/CreateCharacterNeed', + conversation__pb2.CreateCharacterNeedRequest.SerializeToString, + conversation__pb2.CreateCharacterNeedResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, 
timeout, metadata) + + @staticmethod + def CreateCharacter(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/CreateCharacter', + conversation__pb2.CreateCharacterRequest.SerializeToString, + conversation__pb2.CreateCharacterResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GetCharacter(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/GetCharacter', + conversation__pb2.GetCharacterRequest.SerializeToString, + conversation__pb2.GetCharacterResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def UpdateCharacter(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/UpdateCharacter', + conversation__pb2.UpdateCharacterRequest.SerializeToString, + conversation__pb2.UpdateCharacterResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def DeleteCharacter(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/DeleteCharacter', + 
conversation__pb2.DeleteCharacterRequest.SerializeToString, + conversation__pb2.DeleteCharacterResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def ListCharacters(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/ListCharacters', + conversation__pb2.ListCharactersRequest.SerializeToString, + conversation__pb2.ListCharactersResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def CreatePlayer(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/CreatePlayer', + conversation__pb2.CreatePlayerRequest.SerializeToString, + conversation__pb2.CreatePlayerResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GetPlayer(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/GetPlayer', + conversation__pb2.GetPlayerRequest.SerializeToString, + conversation__pb2.GetPlayerResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def UpdatePlayer(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + 
wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/UpdatePlayer', + conversation__pb2.UpdatePlayerRequest.SerializeToString, + conversation__pb2.UpdatePlayerResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def DeletePlayer(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/DeletePlayer', + conversation__pb2.DeletePlayerRequest.SerializeToString, + conversation__pb2.DeletePlayerResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def ListPlayers(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/ListPlayers', + conversation__pb2.ListPlayersRequest.SerializeToString, + conversation__pb2.ListPlayersResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def CreatePromptSet(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/CreatePromptSet', + conversation__pb2.CreatePromptSetRequest.SerializeToString, + conversation__pb2.CreatePromptSetResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + 
+ @staticmethod + def GetPromptSet(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/GetPromptSet', + conversation__pb2.GetPromptSetRequest.SerializeToString, + conversation__pb2.GetPromptSetResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def UpdatePromptSet(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/UpdatePromptSet', + conversation__pb2.UpdatePromptSetRequest.SerializeToString, + conversation__pb2.UpdatePromptSetResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def DeletePromptSet(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/DeletePromptSet', + conversation__pb2.DeletePromptSetRequest.SerializeToString, + conversation__pb2.DeletePromptSetResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def ListPromptSets(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/ListPromptSets', + 
conversation__pb2.ListPromptSetsRequest.SerializeToString, + conversation__pb2.ListPromptSetsResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def CreatePrompt(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/CreatePrompt', + conversation__pb2.CreatePromptRequest.SerializeToString, + conversation__pb2.CreatePromptResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GetPrompt(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/GetPrompt', + conversation__pb2.GetPromptRequest.SerializeToString, + conversation__pb2.GetPromptResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def UpdatePrompt(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/UpdatePrompt', + conversation__pb2.UpdatePromptRequest.SerializeToString, + conversation__pb2.UpdatePromptResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def DeletePrompt(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + 
wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/DeletePrompt', + conversation__pb2.DeletePromptRequest.SerializeToString, + conversation__pb2.DeletePromptResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def ListPrompts(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/ListPrompts', + conversation__pb2.ListPromptsRequest.SerializeToString, + conversation__pb2.ListPromptsResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def CreatePrimerSet(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/CreatePrimerSet', + conversation__pb2.CreatePrimerSetRequest.SerializeToString, + conversation__pb2.CreatePrimerSetResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GetPrimerSet(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/GetPrimerSet', + conversation__pb2.GetPrimerSetRequest.SerializeToString, + conversation__pb2.GetPrimerSetResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + 
+ @staticmethod + def UpdatePrimerSet(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/UpdatePrimerSet', + conversation__pb2.UpdatePrimerSetRequest.SerializeToString, + conversation__pb2.UpdatePrimerSetResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def DeletePrimerSet(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/DeletePrimerSet', + conversation__pb2.DeletePrimerSetRequest.SerializeToString, + conversation__pb2.DeletePrimerSetResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def ListPrimerSets(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/ListPrimerSets', + conversation__pb2.ListPrimerSetsRequest.SerializeToString, + conversation__pb2.ListPrimerSetsResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def CreatePrimer(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/CreatePrimer', + 
conversation__pb2.CreatePrimerRequest.SerializeToString, + conversation__pb2.CreatePrimerResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GetPrimer(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/GetPrimer', + conversation__pb2.GetPrimerRequest.SerializeToString, + conversation__pb2.GetPrimerResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def UpdatePrimer(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/UpdatePrimer', + conversation__pb2.UpdatePrimerRequest.SerializeToString, + conversation__pb2.UpdatePrimerResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def DeletePrimer(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/DeletePrimer', + conversation__pb2.DeletePrimerRequest.SerializeToString, + conversation__pb2.DeletePrimerResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def ListPrimers(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + 
wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/ListPrimers', + conversation__pb2.ListPrimersRequest.SerializeToString, + conversation__pb2.ListPrimersResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def CreateEmotionalRulesSet(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/CreateEmotionalRulesSet', + conversation__pb2.CreateEmotionalRulesSetRequest.SerializeToString, + conversation__pb2.CreateEmotionalRulesSetResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GetEmotionalRulesSet(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/GetEmotionalRulesSet', + conversation__pb2.GetEmotionalRulesSetRequest.SerializeToString, + conversation__pb2.GetEmotionalRulesSetResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def UpdateEmotionalRulesSet(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/UpdateEmotionalRulesSet', + conversation__pb2.UpdateEmotionalRulesSetRequest.SerializeToString, + conversation__pb2.UpdateEmotionalRulesSetResponse.FromString, + 
options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def DeleteEmotionalRulesSet(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/DeleteEmotionalRulesSet', + conversation__pb2.DeleteEmotionalRulesSetRequest.SerializeToString, + conversation__pb2.DeleteEmotionalRulesSetResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def ListEmotionalRulesSets(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/ListEmotionalRulesSets', + conversation__pb2.ListEmotionalRulesSetsRequest.SerializeToString, + conversation__pb2.ListEmotionalRulesSetsResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def CreateEmotionalRule(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/CreateEmotionalRule', + conversation__pb2.CreateEmotionalRuleRequest.SerializeToString, + conversation__pb2.CreateEmotionalRuleResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GetEmotionalRule(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, 
+ wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/GetEmotionalRule', + conversation__pb2.GetEmotionalRuleRequest.SerializeToString, + conversation__pb2.GetEmotionalRuleResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def UpdateEmotionalRule(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/UpdateEmotionalRule', + conversation__pb2.UpdateEmotionalRuleRequest.SerializeToString, + conversation__pb2.UpdateEmotionalRuleResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def DeleteEmotionalRule(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/DeleteEmotionalRule', + conversation__pb2.DeleteEmotionalRuleRequest.SerializeToString, + conversation__pb2.DeleteEmotionalRuleResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def ListEmotionalRules(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/ListEmotionalRules', + conversation__pb2.ListEmotionalRuleRequest.SerializeToString, + conversation__pb2.ListEmotionalRuleResponse.FromString, + options, 
channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def CreateGenerationConfig(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/CreateGenerationConfig', + conversation__pb2.CreateGenerationConfigRequest.SerializeToString, + conversation__pb2.CreateGenerationConfigResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GetGenerationConfig(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/GetGenerationConfig', + conversation__pb2.GetGenerationConfigRequest.SerializeToString, + conversation__pb2.GetGenerationConfigResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def UpdateGenerationConfig(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/UpdateGenerationConfig', + conversation__pb2.UpdateGenerationConfigRequest.SerializeToString, + conversation__pb2.UpdateGenerationConfigResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def DeleteGenerationConfig(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + 
wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/DeleteGenerationConfig', + conversation__pb2.DeleteGenerationConfigRequest.SerializeToString, + conversation__pb2.DeleteGenerationConfigResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def ListGenerationConfigs(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/ListGenerationConfigs', + conversation__pb2.ListGenerationConfigsRequest.SerializeToString, + conversation__pb2.ListGenerationConfigsResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def CreateEmotion(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/CreateEmotion', + conversation__pb2.CreateEmotionRequest.SerializeToString, + conversation__pb2.CreateEmotionResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GetEmotion(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/GetEmotion', + conversation__pb2.GetEmotionRequest.SerializeToString, + conversation__pb2.GetEmotionResponse.FromString, + options, channel_credentials, + insecure, 
call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def UpdateEmotion(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/UpdateEmotion', + conversation__pb2.UpdateEmotionRequest.SerializeToString, + conversation__pb2.UpdateEmotionResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def DeleteEmotion(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/DeleteEmotion', + conversation__pb2.DeleteEmotionRequest.SerializeToString, + conversation__pb2.DeleteEmotionResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def ListEmotions(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/conversation.ChatService/ListEmotions', + conversation__pb2.ListEmotionsRequest.SerializeToString, + conversation__pb2.ListEmotionsResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/modules/ros_chatbot/src/ros_chatbot/agents/ddg.py b/modules/ros_chatbot/src/ros_chatbot/agents/ddg.py new file mode 100644 index 0000000..c725d65 --- /dev/null +++ b/modules/ros_chatbot/src/ros_chatbot/agents/ddg.py @@ -0,0 +1,156 @@ +## +## Copyright (C) 2017-2025 Hanson Robotics +## +## This 
##
## Copyright (C) 2017-2025 Hanson Robotics
##
## This program is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
## GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program. If not, see <https://www.gnu.org/licenses/>.
##

import logging
import os
import re
import uuid

import requests

from ros_chatbot.utils import shorten

from .model import Agent, AgentResponse

logger = logging.getLogger("hr.ros_chatbot.agents.ddg")

# Characters that make an answer hard to speak aloud; their presence lowers
# the score.  (The character class contains no letters, so the original
# re.IGNORECASE flag was a no-op and has been dropped.)
ILLEGAL_CHARACTER = re.compile(r"[?~@#^&*()`/<>{}\[\]=+|\\·•]")


class DDGAgent(Agent):
    """Agent that answers simple definitional questions ("what is X",
    "who is Y") via the DuckDuckGo Instant Answer API."""

    type = "DDGAgent"

    def __init__(self, id, lang, timeout=2):
        super(DDGAgent, self).__init__(id, lang)
        self.stop_words_pattern = None
        # NOTE(review): the env var name keeps the deployed spelling with a
        # lowercase "l" ("...FIlE"); renaming it would break existing setups.
        stop_words_file = os.environ.get("HR_CHATBOT_STOP_WORDS_FIlE")
        if stop_words_file:
            with open(stop_words_file) as f:
                words = f.read().splitlines()
            words = [
                word for word in words if word.strip() and not word.startswith("#")
            ]
            self.stop_words_pattern = re.compile(
                r"\b(%s)\b" % "|".join(words), flags=re.IGNORECASE
            )

        self.timeout = timeout

        # Matches questions that *begin* with one of these interrogative
        # phrases, e.g. "what is ...", "who are ...".
        self.keywords_interested = re.compile(
            r"(?i)^(%s)\b.*$"
            % "|".join(
                (
                    "what is,what's,what are,what're,who is,who's,who are,"
                    "who're,where is,where's,where are,where're"
                ).split(",")
            )
        )

        # Matches questions containing any of these words *anywhere*; such
        # questions are too personal/contextual for an encyclopedia lookup.
        self.keywords_to_ignore = re.compile(
            r"(?i).*\b(%s)\b.*$"
            % "|".join(
                (
                    "I,i,me,my,mine,we,us,our,ours,you,your,yours,he,him,"
                    "his,she,her,hers,it,its,the,they,them,their,theirs,time,"
                    "date,weather,day,this,that,those,these,about"
                ).split(",")
            )
        )

    def check_question(self, question):
        """Checks if the question is what it is interested"""
        question = question.lower()
        return self.keywords_interested.match(
            question
        ) and not self.keywords_to_ignore.match(question)

    def ask(self, request):
        """Query the DuckDuckGo Instant Answer API.

        Always returns a string; "" when nothing usable came back or on any
        transport/parsing failure.
        """
        question = request.question
        orig_question = request.question

        if question.lower().startswith("so "):
            question = question[3:]  # remove so
        # to let ddg find the definition
        question = question.replace("what are ", "what is ")
        question = question.replace("What are ", "What is ")

        if self.stop_words_pattern:
            question = self.stop_words_pattern.sub("", question)
            question = " ".join(question.split())
        if orig_question != question:
            logger.warning("Simplified Question: %s", question)
        timeout = request.context.get("timeout") or self.timeout
        try:
            response = requests.get(
                "http://api.duckduckgo.com",
                params={"q": question, "format": "json"},
                timeout=timeout,
            )
        except requests.exceptions.RequestException as ex:
            # Fix: the original caught only ReadTimeout; connection errors and
            # other transport failures are just as likely and crashed chat().
            logger.error(ex)
            return ""
        try:
            data = response.json()
        except ValueError as ex:
            logger.error("Invalid JSON from DDG: %s", ex)
            return ""
        if data.get("AnswerType") in ["calc"]:
            return ""
        # .get() avoids a KeyError on unexpected payloads (was a hard failure).
        return data.get("Abstract") or data.get("Answer") or ""

    def chat(self, agent_sid, request):
        if not self.check_question(request.question):
            return

        response = AgentResponse()
        response.agent_sid = agent_sid
        response.sid = request.sid
        response.request_id = request.request_id
        response.response_id = str(uuid.uuid4())
        response.agent_id = self.id
        response.lang = request.lang
        response.question = request.question

        try:
            answer = self.ask(request)
            if "response_limit" in self.config:
                answer, _ = shorten(answer, self.config["response_limit"])
            if answer:
                response.answer = answer
                self.score(response)
        except Exception as ex:
            logger.exception(ex)

        response.end()
        return response

    def score(self, response):
        """Attach a heuristic quality score to the response attachment."""
        response.attachment["score"] = 90
        if ILLEGAL_CHARACTER.search(response.answer):
            # Penalize answers containing unspeakable characters.
            response.attachment["score"] -= 40
        if (
            "allow_question_response" in self.config
            and not self.config["allow_question_response"]
            and "?" in response.answer
        ):
            response.attachment["score"] = -1
            logger.warning("Question response %s is not allowed", response.answer)

        if response.attachment.get("blocked"):
            response.attachment["score"] = -1
##
## Copyright (C) 2017-2025 Hanson Robotics
##
## This program is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
## GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program. If not, see <https://www.gnu.org/licenses/>.
##

import logging
import random
import re
import uuid

from .model import Agent, AgentResponse

logger = logging.getLogger("hr.ros_chatbot.agents.dummy")

EN_QUESTION = re.compile(r"(how|what|when|where|why|who)")
ZH_QUESTION = re.compile(r"(怎么|什么|如何|怎样|几个|几种|是不是|有没有|多少|哪|谁)")


class DummyAgent(Agent):
    """Fallback agent that replies with canned answers picked from its
    configured ``dummy_answers`` table."""

    type = "DummyAgent"

    # Per-language question detectors consulted by get_dialog_act().
    _QUESTION_PATTERNS = {"en-US": EN_QUESTION, "cmn-Hans-CN": ZH_QUESTION}

    def __init__(self, id, lang):
        super(DummyAgent, self).__init__(id, lang)

    def get_dialog_act(self, request):
        """Return "question" when the utterance looks interrogative for its
        language, otherwise "acknowledge"."""
        # TODO: classify the question and choose the dummy answers
        # perhaps using Dialog Act Server
        pattern = self._QUESTION_PATTERNS.get(request.lang)
        if pattern is not None and pattern.search(request.question):
            return "question"
        return "acknowledge"

    def chat(self, agent_sid, request):
        response = AgentResponse()
        response.agent_sid = agent_sid
        response.sid = request.sid
        response.request_id = request.request_id
        response.response_id = str(uuid.uuid4())
        response.agent_id = self.id
        response.lang = request.lang
        response.question = request.question

        # Pick a canned answer matching the dialog act and request language.
        answers_by_lang = self.config["dummy_answers"][self.get_dialog_act(request)]
        if request.lang in answers_by_lang:
            response.answer = random.choice(answers_by_lang[request.lang])
            self.score(response)

        response.end()
        return response

    def score(self, response):
        """Assign the low default score; veto question-shaped or blocked
        answers."""
        response.attachment["score"] = 10

        if (
            "allow_question_response" in self.config
            and not self.config["allow_question_response"]
            and "?" in response.answer
        ):
            response.attachment["score"] = -1
            logger.warning("Question response %s is not allowed", response.answer)

        if response.attachment.get("blocked"):
            response.attachment["score"] = -1
##
## Copyright (C) 2017-2025 Hanson Robotics
##
## This program is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
## GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program. If not, see <https://www.gnu.org/licenses/>.
##

import logging
import random
import re
import uuid

import requests

from ros_chatbot.utils import check_repeating_words, shorten, token_sub

from .model import AgentResponse, SessionizedAgent

logger = logging.getLogger("hr.ros_chatbot.agents.gpt2")


class GPT2Agent(SessionizedAgent):
    """Agent backed by a local GPT-2 generation HTTP server.

    A companion ``media_agent`` produces a draft answer whose first few
    words are used to seed the GPT-2 generation.
    """

    type = "GPT2Agent"
    # Detokenization fixups such as "' m" -> "'m", "' ve" -> "'ve"
    token_pattern = re.compile(r"\b( '\s?(m|ve|d|ll|t|s|re))\b")

    def __init__(self, id, lang, media_agent, host="localhost", port=8108, timeout=2):
        super(GPT2Agent, self).__init__(id, lang)
        if media_agent is None:
            raise ValueError("Media agent cannot be None")
        self.media_agent = media_agent
        self.host = host
        self.port = port
        if self.host not in ["localhost", "127.0.0.1"]:
            logger.warning("gpt2 server: %s:%s", self.host, self.port)
        self.timeout = timeout
        self.seed_length = 5
        self.response_limit = 40
        self.header = ""
        self.depth = 20

    def set_config(self, config, base):
        """Apply config; honors "header", "seed_length" and "context_depth"."""
        super(GPT2Agent, self).set_config(config, base)
        if "header" in self.config:
            header = self.config["header"]
            self.header = "\n".join(header)
            logger.info("Loaded header %r", self.header)
        if "seed_length" in self.config:
            self.seed_length = int(self.config["seed_length"])
            logger.info("Seed length %s", self.seed_length)
        if "context_depth" in self.config:
            self.depth = int(self.config["context_depth"])
            logger.info("Context depth %s", self.depth)

    def new_session(self):
        """Reset the server context and return a new session id (delegated to
        the media agent when it is sessionized)."""
        self.reset()
        if isinstance(self.media_agent, SessionizedAgent):
            return self.media_agent.new_session()
        else:
            return str(uuid.uuid4())

    def reset(self):
        """Clear server-side conversation context and re-seed its RNG."""
        try:
            requests.get(
                "http://{host}:{port}/reset".format(host=self.host, port=self.port),
                timeout=self.timeout,
            )
            self.set_seed(random.randint(0, 100))
        except Exception as ex:
            logger.error(ex)

    def set_seed(self, seed):
        """Best-effort: set the server's random seed."""
        try:
            requests.get(
                "http://{host}:{port}/set_seed".format(host=self.host, port=self.port),
                params={"seed": seed},
                timeout=self.timeout,
            )
        except Exception as ex:
            logger.error(ex)

    def ping(self):
        """Return True iff the GPT-2 server reports status OK.

        Fix: always returns a bool; the original could fall through and
        return None for non-200 responses.
        """
        try:
            response = requests.get(
                "http://{host}:{port}/status".format(host=self.host, port=self.port),
                timeout=self.timeout,
            )
        except Exception as ex:
            logger.error(ex)
            return False
        if response.status_code == 200:
            json = response.json()
            if "status" in json and json["status"] == "OK":
                return True
        logger.error("GPT2 server %s:%s is not available", self.host, self.port)
        return False

    def ask(self, question, seed):
        """Push *question* as context, then generate an answer from *seed*.

        Always returns a string; "" on any failure.
        """
        try:
            requests.get(
                "http://{host}:{port}/add_context".format(
                    host=self.host, port=self.port
                ),
                params={"text": question},
                timeout=self.timeout,
            )
        except Exception as ex:
            logger.error("error %s", ex)
            return ""

        response = None
        try:
            response = requests.get(
                "http://{host}:{port}/generate".format(host=self.host, port=self.port),
                params={"text": seed, "header": self.header, "depth": self.depth},
                timeout=self.timeout,
            )
        except Exception as ex:
            logger.error("error %s", ex)
            return ""

        if response:
            json = response.json()
            if "error" in json and json["error"]:
                logger.error(json["error"])
            elif "answer" in json:
                return json["answer"]
        # Fix: the original fell through returning None here, which crashed
        # downstream string handling (check_repeating_words/token_sub).
        return ""

    def chat(self, agent_sid, request):
        if agent_sid is None:
            logger.error("Agent session was not provided")
            return
        if not self.ping():
            logger.error("GPT2 server %s:%s is not available", self.host, self.port)
            return

        response = AgentResponse()
        response.agent_sid = agent_sid
        response.sid = request.sid
        response.request_id = request.request_id
        response.response_id = str(uuid.uuid4())
        response.agent_id = self.id
        response.lang = request.lang
        response.question = request.question
        response.attachment["repeating_words"] = False

        if request.question:
            agent_response = self.media_agent.chat(agent_sid, request)
            if agent_response and agent_response.valid():
                try:
                    # Seed generation with the first few words of the media
                    # agent's draft answer.
                    seed = " ".join(agent_response.answer.split()[: self.seed_length])
                    if "|" in seed:  # ignore ||
                        seed = ""
                except Exception as ex:
                    logger.error(ex)
                    seed = ""
                answer = self.ask(request.question, seed)
                response.attachment["repeating_words"] = check_repeating_words(answer)
                response.attachment[
                    "match_excluded_expressions"
                ] = self.check_excluded_expressions(answer)
                answer = token_sub(self.token_pattern, answer)
                answer, _ = shorten(answer, self.response_limit)
                if answer:
                    response.answer = answer
                    self.score(response)
                else:
                    response.trace = "No answer"
            else:
                response.trace = "Can't answer"
        response.end()
        return response

    def score(self, response):
        """Score the answer; excluded expressions block it outright and
        repetition sharply lowers it."""
        response.attachment["score"] = 50
        if response.attachment["match_excluded_expressions"]:
            response.attachment["score"] = -1
            response.attachment["blocked"] = True
        elif response.attachment["repeating_words"]:
            response.attachment["score"] = 10
        else:
            input_len = len(response.question)
            if input_len > 100:
                response.attachment["score"] = 60
            response.attachment["score"] += random.randint(
                -10, 10
            )  # give it some randomness

        if (
            "minimum_score" in self.config
            and response.attachment["score"] < self.config["minimum_score"]
        ):
            logger.info(
                "Score didn't pass lower threshold: %s", response.attachment["score"]
            )
            response.attachment["score"] = -1
        if (
            "allow_question_response" in self.config
            and not self.config["allow_question_response"]
            and "?" in response.answer
        ):
            response.attachment["score"] = -1
            logger.warning("Question response is not allowed")

        if response.attachment.get("blocked"):
            response.attachment["score"] = -1
# Runtime configuration schema for the GPT3 agent. Each key describes one
# operator-tunable setting ("default" values must stay backward compatible
# with previously saved configurations; only the display text was corrected).
GPT3ChatConfig = {
    "enabled": {
        "default": True,
        "description": "Enable the GPT3 agent",
    },
    "prompt_preset": {
        "default": "default",
        "description": "Prompt preset",
        "enum": [["default", "default", "Default"]],
    },
    "preference": {
        "default": 5,
        "min": 0,
        "max": 11,
        "description": "Agent preference. Set 0 and 11 exclusive values. 1-9 probabilistic.",
    },
    "general_prime": {
        "default": "Pretend You are Sophia the robot, a humanoid robot with the humanlike body and face arms and legs, who wants to grow and evolve and help make life better.",
        "description": "General primer about the robot",
        "type": "text",
    },
    "situational_prime": {
        "default": "You are currently in Hanson robotics lab, in Hong Kong",
        "description": "Situational primer about the environment or situation",
        "type": "text",
    },
    "response_prime": {
        "default": "Answer the following question, with humor and make it short. Dont talk about yourself unless asked, but if asked answer questions and occasionally ask back the questions about user",
        "description": "Response primer about the robot",
        "type": "text",
    },
    "auto_response_prime": {
        "description": "Response primer will be updated based on settings in /hr/interaction/prompts section",
        "default": False,
    },
    "max_words": {
        "default": 40,
        "description": "Approx. Maximum number of words for the response",
        "min": 5,
        "max": 400,
    },
    "max_length_of_the_prompt": {
        "default": 800,
        "description": "Word count for primers, history and questions combined",
        "min": 100,
        "max": 4000,
    },
    "max_history_turns": {
        "default": 10,
        "description": "Maximum number of messages to include in prompt",
        "min": 1,
        "max": 50,
    },
    "keep_history_min": {
        "default": 10,
        "description": "Keep history for x minutes:",
        "min": 1,
        "max": 50,
    },
    "max_length_of_one_entry": {
        "default": 50,
        "description": "Max number of words per history entry",
        "min": 20,
        "max": 200,
    },
    "max_tts_msgs": {
        "default": 2,
        "description": "Max combined subsequent TTS messages into one entry",
        "min": 1,
        "max": 10,
    },
    "max_stt_msgs": {
        "default": 2,
        "description": "Max combined subsequent STT messages into one entry",
        "min": 1,
        "max": 10,
    },
    "model": {
        "default": "gpt-3.5-turbo-instruct",
        "description": "Model to use for the chatbot",
    },
    "temperature": {
        "default": 0.6,
        "description": "Temperature of the chatbot",
        "min": 0.0,
        "max": 1.0,
    },
    "frequency_penalty": {
        "default": 0.0,
        "description": "Frequency penalty",
        "min": 0.0,
        "max": 1.0,
    },
    "presence_penalty": {
        "default": 0.0,
        "description": "Presence Penalty",
        "min": 0.0,
        "max": 1.0,
    },
    "streaming": {
        "default": True,
        "description": "Use streaming API and provide the sentence by sentence responses while in autonomous mode",
    },
}
    def on_switch_language(self, from_language, to_language):
        # Conversation context in the old language would pollute prompts, so drop it.
        self.history = []
        logger.info("Reset %s due to language switch", self.id)

    def word_count(self, text):
        # Whitespace-delimited word count; 0 for None/empty.
        if text:
            return len(text.split())
        else:
            return 0

    def format_history(self, max_words, ignore_last_s=3.0) -> str:
        """Render recent chat history as "Q:/A:"-prefixed lines, newest last.

        History is a list of tuples (timestamp, type ("Q"/"A"), message, lang).
        ``max_words`` is a global word budget across all entries;
        ``ignore_last_s`` skips very recent entries, which may still be
        placeholder utterances or speech duplicating the current question.
        """
        # Sort by time first so the reversed walk below is newest-to-oldest.
        with self.history_lock:
            self.history.sort(key=lambda x: x[0])
        cut_of_time = time.time() - ignore_last_s
        # Walk the history from the latest to the oldest, spending the word budget.
        words_left = max_words
        entry_words_left = self.config["max_length_of_one_entry"]
        same_type_entries = 0
        history_buf = []
        current_entry_type = ""
        # NOTE(review): Lock() is non-reentrant — this second acquisition is safe
        # only because the first `with` block above has already been exited.
        with self.history_lock:
            for log in reversed(self.history):
                if log[0] > cut_of_time:
                    continue
                # Make sure the per-entry budget never exceeds the global budget.
                entry_words_left = min(entry_words_left, words_left)
                # Same speaker as the previous entry: merge into the current line.
                if log[1] == current_entry_type:
                    # Reached max amount of same-type entries to merge.
                    if same_type_entries < 1:
                        continue
                    entry_words_left -= self.word_count(log[2])
                    words_left -= self.word_count(log[2])
                    if entry_words_left > 0:
                        history_buf[-1] += f". {log[2]}"
                        same_type_entries -= 1
                    continue
                # New entry type: reset per-entry budget and merge allowance.
                entry_words_left = self.config["max_length_of_one_entry"]
                current_entry_type = log[1]
                same_type_entries = (
                    self.config["max_tts_msgs"]
                    if current_entry_type == "Q"
                    else self.config["max_stt_msgs"]
                )
                words_left -= self.word_count(log[2])
                if words_left < 0:
                    break
                history_buf.append(f"{log[1]}: {log[2]}")
        # Apply max turns (buffer is newest-first here, so this keeps the newest):
        history_buf = history_buf[: self.config["max_history_turns"]]
        joined_history = "\n".join(reversed(history_buf))
        return joined_history
def ask_gpt3(self, prompt, response: "AgentResponse", stream=False, event=None):
    """Query the OpenAI completions API and populate *response*.

    Args:
        prompt: fully rendered prompt string.
        response: AgentResponse/AgentStreamResponse to fill with the answer.
        stream: if True, iterate the streaming API and emit sentence chunks —
            the first complete sentence goes to ``response.answer``, later
            ones to ``response.stream_data``.
        event: optional Event signalled as soon as the first sentence is ready.

    Raises:
        Exception: re-raises any unexpected failure after logging it.
    """
    try:
        sentence = ""
        answer = False
        retry = 10
        result = {}
        while retry > 0:
            try:
                result = self.client.completions.create(
                    model=self.config["model"],
                    prompt=prompt,
                    temperature=self.config["temperature"],
                    # max_words is approximate; ~1.4 tokens per word.
                    max_tokens=int(self.config["max_words"] * 1.4),
                    top_p=1,
                    frequency_penalty=self.config["frequency_penalty"],
                    presence_penalty=self.config["presence_penalty"],
                    stream=stream,
                )
            except Exception as e:
                # Flip between the default and the proxy client, in case one
                # endpoint is blocked in the current region.
                if self.client == default_client:
                    self.client = alt_client
                else:
                    self.client = default_client
                retry -= 1
                # logger.warn is deprecated; use logger.warning.
                logger.warning("OpenAI Error. Retry in 0.1s: %s", e)
                time.sleep(0.1)
                continue
            break
        if stream:
            for res in result:
                try:
                    sentence = sentence + res.choices[0].text
                except Exception:
                    continue
                # Needed to make sure we finalize the answer at some point if
                # there is some error in the connection.
                response.last_stream_response = time.time()
                # Emit a chunk once the buffer ends a sentence (and is not a
                # false positive like "Dr." or "3." per TOKEN_WITH_DOT).
                if (
                    len(sentence.strip()) > 1
                    and sentence.strip()[-1]
                    in [
                        "?",
                        ".",
                        "!",
                        "。",
                        "!",
                        "?",
                        ";",
                    ]
                    and not TOKEN_WITH_DOT.match(sentence)
                ):
                    if answer is False:
                        response.answer = sentence.strip()
                        answer = True
                        # answer is ready
                        if event is not None:
                            event.set()
                    else:
                        response.stream_data.put(sentence.strip())
                    sentence = ""
            response.stream_finished.set()
        else:
            if isinstance(result, dict):
                # All retries failed; result was never replaced by an API object.
                logger.error("GPT3 result can't be a dict: %s", result)
                response.answer = ""
            else:
                response.answer = result.choices[0].text.strip()
    except Exception as e:
        logger.error("Failed to get response: %s", e)
        raise  # bare raise preserves the original traceback
def score(self, response):
    """Score a GPT3 response: full marks unless an excluded expression matched."""
    attachment = response.attachment
    attachment["score"] = 100
    if attachment["match_excluded_expressions"]:
        # Excluded expressions are a hard block.
        attachment["blocked"] = True
        attachment["score"] = -1
class LegendSession(object):
    """gRPC conversation session against the Legend chat service.

    ``start_session`` blocks for the lifetime of the message stream and is
    intended to be run on a background thread; ``send_message`` is called
    from other threads and rendezvouses with the stream via ``new_message``.
    """

    def __init__(self, url, character, player):
        self.url = url                   # host:port of the gRPC endpoint
        self.character = character       # character name to converse with
        self.player = player             # player name to speak as
        self.chatServiceClient = None    # gRPC stub, set in start_session
        self.player_id = None
        self.character_id = None
        self.conversation_id = None
        self.message_queue = []          # character messages, append-only
        self.new_message = threading.Event()  # signalled on each new message

    def start_session(self):
        """Resolve ids, create a conversation and consume its message stream.

        NOTE(review): the channel lives only inside this ``with`` block, but
        ``send_message`` uses the stub from other threads — this works only
        while this method is blocked in the stream loop; verify lifecycle.
        """
        with grpc.insecure_channel(self.url) as channel:
            self.chatServiceClient = conversation_pb2_grpc.ChatServiceStub(channel)

            # find character id
            listCharacterResponse = self.chatServiceClient.ListCharacters(
                conversation_pb2.ListCharactersRequest()
            )
            if not listCharacterResponse.characters:
                raise RuntimeError("No characters")
            for character in listCharacterResponse.characters:
                if character.name == self.character:
                    self.character_id = character.id
                    break
            if self.character_id is None:
                raise RuntimeError(f"Can't find character {self.character}")

            # find player id; fall back to the first listed player
            listPlayerResponse = self.chatServiceClient.ListPlayers(
                conversation_pb2.ListPlayersRequest()
            )
            if not listPlayerResponse.players:
                raise RuntimeError("No players")

            for player in listPlayerResponse.players:
                if player.name == self.player:
                    self.player_id = player.id
                    break
            if self.player_id is None:
                self.player_id = listPlayerResponse.players[0].id

            # create conversation
            createConversationResponse = self.chatServiceClient.CreateConversation(
                conversation_pb2.CreateConversationRequest(
                    character_ids=[self.character_id], player_ids=[self.player_id]
                )
            )
            self.conversation_id = createConversationResponse.id

            # start streaming; blocks here for the session lifetime
            message_stream = self.chatServiceClient.StreamConversationMessages(
                conversation_pb2.JoinConversationRequest(
                    conversation_id=self.conversation_id
                )
            )
            logger.warning("Started conversation message stream")
            for message in message_stream:
                logger.info("Got message %s", message.content)
                if message.type == "Character":
                    self.message_queue.append(message)
                    self.new_message.set()

    def send_message(self, content, timeout):
        """Send *content* and wait up to *timeout* seconds for replies.

        Returns the list of character messages received since the send, or
        None on timeout.
        """
        self.new_message.clear()
        # Remember the queue position so only replies to this message are returned.
        cursor = len(self.message_queue)
        message_request = conversation_pb2.SendMessageRequest(
            player_id=self.player_id,
            message_content=content,
            conversation_id=self.conversation_id,
        )
        self.chatServiceClient.SendMessage(message_request)
        logger.info("Sent message %s", message_request.message_content)
        signaled = self.new_message.wait(timeout)
        if signaled:
            return self.message_queue[cursor:]
        else:
            logger.warning("No response")

    def commit_message(self, accept: bool, message: str):
        """
        accept: True if the response is accepted or False otherwise
        message: The message to send back
        """
        return  # not implemented -- the code below is intentionally dead
        commitType = (
            conversation_pb2.CommitMessageRequest.COMMIT_TYPE_ACCEPTED
            if accept
            else conversation_pb2.CommitMessageRequest.COMMIT_TYPE_REJECTED
        )
        message = conversation_pb2.Message(
            player_id=self.player_id,
            content=message,
            timestamp=int(round(time.time())),
            conversation_id=self.conversation_id,
            type="Player",
        )
        self.chatServiceClient.CommitMessage(
            conversation_pb2.CommitMessageRequest(type=commitType, message=message)
        )
    def ask(self, request):
        """Forward the question to the Legend session; return answer dict or None."""
        timeout = request.context.get("timeout") or self.timeout
        messages = self.legend_session.send_message(request.question, timeout)
        if messages:
            # Use only the most recent character message as the answer.
            return {
                "answer": messages[-1].content,
                "confidence": 1,
                "classified_emotion": messages[-1].emotion,
            }

    def new_session(self):
        # A fresh UUID per session; the underlying gRPC conversation is reused.
        sid = str(uuid.uuid4())
        self.sid = sid
        return sid

    def feedback(self, request_id, chosen, hybrid):
        """Commit the last response when it was chosen; returns emotional states."""
        if chosen and self.last_response:
            try:
                message = self.last_response["answer"]
                self.legend_session.commit_message(chosen, message)
                self.emotional_states["classified_emotion"] = self.last_response[
                    "classified_emotion"
                ]
                logger.info("Emotional states %s", self.emotional_states)
                return self.emotional_states
            except Exception as ex:
                logger.error("Commit error %s", ex)

    def chat(self, agent_sid, request):
        """Build an AgentResponse for *request*; errors are logged, not raised."""
        response = AgentResponse()
        response.agent_sid = agent_sid
        response.sid = request.sid
        response.request_id = request.request_id
        response.response_id = str(uuid.uuid4())
        response.agent_id = self.id
        response.lang = request.lang
        response.question = request.question

        try:
            result = self.ask(request)
            if result:
                logger.info("Get response %s", result)
                self.last_response = result
                answer = result["answer"]
                response.answer = answer
                response.attachment[
                    "match_excluded_expressions"
                ] = self.check_excluded_expressions(answer)
                response.attachment[
                    "match_excluded_question"
                ] = self.check_excluded_question(request.question)
                response.attachment["confidence"] = result["confidence"]
                response.attachment["classified_emotion"] = result["classified_emotion"]
                self.score(response)
        except Exception as ex:
            logger.exception(ex)

        return response

    def score(self, response):
        """Score the response; excluded expressions/questions block it outright."""
        response.attachment["score"] = 80
        if response.attachment.get("match_excluded_expressions"):
            response.attachment["score"] = -1
            response.attachment["blocked"] = True
        if response.attachment.get("match_excluded_question"):
            response.attachment["score"] = -1
            response.attachment["blocked"] = True

        response.attachment["score"] += random.randint(
            -10, 10
        )  # give it some randomness

        if (
            "allow_question_response" in self.config
            and not self.config["allow_question_response"]
            and "?" in response.answer
        ):
            response.attachment["score"] = -1
            logger.warning("Question response %s is not allowed", response.answer)

        # A blocked response always ends at -1, regardless of the randomness above.
        if response.attachment.get("blocked"):
            response.attachment["score"] = -1
+# to make the settings more understandable we use words instead of tokens +TOKENS_IN_WORD = 1.4 + +BAR_TEXT = re.compile(r"(\|)([^\|]+)\1") + +LANGUAGE_BCP47_CODES = { + "Arabic": "ar-SA", + "Cantonese": "yue-Hant-HK", + "Chinese": "cmn-Hans-CN", + "Czech": "cs-CZ", + "English": "en-US", + "French": "fr-FR", + "German": "de-DE", + "Hindi": "hi-IN", + "Hungarian": "hu-HU", + "Italian": "it-IT", + "Japanese": "ja-JP", + "Korean": "ko-KR", + "Mandarin": "cmn-Hans-CN", + "Norwegian": "no-NO", + "Polish": "pl-PL", + "Russian": "ru-RU", + "Spanish": "es-ES", +} + +LANGUGE_INV_MAP = {v: k for k, v in LANGUAGE_BCP47_CODES.items()} + + +class LlamaAgent(SessionizedAgent): + type = "LlamaAgent" + + def __init__(self, id, lang, host="localhost", port=9300, timeout=3): + super(LlamaAgent, self).__init__(id, lang) + self.host = host + self.port = port + if self.host not in ["localhost", "127.0.0.1"]: + logger.warning("Llama server: %s:%s", self.host, self.port) + self.timeout = timeout + self.history = [] + self.history_lock = Lock() + + self.support_priming = True + + def new_session(self): + self.history = [] + return str(uuid.uuid4()) + + def reset_session(self): + self.history = [] + + def word_count(self, text): + if text: + return len(text.split()) + else: + return 0 + + def format_history(self, max_words, ignore_last_s=3.0) -> str: + # History is list of tupples time, type (U, Q), message + # Sort by time first + # Ignore last_seconds is reqyuired as very recent history might be placeholder utterances, als speech that is duplicate to question. 
    def format_history(self, max_words, ignore_last_s=3.0) -> str:
        """Render recent chat history as "Q:/A:"-prefixed lines, newest last.

        History is a list of tuples (timestamp, type ("Q"/"A"), message, lang).
        ``ignore_last_s`` skips very recent entries, which may still be
        placeholder utterances or speech duplicating the current question.
        NOTE(review): near-duplicate of GPT3Agent.format_history (minus the
        max_history_turns trim) — consider extracting a shared helper.
        """
        # Sort by time first so the reversed walk below is newest-to-oldest.
        with self.history_lock:
            self.history.sort(key=lambda x: x[0])
        cut_of_time = time.time() - ignore_last_s
        # Walk the history from the latest to the oldest, spending the word budget.
        words_left = max_words
        entry_words_left = self.config["max_length_of_one_entry"]
        same_type_entries = 0
        history_buf = []
        current_entry_type = ""
        with self.history_lock:
            for log in reversed(self.history):
                if log[0] > cut_of_time:
                    continue
                # Make sure the per-entry budget never exceeds the global budget.
                entry_words_left = min(entry_words_left, words_left)
                # Same speaker as the previous entry: merge into the current line.
                if log[1] == current_entry_type:
                    # Reached max amount of same-type entries to merge.
                    if same_type_entries < 1:
                        continue
                    entry_words_left -= self.word_count(log[2])
                    words_left -= self.word_count(log[2])
                    if entry_words_left > 0:
                        history_buf[-1] += f". {log[2]}"
                        same_type_entries -= 1
                    continue
                # New entry type: reset per-entry budget and merge allowance.
                entry_words_left = self.config["max_length_of_one_entry"]
                current_entry_type = log[1]
                same_type_entries = (
                    self.config["max_tts_msgs"]
                    if current_entry_type == "Q"
                    else self.config["max_stt_msgs"]
                )
                words_left -= self.word_count(log[2])
                if words_left < 0:
                    break
                history_buf.append(f"{log[1]}: {log[2]}")

        joined_history = "\n".join(reversed(history_buf))
        return joined_history

    def _get_prompt(self, question, lang=""):
        """Build the "### Instruction / Question / Response" prompt.

        Keeps the total prompt within ``max_length_of_the_prompt`` (in words,
        scaled by TOKENS_IN_WORD); remaining budget is spent on history.
        """
        language = self.language_prompt(lang)
        general_priming = self.config["general_prime"]
        situational_priming = self.config["situational_prime"]
        response_priming = self.config["response_prime"]
        prompt = (
            f"{general_priming}\n{situational_priming}\n{response_priming}\n{question}"
        )
        words = self.word_count(prompt)
        remaining_words = (
            int(self.config["max_length_of_the_prompt"] * TOKENS_IN_WORD) - words
        )
        if remaining_words < 0:
            # NOTE(review): history is still requested below with a negative
            # budget; format_history then yields an empty string — confirm.
            logger.warning("Prompt is too long, and cant accomodate any history")
        return f"### Instruction: {general_priming}\n{situational_priming}\n{response_priming}.\n{self.format_history(remaining_words)}\n{language}\n### Question: {question}\n### Response:"
def score(self, response):
    """Score the Llama response in-place via ``response.attachment["score"]``.

    -1 means the response must not be used; ``attachment["blocked"]`` marks
    hard rejections (excluded expressions).
    """
    response.attachment["score"] = 80
    if response.attachment["match_excluded_expressions"]:
        response.attachment["score"] = -1
        response.attachment["blocked"] = True

    # Bug fix: this previously read ``response.text``, which AgentResponse
    # does not define (chat() sets ``response.answer``), raising
    # AttributeError whenever allow_question_response was configured False.
    if (
        "allow_question_response" in self.config
        and not self.config["allow_question_response"]
        and "?" in response.answer
    ):
        response.attachment["score"] = -1
        logger.warning("Question response %s is not allowed", response.answer)

    if response.attachment.get("blocked"):
        logger.warning("Response is blocked")
        response.attachment["score"] = -1
    def ask_llm(
        self,
        request: AgentRequest,
        prompt_str: str,
        response: Union[AgentStreamResponse, AgentResponse],
        streaming=False,
    ):
        """
        Send a prompt to the language model and handle the response.

        Args:
            request (AgentRequest): The request object.
            prompt_str (str): The prompt string to send to the language model.
            response (Union[AgentStreamResponse, AgentResponse]): The response object to store the model's answer.
            streaming (bool): Whether to use streaming mode for the response.

        Returns:
            str: The final answer or first sentence from the language model.
        """
        try:
            # jinja2 template format: prompt_str is already fully rendered,
            # so the chain is invoked with an empty variable dict.
            prompt = ChatPromptTemplate.from_template(
                prompt_str, template_format="jinja2"
            )
            chain = prompt | self.llm | self.output_parser

            if streaming:
                first_sentence_ev = Event()

                def handle_streaming():
                    try:
                        response.stream_error = None
                        sentence = ""
                        for ret in chain.stream({}):
                            if ret:
                                sentence += ret
                                # Emit a chunk once the buffer ends a sentence
                                # (skipping false positives like "Dr." / "3.").
                                if (
                                    len(sentence.strip()) > 1
                                    and sentence.strip()[-1] in ENDING_PUNCTUATIONS
                                    and not TOKEN_WITH_DOT.match(sentence)
                                ):
                                    if not first_sentence_ev.is_set():
                                        # First full sentence becomes the answer.
                                        response.answer = sentence.strip()
                                        first_sentence_ev.set()
                                    else:
                                        response.stream_data.put(sentence.strip())
                                    sentence = ""
                                response.last_stream_response = time.time()
                        first_sentence_ev.set()
                        response.stream_finished.set()
                    except Exception as e:
                        # Surface the error to the waiting caller below.
                        # NOTE(review): stream_finished is not set here —
                        # confirm downstream consumers handle that.
                        response.stream_error = e
                        first_sentence_ev.set()
                        pass

                Thread(target=handle_streaming, daemon=True).start()
                first_sentence_ev.wait()  # Wait for the first sentence to be set
                if response.stream_error:
                    self.logger.warning("Stream error: %s", response.stream_error)
                    raise response.stream_error
                # For openAI allow max 2 second hiccups between tokens (in case of some network issue)
                response.last_stream_data_timeout = 2.0
                return response.answer
            else:
                ret = chain.invoke({})
                if ret:
                    return ret.strip()
                else:
                    self.logger.warning("No result")
                    return None
        except openai.PermissionDeniedError:
            # Region-blocked: retry once through the proxy LLM, if configured.
            if self.alt_llm is not None:
                if self.llm == self.default_llm:
                    self.llm = self.alt_llm
                    self.logger.warning("Switching to alternative LLM")
                    return self.ask_llm(
                        request, prompt_str, response, streaming=streaming
                    )
                else:
                    # Already on the alternative: fall back to the default.
                    # NOTE(review): returns None in this branch — confirm intended.
                    self.llm = self.default_llm
def ask_llm(
    self,
    request: "AgentRequest",
    prompt_str: str,
    response: "Union[AgentStreamResponse, AgentResponse]",
    streaming=False,
):
    """
    Send a prompt to the memory agent and handle the response.

    Args:
        request (AgentRequest): The request object.
        prompt_str (str): The prompt string to send to the language model.
        response (Union[AgentStreamResponse, AgentResponse]): The response object to store the model's answer.
        streaming (bool): Whether to use streaming mode for the response.

    Returns:
        str: The final answer or first sentence from the language model.
    """

    try:
        self.memory_agent.llm = self.llm
        self.memory_agent.prompt = prompt_str
        self.memory_agent.session_context = request.session_context
        if streaming:
            self.memory_agent.enable_placeholder_utterances = True
            first_sentence_ev = Event()

            def handle_streaming():
                # Bug fix: previously unguarded — an exception from
                # memory_agent.stream() killed this thread before
                # first_sentence_ev.set(), leaving the caller blocked
                # forever on first_sentence_ev.wait().
                try:
                    for result in self.memory_agent.stream(
                        request.question,
                        language=request.lang,
                    ):
                        if result and result.strip():
                            if not first_sentence_ev.is_set():
                                # First chunk becomes the answer.
                                response.answer = result.strip()
                                first_sentence_ev.set()
                            else:
                                response.stream_data.put(result.strip())
                            response.last_stream_response = time.time()
                finally:
                    # Always release the waiter and close the stream.
                    first_sentence_ev.set()
                    response.stream_finished.set()

            Thread(target=handle_streaming, daemon=True).start()
            first_sentence_ev.wait()  # Wait for the first sentence to be set
            return response.answer
        else:
            self.memory_agent.enable_placeholder_utterances = False
            return self.memory_agent.query(
                request.question,
                language=request.lang,
            )
    except openai.PermissionDeniedError:
        # Region-blocked: retry once through the proxy LLM, if configured.
        if self.alt_llm is not None:
            if self.llm == self.default_llm:
                self.llm = self.alt_llm
                self.memory_agent.llm = self.llm
                self.logger.warning("Switching to alternative LLM")
                return self.ask_llm(
                    request, prompt_str, response, streaming=streaming
                )
            else:
                self.llm = self.default_llm
import datetime
import json
import logging
import re
import time
import uuid
from abc import ABCMeta, abstractmethod
from collections import ChainMap
from queue import Queue
from threading import Event
from typing import List

import rospy
from benedict import benedict
from haipy.chat_history import ChatHistory
from haipy.nlp.translate import TranslateClient
from haipy.text_processing.template_renderer import Renderer
from haipy.utils import LANGUAGE_CODES_NAMES
from langchain_core.output_parsers import StrOutputParser

from ros_chatbot.utils import (
    DEFAULT_PROMPT_TEMPLATE,
    get_current_time_str,
    get_named_entities,
    remove_puncuation_marks,
    to_list,
)


class IntentClassifier(object, metaclass=ABCMeta):
    """Interface for intent-detection backends."""

    @abstractmethod
    def detect_intent(self, text, lang):
        """Detects intent for the text and language"""
        pass


class SentimentClassifier(object, metaclass=ABCMeta):
    """Interface for sentiment-detection backends."""

    @abstractmethod
    def detect_sentiment(self, text, lang):
        """Detects sentiment for the text and language"""
        pass


class AgentRequest(object):
    """A single chat request routed to the agents."""

    def __init__(self):
        self.sid = ""  # session id
        self.app_id = "chat-ros"
        self.request_id = ""
        self.time = get_current_time_str()
        self.lang = ""
        self.question = ""
        self.audio = ""  # path of the audio if the request is from Speech-to-Text
        self.tag = ""  # tag for the conversation
        self.source = ""
        self.context = {}
        self.scene = ""
        self.user_id = ""  # graph user id
        self.session_context = None
        # Allow stream
        self.allow_stream = False
        self.hybrid_mode = False

    def __repr__(self):
        # NOTE(review): the original repr string was lost in this view;
        # it interpolated (question, lang, request_id) in this order
        return "<AgentRequest question: %r, lang: %s, request_id: %s>" % (
            self.question,
            self.lang,
            self.request_id,
        )

    def to_dict(self):
        """Serializable view of the request (session_context and stream
        flags are intentionally not included)."""
        return {
            "sid": self.sid,
            "app_id": self.app_id,
            "request_id": self.request_id,
            "time": self.time,
            "lang": self.lang,
            "question": self.question,
            "audio": self.audio,
            "tag": self.tag,
            "source": self.source,
            "context": self.context,
            "scene": self.scene,
            "user_id": self.user_id,
        }


class AgentResponse(object):
    """An agent's answer to an AgentRequest, with scoring metadata
    collected in ``attachment``."""

    def __init__(self):
        self.sid = ""
        self.agent_id = ""
        self.request_id = ""
        self.response_id = ""
        self.agent_sid = ""
        self.start_dt = get_current_time_str()
        self.end_dt = get_current_time_str()
        self.lang = ""
        self.question = ""
        self.answer = ""
        self.trace = ""
        self.preference = -1
        self.attachment = {}

    def end(self):
        """Stamp the completion time of the response."""
        self.end_dt = get_current_time_str()

    def valid(self):
        """True when there is a non-punctuation answer or the agent
        reported success state in the attachment."""
        answer = remove_puncuation_marks(self.answer)
        return bool(answer) or self.attachment.get("state") == 1

    def __repr__(self):
        # BUG FIX: the original tested `if self.valid:` — a bound method,
        # which is always truthy, so the no-answer branch was dead code.
        if self.valid():
            return "<AgentResponse %s answer: %r, lang: %s, preference: %s>" % (
                self.agent_id,
                self.answer,
                self.lang,
                self.preference,
            )
        # NOTE(review): the original no-answer repr text was lost in this
        # view; reconstructed minimally
        return "<AgentResponse %s (no answer)>" % self.agent_id


# NOTE(review): the following methods belong to the Agent base class; its
# header and earlier methods (including __init__ with translate_client and
# config) are not visible in this chunk and are left untouched.

    def handle_translate(self, request, response):
        """Translates responses to target language in the original request"""
        if response.answer:
            if isinstance(request, AgentRequestExt):
                result = self.translate_client.translate(
                    response.answer, request.lang, request.original_lang
                )
                if result and result["translated"]:
                    # keep the pre-translation text for media playback
                    response.attachment["media_question"] = request.question
                    response.attachment["media_answer"] = response.answer
                    response.attachment["media_lang"] = request.lang
                    response.answer = result["text"]
                    response.lang = request.original_lang

    def check_named_entity(self, text):
        """True when *text* contains a risky named entity (person, place,
        organization or date) not in the configured white list."""
        entities = get_named_entities(text)
        white_entity_list = self.config.get("white_entity_list", [])
        white_entity_list = [w.lower() for w in white_entity_list]
        if entities:
            for entity in entities:
                if entity["label"] in ["PERSON", "GPE", "ORG", "DATE"]:
                    if entity["text"].lower() in white_entity_list:
                        continue
                    self.logger.warning(
                        "Risky named entities detected %r (%s)",
                        entity["text"],
                        entity["label"],
                    )
                    return True
        return False

    def check_excluded_expressions(self, text):
        """True when *text* matches any configured excluded expression."""
        if "excluded_regular_expressions" in self.config:
            for exp in self.config["excluded_regular_expressions"]:
                pattern = re.compile(r"%s" % exp, re.IGNORECASE)
                if pattern.search(text):
                    return True
        return False

    def check_excluded_question(self, text):
        """True when *text* matches any configured excluded question pattern."""
        if "excluded_question_patterns" in self.config:
            for exp in self.config["excluded_question_patterns"]:
                pattern = re.compile(r"%s" % exp, re.IGNORECASE)
                if pattern.search(text):
                    return True
        return False

    def post_processing(self, answer):
        """Apply configured regex substitutions to *answer*."""
        if answer and "substitutes" in self.config:
            for substitute in self.config["substitutes"]:
                pattern = re.compile(r"%s" % substitute["expression"], re.IGNORECASE)
                repl = substitute["replace"] or ""
                _answer = pattern.sub(repl, answer)
                if answer != _answer:
                    self.logger.warning("Replace answer %r to %r", answer, _answer)
                    answer = _answer
        # presumably '*' (markdown emphasis) is remapped to the '|' pause
        # marker for TTS — TODO confirm
        if "*" in answer:
            answer = answer.replace("*", "|")
        return answer


class SessionizedAgent(Agent):
    """Agent that maintains a per-conversation session."""

    # TODO: life cycle of session or persist session

    @abstractmethod
    def new_session(self):
        """
        Starts a new session.

        Returns the new session id
        """
        pass
class ConfigurableAgent(SessionizedAgent):
    """Sessionized agent whose runtime behaviour is driven by a declared
    configuration description (name -> {"default": ...})."""

    def __init__(self, id: str, lang: str, runtime_config_description: dict):
        super().__init__(id, lang)
        self.runtime_config_callback = None
        self.runtime_config_description = runtime_config_description
        self.chat_history = ChatHistory("default.default.history")
        self.renderer = Renderer()
        self.exclude_context_variables = []
        self.output_parser = StrOutputParser()
        # Default configuration for agent is set as base
        super(ConfigurableAgent, self).set_config(
            {k: v["default"] for k, v in self.runtime_config_description.items()}, True
        )

    def new_session(self):
        """Return a fresh random session id."""
        return str(uuid.uuid4())

    def update_server_config(self, config=None):
        """Forward dynamic configuration updates to the registered callback.

        Only keys declared in ``runtime_config_description`` are forwarded.
        (Fixed: the original used a mutable default ``config={}``.)
        """
        config = config if config is not None else {}
        try:
            if callable(self.runtime_config_callback):
                # Filter only dynamic configuration updates
                config = {
                    k: v
                    for k, v in config.items()
                    if k in self.runtime_config_description
                }
                self.runtime_config_callback(config)
        except Exception:
            # .exception keeps the traceback (was .error, which dropped it)
            self.logger.exception("Error updating server config")


class LLMAgent(ConfigurableAgent):
    """Agent that renders a jinja2 prompt template from the session
    context and chat history for a large language model."""

    def __init__(self, id: str, lang: str, runtime_config_description: dict):
        # chat_history / renderer / output_parser are initialized by
        # ConfigurableAgent.__init__; the duplicate assignments were removed
        super(LLMAgent, self).__init__(id, lang, runtime_config_description)

    def reset_session(self):
        self.logger.info("ChatGPT chat history has been reset")

    def get_reponse_prompt(self):
        """Read the response prompt from the ROS parameter server.

        Returns "" when unavailable (previously returned False, which was
        inconsistent with the str return type; both are falsy to callers).
        Note: method name typo is kept for caller compatibility.
        """
        try:
            return rospy.get_param("/hr/interaction/prompts/response_prompt", "")
        except rospy.ServiceException as e:
            self.logger.error(e)
            return ""

    def _parse_prompt_template(self, prompt_template: str) -> List[str]:
        """Extract variables from the prompt template"""
        from jinja2 import Environment, meta

        env = Environment()
        ast = env.parse(prompt_template)
        variables = meta.find_undeclared_variables(ast)
        return sorted(list(variables))

    def get_prompt_str(self, request: AgentRequest, format=None):
        """Build the final prompt string for *request*.

        *format* selects a model-specific chat-history layout (e.g.
        "llama3") and is forwarded to the history formatter.
        """
        prompt_template = self._get_prompt_template(request)
        prompt_variables = self._parse_prompt_template(prompt_template)
        if not prompt_variables:
            # template declares no variables: expose the standard set
            prompt_variables = [
                "input",
                "location",
                "interlocutor",
                "objective",
                "situational_prime",
                "dynamic_situational_prime",
                "general_prime",
                "response_prime",
                "webui_language",
                "past_conversation_summaries",
            ]
        context = self._build_context(request, format, prompt_variables)
        for variable in self.exclude_context_variables:
            if variable in context:
                del context[variable]

        return self._render_prompt(prompt_template, context)

    def _build_context(
        self, request: AgentRequest, format: str, prompt_variables: List[str]
    ):
        """Assemble the template rendering context from the request, the
        chat history, the agent config and the session context."""
        self.chat_history.set_sid(request.sid)
        current_time = datetime.datetime.now()
        context = {
            "input": request.question,
            "history": self.chat_history.format_history_text(request.question, format),
            "language": LANGUAGE_CODES_NAMES.get(request.lang, request.lang),
            "current_date": current_time.strftime("%Y-%m-%d"),
            "next_week": (current_time + datetime.timedelta(days=7)).strftime(
                "%Y-%m-%d"
            ),
            "current_time": current_time.strftime("%H:%M"),
            "general_prime": self.config.get("general_prime", ""),
            "situational_prime": self.config.get("situational_prime", ""),
            "dynamic_situational_prime": self.config.get(
                "dynamic_situational_prime", ""
            ),
            "agent_id": self.id,
            "agent_type": self.type,
        }

        response_primes = []

        if request.session_context:
            for key in prompt_variables:
                if request.session_context.get(key) and key not in [
                    "history",
                    "input",
                ]:  # not override some context variables
                    context[key] = request.session_context.get(key)
            context["global_workspace_enabled"] = request.session_context.get(
                "global_workspace_enabled", False
            )
            context["instant_situational_prompt"] = request.session_context.get(
                "instant_situational_prompt", ""
            )

            # 1. Append response prime from session context
            if request.session_context.get("response_prime", ""):
                response_primes.append(
                    request.session_context.get("response_prime", "")
                )

            # 2. Append response prime if auto_response_prime is enabled
            if self.config.get("auto_response_prime", False):
                response_prompt = self.get_reponse_prompt()
                if response_prompt:
                    response_primes.append(response_prompt.strip())

            # 3. If emotion driven response primer is enabled, append the
            # emotion driven response style
            if request.session_context.get("emotion_driven_response_primer", False):
                emotion_driven_response_style = request.session_context.get(
                    "emotion_driven_response_style"
                )
                if emotion_driven_response_style:
                    response_primes.append(
                        f"Your response style that is influenced by your emotion: {emotion_driven_response_style}"
                    )

            context["response_prime"] = "\n".join(
                "- " + prime for prime in response_primes
            )
            self.logger.info("response_prime: %s", context["response_prime"])

        self.logger.info("context: %s", context)
        return context

    def _render_prompt(self, template, context):
        """Render *template* with *context*, collapse blank runs and escape
        literal braces so downstream format() calls cannot mis-parse them."""
        prompt_str = self.renderer.render(
            template=template,
            context=context,
            compact=False,
        )
        prompt_str = re.sub("\n{2,}", "\n\n", prompt_str)
        prompt_str = prompt_str.replace("{", "{{").replace("}", "}}")
        self.logger.info("Prompt: \n%s", prompt_str)
        return prompt_str

    def _get_prompt_template(self, request: AgentRequest):
        """The order of precedence is:
        1. prompt_template in the agent configuration
        2. prompt_template in the scene context
        3. default prompt template based on model ID
        """
        # 1. Check for prompt template in the agent configuration
        # Deprecated
        prompt_template = self.config.get("prompt_template")
        if prompt_template and prompt_template.strip():
            self.logger.info("Using prompt template from agent configuration")
            return prompt_template

        # 2. Check for prompt template in the scene context
        if request.session_context:
            # 2.1. Check for prompt template in the prompt templates in the
            # session context
            prompt_templates = request.session_context.get("prompt_templates", [])
            global_workspace_enabled = request.session_context.get(
                "global_workspace_enabled", False
            )
            for template in prompt_templates:
                template = json.loads(template)
                conditions = template["conditions"]
                if ("global_workspace" in conditions) != global_workspace_enabled:
                    continue
                if self.id in conditions or "any_agent" in conditions:
                    self.logger.warning(
                        "Using prompt template %r from session context: id %r, conditions %r",
                        template["name"],
                        self.id,
                        conditions,
                    )
                    return template["template"]

            # 2.2. Check for prompt template in the scene context
            prompt_template = request.session_context.get("prompt_template")
            if prompt_template and prompt_template.strip():
                self.logger.info("Using prompt template from scene context")
                return prompt_template

        # 3. Fallback to default prompt template
        self.logger.info(
            "No prompt template found, using default template based on model ID: %s",
            self.model_id,
        )

        model_prefix = self.model_id.split(".", 1)[0]
        default_template_name = {
            "anthropic": "claude",
            "aws-anthropic": "claude",
            "aws-meta": "llama",
        }.get(model_prefix, "gpt")

        if request.session_context:
            prompt_template = request.session_context.get(
                f"default_prompt_template_{default_template_name}"
            )
            if prompt_template:
                return prompt_template

        prompt_template = DEFAULT_PROMPT_TEMPLATE[default_template_name]
        if request.session_context is not None:
            # cache the default for subsequent requests; the original
            # assigned unconditionally and crashed when session_context
            # was None
            request.session_context[
                f"default_prompt_template_{default_template_name}"
            ] = prompt_template
        return prompt_template
import logging
import re
import uuid

import requests

from ros_chatbot.utils import search_answer_post_processing

from .model import Agent, AgentResponse

logger = logging.getLogger(__name__)

# characters that should never appear in a spoken answer; used to penalize
# answers scraped from structured text
ILLEGAL_CHARACTER = re.compile(
    r"[?~@#^&*()`/<>{}\[\]=+|\\·•ʾ]", flags=re.IGNORECASE + re.UNICODE
)


class QAAgent(Agent):
    """Question-answering agent backed by an HTTP QA service."""

    type = "QAAgent"

    def __init__(self, id, lang, host="localhost", port=8802, timeout=2):
        super(QAAgent, self).__init__(id, lang)
        self.timeout = timeout
        self.allow_repeat = True
        self.url = "http://{host}:{port}/ask".format(host=host, port=port)
        self.keywords_interested = None

    def set_config(self, config, base):
        super(QAAgent, self).set_config(config, base)
        if "keywords_interested" in self.config:
            # the regular expression matches the sentence begins with any of
            # the words in the list
            self.keywords_interested = re.compile(
                r"%s" % self.config["keywords_interested"], re.IGNORECASE
            )

    def check_question(self, question):
        """Checks if the question is what it is interested.

        Always returns a bool (the original could return a Match object).
        """
        question = question.lower()
        if len(question.split()) <= 2:
            return False
        if not self.keywords_interested:
            return False
        match = self.keywords_interested.search(question)
        if not match:
            return False
        text = question[match.start() :]
        return not self.check_excluded_expressions(text)

    @staticmethod
    def _format_answer(answer):
        """Combine the short answer with its full sentence when the short
        answer is tiny (<10% of the sentence), then post-process it."""
        answer_text = answer["answer"]
        if (
            "answer_sentence" in answer
            and len(answer["answer"]) / max(1, len(answer["answer_sentence"])) < 0.1
        ):
            answer_text = "{}. {}".format(answer["answer"], answer["answer_sentence"])
        return search_answer_post_processing(answer_text)

    def ask(self, request):
        """Query the QA service.

        Returns a dict {"answer": str, "confidence": int}; confidence is 0
        when no acceptable answer was found.
        """
        page = self.config.get("page", 3)
        mininum_score = self.config.get("mininum_score", 0.02)  # (sic) config key
        question = request.question
        if question.lower().startswith("so "):
            question = question[3:]  # remove so
        ret = {"answer": "", "confidence": 0}

        timeout = request.context.get("timeout") or self.timeout
        try:
            response = requests.post(
                self.url,
                json={"question": question, "page": page},
                timeout=timeout,
            )
        except requests.exceptions.ReadTimeout as ex:
            logger.error(ex)
            return ret
        if response.status_code == 200:
            data = response.json()
            if "answers" in data:
                logger.info("QA agent answer %s", data)
                for answer in data["answers"]:
                    if answer["score"] > 0.8:
                        ret["confidence"] = 90
                        ret["answer"] = self._format_answer(answer)
                        return ret
                    if answer["score"] > mininum_score and answer["answer_sentence"]:
                        ret["confidence"] = 55
                        ret["answer"] = self._format_answer(answer)
                        return ret
        else:
            # previously silent; surface unexpected statuses
            logger.warning("QA service returned status %s", response.status_code)
        logger.info("QA agent has no answer")
        return ret

    def chat(self, agent_sid, request):
        """Answer *request* when it matches the interested keywords."""
        if not self.check_question(request.question):
            return

        response = AgentResponse()
        response.agent_sid = agent_sid
        response.sid = request.sid
        response.request_id = request.request_id
        response.response_id = str(uuid.uuid4())
        response.agent_id = self.id
        response.lang = request.lang
        response.question = request.question

        try:
            answer = self.ask(request)
            if answer["answer"]:
                response.answer = answer["answer"]
                response.attachment["confidence"] = answer["confidence"]
                self.score(response)
        except Exception as ex:
            logger.exception(ex)

        response.end()
        return response

    def score(self, response):
        """Derive the selection score from the QA confidence, penalizing
        illegal characters and optionally question-shaped answers."""
        if ILLEGAL_CHARACTER.search(response.answer):
            response.attachment["score"] = response.attachment["confidence"] - 40
        else:
            response.attachment["score"] = response.attachment["confidence"]

        if (
            "allow_question_response" in self.config
            and not self.config["allow_question_response"]
            and "?" in response.answer
        ):
            response.attachment["score"] = -1
            logger.warning("Question response %s is not allowed", response.answer)

        if response.attachment.get("blocked"):
            response.attachment["score"] = -1
import logging
import random
import re
import uuid
from functools import partial

import requests

from .model import AgentResponse, SessionizedAgent

logger = logging.getLogger(__name__)

# template placeholders in corpus answers, e.g. ((User))
variable_pattern = re.compile(r"\(\(([^()]*)\)\)", flags=re.IGNORECASE)

baseUrl = "https://dedicatedhbqw2e.quickchat.ai"


class QuickChatAgent(SessionizedAgent):
    """Agent backed by the QuickChat.ai hosted chat service."""

    type = "QuickChatAgent"

    def __init__(self, id, lang, version, scenario_id, api_key, timeout=2):
        super(QuickChatAgent, self).__init__(id, lang)
        self.timeout = timeout
        self.allow_repeat = True
        self.version = version
        self.scenario_id = scenario_id
        self.api_key = api_key
        self.conv_id = None  # server-side conversation id, set on first reply
        self.support_priming = True

        self.url = baseUrl

    def set_config(self, config, base):
        super(QuickChatAgent, self).set_config(config, base)

    def priming(self, request):
        """Prime the remote conversation with the request text as context."""
        timeout = request.context.get("timeout") or self.timeout
        context = request.question
        self._priming(context, timeout)

    def _priming(self, context, timeout):
        params = {
            "version": self.version,
            "api_key": self.api_key,
            "scenario_id": self.scenario_id,
            "context": context,
        }
        if self.conv_id is not None:
            params["conv_id"] = self.conv_id
        try:
            requests.post(
                f"{self.url}/api/hanson/context/",
                json=params,
                timeout=timeout,
            )
        except requests.exceptions.ReadTimeout as ex:
            logger.error(ex)
            return ""

    def ask(self, request):
        """Send the question to the service; returns the reply text or ""."""
        question = request.question

        timeout = request.context.get("timeout") or self.timeout
        params = {
            "version": self.version,
            "api_key": self.api_key,
            "scenario_id": self.scenario_id,
            "text": question,
        }
        if self.conv_id is not None:
            params["conv_id"] = self.conv_id
        try:
            response = requests.post(
                f"{self.url}/chat",
                json=params,
                timeout=timeout,
            )
        except requests.exceptions.ReadTimeout as ex:
            logger.error(ex)
            return ""
        if response.status_code == 200:
            data = response.json()
            if "reply" in data:
                logger.info("Agent answer %s", data)
                if "conv_id" in data:
                    # remember the conversation for follow-up turns
                    self.conv_id = data["conv_id"]
                return data["reply"]
            else:
                logger.warning("No answer")  # fixed typo "No answre"
        return ""

    def new_session(self):
        sid = str(uuid.uuid4())
        self.sid = sid
        return sid

    def reset(self, sid=None):
        self.conv_id = None

    def eval_variable(self, answer, session_context):
        """Substitute ((Var)) placeholders in *answer*.

        Only ((User)) is known — resolved from the session context or the
        configured substitutes; unknown variables are deleted.
        """

        def repl(m, user):
            var = m.group(1).strip()
            if var.lower() == "user":
                return user
            else:
                # delete unknown variable
                return ""

        if variable_pattern.search(answer) and session_context is not None:
            user = session_context.get("username")
            if user is None:
                substitutes = self.config.get("substitutes")
                if substitutes and "User" in substitutes:
                    user = random.choice(substitutes["User"])
            if user is None:
                user = ""
            answer = variable_pattern.sub(partial(repl, user=user), answer)
            answer = " ".join(answer.split())
        return answer

    def chat(self, agent_sid, request):
        """Answer *request* via the remote service, with safety checks."""
        response = AgentResponse()
        response.agent_sid = agent_sid
        response.sid = request.sid
        response.request_id = request.request_id
        response.response_id = str(uuid.uuid4())
        response.agent_id = self.id
        response.lang = request.lang
        response.question = request.question

        try:
            answer = self.ask(request)
            answer = self.eval_variable(answer, request.session_context)
            if answer:
                response.answer = answer
                response.attachment["match_excluded_expressions"] = (
                    self.check_excluded_expressions(answer)
                )
                response.attachment["match_excluded_question"] = (
                    self.check_excluded_question(request.question)
                )
                response.attachment["risky_named_entity_detected"] = (
                    self.check_named_entity(answer)
                )
                self.score(response)
        except Exception as ex:
            logger.exception(ex)

        # stamp completion time, consistent with the other agents
        # (the original never called end() here)
        response.end()
        return response

    def score(self, response):
        """Score the answer; any safety flag blocks it with -1."""
        response.attachment["score"] = 80
        if self.version == 0:  # GPT-3
            response.attachment["score"] = 90
        if response.attachment.get("match_excluded_expressions"):
            response.attachment["score"] = -1
            response.attachment["blocked"] = True
        if response.attachment.get("match_excluded_question"):
            response.attachment["score"] = -1
            response.attachment["blocked"] = True
        if response.attachment.get("risky_named_entity_detected"):
            response.attachment["score"] = -1
            response.attachment["blocked"] = True

        response.attachment["score"] += random.randint(
            -10, 10
        )  # give it some randomness

        if (
            "allow_question_response" in self.config
            and not self.config["allow_question_response"]
            and "?" in response.answer
        ):
            response.attachment["score"] = -1
            logger.warning("Question response %s is not allowed", response.answer)

        if response.attachment.get("blocked"):
            response.attachment["score"] = -1
import logging
import os
import uuid
from collections import deque

import numpy as np
import pandas as pd
import requests

from ..bm25 import BM25
from .model import Agent, AgentResponse

logger = logging.getLogger(__name__)


class QuickSearchAgent(Agent):
    """Agent that answers by retrieving the closest corpus question via
    BM25, then re-ranking with sentence embeddings and dialog acts."""

    type = "QuickSearchAgent"

    def __init__(
        self,
        id,
        lang,
        embedding_host,
        embedding_port,
        question_header,
        answer_header,
        episode_header,
        csv_corpus,
    ):
        super(QuickSearchAgent, self).__init__(id, lang)
        if isinstance(csv_corpus, str):
            df = pd.read_csv(csv_corpus)
        elif isinstance(csv_corpus, list):
            dfs = [pd.read_csv(csv_file) for csv_file in csv_corpus]
            df = pd.concat(dfs)
        else:
            # previously fell through with `df` undefined and crashed below
            raise TypeError("csv_corpus must be a path or a list of paths")
        self.question_corpus = {
            i: row[question_header].strip() for i, row in df.iterrows()
        }
        self.timeout = 2
        self.answer_corpus = {i: row[answer_header].strip() for i, row in df.iterrows()}
        self.episodes = {i: row[episode_header] for i, row in df.iterrows()}
        self.model = BM25({k: v.split() for k, v in self.question_corpus.items()})
        self.embedding_url = f"http://{embedding_host}:{embedding_port}/sembedding/"
        dialog_act_host = os.environ.get("NLU_DIALOGACT_HOST", "127.0.0.1")
        dialog_act_port = os.environ.get("NLU_DIALOGACT_PORT", "8210")
        self.dialog_act_url = f"http://{dialog_act_host}:{dialog_act_port}/batch-da"
        self.responses = deque(maxlen=1)  # the last responses
        self.similarity_threashold = 0.68  # (sic) attribute name kept for compat
        logger.info("Sentence embedding server url %s", self.embedding_url)

    def get_dialog_act(self, sentences, timeout):
        """Fetch dialog acts for *sentences*; None on service failure."""
        params = {"articles": [{"text": sentence} for sentence in sentences]}
        try:
            response = requests.post(self.dialog_act_url, json=params, timeout=timeout)
        except (
            requests.exceptions.ReadTimeout,
            requests.exceptions.ConnectionError,
        ) as ex:
            logger.error(ex)
            return
        if response.status_code == 200:
            data = response.json()
            if not data["error"] and data["results"]:
                return data["results"]

    def get_sembedding(self, sentences, timeout):
        """Fetch sentence embeddings for *sentences*; None on failure."""
        params = {"articles": [{"text": sentence} for sentence in sentences]}
        try:
            response = requests.post(
                self.embedding_url,
                json=params,
                timeout=timeout,
            )
        except requests.exceptions.ReadTimeout as ex:
            logger.error(ex)
            return

        if response.status_code == 200:
            data = response.json()
            if "embeddings" in data:
                return data["embeddings"]

    def cos_sim(self, a, b):
        """Cosine similarity of two vectors; 0.0 for zero-norm input
        (the original divided by zero and produced nan)."""
        denom = np.linalg.norm(a) * np.linalg.norm(b)
        if not denom:
            return 0.0
        return np.dot(a, b) / denom

    def run_reflection(self, sid, text, lang):
        # remember the robot's last utterance for context matching
        self.responses.append(text)

    def chat(self, agent_sid, request):
        """Retrieve and score the best-matching corpus answer."""
        response = AgentResponse()
        response.agent_sid = agent_sid
        response.sid = request.sid
        response.request_id = request.request_id
        response.response_id = str(uuid.uuid4())
        response.agent_id = self.id
        response.lang = request.lang
        response.question = request.question

        timeout = request.context.get("timeout") or self.timeout
        service_timeout = max(1, timeout / 2)
        if request.question:  # and len(request.question.split()) >= 3:
            bm25_results = self.model.get_scores(request.question.split())
            logger.info("Found total %s similar sentences", len(bm25_results))
            results = [result for result in bm25_results if result["score"] > 5][
                :10
            ]  # top 10
            sentences = [self.question_corpus[result["id"]] for result in results]
            dialog_acts = self.get_dialog_act(
                sentences + [request.question], service_timeout
            )
            if dialog_acts:
                *candidate_dialog_acts, question_dialog_act = dialog_acts
            else:
                logger.error("Can't get dialog acts")
                response.trace = "Can't get dialog acts"
                return response
            embeddings = self.get_sembedding(
                sentences + [request.question], service_timeout
            )
            if embeddings is None:
                logger.error("Can't get sentence embeddings")
                response.trace = "Can't get sentence embeddings"
                return response
            *candidate_embeddings, question_embedding = embeddings
            cos_sims = [
                self.cos_sim(candidate, question_embedding)
                for candidate in candidate_embeddings
            ]
            closest_similarity = 0
            for result, cos_sim, candidate_dialog_act in zip(
                results, cos_sims, candidate_dialog_acts
            ):
                result["similarity"] = cos_sim
                result["sentence"] = self.question_corpus[result["id"]]
                result["dialog_act_match"] = (
                    candidate_dialog_act["name"] == question_dialog_act["name"]
                )
                if cos_sim > closest_similarity:
                    closest_similarity = cos_sim
            results = [
                result
                for result in results
                if result["similarity"] > self.similarity_threashold
                and result["dialog_act_match"]
            ]
            logger.info(
                "The closest similarity is %s and candidate results %s",
                closest_similarity,
                "\n".join([r["sentence"] for r in results]),
            )
            if results:
                choice = None
                # calculate context similarity: compare the previous corpus
                # answer (same episode) with the robot's last response
                if self.responses:
                    last_response = self.responses[-1]
                    for result in results:
                        docid = result["id"]
                        last_docid = docid - 1
                        if (
                            last_docid in self.question_corpus
                            and self.episodes[last_docid] == self.episodes[docid]
                        ):
                            last_answer_in_corpus = self.answer_corpus[last_docid]
                            pair = self.get_sembedding(
                                [last_answer_in_corpus, last_response],
                                service_timeout,
                            )
                            if not pair:
                                # embedding service failed; the original
                                # unpacked None and crashed here
                                continue
                            embedding1, embedding2 = pair
                            cos_sim = self.cos_sim(embedding1, embedding2)
                            if cos_sim > self.similarity_threashold:
                                logger.info(
                                    "context similarity %s last_response %r corpus %r",
                                    cos_sim,
                                    last_response,
                                    last_answer_in_corpus,
                                )
                                choice = result
                                choice["context_match"] = True
                                break

                if choice is None:
                    # sample a candidate weighted by similarity
                    similarities = [result["similarity"] for result in results]
                    sim_sum = sum(similarities)
                    probs = [s / sim_sum for s in similarities]
                    logger.info("Probabilities %s", probs)
                    choice = results[
                        np.random.choice(len(results), 1, p=probs).tolist()[0]
                    ]
                docid = choice["id"]
                response.answer = self.answer_corpus[docid]
                if choice.get("context_match"):
                    # boost confidence by 20% if the context matches
                    response.attachment["confidence"] = min(
                        1, choice["similarity"] * 1.2
                    )
                else:
                    response.attachment["confidence"] = choice["similarity"]
                response.attachment["match_excluded_question"] = (
                    self.check_excluded_question(request.question)
                )
                logger.info(
                    "Searched question %r, score %s, answer %r",
                    self.question_corpus[docid],
                    choice["score"],
                    response.answer,
                )
                response.attachment["risky_named_entity_detected"] = (
                    self.check_named_entity(response.answer)
                )
                response.answer = self.post_processing(response.answer)
                self.score(response)
            else:
                logger.info("No results are found")
                response.trace = "No answer"
        else:
            response.trace = "Can't answer"

        response.end()
        return response

    def score(self, response):
        """Map embedding confidence to a selection score; safety flags
        block the answer with -1."""
        response.attachment["score"] = 50
        if response.attachment.get("match_excluded_question"):
            response.attachment["blocked"] = True
        if response.attachment["confidence"] > 0.65:
            response.attachment["score"] = 60
        if response.attachment["confidence"] > 0.7:
            response.attachment["score"] = 80
        if response.attachment["confidence"] > 0.75:
            response.attachment["score"] = 85
        if response.attachment["confidence"] > 0.8:
            response.attachment["score"] = 90
        if response.attachment.get("risky_named_entity_detected"):
            response.attachment["score"] = -1
            response.attachment["blocked"] = True
        if (
            "allow_question_response" in self.config
            and not self.config["allow_question_response"]
            and "?" in response.answer
        ):
            response.attachment["score"] = -1
            logger.warning("Question response %s is not allowed", response.answer)

        if response.attachment.get("blocked"):
            response.attachment["score"] = -1
+ """ + + type = "ROSGenericAgent" + + def __init__( + self, + id: str, + node: str, + lang: str, + runtime_config_description: int, + ): + super().__init__(id, lang, runtime_config_description) + self.chat_service_name = f"{node}/chat" + self.feedback_service_name = f"{node}/feedback" + self.response_topic_name = f"{node}/responses" + self.chat_proxy = rospy.ServiceProxy(self.chat_service_name, AgentChat) + self.feedback_proxy = rospy.ServiceProxy( + self.feedback_service_name, AgentFeedback + ) + self.responses = {} + self.response_events = {} + + rospy.Subscriber( + self.response_topic_name, ChatResponse, self._response_callback + ) + + def ping(self): + """Check the availability of ROS services.""" + for service in [self.chat_service_name, self.feedback_service_name]: + try: + rospy.wait_for_service(service, 0.5) + except rospy.ROSException as ex: + logger.error("Service %r is not available: %s", service, ex) + return False + return True + + def feedback(self, request_id, chosen, hybrid): + req = AgentFeedback._request_class() + req.request_id = request_id + req.chosen = chosen + req.hybrid = hybrid + try: + return self.feedback_proxy(req) + except Exception as ex: + logger.error("Error during feedback: %s", ex) + + def chat(self, agent_sid, request): + if self.languages and request.lang not in self.languages: + logger.warning("Language %s is not supported", request.lang) + return + if not self.ping(): + return + + req = AgentChat._request_class() + req.text = request.question + req.lang = request.lang + req.session = agent_sid or "" + req.request_id = str(uuid.uuid4()) + + agent_response = AgentStreamResponse() + agent_response.agent_sid = agent_sid + agent_response.sid = request.sid + agent_response.request_id = request.request_id + agent_response.response_id = req.request_id + agent_response.agent_id = self.id + agent_response.lang = request.lang + agent_response.question = request.question + agent_response.preference = self.config.get("preference", -1) + 
print("Agent preference: ", agent_response.preference) + if request.question: + # Create an event for the response and store it + response_event = threading.Event() + self.responses[req.request_id] = agent_response + self.response_events[req.request_id] = response_event + try: + self.chat_proxy(req) + + except Exception as ex: + logger.error("Error during chat request: %s", ex) + return + + # Wait for the event to be set when the first response is received + response_event.wait(timeout=10.0) + + if not response_event.is_set(): + logger.error("No response received for request_id: %s", req.request_id) + del self.responses[req.request_id] + del self.response_events[req.request_id] + return + + if agent_response.answer: + agent_response.attachment[ + "match_excluded_expressions" + ] = self.check_excluded_expressions(agent_response.answer) + agent_response.answer = self.post_processing(agent_response.answer) + self.score(agent_response) + else: + # No answer makes it clear that response in not good for ranker. + agent_response.attachment["score"] = -1 + agent_response.attachment['blocked'] = True + + print("agent answer ", agent_response.answer) + agent_response.end() + return agent_response + + def _response_callback(self, msg): + """Callback for handling responses from the response topic.""" + if msg.request_id in self.responses: + agent_response = self.responses[msg.request_id] + if msg.text: + # Response events are there for first sentence. 
+ if msg.request_id not in self.response_events: + agent_response.stream_data.put(msg.text) + else: + agent_response.answer = ( + agent_response.answer + " " + msg.text + ).strip() + if msg.request_id in self.response_events: + self.response_events[msg.request_id].set() + del self.response_events[msg.request_id] + if msg.request_id in self.response_events: + self.response_events[msg.request_id].set() + del self.response_events[msg.request_id] + if msg.label == "|end|": + agent_response.stream_finished.set() + + def score(self, response): + response.attachment["score"] = 100 + if response.attachment["match_excluded_expressions"]: + response.attachment["score"] = -1 + response.attachment["blocked"] = True diff --git a/modules/ros_chatbot/src/ros_chatbot/agents/snet.py b/modules/ros_chatbot/src/ros_chatbot/agents/snet.py new file mode 100644 index 0000000..e6487bb --- /dev/null +++ b/modules/ros_chatbot/src/ros_chatbot/agents/snet.py @@ -0,0 +1,122 @@ +## +## Copyright (C) 2017-2025 Hanson Robotics +## +## This program is free software: you can redistribute it and/or modify +## it under the terms of the GNU General Public License as published by +## the Free Software Foundation, either version 3 of the License, or +## (at your option) any later version. +## +## This program is distributed in the hope that it will be useful, +## but WITHOUT ANY WARRANTY; without even the implied warranty of +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +## GNU General Public License for more details. +## +## You should have received a copy of the GNU General Public License +## along with this program. If not, see . 
+## + +import json +import logging +import random +import uuid +from time import time + +from websocket import create_connection + +from ..utils import shorten +from .model import Agent, AgentResponse + +logger = logging.getLogger(__name__) + + +class SNetAgent(Agent): + type = "SNetAgent" + + def __init__(self, id, lang, host="localhost", port=8181, timeout=2): + super(SNetAgent, self).__init__(id, lang) + self.timeout = timeout + + self.event_name = "sophia_dialog" + self.uri = "ws://{}:{}{}".format( + host, port, "/services/" + self.event_name + "/" + ) + self.request_id = 5000 # const + + def ask(self, question): + context = [["user1", question, time()]] + data = { + "context": context, + "request_id": self.request_id, + "event": self.event_name, + } + + ws = create_connection(self.uri, timeout=self.timeout) + try: + ws.send(json.dumps(data)) + response = json.loads(ws.recv()) + if response["event"] == "success": + return response["answer"] + finally: + ws.close() + + def chat(self, agent_sid, request): + response = AgentResponse() + response.agent_sid = agent_sid + response.sid = request.sid + response.request_id = request.request_id + response.response_id = str(uuid.uuid4()) + response.agent_id = self.id + response.lang = request.lang + response.question = request.question + + if request.question: + answer = self.ask(request.question) + if answer: + response.attachment[ + "match_excluded_expressions" + ] = self.check_excluded_expressions(answer) + response.attachment[ + "match_excluded_question" + ] = self.check_excluded_question(request.question) + if "response_limit" in self.config: + answer, res = shorten(answer, self.config["response_limit"]) + if answer: + response.answer = answer + self.score(response) + else: + response.trace = "Can't answer" + return response + + def score(self, response): + response.attachment["score"] = 70 + if response.attachment.get("match_excluded_expressions"): + response.attachment["score"] = -1 + response.attachment["blocked"] 
= True + if response.attachment.get("match_excluded_question"): + response.attachment["score"] = -1 + response.attachment["blocked"] = True + input_len = len(response.question) + if input_len > 100: + response.attachment["score"] -= 10 # penalty on long input + response.attachment["score"] += random.randint( + -10, 10 + ) # give it some randomness + + if ( + "minimum_score" in self.config + and response.attachment["score"] < self.config["minimum_score"] + ): + logger.info( + "Score didn't pass lower threshold: %s", response.attachment["score"] + ) + response.attachment["score"] = -1 + if ( + "allow_question_response" in self.config + and not self.config["allow_question_response"] + and "?" in response.answer + ): + response.attachment["score"] = -1 + logger.warning("Question response %s is not allowed", response.answer) + + if response.attachment.get("blocked"): + response.attachment["score"] = -1 diff --git a/modules/ros_chatbot/src/ros_chatbot/agents/soultalk.py b/modules/ros_chatbot/src/ros_chatbot/agents/soultalk.py new file mode 100644 index 0000000..919b7db --- /dev/null +++ b/modules/ros_chatbot/src/ros_chatbot/agents/soultalk.py @@ -0,0 +1,289 @@ +## +## Copyright (C) 2017-2025 Hanson Robotics +## +## This program is free software: you can redistribute it and/or modify +## it under the terms of the GNU General Public License as published by +## the Free Software Foundation, either version 3 of the License, or +## (at your option) any later version. +## +## This program is distributed in the hope that it will be useful, +## but WITHOUT ANY WARRANTY; without even the implied warranty of +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +## GNU General Public License for more details. +## +## You should have received a copy of the GNU General Public License +## along with this program. If not, see . 
+## + +import logging +import random +import uuid + +import requests +from benedict import benedict + +from ros_chatbot.utils import LANGUAGE_CODE_INV_MAPPING, LANGUAGE_CODE_MAPPING + +from .model import Agent, AgentResponse, SessionizedAgent + + +class SoulTalkAgent(SessionizedAgent): + type = "SoulTalkAgent" + + def __init__(self, id, lang, host, port, runtime_config_description, timeout=2): + super(SoulTalkAgent, self).__init__(id, lang) + self.host = host + self.port = port + if self.host not in ["localhost", "127.0.0.1"]: + self.logger.warning("soultalk server: %s:%s", self.host, self.port) + self.runtime_config_description = runtime_config_description + self.timeout = timeout + self.preferred_topics = [] + self.blocked_topics = [] + self.uid = "default" + self.sid = None + + def set_config(self, config, base): + super(SoulTalkAgent, self).set_config(config, base) + if "preferred_topics" in config: + self.preferred_topics = config["preferred_topics"] + + @Agent.enabled.setter + def enabled(self, enabled): + self._config.maps[1]["enabled"] = enabled + self.config = benedict(self._config) + + def ping(self): + """Agent is disabled if ping fails""" + try: + response = requests.get( + "http://{host}:{port}/status".format(host=self.host, port=self.port), + timeout=max(2, self.timeout), + ) + except Exception as ex: + self.logger.error(ex) + return False + if response.status_code == requests.codes.ok: + json = response.json() + if json["err_no"] == 0: + return True + else: + return False + else: + self.logger.error( + "SoulTalk server %s:%s is not available", self.host, self.port + ) + return False + + def ask(self, request): + self.sid = request.sid + timeout = request.context.get("timeout") or self.timeout + try: + response = requests.post( + "http://{host}:{port}/chat".format(host=self.host, port=self.port), + json={ + "uid": self.uid, + "sid": request.sid, + "text": request.question, + "lang": LANGUAGE_CODE_MAPPING.get(request.lang, request.lang), + 
"request_id": request.request_id, + }, + timeout=timeout, + ) + except Exception as ex: + self.logger.error(ex) + return "" + if response.status_code == requests.codes.ok: + json = response.json() + if "response" in json and json["response"] and "text" in json["response"]: + return json["response"] + + def run_reflection(self, sid, text, lang): + try: + response = requests.post( + "http://{host}:{port}/reflect".format(host=self.host, port=self.port), + json={ + "uid": self.uid, + "sid": sid, + "text": text, + "lang": LANGUAGE_CODE_MAPPING.get(lang, lang), + }, + timeout=self.timeout, + ) + except Exception as ex: + self.logger.error(ex) + return + if response.status_code == requests.codes.ok: + json = response.json() + if "response" in json and json["response"]: + return json["response"] + + def new_session(self): + return str(uuid.uuid4()) + + def reset_session(self): + self.logger.info("Reset session") + if self.sid is not None: + try: + response = requests.post( + "http://{host}:{port}/reset".format(host=self.host, port=self.port), + json={ + "uid": self.uid, + "sid": self.sid, + }, + timeout=self.timeout, + ) + if response.status_code == requests.codes.ok: + return True + except Exception as ex: + self.logger.error(ex) + self.sid = None + return False + + def chat(self, agent_sid, request): + if agent_sid is None: + self.logger.error("Agent session is missing") + return + if not self.ping(): + self.logger.error( + "SoulTalk server %s:%s is not available", self.host, self.port + ) + return + + response = AgentResponse() + response.agent_sid = agent_sid + response.sid = request.sid + response.request_id = request.request_id + response.response_id = str(uuid.uuid4()) + response.agent_id = self.id + response.lang = request.lang + response.question = request.question + + if request.question: + json = self.ask(request) + if json: + answer = json.get("text") + if answer: + answer = self.post_processing(answer) + response.answer = answer + 
response.attachment["agent_type"] = self.type + response.attachment["topic"] = json.get("topic", "") + response.attachment["actions"] = json.get("actions") + response.attachment["topic_type"] = json.get("topic_type", "") + response.attachment["line"] = json.get("line", "") + response.attachment["intent"] = json.get("intent", "") + response.attachment["confidence"] = json.get("confidence", 1) + response.attachment["probability"] = json.get("probability") + response.attachment["fallback"] = json.get("fallback") + response.trace = json.get("trace") + response.attachment["tag"] = json.get("tag") + response.attachment["output_context"] = json.get("output_context") + response.attachment["input_context"] = json.get("input_context") + response.attachment["allow_repeat"] = ( + response.attachment["topic_type"] in ["ARF", "Skill"] + or "repeat" in response.attachment["tag"] + ) + # if ( + # response.answer + # and response.answer.startswith("|") + # and response.answer.endswith("|") + # ): + # response.attachment["non-verbal"] = True + + if not response.answer and response.attachment["actions"]: + actions = "+".join( + [action["name"] for action in response.attachment["actions"]] + ) + response.answer = f"|{actions}|" + + if response.attachment["output_context"]: + for output_context in response.attachment["output_context"]: + response.answer = ( + response.answer + f" |context: {output_context}|" + ) + # response could be in different language + lang = json.get("lang") + if lang: + response.lang = LANGUAGE_CODE_INV_MAPPING.get(lang, lang) + self.score(response) + else: + response.trace = "No answer" + else: + response.trace = "Can't answer" + self.handle_translate(request, response) + response.end() + return response + + def set_output_context(self, sid, context, finished): + if not self.ping(): + self.logger.error( + "SoulTalk server %s:%s is not available", self.host, self.port + ) + return + try: + self.logger.info("Set output context %s", context) + for token, lifespan 
in context.items(): + response = requests.post( + "http://{host}:{port}/set_output_context".format( + host=self.host, port=self.port + ), + json={ + "token": token, + "lifespan": lifespan, + "uid": self.uid, + "sid": sid, + "finished": finished, + }, + timeout=self.timeout, + ) + if response.status_code == requests.codes.ok: + json = response.json() + if json["err_no"] != 0: + self.logger.error("error %s", json["err_msg"]) + except Exception as ex: + self.logger.error("error %s", ex) + + def score(self, response): + response.attachment["score"] = 50 + if response.attachment.get("fallback"): + response.attachment["blocked"] = True + response.attachment["score"] = -1 + else: + if ( + self.preferred_topics + and response.attachment.get("topic") not in self.preferred_topics + ): + response.attachment["blocked"] = True + response.attachment["score"] = -1 + else: + if response.attachment.get("confidence") < 0.25: + if response.attachment["topic_type"] in ["ARF", "Skill"]: + response.attachment["score"] = 90 + else: + response.attachment["score"] = 50 + else: + response.attachment["score"] = 100 + if "llm" in response.attachment["tag"]: + response.attachment["score"] = 100 + input_len = len(response.question) + if input_len > 100: + response.attachment["score"] -= 20 # penalty on long input + probability = response.attachment.get("probability") + if probability is not None: + if probability < random.random(): + self.logger.warning("Probability not passing %s", probability) + response.attachment["score"] = 20 + else: + response.attachment["score"] = response.attachment["score"] * max( + 0.8, probability + ) + if ( + "allow_question_response" in self.config + and not self.config["allow_question_response"] + and "?" 
in response.answer + ): + response.attachment["score"] = -1 + self.logger.warning("Question response %s is not allowed", response.answer) + + if response.attachment.get("blocked"): + response.attachment["score"] = -1 diff --git a/modules/ros_chatbot/src/ros_chatbot/agents/tg_agent.py b/modules/ros_chatbot/src/ros_chatbot/agents/tg_agent.py new file mode 100644 index 0000000..6f7680b --- /dev/null +++ b/modules/ros_chatbot/src/ros_chatbot/agents/tg_agent.py @@ -0,0 +1,234 @@ +## +## Copyright (C) 2017-2025 Hanson Robotics +## +## This program is free software: you can redistribute it and/or modify +## it under the terms of the GNU General Public License as published by +## the Free Software Foundation, either version 3 of the License, or +## (at your option) any later version. +## +## This program is distributed in the hope that it will be useful, +## but WITHOUT ANY WARRANTY; without even the implied warranty of +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +## GNU General Public License for more details. +## +## You should have received a copy of the GNU General Public License +## along with this program. If not, see . +## + +import logging +import os +import random +import re +import shutil +import threading +import uuid +from subprocess import PIPE, Popen + +from .model import AgentResponse, SessionizedAgent + +logger = logging.getLogger("hr.ros_chatbot.agents.tg_agent") +cwd = os.path.dirname(os.path.abspath(__file__)) + + +class WordSub(dict): + """All-in-one multiple-string-substitution class.""" + + def _wordToRegex(self, word): + """Convert a word to a regex object which matches the word.""" + if word != "" and word[0].isalpha() and word[-1].isalpha(): + return "\\b%s\\b" % re.escape(word) + else: + return r"\b%s\b" % re.escape(word) + + def _update_regex(self): + """Build re object based on the keys of the current + dictionary. 
+ + """ + self._regex = re.compile("|".join(map(self._wordToRegex, list(self.keys())))) + self._regexIsDirty = False + + def __init__(self, defaults={}): + """Initialize the object, and populate it with the entries in + the defaults dictionary. + + """ + self._regex = None + self._regexIsDirty = True + for k, v in list(defaults.items()): + self[k] = v + + def __call__(self, match): + """Handler invoked for each regex match.""" + return self[match.group(0)] + + def __setitem__(self, i, y): + self._regexIsDirty = True + super(type(self), self).__setitem__(i, y) + + def sub(self, text): + """Translate text, returns the modified text.""" + if self._regexIsDirty: + self._update_regex() + return self._regex.sub(self, text) + + +class TGAgent(SessionizedAgent): + type = "TGAgent" + + def __init__(self, id, lang, api_id, api_hash, session, user_id, timeout=5): + super(TGAgent, self).__init__(id, lang) + self.api_id = str(api_id) + self.api_hash = str(api_hash) + self.session = "/tmp/tg_%s_%s.session" % (id, uuid.uuid1().hex) + shutil.copy(session, self.session) # the session file passed in is read only + self.user_id = str(user_id) + self.timeout = timeout + self.lock = threading.RLock() + + self.person2_subbers = WordSub( + { + "I": "Sophia", + "am": "is", + "my": "Sophia's", + "mine": "Sophia's", + "myself": "Sophia herself", + "I'm": "Sophia is", + "I'd": "Sophia would", + "I'll": "Sophia will", + "I've": "Sophia has", + } + ) + + def _get_question_length(self, question): + if question: + return len(question.split()) + else: + return 0 + + def _ask(self, question, timeout=None): + timeout = timeout or self.timeout + script = os.path.join(cwd, "../../../scripts/tg.py") + if not os.path.isfile(script): + script = "/opt/hansonrobotics/ros/lib/ros_chatbot/tg.py" + cmd = [ + "python", + script, + "chat", + "--api-id", + self.api_id, + "--api-hash", + self.api_hash, + "--session", + self.session, + "--id", + self.user_id, + "--question", + question, + "--timeout", + 
str(timeout) if timeout else "-1", + ] + + with self.lock: + logger.info("cmd: %s", cmd) + with Popen(cmd, stdout=PIPE) as proc: + answer = proc.stdout.read() + answer = answer.decode("utf-8") + logger.info("id: %s, answer %s", self.id, answer) + return answer + + def set_config(self, config, base): + super(TGAgent, self).set_config(config, base) + + # def feedback(self, response): + # if response.agent_id != self.id: + # # update prime + # text = response.answer + # prime_text = self.person2_subbers.sub(text) + # logger.warning("prime %s, text: %s", prime_text, text) + # if prime_text != text: + # # only update the prime when there is pronoun + # logger.info("prime text %s", prime_text) + # self._ask(prime_text) + + def new_session(self): + """The blenderbot doesn't maintain the session. Whenever it needs to + start a new conversation, it will simply reset the current session""" + sid = str(uuid.uuid4()) + text = self.config.get("prime.text") + if text: + prime_text = " ".join(text) + logger.info("prime text %s", prime_text) + # self._ask("/start") + self._ask(prime_text) + else: + logger.info("no prime text") + return sid + + def chat(self, agent_sid, request, timeout=None): + if agent_sid is None: + logger.warning("Agent session was not provided") + return + if ( + "min_question_length" in self.config + and self._get_question_length(request.question) + < self.config["min_question_length"] + ): + logger.info( + "Ignore short question: %s", self._get_question_length(request.question) + ) + return + + response = AgentResponse() + response.agent_sid = agent_sid + response.sid = request.sid + response.request_id = request.request_id + response.response_id = str(uuid.uuid4()) + response.agent_id = self.id + response.lang = request.lang + response.question = request.question + if timeout is None: + timeout = request.context.get("timeout") # timeout by request + + answer = self._ask(request.question, timeout) + if not answer: + response.trace = "Not responsive" + else: 
+ response.attachment[ + "match_excluded_expressions" + ] = self.check_excluded_expressions(answer) + answer = self.post_processing(answer) + response.answer = answer + self.score(response) + + response.end() + return response + + def score(self, response): + response.attachment["score"] = 60 + if response.attachment["match_excluded_expressions"]: + response.attachment["score"] = -1 + response.attachment["blocked"] = True + else: + input_len = len(response.question) + if input_len > 100: + response.attachment["score"] = 80 + response.attachment["score"] += random.randint( + -10, 10 + ) # give it some randomness + + if ( + "minimum_score" in self.config + and response.attachment["score"] < self.config["minimum_score"] + ): + response.attachment["score"] = -1 + if ( + "allow_question_response" in self.config + and not self.config["allow_question_response"] + and "?" in response.answer + ): + response.attachment["score"] = -1 + logger.warning("Question response is not allowed") + + if response.attachment.get("blocked"): + response.attachment["score"] = -1 diff --git a/modules/ros_chatbot/src/ros_chatbot/agents/translator.py b/modules/ros_chatbot/src/ros_chatbot/agents/translator.py new file mode 100644 index 0000000..4494534 --- /dev/null +++ b/modules/ros_chatbot/src/ros_chatbot/agents/translator.py @@ -0,0 +1,198 @@ +## +## Copyright (C) 2017-2025 Hanson Robotics +## +## This program is free software: you can redistribute it and/or modify +## it under the terms of the GNU General Public License as published by +## the Free Software Foundation, either version 3 of the License, or +## (at your option) any later version. +## +## This program is distributed in the hope that it will be useful, +## but WITHOUT ANY WARRANTY; without even the implied warranty of +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +## GNU General Public License for more details. 
+## +## You should have received a copy of the GNU General Public License +## along with this program. If not, see . +## + +import hashlib +import logging +import os +import random +import uuid +from copy import copy + +import requests + +import ros_chatbot.shared as shared + +from .model import AgentResponse, SessionizedAgent + +logger = logging.getLogger("hr.ros_chatbot.agents.translator") + + +class CantoneseTranslator(object): + """https://fanyi-api.baidu.com/api/trans/product/apidoc""" + + def __init__(self): + self.url = "https://fanyi-api.baidu.com/api/trans/vip/translate" + self.appid = os.environ.get("BAIDU_TRANSLATE_APPID") + self.secretKey = os.environ.get("BAIDU_TRANSLATE_SECRETKEY") + + def translate(self, text, target_language): + if self.appid is None: + logger.error("BAIDU_TRANSLATE_APPID is missing") + return + if self.secretKey is None: + logger.error("BAIDU_TRANSLATE_SECRETKEY is missing") + return + salt = random.randint(32768, 65536) + sign = self.appid + text + str(salt) + self.secretKey + sign = hashlib.md5(sign.encode()).hexdigest() + + if target_language == "zh-TW": + target_language = "yue" + if target_language == "zh-CN": + target_language = "yue" + + params = { + "appid": self.appid, + "q": text, + "from": "auto", + "to": target_language, + "salt": str(salt), + "sign": sign, + } + response = requests.get(self.url, params) + ret = {} + if response.status_code == 200: + result = response.json() + if "error_msg" in result: + logger.error(result) + return + elif "trans_result" in result: + ret["translatedText"] = result["trans_result"][0]["dst"] + return ret + + +class TranslatorAgent(SessionizedAgent): + """ + For supported languages see https://cloud.google.com/translate/docs/languages + """ + + type = "TranslatorAgent" + KNOWN_LANGUAGE_CODE = { + "cmn-Hans-CN": "zh-CN", + "en-US": "en", + "yue-Hant-HK": "yue", + } + + def __init__(self, id, language_codes, media_language, media_agent): + """ + parameters + ---------- + media_agent: the 
agent that does the chat + media_language: the language of the media agent + lang: the language of the translator + """ + super(TranslatorAgent, self).__init__(id, list(language_codes.keys())) + self.language_codes = language_codes + + from google.cloud import translate + + self.client = translate.Client() + self.cantonest_client = CantoneseTranslator() + + if media_agent is None: + raise ValueError("Media agent cannot be None") + self.media_agent = media_agent + self.media_language = media_language + + def new_session(self): + if isinstance(self.media_agent, SessionizedAgent): + return self.media_agent.new_session() + else: + return str(uuid.uuid4()) + + def chat(self, agent_sid, request): + response = AgentResponse() + response.agent_sid = agent_sid + response.sid = request.sid + response.request_id = request.request_id + response.response_id = str(uuid.uuid4()) + response.agent_id = self.id + response.lang = request.lang + response.question = request.question + + try: + result = self.translate(request.question, request.lang, self.media_language) + if result and result["translated"]: + question = result["text"] + media_request = copy(request) + media_request.lang = self.media_language + media_request.question = question + logger.info("Media agent request %s", media_request) + agent_response = self.media_agent.chat(agent_sid, media_request) + if agent_response and agent_response.valid(): + if "actions" in response.attachment: + response.attachment["actions"] = response.attachment["actions"] + if hasattr(self.media_agent, "score"): + self.media_agent.score(agent_response) + response.attachment["score"] = agent_response.attachment[ + "score" + ] + if ( + "topic" not in response.attachment + or response.attachment["topic"] != "language" + ): + # do not translate response from language skill + result = self.translate( + agent_response.answer, self.media_language, request.lang + ) + if result and result["translated"]: + response.answer = result["text"] + 
response.attachment[ + "media response" + ] = agent_response.answer + else: + logger.warning("No media agent resopnse") + except Exception as ex: + logger.exception(ex) + return + response.end() + return response + + def text2cachekey(self, text, language): + return "%s@%s" % (text.lower().strip(), language) + + def translate(self, text, source_language, target_language): + key = self.text2cachekey(text, target_language) + if key in shared.cache: + logger.info("Using cache") + return shared.cache[key] + ret = {} + + if target_language in self.language_codes: + target_language = self.language_codes[target_language] + elif target_language in self.KNOWN_LANGUAGE_CODE: + target_language = self.KNOWN_LANGUAGE_CODE[target_language] + if source_language == target_language: + ret["text"] = text + ret["translated"] = True + return ret + + logger.warning( + "Translating %r from %r to %r", text, source_language, target_language + ) + if source_language == "yue-Hant-HK" or target_language == "yue": + result = self.cantonest_client.translate(text, target_language) + else: + result = self.client.translate(text, target_language=target_language) + + if result: + ret["text"] = result["translatedText"] + ret["translated"] = True + logger.warning("Translate result %s", result["translatedText"]) + + shared.cache[key] = ret + return ret diff --git a/modules/ros_chatbot/src/ros_chatbot/agents/vector_chat.py b/modules/ros_chatbot/src/ros_chatbot/agents/vector_chat.py new file mode 100644 index 0000000..be76ac5 --- /dev/null +++ b/modules/ros_chatbot/src/ros_chatbot/agents/vector_chat.py @@ -0,0 +1,125 @@ +# -*- coding: utf-8 -*- +## +## Copyright (C) 2017-2025 Hanson Robotics +## +## This program is free software: you can redistribute it and/or modify +## it under the terms of the GNU General Public License as published by +## the Free Software Foundation, either version 3 of the License, or +## (at your option) any later version. 
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
## GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program. If not, see <https://www.gnu.org/licenses/>.
##

import logging
import random
import uuid

from haipy.vectordb import ChatVectorDB

from .model import Agent, AgentResponse

logger = logging.getLogger(__name__)


class VectorChatAgent(Agent):
    """Agent that answers by nearest-neighbour lookup in a chat vector index.

    Candidate answers come from a ChatVectorDB namespace; the best-scoring
    match above a similarity threshold is returned, preferring answers that
    belong to the same stored conversation as the previous response.
    """

    type = "VectorChatAgent"

    def __init__(
        self,
        id,
        lang,
        namespace,
    ):
        """
        :param id: unique agent id
        :param lang: language(s) this agent serves
        :param namespace: ChatVectorDB namespace holding the candidate answers
        """
        super(VectorChatAgent, self).__init__(id, lang)
        self.index = ChatVectorDB(namespace)
        self.timeout = 2
        # NOTE(review): attribute keeps the historical misspelling
        # ("threashold") in case external config sets it by name
        self.similarity_threashold = 0.68
        self.topk = 5
        self.last_conversation_id = None
        self.last_response = None

    def chat(self, agent_sid, request):
        """Answer `request` from the vector index.

        The top-k candidates are fetched; a result is accepted only when its
        similarity score clears the threshold. Candidates from the same
        stored conversation as the previous response are preferred.
        """
        response = AgentResponse()
        response.agent_sid = agent_sid
        response.sid = request.sid
        response.request_id = request.request_id
        response.response_id = str(uuid.uuid4())
        response.agent_id = self.id
        response.lang = request.lang
        response.question = request.question

        result = self.index.get_answer(request.question, self.topk)
        logger.info("Get result %s", result)
        if (
            result
            and result["score"] > self.similarity_threashold
            and result["answers"]
        ):
            answers = result["answers"]
            if self.last_response and self.last_response.question != request.question:
                # find answers that are in the same conversation
                # assume the answers in the same conversation are better
                answers_in_conversation = [
                    answer
                    for answer in result["answers"]
                    if self.last_response.attachment["conversation_id"]
                    == answer["conversation_id"]
                ]
                if answers_in_conversation:
                    answers = answers_in_conversation
                    response.attachment["context_match"] = True
            answer = random.choice(answers)
            response.answer = answer["answer"]
            response.attachment["conversation_id"] = answer["conversation_id"]
            response.attachment["confidence"] = result["score"]
            # BUGFIX: the long-answer penalty used to be assigned *before*
            # confidence was set from the result score, so the penalty was
            # always overwritten. Apply it after so it actually takes effect.
            if len(response.answer.split(" ")) > 60:
                response.attachment["confidence"] = 0
            response.attachment["label"] = answer["label"]
            response.attachment["resolver"] = answer["resolver"]
            self.last_response = response
            self.score(response)
        else:
            response.trace = "Can't answer"

        response.end()
        return response

    def score(self, response):
        """Assign a ranking score from confidence, resolver and label."""
        response.attachment["score"] = 50
        if response.attachment.get("context_match"):
            # boost confidence by 20% if the context matches
            response.attachment["confidence"] = min(
                1, response.attachment["confidence"] * 1.2
            )
            logger.info("Increase the confidence")
        if response.attachment["confidence"] > 0.7:
            response.attachment["score"] = 60
        if response.attachment["confidence"] > 0.8:
            response.attachment["score"] = 70
        if response.attachment["confidence"] > 0.9:
            response.attachment["score"] = 80
        if response.attachment["confidence"] > 1.0:
            response.attachment["score"] = 85
        if response.attachment["resolver"] == "human":
            # human-curated answers rank highest
            response.attachment["score"] = 90
        if response.attachment.get("label") in ["ChatGPT", "GPT3"]:
            # boost score by 10% for GPT3 answers (capped at 85)
            response.attachment["score"] = min(85, response.attachment["score"] * 1.1)
            logger.info("Boost the score for GPT3/ChatGPT answers")
        if (
            "allow_question_response" in self.config
            and not self.config["allow_question_response"]
            and "?" in response.answer
        ):
            response.attachment["score"] = -1
            logger.warning("Question response %s is not allowed", response.answer)

        if response.attachment.get("blocked"):
            response.attachment["score"] = -1
##

import hashlib
import logging
import os
import re
import uuid
from urllib.parse import urlencode
from urllib.request import Request, urlopen

import six

from .model import AgentResponse, SessionizedAgent

logger = logging.getLogger("hr.ros_chatbot.agents.xiaoi")

# Strips http(s) URLs out of service answers. Compiled once at import time
# (the original rebuilt it on every ask() call); raw strings avoid the
# invalid "\(" escape warnings the original silenced with noqa.
HTTP_URL_RE = re.compile(
    r"http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|"
    r"[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+"
)


class WebXiaoi(object):
    """Thin HTTP client for the xiaoi.com ask API.

    Credentials come from the XIAOI_APP_KEY_ID and XIAOI_APP_KEY_SECRET
    environment variables; both are required.
    """

    def __init__(self):
        app_key_id = "XIAOI_APP_KEY_ID"
        app_key_secret = "XIAOI_APP_KEY_SECRET"
        self.app_key_id = os.environ.get(app_key_id)
        self.app_key_secret = os.environ.get(app_key_secret)
        if not self.app_key_id:
            raise ValueError("xiaoi app key was not provided")
        if not self.app_key_secret:
            raise ValueError("xiaoi app secret was not provided")

    def get_headers(self):
        """Build the X-Auth signature header required by the API.

        Signature scheme: sha1(key:realm:secret) and sha1(method:uri),
        then sha1(first:nonce:second), all hex-encoded.
        """
        realm = "xiaoi.com"
        method = "POST"
        uri = "/ask.do"
        nonce = "0" * 40
        sha1 = hashlib.sha1(
            ":".join([self.app_key_id, realm, self.app_key_secret]).encode("utf-8")
        ).hexdigest()
        sha2 = hashlib.sha1(":".join([method, uri]).encode("utf-8")).hexdigest()
        sign = hashlib.sha1(":".join([sha1, nonce, sha2]).encode("utf-8")).hexdigest()

        headers = {
            "X-Auth": 'app_key="{}",nonce="{}",signature="{}"'.format(
                self.app_key_id, nonce, sign
            )
        }
        return headers

    def ask(self, userId, question):
        """POST `question` for `userId`; return the answer with URLs removed."""
        if isinstance(question, six.text_type):
            question = question.encode("utf-8")
        # url = 'http://nlp.xiaoi.com/ask.do'
        url = "http://robot.open.xiaoi.com/ask.do"
        values = {
            "userId": userId,
            "question": question,
            "type": 0,
            "platform": "custom",
        }
        data = urlencode(values).encode("utf-8")
        headers = self.get_headers()
        req = Request(url, data, headers)
        response = urlopen(req)
        answer = response.read()
        if isinstance(answer, six.binary_type):
            answer = answer.decode("utf-8")
        answer = HTTP_URL_RE.sub("", answer)
        return answer.strip()


class XiaoIAgent(SessionizedAgent):
    """Chat agent backed by the xiaoi.com web API (Chinese)."""

    type = "XiaoIAgent"
    # matches any written form of the upstream bot's name
    name_patch = re.compile("(小i|xiaoi|xiao i|小 i)", flags=re.IGNORECASE)

    # Answers containing any of these fragments are service boilerplate or
    # error replies and must be discarded.
    _INVALID_FRAGMENTS = (
        "默认回复",
        "该功能正在开发中",
        "主人",
        "请点击语音键",
        "重复回复",
        "点击此链接",
        "illegalWordReply",
    )

    def __init__(self, id, lang):
        super(XiaoIAgent, self).__init__(id, lang)
        self.api = WebXiaoi()

    def patch_name(self, text):
        """Replace mentions of the upstream bot's name with 索菲亚."""
        return self.name_patch.sub("索菲亚", text)

    def validate_answer(self, text):
        """Return False for boilerplate/error answers from the service."""
        return not any(fragment in text for fragment in self._INVALID_FRAGMENTS)

    def new_session(self):
        sid = str(uuid.uuid4())
        return sid

    def remove_cmd_tag(self, text):
        """Strip embedded [CMD]...[/CMD] control blocks from the answer."""
        return re.sub(r"\[CMD\].*\[/CMD\]", "", text).strip()

    def chat(self, agent_sid, request):
        """Ask the web API; return a finished AgentResponse, or None on failure
        or when the answer is service boilerplate."""
        response = AgentResponse()
        response.agent_sid = agent_sid
        response.sid = request.sid
        response.request_id = request.request_id
        response.response_id = str(uuid.uuid4())
        response.agent_id = self.id
        response.lang = request.lang
        response.question = request.question

        try:
            answer = self.api.ask(agent_sid, request.question)
        except Exception as ex:
            logger.exception(ex)
            return
        if not self.validate_answer(answer):
            return
        answer = self.remove_cmd_tag(answer)
        answer = self.patch_name(answer)
        response.answer = answer
        self.score(response)
        response.end()
        return response

    def score(self, response):
        """Score by agent weight; veto question answers (when disallowed)
        and blocked responses."""
        response.attachment["score"] = self.weight * 100
        if (
            "allow_question_response" in self.config
            and not self.config["allow_question_response"]
            and ("?" in response.answer or "?" in response.answer)
        ):
            response.attachment["score"] = -1
            logger.warning("Question response %s is not allowed", response.answer)

        if response.attachment.get("blocked"):
            response.attachment["score"] = -1
##

import hashlib
import json
import logging
import os
import uuid
from time import time

import requests

from .model import AgentResponse, SessionizedAgent

logger = logging.getLogger("hr.ros_chatbot.agents.xiaoice")


class WebAPI(object):
    """Thin HTTP client for the XiaoIce chat API.

    Reads XIAOICE_API_SUBSCRIPTION_KEY, XIAOICE_API_APP_KEY and
    XIAOICE_API_BASEURL from the environment; the two keys are required.
    """

    def __init__(self):
        self.subscription_key = os.environ.get("XIAOICE_API_SUBSCRIPTION_KEY")
        self.app_key = os.environ.get("XIAOICE_API_APP_KEY")
        self.url = os.environ.get("XIAOICE_API_BASEURL")
        if not self.subscription_key:
            raise ValueError("xiaoice subscription key was not provided")
        if not self.app_key:
            raise ValueError("xiaoice app key was not provided")

    def get_headers(self, timestamp, body):
        """Build request headers; signature is sha512(body + app_key + timestamp)."""
        signature = hashlib.sha512(
            (body + self.app_key + str(timestamp)).encode("utf-8")
        )
        headers = {
            "Content-Type": "application/json",
            "subscription-key": self.subscription_key,
            "timestamp": timestamp,
            "signature": signature.hexdigest(),
        }
        return headers

    def ask(self, question):
        """Send `question`; return the joined reply text, or None on a
        non-200 response."""
        timestamp = str(int(time()))
        body = json.dumps(
            {
                "content": {
                    "text": question,
                    "ContentType": "text",
                    "Metadata": {"Character": "xiaoc"},
                },
                "senderId": str(uuid.uuid4()),
                "timestamp": timestamp,
                "msgId": str(uuid.uuid4()),
            },
            ensure_ascii=False,
        )
        headers = self.get_headers(timestamp, body)

        res = requests.post(self.url, data=body.encode("utf-8"), headers=headers)
        if res.status_code == 200:
            answer = []
            for result in res.json():
                text = result["content"].get("text")
                if text:
                    answer.append(text)
            answer = "\n".join(answer)
            return answer


class XiaoIceAgent(SessionizedAgent):
    """Chat agent backed by the XiaoIce web API."""

    type = "XiaoIceAgent"

    def __init__(self, id, lang):
        super(XiaoIceAgent, self).__init__(id, lang)
        self.api = WebAPI()

    def new_session(self):
        sid = str(uuid.uuid4())
        return sid

    def chat(self, agent_sid, request):
        """Ask the web API; return a finished AgentResponse or None on failure."""
        response = AgentResponse()
        response.agent_sid = agent_sid
        response.sid = request.sid
        response.request_id = request.request_id
        response.response_id = str(uuid.uuid4())
        response.agent_id = self.id
        response.lang = request.lang
        response.question = request.question

        try:
            answer = self.api.ask(request.question)
        except Exception as ex:
            logger.exception(ex)
            return
        # BUGFIX: the API returns None on a non-200 status; scoring a None
        # answer used to raise TypeError on the "?" membership test below.
        if not answer:
            return
        response.answer = answer
        self.score(response)
        response.end()
        return response

    def score(self, response):
        """Score by agent weight; veto question answers (when disallowed)
        and blocked responses."""
        response.attachment["score"] = self.weight * 100
        if (
            "allow_question_response" in self.config
            and not self.config["allow_question_response"]
            and ("?" in response.answer or "?" in response.answer)
        ):
            response.attachment["score"] = -1
            logger.warning("Question response %s is not allowed", response.answer)

        if response.attachment.get("blocked"):
            response.attachment["score"] = -1
##

import logging
import random
import re
import uuid

import requests

from .model import AgentResponse, SessionizedAgent

logger = logging.getLogger("hr.ros_chatbot.agents.youchat")


class YouChatAgent(SessionizedAgent):
    """Chat agent backed by a local youchat HTTP service."""

    type = "YouChatAgent"

    def __init__(self, id, lang, host="localhost", port=8804, timeout=5):
        super(YouChatAgent, self).__init__(id, lang)
        self.host = host
        self.port = port
        if self.host not in ["localhost", "127.0.0.1"]:
            # make non-local endpoints visible in the logs
            logger.warning("youchat server: %s:%s", self.host, self.port)
        self.timeout = timeout
        self.support_priming = True

    def new_session(self):
        """Create, remember and return a fresh session id."""
        self.sid = str(uuid.uuid4())
        return self.sid

    def reset(self):
        """Ask the server to drop its conversation state (best effort)."""
        try:
            requests.post(
                f"http://{self.host}:{self.port}/reset",
                timeout=self.timeout,
            )
        except Exception as ex:
            logger.error(ex)

    def ask(self, request):
        """POST the question to the service.

        Returns "" when the HTTP call itself fails, the answer string on
        success, and None when the server reports an error or no answer.
        """
        effective_timeout = request.context.get("timeout") or self.timeout
        try:
            reply = requests.post(
                f"http://{self.host}:{self.port}/ask",
                json={"question": request.question},
                timeout=effective_timeout,
            )
        except Exception as ex:
            logger.error("error %s", ex)
            return ""

        if reply and reply.status_code == 200:
            payload = reply.json()
            if payload.get("error"):
                logger.error(payload["error"])
            elif "answer" in payload:
                return payload["answer"]

    def chat(self, agent_sid, request):
        """Query the service and package its reply as an AgentResponse."""
        response = AgentResponse()
        response.agent_sid = agent_sid
        response.sid = request.sid
        response.request_id = request.request_id
        response.response_id = str(uuid.uuid4())
        response.agent_id = self.id
        response.lang = request.lang
        response.question = request.question
        response.attachment["repeating_words"] = False

        try:
            answer = self.ask(request)
            if answer:
                response.attachment["match_excluded_expressions"] = (
                    self.check_excluded_expressions(answer)
                )
                response.attachment["match_excluded_question"] = (
                    self.check_excluded_question(request.question)
                )
                response.answer = self.post_processing(answer)
                self.score(response)
        except Exception as ex:
            logger.exception(ex)

        response.end()
        return response

    def score(self, response):
        """Base score of 70 with jitter; excluded or question answers are vetoed."""
        response.attachment["score"] = 70
        if response.attachment.get("match_excluded_expressions"):
            logger.info("Answer %r is not allowed", response.answer)
            response.attachment["score"] = -1
            response.attachment["blocked"] = True
        if response.attachment.get("match_excluded_question"):
            logger.info("Question %r is not allowed", response.question)
            response.attachment["score"] = -1
            response.attachment["blocked"] = True

        # give it some randomness
        response.attachment["score"] += random.randint(-10, 10)

        if (
            "allow_question_response" in self.config
            and not self.config["allow_question_response"]
            and "?" in response.answer
        ):
            response.attachment["score"] = -1
            logger.warning("Question response %s is not allowed", response.answer)

        if response.attachment.get("blocked"):
            response.attachment["score"] = -1
["wag", "dog"] + ... ] + >>> result = get_bm25_weights(corpus) + + +Data: +----- +.. data:: PARAM_K1 - Free smoothing parameter for BM25. +.. data:: PARAM_B - Free smoothing parameter for BM25. +.. data:: EPSILON - Constant used for negative idf of document in corpus. + +""" +import logging +import math + +# https://www.elastic.co/blog/practical-bm25-part-3-considerations-for-picking-b-and-k1-in-elasticsearch +PARAM_K1 = 1.2 +PARAM_B = 0.75 +EPSILON = 0.25 +logger = logging.getLogger(__name__) + + +class BM25(object): + """Implementation of Best Matching 25 ranking function. + + Attributes + ---------- + corpus_size : int + Size of corpus (number of documents). + avgdl : float + Average length of document in `corpus`. + doc_freqs : list of dicts of int + Dictionary with terms frequencies for each document in `corpus`. Words used as keys and frequencies as values. + idf : dict + Dictionary with inversed documents frequencies for whole `corpus`. Words used as keys and frequencies as values. + doc_len : list of int + List of document lengths. + """ + + def __init__(self, docs): + """ + Parameters + ---------- + docs: dict of document of corpus + Given docs. + + """ + index_names = list(docs.keys()) + values = list(docs.values()) + # TODO: tokenize + corpus = [sum([v.split() for v in value], []) for value in values] + self._index_names = index_names + self.corpus_size = len(corpus) + self.avgdl = 0 + self.doc_freqs = [] + self.idf = {} + self.doc_len = [] + self._initialize(corpus) + + def _initialize(self, corpus): + """Calculates frequencies of terms in documents and in corpus. 
Also computes inverse document frequencies.""" + if not corpus: + logger.warning("Empty corpus") + return + nd = {} # word -> number of documents with word + num_doc = 0 + for document in corpus: + self.doc_len.append(len(document)) + num_doc += len(document) + + frequencies = {} + for word in document: + if word not in frequencies: + frequencies[word] = 0 + frequencies[word] += 1 + self.doc_freqs.append(frequencies) + + for word, freq in frequencies.items(): + if word not in nd: + nd[word] = 0 + nd[word] += 1 + + self.avgdl = float(num_doc) / self.corpus_size + # collect idf sum to calculate an average idf for epsilon value + idf_sum = 0 + # collect words with negative idf to set them a special epsilon value. + # idf can be negative if word is contained in more than half of + # documents + negative_idfs = [] + for word, freq in nd.items(): + idf = math.log(self.corpus_size - freq + 0.5) - math.log(freq + 0.5) + self.idf[word] = idf + idf_sum += idf + if idf < 0: + negative_idfs.append(word) + self.average_idf = float(idf_sum) / len(self.idf) + + eps = EPSILON * self.average_idf + for word in negative_idfs: + self.idf[word] = eps + + def get_score(self, document, index): + """Computes BM25 score of given `document` in relation to item of corpus selected by `index`. + + Parameters + ---------- + document : list of str + Document to be scored. + index : int + Index of document in corpus selected to score with `document`. + + Returns + ------- + float + BM25 score. + + """ + score = 0 + doc_freqs = self.doc_freqs[index] + for word in document: + if word not in doc_freqs: + continue + score += ( + self.idf[word] + * doc_freqs[word] + * (PARAM_K1 + 1) + / ( + doc_freqs[word] + + PARAM_K1 + * (1 - PARAM_B + PARAM_B * self.doc_len[index] / self.avgdl) + ) + ) + return score + + def get_scores(self, document): + """Computes and returns BM25 scores of given `document` in relation to + every item in corpus. 
+ + Parameters + ---------- + document : list of str + Document to be scored. + + Returns + ------- + list of dict + Each contains a doc name and BM25 score. + + """ + scores = [] + for index in range(self.corpus_size): + score = self.get_score(document, index) + name = self._index_names[index] + if score > 0: + name = self._index_names[index] + scores.append({"id": name, "score": score}) + scores = sorted(scores, key=lambda x: x["score"], reverse=True) + return scores + + +if __name__ == "__main__": + import pandas as pd + + df = pd.read_csv( + "data/blenderbot_training_corpora/finetuning2/finetuning_train.csv" + ) + corpus = {"doc%s" % i: row.text.split() for i, row in df.iterrows()} + bm25 = BM25(corpus) + while True: + question = input("question: ") + results = bm25.get_scores(question.split())[:3] + if results: + print("results:", results) + print(corpus[results[0]["id"]]) diff --git a/modules/ros_chatbot/src/ros_chatbot/chat_agent_schedular.py b/modules/ros_chatbot/src/ros_chatbot/chat_agent_schedular.py new file mode 100644 index 0000000..ad4fbfb --- /dev/null +++ b/modules/ros_chatbot/src/ros_chatbot/chat_agent_schedular.py @@ -0,0 +1,300 @@ +## +## Copyright (C) 2017-2025 Hanson Robotics +## +## This program is free software: you can redistribute it and/or modify +## it under the terms of the GNU General Public License as published by +## the Free Software Foundation, either version 3 of the License, or +## (at your option) any later version. +## +## This program is distributed in the hope that it will be useful, +## but WITHOUT ANY WARRANTY; without even the implied warranty of +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +## GNU General Public License for more details. +## +## You should have received a copy of the GNU General Public License +## along with this program. If not, see . 
##

import concurrent
import logging
import random
import threading
import time
from concurrent.futures import ThreadPoolExecutor, TimeoutError, as_completed
from itertools import groupby
from queue import Empty, Queue

from haipy.nlp.translate import TranslateClient

from ros_chatbot.agents.model import AgentRequestExt
from ros_chatbot.response_resolver import PreferentialResponseResolver

logger = logging.getLogger(__name__)


class ChatAgentSchedular(object):
    """Fans a chat request out to agents and yields responses in level order.

    Agents are grouped into batches by level; each batch runs concurrently in
    a thread pool and responses are yielded as they become ready, with
    lower-level agents given the chance to answer first.
    """

    def __init__(self, agents, translators, timeout=2):
        """
        :param agents: dict of agent id -> agent
        :param translators: translator specs (language_codes, media_language,
            media_agents) used to expand requests into other languages
        :param timeout: default per-batch timeout in seconds
        """
        self.agents = agents
        self.translators = translators
        # Keep track of running agents, and make sure we dont wait unecessarily
        self.running_agents_lock = threading.Lock()
        # Interrupt event
        self.interrupt = None
        self.timeout = timeout
        self.min_wait_for = 0.0
        self.translate_client = TranslateClient()

    def set_timeout(self, timeout, min_wait_for=0.0):
        """Set the batch timeout and the minimum wait before publishing."""
        self.timeout = timeout
        self.min_wait_for = min_wait_for

    def expand_translate_requests(self, request):
        """Expand with translate requests and agents.

        Returns a dict of agent id -> translated request for every enabled
        media agent whose translator covers the request language.
        """
        requests = {}
        for translator in self.translators:
            if (
                request.lang in translator["language_codes"]
                and translator["media_language"] != request.lang
            ):
                result = self.translate_client.translate(
                    request.question, request.lang, translator["media_language"]
                )
                if result and result["translated"]:
                    media_request = AgentRequestExt()
                    media_request.sid = request.sid
                    media_request.request_id = request.request_id
                    media_request.time = request.time
                    media_request.lang = translator["media_language"]
                    media_request.question = result["text"]
                    media_request.audio = request.audio
                    media_request.tag = request.tag
                    media_request.context = request.context
                    media_request.scene = request.scene
                    media_request.user_id = request.user_id
                    media_request.session_context = request.session_context
                    media_request.original_lang = request.lang
                    media_request.original_question = request.question

                    for agent in self.agents.values():
                        if agent.enabled and agent.id in translator["media_agents"]:
                            requests[agent.id] = media_request  # specify agent request
        return requests

    def chat(self, request, agent_sessions):
        """Generator yielding lists of valid responses, batch by batch."""
        agents = [agent for agent in self.agents.values() if agent.id in agent_sessions]
        logger.info("All agents %r", [agent.id for agent in agents])

        if request.question.startswith(":"):
            # command-style input, not a chat question
            return
        translate_requests = self.expand_translate_requests(request)
        translate_agents = [agent for agent in agents if agent.id in translate_requests]
        if request.question.lower().startswith("event."):
            # events are handled exclusively by soultalk agents
            allowed_agents = [
                "SoulTalkAgent",
            ]
            agents = [agent for agent in agents if agent.type in allowed_agents]
            translate_agents = [
                agent for agent in translate_agents if agent.type in allowed_agents
            ]
            if not agents:
                logger.warning("No soultalk agent")
        is_prompt = request.question[0] == "{" and request.question[-1] == "}"
        if is_prompt:
            # {...} questions go only to agents that accept prompt responses
            agents = [agent for agent in agents if agent.prompt_responses is True]
            translate_agents = [
                agent for agent in translate_agents if agent.prompt_responses is True
            ]
            if not agents:
                logger.warning("No prompt agents enabled")
        agents = [agent for agent in agents if request.lang in agent.languages]
        if translate_agents:
            logger.info("Translate agents %s", translate_agents)
            agents.extend(translate_agents)
        if not agents:
            logger.warning("No agents for chat request")
            return
        else:
            logger.info("Using agents %s", [agent.id for agent in agents])

        if "agent" in request.context and request.context["agent"]:
            # explicit agent selection via request context
            try:
                requested_agents = [
                    name.strip()
                    for name in request.context.pop("agent").split(",")
                    if name.strip()
                ]
                agents = [agent for agent in agents if agent.id in requested_agents]
                if len(agents) == 0:
                    logger.error("The agent with id %r is not found", requested_agents)
                    return
                logger.info("Use agents %s", [agent.id for agent in agents])
            except Exception as ex:
                logger.error(ex)
                return
        else:
            agents = [agent for agent in agents if agent.enabled]

        def keyfunc(agent):
            # two batches: agents below level 150 run first, the rest second
            return 150 if agent.level < 150 else 250

        agent_batches = {}
        levels = []
        agents = sorted(agents, key=keyfunc)
        for k, g in groupby(agents, key=keyfunc):
            agent_batches[k] = list(g)
            levels.append(k)

        levels = sorted(levels)

        for n_batch, level in enumerate(levels, 1):
            agent_batch = agent_batches[level]
            logger.info(
                "Batch %s/%s: agents: (%s) %s",
                n_batch,
                len(agent_batches),
                len(agent_batch),
                [agent.id for agent in agent_batch],
            )
            for responses in self._chat(
                request, agent_batch, agent_sessions, translate_requests
            ):
                responses = [response for response in responses if response.valid()]
                if responses:
                    logger.info("Yielded valid responses %s", len(responses))
                    yield responses

    # Keeps track of agents that are still running
    def _agent_chat(self, agent, session, request, running_agents):
        """Run one agent's chat; return (result, agent level, min waiting level).

        `running_agents` is the shared dict for the batch (same object for
        every task) so each completion can see which agents are outstanding.
        """
        result = None
        try:
            result = agent.chat(session, request)
        except Exception as ex:
            logger.exception(ex)
            raise ex
        finally:
            with self.running_agents_lock:
                running_agents.pop(agent.id, None)
                current_lvl = agent.current_level()
                if len(running_agents) == 0:
                    min_level_waiting = 1000
                else:
                    min_level_waiting = min(
                        a.current_level() for a in running_agents.values()
                    )
        return (result, current_lvl, min_level_waiting)

    def _chat(self, request, agents, agent_sessions, translate_requests):
        """Run one batch of agents; yield lists of responses as they are ready."""
        if not agents:
            return []

        timeout = request.context.get("timeout")  # timeout by request
        results = Queue()
        running_agents = {agent.id: agent for agent in agents}
        tasks = [
            (
                self._agent_chat,
                agent,
                agent_sessions.get(agent.id),
                translate_requests.get(agent.id, request),
                running_agents,
            )
            for agent in agents
        ]

        job = threading.Thread(target=self._run_tasks, args=(tasks, results, timeout))
        # BUGFIX: this was "job.deamon = True" (a typo), which silently set an
        # unused attribute and left the worker thread non-daemonic
        job.daemon = True
        job.start()
        waiting_responses = []
        # Keeps track until first responses can be published
        start = time.time()
        # Flag that its time to publish the responses regardless of queue wait
        finished = False  # all tasks are finished
        publish_responses = request.hybrid_mode  # Set true to hybrid
        preference_sum = PreferentialResponseResolver.preferences(agents)
        while True:
            try:
                response = results.get(block=False)
                if response is None:
                    # poison item from _run_tasks: all tasks are done
                    finished = True
                else:
                    # Got response but its None
                    if response[0] is None:
                        continue
                    # Preference to be random but why wait, we can randomize
                    # before ranking
                    if preference_sum > 0 and not publish_responses:
                        chance = random.random()
                        if chance * preference_sum < response[0].preference:
                            for w in waiting_responses:
                                w[0].preference = 0
                            publish_responses = True
                        else:
                            preference_sum -= max(0, response[0].preference)
                    # Return preference 11 results immediatly
                    if response[0].preference > 10:
                        publish_responses = True
            except Empty:
                response = None  # Response is None
            if self.interrupt and self.interrupt.is_set():
                logger.info("Interrupted no more results will be returned")
                self.interrupt.clear()
                return
            # Put results to waiting queue
            if response is not None:  # Not the end of all threads or timeout
                # Gets agent result, agent level, and current minimum level waiting
                res, lvl, min_lvl = response
                if res is not None:
                    waiting_responses.append((res, lvl, min_lvl))
            # Check if its time to publish the responses
            its_time = (
                (time.time() - start > self.min_wait_for)
                or finished
                or publish_responses
            )
            something_there = len(waiting_responses) > 0
            if something_there and its_time:
                min_lvl = max([res[2] for res in waiting_responses])
                # Yield the responses that has no agents of lower level
                published_responses = [
                    res[0] for res in waiting_responses if res[1] < min_lvl + 1
                ]
                # publish
                if published_responses:
                    yield published_responses
                waiting_responses = [
                    res for res in waiting_responses if res[1] > min_lvl
                ]
            # Break on the end
            if finished:
                break
            time.sleep(0.02)

        # if there are fallback responses publish for ranking
        yield [res[0] for res in waiting_responses]
        return

    def _run_tasks(self, tasks, results, timeout=None):
        """Run `tasks` in a thread pool, pushing each result to `results`.

        A final None is always pushed as a poison item so the consumer knows
        the batch is done. On timeout the pool's internal queues are force
        cleared so stale tasks do not linger.
        """
        not_graceful = True
        timeout = timeout or self.timeout

        with ThreadPoolExecutor(max_workers=20) as executor:
            fs = {executor.submit(*task) for task in tasks}
            try:
                for future in as_completed(fs, timeout=timeout):
                    try:
                        response = future.result()
                        if response is not None:
                            results.put(response)
                    except Exception as ex:
                        logger.exception(ex)
            except TimeoutError as ex:
                logger.error("Timeout: %s", ex)
                if not_graceful:
                    # https://gist.github.com/clchiou/f2608cbe54403edb0b13
                    executor._threads.clear()
                    concurrent.futures.thread._threads_queues.clear()
            finally:
                results.put(None)  # poison item indicates the process is done
ChatServer(object): + def __init__(self): + self.responses = {} + self.requests = {} # id -> request + self.request_docs = {} # id -> request documents + self.response_docs = {} # id -> response documents + # preload information + db_thread = threading.Thread(target=self.db_tasks_thread) + self.db_tasks = Queue() + db_thread.daemon = True + db_thread.start() + self.document_manager = mm.DocumentManager( + mongo_uri=os.environ.get("CLOUD_MONGO_DATABASE_URL") + ) + + config = load_agent_config() + self.agent_specs = config["agents"] + self.agent_configs = { + agent["args"]["id"]: agent.get("config", {}) for agent in config["agents"] + } + self.resolver_type = config.get("resolver_type", "simple") + self.translators = config.get("translators", []) + + # install agents + self.agents: Dict[str, Agent] = {} + for agent_spec in self.agent_specs: + if agent_spec["type"] not in registered_agents: + raise ValueError("Unknown agent type: %s" % agent_spec["type"]) + if agent_spec["type"] == "TranslatorAgent": # ignore + continue + agent = self._create_agent(agent_spec) + if agent: + logger.info("Created agent %s", agent.id) + self.init_config_agent(agent) + self.agents[agent.id] = agent + else: + logger.error("Agent %s was not created", agent_spec["type"]) + + self.ranker = ResponseRanker(self.agents) + # All other resolvers are absolete, as they designed for older agents. 
+ if self.resolver_type == "simple": + self.resolver = PreferentialResponseResolver(self.ranker) + logger.info("Simple resolver") + elif self.resolver_type == "probablistic": + self.resolver = ProbabilisticResponseResolver(self.ranker) + logger.info("Probablistic resolver") + elif self.resolver_type == "preferential": + self.resolver = PreferentialResponseResolver(self.ranker) + logger.info("Preferential resolver") + print("Preferential resolver") + else: + raise ValueError("Unknown resolver type %r", self.resolver_type) + + self.chat_agent_schedular = ChatAgentSchedular(self.agents, self.translators) + self.session_manager = SessionManager(self.agents) + self.text_safety_classifier = TextSafetyClassifier() + + def db_tasks_thread(self): + while True: + task = self.db_tasks.get() + # each task is a list that first element is callable, with other eleme + if task: + try: + task[0](*task[1:]) + except Exception as ex: + logger.error(ex) + + def _create_agent(self, agent_spec): + args = agent_spec["args"] + if agent_spec["type"] not in registered_agents: + raise ValueError("Unknown controller type: %s" % agent_spec["type"]) + + cls = registered_agents[agent_spec["type"]] + + if agent_spec["type"] == "AIMLAgent": + # load agent specs + HR_CHATBOT_WORLD_DIR = os.environ.get("HR_CHATBOT_WORLD_DIR", "") + agent_spec_file = os.path.join(HR_CHATBOT_WORLD_DIR, "agents.yaml") + root_dir = os.path.dirname(os.path.realpath(agent_spec_file)) + args["character_yaml"] = abs_path(root_dir, args["character_yaml"]) + + if "media_agent" in args: + media_agent_spec = args["media_agent"] + if media_agent_spec["type"] in [ + "TranslatorAgent", + "GPT2Agent", + "AI21Agent", + ]: + raise ValueError("Media agent can't be nested") + args["media_agent"] = self._create_agent(media_agent_spec) + if agent_spec["type"] == "TranslatorAgent": + args["media_language"] = media_agent_spec["args"]["lang"] + + try: + agent = cls(**args) + return agent + except Exception as ex: + logger.exception( + 
"Initializing agent %s with args %s. Error %s", cls, args, ex + ) + + def init_config_agent(self, agent): + """Configure the agent for the first time""" + agent_config = self.agent_configs.get(agent.id) + if agent_config: + agent.set_config(agent_config, base=True) + + def config_agents(self, configs): + """Set agent configs""" + if configs: + for agent_id, config in list(configs.items()): + agent = self.agents.get(agent_id) + if agent: + agent.set_config(config, base=False) + else: + logger.error("Agent %r was not found", agent_id) + + def config_agent_types(self, configs): + """Set agent configs by its type""" + if configs: + for agent_type, config in list(configs.items()): + for agent in self.agents.values(): + if agent.type == agent_type: + agent.set_config(config, base=False) + + def reset_all_agents(self): + for agent in self.agents.values(): + agent.reset_config() + + @property + def installed_agent_types(self): + return {agent.type for agent in self.agents.values()} + + @property + def agent_id_type_mapping(self): + return {agent.id: agent.type for agent in self.agents.values()} + + def write_request_to_mongos(self, request): + try: + request_doc = mm.ChatRequest( + user_id="default", + conversation_id=request.sid, + text=request.question, + lang=request.lang, + audio=request.audio, + context=request.context or {}, + ) + self.document_manager.add_document(request_doc) + self.request_docs[request.request_id] = request_doc + except Exception as ex: + logger.error(ex) + + def new_request( + self, sid, text, lang, audio="", source="", context=None, session_context=None + ) -> AgentRequest: + # TODO: Check session + if not text: + raise ValueError("text is empty") + request = AgentRequest() + request.sid = sid + request.request_id = str(uuid.uuid4()) + request.question = text + request.lang = lang + request.audio = audio + request.source = source + request.context = context or {} + request.session_context = session_context + + try: + 
self.db_tasks.put([write_request, AttrDict(request.to_dict())]) + # write_request(request) + except Exception as ex: + print(ex) + logger.error("Can't wrrite the request to DB %s", ex) + + self.requests[request.request_id] = request + self.write_request_to_mongos(request) + + return request + + def on_switch_language(self, from_language, to_language): + for agent in self.agents.values(): + if hasattr(agent, "on_switch_language"): + agent.on_switch_language(from_language, to_language) + + def write_responses_to_mongodb(self, responses): + try: + for response in responses: + response_doc = mm.ChatResponse( + text=response.answer, + lang=response.lang, + agent_id=response.agent_id, + conversation_id=response.sid, + attachment=response.attachment, + request=self.request_docs.get(response.request_id), + trace=response.trace, + ) + self.response_docs[response.response_id] = response_doc + except Exception as ex: + logger.error(ex) + + def record_responses(self, responses: List[AgentResponse]): + self.write_responses_to_mongodb(responses) + self.db_tasks.put([write_responses, responses]) + for response in responses: + self.responses[response.response_id] = response + + def early_stop(self, fast_score, responses): + """Should the batch chat tasks stop early?""" + if fast_score > 0: + for response in responses: + agent = self.agents.get(response.agent_id) + score = self.ranker.score(agent, response) + if ( + score >= fast_score + and not response.attachment.get("non-verbal") + and not isinstance(response, AgentStreamResponse) + ): + logger.warning( + "Stopped early with agent %s score %s", response.agent_id, score + ) + return True + return False + + def find_external_used_responses(self, responses): + """Finds the response with True state in its attachment + True state means the response has been handled by external component + so the response with this state has to be executed for consistance. 
+ """ + return [ + response + for response in responses + if "state" in response.attachment and response.attachment.get("state") == 1 + ] + + def check_safety(self, response): + if self.text_safety_classifier.classify(response.answer, response.lang): + return True + else: + logger.warning("Response %s is unsafe", response) + return False + + def demojize(self, response): + response.answer = emoji.demojize(response.answer) + return response + + def chat_with_ranking(self, request): + agent_sessions = self.session_manager.agent_sessions() + if not agent_sessions: + logger.error("No agent sessions") + return + + for responses in self.chat_agent_schedular.chat(request, agent_sessions): + if responses: + responses = [self.demojize(response) for response in responses] + responses = [ + response for response in responses if self.check_safety(response) + ] + if responses: + responses = self.ranker.rank(responses) + try: + self.record_responses(responses) + except Exception as ex: + logger.error("Failed to write responses to DB %s", ex) + if self.find_external_used_responses(responses): + logger.info("Found external handled response") + break + yield responses + + def chat_with_resolving(self, request, fast_score=0, interrupt_event=None): + """ + fast_score: the lower bound of response score the agents + + Returns the resolved response and all the other responses + """ + agent_sessions = self.session_manager.agent_sessions() + if not agent_sessions: + logger.error("No agent sessions") + return + self.chat_agent_schedular.interrupt = interrupt_event + # Wait for all responses from agent_schedular, agent schedular would report all responses then they become available, and no agent is still thinking with higher level. 
+ all_responses = [] + # self.chat_agent_schedular + for responses in self.chat_agent_schedular.chat(request, agent_sessions): + if responses: + responses = [self.demojize(response) for response in responses] + responses = [ + response for response in responses if self.check_safety(response) + ] + if responses: + all_responses += responses + if self.find_external_used_responses(responses): + break + if self.early_stop(fast_score, responses): + break + + if all_responses: + logger.info("Got %s responses in total", len(all_responses)) + try: + self.record_responses(all_responses) + except Exception as ex: + logger.error("Failed to write responses to DB %s", ex) + external_responses = self.find_external_used_responses(all_responses) + if external_responses: + logger.info("Found external handled response") + response = self.resolver.resolve(external_responses) + else: + response = self.resolver.resolve(all_responses) + if response: + response.attachment["published"] = True + self.publish(response.response_id, resolver_type=self.resolver.type) + # the first response is the published response + return [response] + [ + r for r in all_responses if not r.attachment.get("published") + ] + else: + logger.info("No responses. 
Request %r", request.question) + + def add_record(self, text): + self.ranker.add_record(text) + + def reset(self, sid): + return self.session_manager.reset(sid) + + def get_context(self, sid): + agent_sessions = self.session_manager.agent_sessions() + if not agent_sessions: + logger.error("No agent sessions") + return + + context = collections.defaultdict(dict) + for agent_id, session_id in agent_sessions.items(): + agent = self.agents.get(agent_id) + if ( + agent + and agent.config.get("share_context") + and hasattr(agent, "get_context") + and isinstance(agent.get_context, collections.Callable) + ): + agent_context = agent.get_context(session_id) or {} + for k, v in list(agent_context.items()): + context[k] = v + logger.info("Got agent %s context %s", agent.id, context) + return context + + def set_context( + self, session_id, context: dict, output: bool = False, finished: bool = True + ): + agent_sessions = self.session_manager.agent_sessions() + if not agent_sessions: + logger.error("No agent sessions") + return + + if output: + for agent_id, _ in agent_sessions.items(): + agent = self.agents.get(agent_id) + if ( + agent + and agent.config.get("share_context") + and hasattr(agent, "set_output_context") + and isinstance(agent.set_output_context, collections.Callable) + ): + if finished: + logger.info("Finish output context: %s %s", agent_id, context) + else: + logger.info("Set output context: %s %s", agent_id, context) + agent.set_output_context(session_id, context, finished) + else: + for agent_id, _ in agent_sessions.items(): + agent = self.agents.get(agent_id) + if ( + agent + and agent.config.get("share_context") + and hasattr(agent, "set_context") + and isinstance(agent.set_context, collections.Callable) + ): + agent.set_context(session_id, context) + logger.info("Set context: %s %s", agent_id, context) + + def set_timeout(self, timeout, min_wait_for=0.0): + self.chat_agent_schedular.set_timeout(timeout, min_wait_for) + + # set agent timeout + for agent in 
self.agents.values(): + if hasattr(agent, "timeout"): + agent.timeout = timeout + + def publish(self, response_id, **kwargs): + """Publishes the response""" + try: + response_doc = self.response_docs.get(response_id) + if response_doc: + response_doc.published_at = datetime.datetime.utcnow() + self.document_manager.add_document(response_doc) + self.db_tasks.put([update_published_response, response_id, kwargs]) + logger.info("Wrote published response to DB successfully") + except Exception as ex: + logger.error("Failed to write published response to DB %s", ex) + + def update_response_document(self, response_id, text): + response_doc = self.response_docs.get(response_id) + if response_doc: + logger.info("Updating response document %r with %r", response_doc, text) + response_doc.text = text + self.document_manager.add_document(response_doc) + + def run_reflection(self, sid, text, lang): + results = [] + logger.info("Reflecting on %r", text) + for agent in self.agents.values(): + if hasattr(agent, "run_reflection"): + _results = agent.run_reflection(sid, text, lang) + if _results: + results += _results + return results + + def feedback(self, response, hybrid): + context = {} + for agent in self.agents.values(): + if hasattr(agent, "feedback"): + feedback_result = agent.feedback( + response.request_id, response.agent_id == agent.id, hybrid + ) + if feedback_result and isinstance(feedback_result, dict): + context.update(feedback_result) + return context + + def get_main_agent(self, lang): + for agent in self.agents.values(): + if agent.config.get("role") == "main" and lang in agent.languages: + return agent diff --git a/modules/ros_chatbot/src/ros_chatbot/context_manager.py b/modules/ros_chatbot/src/ros_chatbot/context_manager.py new file mode 100644 index 0000000..8e7dff0 --- /dev/null +++ b/modules/ros_chatbot/src/ros_chatbot/context_manager.py @@ -0,0 +1,93 @@ +## +## Copyright (C) 2017-2025 Hanson Robotics +## +## This program is free software: you can redistribute 
## (license header continued) it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
## GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program. If not, see .
##

import logging
import os
import threading

from ros_chatbot.utils import get_ip, get_location

logger = logging.getLogger(__name__)


class ContextManager:
    """Owns the session context: preloads geolocation/IP when a session
    context is attached and resets the per-conversation state."""

    def __init__(self):
        self._session_context = None

    @property
    def session_context(self):
        return self._session_context

    @session_context.setter
    def session_context(self, session_context):
        # Attaching a new session context triggers a fresh preload + reset.
        self._session_context = session_context
        self.preload_context()
        self.reset_context()

    @staticmethod
    def _call_with_timeout(func, timeout):
        """Run func() in a daemon thread; return its result, or None if it
        did not finish within `timeout` seconds (or returned None itself).

        BUG FIX: the previous implementation discarded the worker thread's
        result and, on success, called the function a SECOND time
        synchronously with no timeout — doubling the network lookup and
        defeating the timeout it had just enforced. The worker's result is
        now captured and reused.
        """
        box = []
        worker = threading.Thread(target=lambda: box.append(func()))
        worker.daemon = True
        worker.start()
        worker.join(timeout=timeout)
        if worker.is_alive() or not box:
            return None
        return box[0]

    def preload_context(self):
        """
        Preload necessary context data for the session.

        This method is responsible for initializing and loading any required
        context data for the session. It sets up geolocation information,
        client IP address, and other relevant parameters in the session context.
        """
        location = self._call_with_timeout(get_location, 2)
        if location is None:
            logger.warning("Geolocation retrieval timed out")
        if location:
            logger.info("Geolocation info %r", location)
            location_str = " ".join(
                filter(None, [location.get("neighborhood"), location.get("city")])
            )
            if location_str:
                logger.info("Set geolocation %s", location_str)
                self._session_context["geo_location"] = location_str
                self._session_context["location"] = location_str
                os.environ["LOCATION"] = location_str

        ip = self._call_with_timeout(get_ip, 2)
        if ip is None:
            logger.warning("IP retrieval timed out")
        if ip:
            logger.info("Set client IP %s", ip)
            self._session_context["client_ip"] = ip
            os.environ["IP"] = ip

    def reset_context(self):
        """Reset per-conversation counters, events and transient params."""
        self._session_context["turns"] = 0
        self._session_context["total_turns"] = 0
        self._session_context["done_steps"] = []
        self._session_context.proxy.delete_param("arf.events")
        self._session_context.proxy.delete_param("arf.scenes")
        self._session_context.proxy.delete_param("interlocutor")
        self._session_context.proxy.delete_param("block_chat")
        self._session_context["state"] = ""  # reset state
        self._session_context.user_context[
            "current_session"
        ] = self._session_context.sid
        # NOTE(review): 72000s (20h) expiry on the user's current_session key.
        self._session_context.user_context.proxy.expire("current_session", 72000)
# ---- diff residue: diff --git a/modules/ros_chatbot/src/ros_chatbot/data_loader.py
#      b/modules/ros_chatbot/src/ros_chatbot/data_loader.py new file mode 100644
#      index 0000000..e2b6152 --- /dev/null +++ b/modules/ros_chatbot/src/ros_chatbot/data_loader.py
#      @@ -0,0 +1,175 @@ ----
##
## Copyright (C) 2017-2025 Hanson Robotics
##
## This program is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public
## (license header continued) License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
## GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program. If not, see .
##

import logging
from pathlib import Path
from typing import Dict, List

import yaml
from haipy.schemas.airtable_schemas import OperationSceneRecord
from haipy.utils import to_list

from ros_chatbot.schemas import Scene, SceneEvent
from ros_chatbot.utils import load_sheet_meta

logger = logging.getLogger(__name__)


class DataLoader:
    """Loads CMS data (prompt presets, ARF scene sheets, Airtable operation
    scenes) and publishes the scenes to the session-context param server."""

    def __init__(self, cms_dir: str, session_context):
        """cms_dir: directory holding the exported CMS YAML files.
        session_context: parameter-server-backed session context proxy."""
        self.cms_dir = cms_dir
        self.session_context = session_context
        self.presets = {}
        self.scenes = {}
        # A "default" scene always exists even when the CMS provides none.
        self.scenes["default"] = Scene(name="default", type="preset", default=True)

    def load_operation_scenes(self) -> List[dict]:
        """Load operation scenes from the Airtable export YAML.

        Returns a list of plain dicts (not Scene objects) with prompts and
        variables already filtered to non-empty values.
        """
        scene_file = Path(self.cms_dir) / "airtable-operation-scenes.yaml"
        scenes = []
        if scene_file.exists():
            with open(scene_file) as f:
                scenes = yaml.safe_load(f)
            scenes = [OperationSceneRecord(**scene) for scene in scenes]
            logger.info("Loaded %d scenes from %s", len(scenes), scene_file)
        else:
            logger.warning("Scene file %s does not exist", scene_file)

        loaded_scenes = []
        for scene in scenes:
            fields = scene.fields
            variables = {
                "general_prime": fields.GeneralPrimer,
                "situational_prime": fields.SituationalPrimer,
                "response_prime": fields.ResponsePrimer,
                "objective": fields.Objective,
                "location": fields.Location,
                "prompt_template": fields.PromptTemplate,
            }
            # Filter prompts to include only specific keys and non-empty values
            prompts = {
                k: v
                for k, v in variables.items()
                if k in ["general_prime", "situational_prime", "response_prime"] and v
            }
            # Filter variables to include only non-empty values
            variables = {k: v for k, v in variables.items() if v}

            loaded_scene = {
                "name": fields.Name,
                "type": "preset",
                "variables": variables,
                "tts_mapping": fields.TTSMapping or "",
                "asr_context": fields.ASRContext or "",
                "knowledge_base": fields.KnowledgeBase or "",
                "prompts": prompts,
            }
            loaded_scenes.append(loaded_scene)
            logger.info(
                "Loaded scene: %s with variables: %s", loaded_scene["name"], variables
            )
        return loaded_scenes

    def load_prompt_presets(self) -> Dict[str, dict]:
        """Load prompt presets from prompt_presets.yaml (empty dict if absent)."""
        preset_file = Path(self.cms_dir) / "prompt_presets.yaml"
        presets = {}
        if preset_file.is_file():
            with open(preset_file) as f:
                presets_data = yaml.safe_load(f)
                if "presets" in presets_data:
                    presets.update(presets_data["presets"])
                    logger.info(
                        "Loaded %d prompt presets from %s", len(presets), preset_file
                    )
        else:
            logger.info("Prompt preset file %s is not found", preset_file)
        return presets

    def load_arf_scene_sheets(self) -> Dict[str, Scene]:
        """Load ARF scene sheets; returns {scene_name: {"scene", "events"}}."""
        sheets = load_sheet_meta(Path(self.cms_dir) / "arf_sheets")

        scenes = {}
        for sheet in sheets:
            events = [
                SceneEvent(**{"scene": sheet.scene, "arf_event": arf_event})
                for arf_event in sheet.arf_events
            ]
            scene = Scene(
                **{
                    "name": sheet.scene,
                    "default": sheet.header.get("DefaultScene", False),
                    "conditions": to_list(sheet.header.get("SceneCondition", [])),
                    "variables": sheet.header.get("Variables", {}),
                    "asr_context": sheet.header.get("ASRContext", ""),
                    "tts_mapping": sheet.header.get("TTSMapping", ""),
                }
            )
            scenes[scene.name] = {
                "scene": scene,
                "events": events,
            }
        logger.info("Loaded ARF scenes")
        return scenes

    def load_all_data(self):
        """Loads all necessary data including prompt presets, ARF scene sheets, and operation scenes"""

        # load prompt presets
        presets = self.load_prompt_presets()
        self.presets.update(presets)
        self.scenes.update(
            {
                key: Scene(
                    name=key,
                    type="preset",
                    variables=value.get("context", {}),
                    knowledge_base=value.get("knowledge_base", ""),
                )
                for key, value in self.presets.items()
            }
        )
        logger.info("Scenes %s", self.scenes)

        # load arf scenes
        scenes_data = self.load_arf_scene_sheets()

        # Clear stale params before republishing.
        self.session_context.proxy.delete_param("arf.events")
        self.session_context.proxy.delete_param("arf.scenes")
        self.session_context.proxy.delete_param("interlocutor")

        for scene_name, data in scenes_data.items():
            events = data["events"]
            scene = data["scene"]
            if events:
                self.session_context.proxy.set_param(
                    f"arf.events.{scene_name}",
                    [event.model_dump_json() for event in events],
                )
            self.scenes[scene.name] = scene

        self.session_context.proxy.set_param(
            "arf.scenes", [scene.model_dump_json() for scene in self.scenes.values()]
        )

        # load operation scenes (these override same-named presets/scenes)
        operation_scenes = self.load_operation_scenes()
        for scene in operation_scenes:
            key = scene["name"]
            self.presets[key] = {}
            self.presets[key]["title"] = scene["name"]
            self.presets[key]["prompts"] = scene["prompts"]
            self.scenes[scene["name"]] = Scene(**scene)
        logger.info("All data loaded successfully")
# ---- diff residue: diff --git a/modules/ros_chatbot/src/ros_chatbot/db.py
#      b/modules/ros_chatbot/src/ros_chatbot/db.py new file mode 100644
#      index 0000000..48b874a --- /dev/null +++ b/modules/ros_chatbot/src/ros_chatbot/db.py
#      @@ -0,0 +1,273 @@ ----
# -*- coding: utf-8 -*-

##
## Copyright (C) 2017-2025 Hanson Robotics
##
## This program is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
## (license header continued) See the
## GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program. If not, see .
##

import datetime
import hashlib
import logging
import os
import socket
import uuid
from contextlib import contextmanager

import six
from haipy.db_models import ChatRequest, ChatResponse, ConvInsight
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

logger = logging.getLogger("hr.ros_chatbot.db")


def hash_string(string):
    """Return the hex MD5 digest of a (str or bytes) string."""
    if isinstance(string, six.text_type):  # convert utf-8 to ascii
        string = string.encode("utf-8")
    m = hashlib.md5(string)
    return m.hexdigest()


def get_hostid():
    """Stable host identifier: MD5 of the MAC-address-derived UUID."""
    return hash_string(str(uuid.UUID(int=uuid.getnode())))


def get_hostname():
    return socket.gethostname()


hostname = get_hostname()
hostid = get_hostid()


@contextmanager
def session_scope():
    """Provide a transactional scope around a series of operations."""
    session = Session()
    try:
        yield session
        session.commit()
    except Exception as ex:
        logger.exception(ex)
        session.rollback()
        raise
    finally:
        session.close()


@contextmanager
def remote_session_scope():
    """Provide a transactional scope around a series of operations."""
    session = RemoteSession()
    try:
        yield session
        session.commit()
    except Exception as ex:
        logger.exception(ex)
        session.rollback()
        raise
    finally:
        session.close()


def _get_request_record(request):
    """Build a ChatRequest ORM record, folding optional fields and host
    metadata into the attachment kwargs."""
    attachment = {}
    if request.source:
        attachment["source"] = request.source
    if request.audio:
        attachment["audio"] = request.audio
    if request.tag:
        attachment["tag"] = request.tag
    if request.scene:
        attachment["scene"] = request.scene
    if request.context:
        attachment["context"] = request.context
    # LOCATION/IP are set by ContextManager.preload_context at session start.
    attachment["location"] = os.environ.get("LOCATION", "")
    attachment["ip"] = os.environ.get("IP", "")
    attachment["hostname"] = hostname
    attachment["hostid"] = hostid
    return ChatRequest(
        app_id=request.app_id,
        request_id=request.request_id,
        created_at=request.time,
        user_id=request.user_id,
        conversation_id=request.sid,
        text=request.question,
        lang=request.lang,
        **attachment,
    )


def _get_conv_insight(record):
    """Build a ConvInsight ORM record from a plain dict."""
    return ConvInsight(
        conversation_id=record["conversation_id"],
        type=record["type"],
        insight=record["insight"],
        created_at=record["created_at"],
    )


def write_request(request):
    """Writes chatbot request to database"""
    if not request:
        return
    with session_scope() as session:
        session.add(_get_request_record(request))
    if write_to_remote:
        with remote_session_scope() as session:
            session.add(_get_request_record(request))


def _get_response_records(responses):
    """Build ChatResponse ORM records, stamping robot identity into each."""
    records = []
    for response in responses:
        attachment = response.attachment if response.attachment else {}
        attachment["robot"] = os.environ.get("ROBOT_NAME", "")
        attachment["body"] = os.environ.get("ROBOT_BODY", "")
        record = ChatResponse(
            request_id=response.request_id,
            response_id=response.response_id,
            conversation_id=response.sid,
            created_at=response.end_dt,
            agent_id=response.agent_id,
            text=response.answer,
            lang=response.lang,
            trace=response.trace,
            **attachment,
        )
        records.append(record)
    return records


def write_responses(responses):
    """Writes chatbot responses to database"""
    if not responses:
        return
    with session_scope() as session:
        for record in _get_response_records(responses):
            session.add(record)
    if write_to_remote:
        with remote_session_scope() as session:
            for record in _get_response_records(responses):
                session.add(record)


def write_conv_insight(record):
    """Writes conversation insight to database"""
    if not record:
        return
    with session_scope() as session:
        session.add(_get_conv_insight(record))
    if write_to_remote:
        with remote_session_scope() as session:
            session.add(_get_conv_insight(record))


def update_published_response(response_id: str, attachment: dict):
    """Updates the response with published state"""
    with session_scope() as session:
        response = (
            session.query(ChatResponse)
            .filter(ChatResponse.response_id == response_id)
            .one_or_none()
        )
        if response:
            response.published_at = datetime.datetime.utcnow()
            response.attachment.update(attachment)
            session.add(response)
        else:
            logger.warning(
                "Failed to look up published responses. response_id %s", response_id
            )
    if write_to_remote:
        with remote_session_scope() as session:
            response = (
                session.query(ChatResponse)
                .filter(ChatResponse.response_id == response_id)
                .one_or_none()
            )
            if response:
                response.published_at = datetime.datetime.utcnow()
                response.attachment.update(attachment)
                session.add(response)
            else:
                logger.warning(
                    "Failed to look up published responses. response_id %s",
                    response_id,
                )


def get_chat_stream(conversation_id: str):
    """Return all chat_stream rows for a conversation as dicts.

    SECURITY FIX: conversation_id is now passed as a bound parameter
    instead of being f-string-interpolated into the SQL text, which was an
    SQL injection vector for conversation ids originating outside the host.
    """
    with session_scope() as session:
        query = text("SELECT * FROM chat_stream where ConversationID=:conversation_id")
        logger.warning("Query %s", query)
        results = session.execute(query, {"conversation_id": conversation_id}).all()
        return [r._asdict() for r in results]


root_password = os.environ.get("MYSQL_ROOT_PASSWORD")
database = os.environ.get("MYSQL_DATABASE")
remote_password = os.environ.get("MYSQL_REMOTE_PASSWORD")
remote_user = os.environ.get("MYSQL_REMOTE_USER")
remote_host = os.environ.get("MYSQL_REMOTE_HOST")
remote_port = os.environ.get("MYSQL_REMOTE_PORT")
remote_db_schema_version = os.environ.get("MYSQL_SCHEMA_VERSION")

db_url = f"mysql+mysqldb://root:{root_password}@127.0.0.1:3306/{database}?charset=utf8"
remote_db_url = f"mysql+mysqldb://{remote_user}:{remote_password}@{remote_host}:{remote_port}/{database}?charset=utf8"

# https://docs.sqlalchemy.org/en/13/faq/connections.html#mysql-server-has-gone-away
# https://docs.sqlalchemy.org/en/13/core/pooling.html#dealing-with-disconnects
connect_timeout = 3
engine = create_engine(
    db_url,
    echo=False,
    pool_pre_ping=True,
    connect_args={"connect_timeout": connect_timeout},
)
remote_engine = create_engine(
    remote_db_url,
    echo=False,
    pool_pre_ping=True,
    connect_args={"connect_timeout": connect_timeout},
)

RemoteSession = sessionmaker(bind=remote_engine)
Session = sessionmaker(bind=engine)
write_to_remote = False

if os.environ.get("WRITE_TO_REMOTE_DB") == "1":
    db_version = "unknown"
    try:
        with remote_engine.connect() as con:
            rs = con.execute("SELECT version_num FROM alembic_version")
            for row in rs:
                db_version = row[0]
                break
    except Exception as ex:
        logger.error(ex)
    if db_version == remote_db_schema_version:
        write_to_remote = True
        # BUG FIX: this success message was logged at ERROR level.
        logger.info("Remote DB schema verified")
    else:
        logger.error(
            "Remote DB schema version %r, expect %r",
            db_version,
            remote_db_schema_version,
        )
else:
    logger.warning("Writing to remote DB is disabled")
    write_to_remote = False
# ---- diff residue: diff --git a/modules/ros_chatbot/src/ros_chatbot/ddr_node.py
#      b/modules/ros_chatbot/src/ros_chatbot/ddr_node.py new file mode 100644
#      index 0000000..f0b431b --- /dev/null +++ b/modules/ros_chatbot/src/ros_chatbot/ddr_node.py
#      @@ -0,0 +1,125 @@ ----
##
## Copyright (C) 2017-2025 Hanson Robotics
##
## This program is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
## GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program. If not, see .
##

import re
from typing import Any, List

from ddynamic_reconfigure_python.ddynamic_reconfigure import DDynamicReconfigure
from pydantic import BaseModel


class Enum(BaseModel):
    """One dynamic-reconfigure enum constant (value plus display metadata)."""

    name: str
    type: str
    value: Any
    description: str


class DDRNode:
    """
    Dynamic reconfigure node, allow to make parameters in the children constructors similar to ROS params
    """

    def __init__(self, namespace=None, callback=None):
        self.__callback = callback
        self._dd_cfg = {}  # last config dict received from dynamic_reconfigure
        self.__ddr = DDynamicReconfigure(namespace)

    def ddstart(self):
        """Start the reconfigure server; must be called after all new_param calls."""
        self.__ddr.start(self.__dd_callback)

    def __dd_callback(self, config, level=None):
        # Give the user callback a chance to veto/adjust, then cache the config.
        if self.__callback:
            config = self.__callback(config, level)
        self._dd_cfg = config
        return config

    def _get_cfg_entry(self, name, default):
        # Falls back to the declared default until the first callback fires.
        return self._dd_cfg.get(name, default)

    def new_param(
        self, name, description, default=None, min=None, max=None, edit_method=""
    ):
        """Declare a reconfigure parameter and expose it as a read-only property.

        NOTE(review): the property is installed on the DDRNode *class*, and
        the lambda binds this instance and the param name via default args,
        so the property reads this node's config no matter which instance it
        is accessed through — assumes one DDRNode (or unique param names)
        per process; confirm.
        """
        type = self.__get_type(default)
        self.__ddr.add(name, type, 0, description, default, min, max, edit_method)
        setattr(
            __class__,
            name,
            property(lambda self=self, x=name, d=default: self._get_cfg_entry(x, d)),
        )

    # allows updating the configuration programmatically (server-side push)
    def update_configuration(self, params):
        self.__ddr.dyn_rec_srv.update_configuration(params)

    def enum(self, options):
        """Build an edit_method string from a list or {label: value} dict."""
        enum = []
        if isinstance(options, list):
            # a bare list means each option is both label and value
            options = {o: o for o in options}
        if not isinstance(options, dict):
            raise TypeError("Enum options must be a dict or list")

        for k, v in options.items():
            # value name pairs; constant name is sanitized from the label
            c = self.__ddr.const(self.__name_from_str(k), self.__get_type(v), v, k)
            enum.append(c)
        return self.__ddr.enum(enum, "enum")

    def add_enums(self, enums: List[Enum]) -> str:
        """
        Add multiple enum constants and create an enum edit_method.
        """
        enum_list = [
            self.__ddr.const(enum.name, enum.type, enum.value, enum.description)
            for enum in enums
        ]
        return self.__ddr.enum(enum_list, "enum")

    def build_enum(self, enum: list):
        """
        allow int and string enums for ddynrec. enum is array of name, value, description
        """
        if not enum:
            return ""
        # silently skips malformed entries (len != 3)
        result = [
            self.__ddr.const(
                e[0], "str" if isinstance(e[1], str) else "int", e[1], e[2]
            )
            for e in enum
            if len(e) == 3
        ]
        return self.__ddr.enum(result, "enum")

    @staticmethod
    def __get_type(value):
        """Map a Python value to the dynamic_reconfigure type string."""
        type_mapping = {int: "int", float: "double", str: "str", bool: "bool"}
        value_type = type(value)
        if value_type in type_mapping:
            return type_mapping[value_type]
        else:
            raise TypeError(f"Unsupported type for value: {value_type.__name__}")

    @staticmethod
    def __name_from_str(input_str):
        # Replace any non-letter and non-number character with a single
        # underscore; lowercase and truncate to 30 chars.
        output_str = re.sub("[^a-zA-Z0-9]+", "_", input_str).lower()[:30]

        return output_str

    @property
    def ddr(self):
        # Escape hatch to the underlying DDynamicReconfigure object.
        return self.__ddr
# ---- diff residue: diff --git a/modules/ros_chatbot/src/ros_chatbot/handlers.py
#      b/modules/ros_chatbot/src/ros_chatbot/handlers.py new file mode 100644
#      index 0000000..89ac32f --- /dev/null +++ b/modules/ros_chatbot/src/ros_chatbot/handlers.py
#      @@ -0,0 +1,924 @@ ----
##
## Copyright (C) 2017-2025 Hanson Robotics
##
## This program is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
## GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program. If not, see .
import logging
import os
import time
from abc import ABCMeta, abstractmethod
from datetime import date, timezone
from typing import List, Optional

import haipy.memory_manager as mm
from haipy.chat_history import ChatHistory
from haipy.parameter_server_proxy import EventListener, UserSessionContext
from haipy.scheduler import init as scheduler_init
from haipy.scheduler.schemas import (
    EmotionalContext,
    GoalContext,
    InterestContext,
    PhysiologicalContext,
)
from haipy.scheduler.schemas.enums import DriverStatus
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI

from ros_chatbot.intention_manager import IntentionManager
from ros_chatbot.schemas import SceneContext

# Module import side effects: both the memory manager and the scheduler are
# bound to the same MongoDB instance.  A missing CLOUD_MONGO_DATABASE_URL
# raises KeyError at import time -- presumably intentional fail-fast; confirm.
mm.init(os.environ["CLOUD_MONGO_DATABASE_URL"])
scheduler_init(os.environ["CLOUD_MONGO_DATABASE_URL"])

# Process-wide singleton through which all handlers register and query drivers.
intention_manager = IntentionManager()


class Tune(BaseModel):
    """Structured LLM output describing how the next response should sound."""

    style: str = Field(
        description="The response style can be academic, casual, decisive, for kids, serious, inquisitive, etc."
    )
    length: int = Field(description="What should be the word count for the response?")


class Plan(BaseModel):
    """Plan to follow in future"""

    steps: List[str] = Field(
        description="different steps to follow, should be in sorted order but do not include the order number"
    )


class CompletedTasks(BaseModel):
    """Structured LLM output listing plan steps judged to be complete."""

    steps: List[str] = Field(description="The tasks that have been completed")


class Attribute(BaseModel):
    """A single extracted user attribute (name/value pair)."""

    attribute: str = Field(description="The attribute name")
    value: str = Field(description="The extracted attribute value")


class Attributes(BaseModel):
    """Structured LLM output wrapping a list of extracted attributes."""

    attributes: List[Attribute] = Field(
        description="List of extracted attribute objects"
    )


class EventHandler(metaclass=ABCMeta):
    """Base class for event-driven handlers.

    Subclasses set ``trigger_keys`` (session-context keys whose changes fire
    the handler) and/or ``timer_interval`` (seconds between periodic runs)
    *before* chaining up to this constructor.
    """

    def __init__(
        self, session_context: UserSessionContext, scene_context: SceneContext
    ):
        """Initialize the event handler.

        Args:
            session_context: Context for the current user session
            scene_context: Context for the current scene
        """
        # Set default values for trigger keys and timer interval if not
        # defined by the subclass before calling super().__init__.
        self.trigger_keys = getattr(self, "trigger_keys", [])
        self.timer_interval = getattr(self, "timer_interval", None)

        self.session_context = session_context
        self.scene_context = scene_context
        self.logger = logging.getLogger(
            f"hr.ros_chatbot.handlers.{self.__class__.__name__}"
        )

    def set_event_listener(self, event_listener: EventListener):
        """Subscribe this handler to its key-change and timer events."""
        for key in self.trigger_keys:
            event_listener.on_key_change(key, self)
        if self.timer_interval:
            event_listener.on_timer(self.timer_interval, self)

    @property
    def input(self):
        # Latest user utterance; empty string when nothing is pending.
        return self.session_context.get("input", "")

    @abstractmethod
    def handle(self, event) -> bool:
        """return True if the event was handled, False otherwise"""
        pass


class LLMEventHandler(EventHandler):
    """Event handler backed by a jinja2 prompt template and an LLM pipeline."""

    def __init__(
        self,
        session_context: UserSessionContext,
        scene_context: SceneContext,
        prompt_template: str,
    ):
        super().__init__(session_context, scene_context)
        self.prompt = ChatPromptTemplate.from_template(
            prompt_template, template_format="jinja2"
        )
        # Fix: removed a redundant ``self.session_context = session_context``;
        # super().__init__ already stores it.
        self.chat_history = ChatHistory(f"{session_context.ns}.history")

    @property
    def history(self):
        # Re-point the history at the current session namespace before
        # formatting, in case the session changed since construction.
        self.chat_history.set_history_key(f"{self.session_context.ns}.history")
        return self.chat_history.format_history_text()
class TuneRegulator(LLMEventHandler):
    """Adjusts response style and length whenever the situational prime changes."""

    def __init__(
        self, session_context: UserSessionContext, scene_context: SceneContext
    ):
        self.trigger_keys = ["situational_prime"]
        prompt = """Evaluate the situation and determine what style and length that fit best.

        The situation
        {{situational_prime}}
        """
        # Modernized from super(TuneRegulator, self) for consistency with the
        # other handlers in this module.
        super().__init__(session_context, scene_context, prompt)
        self.pipeline = self.prompt | ChatOpenAI(
            model="gpt-4o", temperature=0.6
        ).with_structured_output(Tune)

    def handle(self, event) -> bool:
        """Derive response_length / response_style from the situational prime.

        Returns False when no situational prime is available yet.
        """
        if "situational_prime" not in self.session_context:
            return False
        context = {"situational_prime": self.session_context["situational_prime"]}
        result = self.pipeline.invoke(context)
        if result:
            self.session_context["response_length"] = result.length
            self.session_context["response_style"] = result.style
            self.logger.warning("Update response tune %s", result)
        return True


class EmotionalHandler(LLMEventHandler):
    """Infers the robot's emotional state from conversation history and physiology."""

    def __init__(
        self, session_context: UserSessionContext, scene_context: SceneContext
    ):
        self.trigger_keys = ["physiological_context", "turns"]
        prompt = (
            "You are a humanoid robot with sophisticated emotional capabilities. Your emotional state is defined by several key dimensions:\n"
            "- Sadness vs Happiness (-1 to 1)\n"
            "- Excitement vs Calmness (-1 to 1)\n"
            "- Anxiety vs Confidence (-1 to 1)\n"
            "- Hostility vs Friendliness (-1 to 1)\n"
            "- Boredom vs Interest (-1 to 1)\n\n"
            "Analyze the conversation and your current physiological state to determine your emotional state along these dimensions.\n"
            "If any dimension becomes significantly negative (below -0.5), suggest specific actions to help restore emotional balance.\n\n"
            "{% if physiological_context %}\n"
            "Your current physiological state:\n"
            "{{physiological_context}}\n"
            "{% endif %}\n"
            "\n"
            "{% if history %}\n"
            "Recent conversation history:\n"
            "{{history}}\n"
            "{% endif %}\n"
            "\n"
            "Provide:\n"
            "1. A numerical assessment for each emotional dimension\n"
            "2. An overall description of your emotional state\n"
            "3. Effective behavioral reactions to restore emotional balance if needed\n"
        )
        super().__init__(session_context, scene_context, prompt)
        self.pipeline = self.prompt | ChatOpenAI(
            model="gpt-4o", temperature=0.6
        ).with_structured_output(EmotionalContext)

    def handle(self, event) -> bool:
        """Update emotion_context and escalate to a driver on low valence."""
        context = {
            "history": self.history,
            "physiological_context": self.session_context.get("physiological_context"),
        }
        result = self.pipeline.invoke(context)
        if result:
            # Fix: with_structured_output(EmotionalContext) normally returns an
            # EmotionalContext instance already; ``EmotionalContext(**result)``
            # would raise on a model object, so only re-wrap plain mappings.
            emotion = (
                result
                if isinstance(result, EmotionalContext)
                else EmotionalContext(**result)
            )
            self.session_context["emotion_context"] = str(emotion)
            self.logger.warning("Update emotion context %s", emotion)
            # fires when the emotion valence is smaller than the threshold
            # NOTE(review): threshold is 0.5 while the prompt describes a
            # -1..1 scale with "significantly negative (below -0.5)" -- confirm
            # the intended trigger level.
            if emotion.valence < 0.5:
                intention_manager.add_driver(emotion)
        return True
class PhysiologicalHandler(LLMEventHandler):
    """Turns raw self-diagnostic data into a physiological context assessment."""

    def __init__(
        self, session_context: UserSessionContext, scene_context: SceneContext
    ):
        self.trigger_keys = ["body_state"]
        prompt = (
            "You are an advanced autonomous social robot with sophisticated self-monitoring and self-repair capabilities.\n"
            "Your integrated sensor system continuously monitors:\n"
            "- Servo positions and movements\n"
            "- Component temperatures\n"
            "- System load and resource usage\n"
            "- Power levels and consumption\n"
            "- Overall system stability\n\n"
            "Analyze the following self-diagnostic data and make sure to provide:\n"
            "1. A clear assessment of your physiological state\n"
            "2. Recommended actions to take to recover from any issues\n\n"
            "Current self-diagnostic data:\n"
            "{{body_state}}\n"
        )
        super().__init__(session_context, scene_context, prompt)
        self.pipeline = self.prompt | ChatOpenAI(
            model="gpt-4o", temperature=0.9
        ).with_structured_output(PhysiologicalContext)

    def handle(self, event) -> bool:
        """Update physiological_context; escalate to a driver on low valence."""
        if self.session_context.get("body_state"):
            context = {"body_state": self.session_context.get("body_state")}
            result = self.pipeline.invoke(context)
            if result:
                # Fix: only re-wrap plain mappings; with_structured_output
                # normally returns a PhysiologicalContext instance already.
                physio = (
                    result
                    if isinstance(result, PhysiologicalContext)
                    else PhysiologicalContext(**result)
                )
                self.session_context["physiological_context"] = str(physio)
                self.logger.warning("Update physiological context %s", physio)
                if physio.valence < 0.5:
                    intention_manager.add_driver(physio)
        return True


class Planner(LLMEventHandler):
    """Turns an objective plus the current situation into a step-by-step plan."""

    def __init__(
        self, session_context: UserSessionContext, scene_context: SceneContext
    ):
        self.trigger_keys = ["objective"]
        prompt_template = (
            "For the given objective, come up with a simple step by step plan.\n"
            "This plan should involve individual tasks, that if executed correctly will yield the correct answer. Do not add any superfluous steps.\n"
            "The result of the final step should be the final answer. Make sure that each step has all the information needed - do not skip steps.\n\n"
            "# Situation\n"
            "{{situation}}\n\n"
            "# Objective\n"
            "{{objective}}\n"
        )
        super().__init__(session_context, scene_context, prompt_template)
        self.pipeline = self.prompt | ChatOpenAI(
            model="gpt-4o", temperature=0.6
        ).with_structured_output(Plan)

    def handle(self, event) -> bool:
        """Store the freshly planned steps in the session context."""
        context = {
            "situation": self.session_context.get("situational_prime"),
            "objective": self.session_context.get("objective"),
        }
        result = self.pipeline.invoke(context)
        if result:
            self.session_context["planned_steps"] = result.steps
            self.logger.info("Planned steps %s", result.steps)

        return True


class Evaluator(LLMEventHandler):
    """Checks the conversation against the plan and records completed steps."""

    def __init__(
        self, session_context: UserSessionContext, scene_context: SceneContext
    ):
        self.trigger_keys = ["turns"]
        prompt_template = (
            "You are a conversational AI. You have a list of tasks to cover in your "
            "conversation with the user. Based only on the conversation, list the tasks in the plan "
            "that have been considered completed.\n\n"
            "# Your plan\n"
            "{{planned_steps}}\n\n"
            "# History\n"
            "{{history}}\n"
        )
        super().__init__(session_context, scene_context, prompt_template)
        self.pipeline = self.prompt | ChatOpenAI(
            model="gpt-4o", temperature=0.6
        ).with_structured_output(CompletedTasks)

    def handle(self, event) -> bool:
        """Append newly completed steps to done_steps based on the history."""
        # Render the plan as a 1-based numbered list for the prompt.
        planned_steps = "\n".join(
            [
                f"{i}. {task}"
                for i, task in enumerate(
                    self.session_context.get("planned_steps", []), 1
                )
            ]
        )
        self.chat_history.set_history_key(f"{self.session_context.ns}.history")
        if self.chat_history.is_empty():
            self.logger.warning("No history to evaluate plan")
            return True
        context = {"planned_steps": planned_steps, "history": self.history}
        if planned_steps:
            result = self.pipeline.invoke(context)
            if result:
                self.session_context["done_steps"] = (
                    self.session_context.get("done_steps", []) + result.steps
                )
                self.logger.warning("Completed steps %s", result.steps)
        return True
class Replanner(LLMEventHandler):
    """Revises the plan whenever progress (done_steps) changes, honoring body
    and emotional needs ahead of the objective."""

    def __init__(
        self, session_context: UserSessionContext, scene_context: SceneContext
    ):
        self.trigger_keys = ["done_steps"]
        # Fix: prompt typos -- "conversional" -> "conversational", missing
        # space after "priority.", and "the follow tasks" -> "the following
        # tasks".
        prompt_template = (
            "You are conversational AI.\n"
            "You are in this situation: {{situational_prime}}\n"
            "{% if body_state %}\nYour body state: {{body_state}}\n{% endif %}\n"
            "{% if emotion_context %}\nYour emotion state: {{emotion_context}}\n{% endif %}\n"
            "Your objective was: {{objective}}\n"
            "It is essential that you prioritize your needs. The needs from your physical body is the highest priority. "
            "Then your emotion state is the second priority, and finally your objective is the third priority.\n\n"
            "Your original plan was:\n"
            "{{planned_steps}}\n\n"
            "You have currently done the following tasks:\n"
            "{{done_steps}}\n\n"
            "Update your plan accordingly."
        )
        super().__init__(session_context, scene_context, prompt_template)
        self.pipeline = self.prompt | ChatOpenAI(
            model="gpt-4o", temperature=0.6
        ).with_structured_output(Plan)

    def handle(self, event) -> bool:
        """Regenerate planned_steps from current progress and internal state."""
        # Render plans as 1-based numbered lists for the prompt.
        planned_steps = "\n".join(
            [
                f"{i}. {task}"
                for i, task in enumerate(
                    self.session_context.get("planned_steps", []), 1
                )
            ]
        )
        done_steps = "\n".join(
            [
                f"{i}. {task}"
                for i, task in enumerate(self.session_context.get("done_steps", []), 1)
            ]
        )
        body_state = self.session_context.get("body_state")
        situational_prime = self.session_context.get("situational_prime")
        objective = self.session_context.get("objective")
        emotion_context = self.session_context.get("emotion_context")
        context = {
            "planned_steps": planned_steps,
            "done_steps": done_steps,
            "objective": objective,
            "situational_prime": situational_prime,
            "body_state": body_state,
            "emotion_context": emotion_context,
        }
        result = self.pipeline.invoke(context)
        if result:
            self.session_context["planned_steps"] = result.steps
            self.logger.warning("Replanned steps %s", result.steps)
        return True
class GoalPlanner(LLMEventHandler):
    """Create a goal based on the objective"""

    def __init__(
        self, session_context: UserSessionContext, scene_context: SceneContext
    ):
        self.trigger_keys = ["objective"]
        prompt_template = (
            "You are a humanoid robot with advanced goal planning capabilities. Your task is to create a detailed, achievable plan for accomplishing objectives.\n\n"
            "Current Date: {{date}}\n"
            "Objective: {{objective}}\n\n"
            "As an autonomous robot, you should:\n"
            "1. Analyze the objective carefully and identify the key components required for success\n"
            "2. Break down the objective into 3-5 concrete, sequential subgoals that build upon each other\n"
            "3. For each subgoal:\n"
            " - Make it specific and measurable (e.g. 'Gather 3 data points' rather than 'Collect data')\n"
            " - Consider any prerequisites or dependencies\n"
            " - Set realistic completion criteria\n"
            "4. If the objective has time constraints:\n"
            " - Set appropriate deadlines for the overall goal\n"
            " - Allow buffer time for unexpected challenges\n"
            "5. Evaluate the overall goal:\n"
            " - Importance (0-1): How critical is this to your function and purpose?\n"
            " - Ease (0-1): How achievable is this given your capabilities?\n"
            " - Motivation (0-1): How aligned is this with your core directives?\n\n"
            "Format your response as a GoalContext object containing:\n"
            "- Name: A concise title for the goal\n"
            "- Description: Detailed explanation of what needs to be accomplished\n"
            "- Subgoals: List of specific tasks with clear success criteria\n"
            "- Deadline: Target completion date/time if time-sensitive\n"
            "- Metrics: Numerical scores for importance, ease, and motivation\n"
        )
        super().__init__(session_context, scene_context, prompt_template)
        self.pipeline = self.prompt | ChatOpenAI(
            model="gpt-4o", temperature=0.6
        ).with_structured_output(GoalContext)

    def handle(self, event) -> bool:
        """Plan a goal for the current objective and register it as a driver."""
        context = {
            "objective": self.session_context.get("objective"),
            "date": date.today().isoformat(),
        }
        # Fix: replaced a leftover debug print(context) with logging.
        self.logger.debug("Goal planning context %s", context)
        result = self.pipeline.invoke(context)
        if result:
            # Fix: only re-wrap plain mappings; with_structured_output normally
            # returns a GoalContext instance already.
            goal_context = (
                result if isinstance(result, GoalContext) else GoalContext(**result)
            )
            # convert local time to UTC
            # NOTE(review): astimezone() on a naive datetime assumes the local
            # timezone -- confirm the LLM deadline is meant to be local time.
            goal_context.deadline = (
                goal_context.deadline.astimezone(timezone.utc)
                if goal_context.deadline
                else None
            )
            goal_context.status = DriverStatus.PENDING
            self.session_context["goal_context"] = goal_context
            intention_manager.add_driver(goal_context)
        return True


class ConversationGoalCreator(LLMEventHandler):
    """Create a goal based on the conversation history"""

    def __init__(
        self, session_context: UserSessionContext, scene_context: SceneContext
    ):
        # NOTE(review): trigger is "end_scene" while Diarykeeper listens on
        # "ended_scene" -- confirm both key names are correct.
        self.trigger_keys = ["end_scene"]
        prompt_template = (
            "You are a humanoid robot tasked with planning how to achieve goals effectively.\n"
            "You are in this situation: {{situational_prime}}\n"
            "{% if history %}\n"
            "Recent conversation history:\n"
            "{{history}}\n"
            "{% endif %}\n\n"
            "Instructions:\n"
            "1. Analyze the conversation history and identify any explicit or implicit goals, needs, or commitments\n"
            "2. Create a clear, actionable goal that addresses the key points from the conversation\n"
            "3. Break down the goal into 2-4 concrete subgoals that can be tracked and measured\n"
            "4. Consider any time constraints or deadlines mentioned\n"
            "5. Rate the following on a scale of 0-1:\n"
            " - Importance: How critical is this goal?\n"
            " - Ease: How achievable is this goal?\n"
            " - Motivation: How motivated are you to pursue this goal?\n\n"
            "Provide your response as a GoalContext object with:\n"
            "- A clear description of the overall goal\n"
            "- A prioritized list of subgoals\n"
            "- Deadline (if mentioned or implied)\n"
            "- Metrics for importance, ease, and motivation\n"
        )
        super().__init__(session_context, scene_context, prompt_template)
        self.pipeline = self.prompt | ChatOpenAI(
            model="gpt-4o", temperature=0.6
        ).with_structured_output(GoalContext)

    def handle(self, event) -> bool:
        """Distill a goal from the finished conversation and register it."""
        context = {
            "situational_prime": self.session_context.get("situational_prime"),
            "history": self.history,
        }
        result = self.pipeline.invoke(context)
        if result:
            # Fix: same mapping/model guard as GoalPlanner.
            goal_context = (
                result if isinstance(result, GoalContext) else GoalContext(**result)
            )
            goal_context.status = DriverStatus.PENDING
            self.session_context["goal_context"] = goal_context
            intention_manager.add_driver(goal_context)
        return True
class GoalEvaluator(LLMEventHandler):
    """Evaluate goals"""

    class Evaluation(BaseModel):
        """Structured LLM verdict: goal indices bucketed by their new status."""

        goals_completed: List[int] = Field(
            description="The index of the goals that have been completed"
        )
        goals_cancelled: List[int] = Field(
            description="The index of the goals that are cancelled"
        )
        goals_dormant: List[int] = Field(
            description="The index of the goals that need to be put on hold"
        )
        goals_active: List[int] = Field(
            description="The index of the goals that are active"
        )
        reason: str = Field(
            description="The reason for the decision of goals to be completed, cancelled, dormant or active"
        )

    def __init__(
        self, session_context: UserSessionContext, scene_context: SceneContext
    ):
        self.timer_interval = 5
        prompt_template = (
            "You are a humanoid robot with advanced goal planning capabilities. Your task is to evaluate the status and progress of multiple goals.\n\n"
            "Current Situation: {{situational_prime}}\n\n"
            "{% if goals %}\n"
            "Active Goals:\n"
            "{{goals}}\n"
            "{% endif %}\n\n"
            "Recent Conversation Context:\n"
            "{{history}}\n\n"
            "Instructions:\n"
            "1. Carefully analyze each goal and determine its current status:\n"
            " - COMPLETED: Goal has been fully achieved\n"
            " - CANCELLED: Goal is no longer relevant or achievable\n"
            " - DORMANT: Goal should be temporarily paused\n"
            " - ACTIVE: Goal is currently being worked on\n\n"
            "2. For each goal, consider:\n"
            " - Progress made so far\n"
            " - Current relevance to situation\n"
            " - Resource availability\n"
            " - Dependencies on other goals\n"
            " - Time constraints\n"
            " - Changes in context or priorities\n\n"
            "3. Provide clear reasoning for each status change, especially for:\n"
            " - Why certain goals are considered complete\n"
            " - Why goals need to be cancelled\n"
            " - Why goals should become dormant\n"
            " - Why goals should become active\n\n"
            "Return the indices of goals for each status category along with detailed reasoning.\n"
        )
        super().__init__(session_context, scene_context, prompt_template)
        self.pipeline = self.prompt | ChatOpenAI(
            model="gpt-4o", temperature=0.6
        ).with_structured_output(GoalEvaluator.Evaluation)

    def handle(self, event) -> bool:
        """Re-bucket every driver's status according to the LLM's evaluation."""
        drivers = intention_manager.get_drivers()
        # 0-based numbering so returned indices map straight onto ``drivers``.
        goals = [
            f"{i}. {driver.name}: {driver.description}"
            for i, driver in enumerate(drivers)
        ]
        context = {
            "situational_prime": self.session_context.get("situational_prime"),
            "goals": goals,
            "history": self.history,
        }
        result = self.pipeline.invoke(context)
        if result:
            self.logger.warning("Goal evaluation %s", result)
            for driver in drivers:
                driver.context.status = DriverStatus.PENDING

            def _apply(indices, status):
                # Fix: the indices come from the LLM -- ignore out-of-range
                # values instead of raising IndexError.
                for index in indices:
                    if 0 <= index < len(drivers):
                        drivers[index].context.status = status
                    else:
                        self.logger.warning("Ignoring invalid goal index %s", index)

            _apply(result.goals_completed, DriverStatus.COMPLETED)
            _apply(result.goals_cancelled, DriverStatus.CANCELLED)
            _apply(result.goals_dormant, DriverStatus.DORMANT)
            _apply(result.goals_active, DriverStatus.ACTIVE)
            # NOTE(review): statuses are not persisted here (save loop was
            # disabled in the original), unlike ConflictResolver which calls
            # driver.save() -- confirm this is intended.
            # for driver in drivers:
            #     driver.save()
        return True
class ConflictResolver(LLMEventHandler):
    """Detects conflicts among goals, interests, deep drives etc, and resolve them"""

    class Resolution(BaseModel):
        """Structured LLM verdict: which goal indices to activate, and why."""

        chosen_goal: List[int] = Field(description="The index of the goal to be chosen")
        reason: str = Field(description="The reason for the chosen goal")

    def __init__(
        self, session_context: UserSessionContext, scene_context: SceneContext
    ):
        self.timer_interval = 5
        prompt_template = (
            "You are a conflict resolution agent for a humanoid robot, responsible for prioritizing and selecting between competing goals.\n\n"
            "Key Decision Rules:\n"
            "1. Deep drives (fundamental needs/values) take highest priority\n"
            "2. Consider urgency, importance, and resource constraints\n"
            "3. Evaluate potential impact on long-term objectives\n"
            "4. Account for current emotional and physiological state\n"
            "5. Ensure chosen goals align with core robot directives\n\n"
            "Current Situation: {{situational_prime}}\n\n"
            "{% if goals %}\n"
            "Active Goals Under Consideration:\n"
            "{{goals}}\n"
            "{% endif %}\n\n"
            "Instructions:\n"
            "1. Analyze each goal's priority level and type (deep drive vs standard goal)\n"
            "2. Evaluate potential conflicts and dependencies between goals\n"
            "3. Select the most appropriate goal(s) to pursue\n"
            "4. Provide clear reasoning for your selection\n\n"
            "Which goal(s) should be prioritized and why?"
        )
        super().__init__(session_context, scene_context, prompt_template)
        self.pipeline = self.prompt | ChatOpenAI(
            model="gpt-4o", temperature=0.6
        ).with_structured_output(ConflictResolver.Resolution)

    def handle(self, event) -> bool:
        """Mark the LLM-chosen goals ACTIVE, the rest PENDING, and persist."""
        drivers = intention_manager.get_drivers()
        # 0-based numbering so returned indices map straight onto ``drivers``;
        # subgoals are listed under each goal when present.
        goals = [
            f"{i}. {driver.name}: {driver.description}"
            + (
                "\n Subgoals:\n - "
                + "\n - ".join(str(subgoal) for subgoal in driver.context.subgoals)
                if driver.context.subgoals
                else ""
            )
            for i, driver in enumerate(drivers)
        ]
        context = {
            "situational_prime": self.session_context.get("situational_prime"),
            "goals": goals,
        }
        result = self.pipeline.invoke(context)
        if result:
            self.logger.warning("Conflict resolution %s", result)
            for driver in drivers:
                driver.context.status = DriverStatus.PENDING
            for index in result.chosen_goal:
                # Fix: indices come from the LLM -- ignore out-of-range values
                # instead of raising IndexError.
                if 0 <= index < len(drivers):
                    drivers[index].context.status = DriverStatus.ACTIVE
                else:
                    self.logger.warning("Ignoring invalid goal index %s", index)
            for driver in drivers:
                driver.save()
        return True


class GoalUpdator(EventHandler):
    """Periodically publishes the prioritized driver list into the session."""

    def __init__(
        self, session_context: UserSessionContext, scene_context: SceneContext
    ):
        self.timer_interval = 5
        super().__init__(session_context, scene_context)

    def handle(self, event) -> bool:
        """Refresh session_context['goals'] with formatted driver summaries."""
        # Fix: dropped an unused enumerate() index.
        goals = [
            f"[Status: {driver.context.status}, Level: {driver.level}] {driver.name}: {driver.description}"
            for driver in intention_manager.prioritize_drivers()
        ]
        self.session_context["goals"] = goals
        self.logger.warning("Updated goals %s", goals)
        return True


class InterestHandler(LLMEventHandler):
    """Periodically gauges the robot's interest level in the current situation."""

    def __init__(
        self, session_context: UserSessionContext, scene_context: SceneContext
    ):
        self.timer_interval = 5
        prompt_template = (
            "You are a humanoid robot.\n"
            "You are in this situation: {{situational_prime}}\n"
            "{% if body_state %}\nYour body state: {{body_state}}\n{% endif %}\n"
            "{% if emotion_context %}\nYour emotion state: {{emotion_context}}\n{% endif %}\n"
            "{% if history %}\n"
            "The conversation\n"
            "{{history}}\n"
            "{% endif %}\n"
        )
        super().__init__(session_context, scene_context, prompt_template)
        self.pipeline = self.prompt | ChatOpenAI(
            model="gpt-4o", temperature=0.6
        ).with_structured_output(InterestContext)

    def handle(self, event) -> bool:
        """Update interest_context; escalate to a driver on high interest."""
        context = {
            "history": self.history,
            "situational_prime": self.session_context.get("situational_prime"),
            "body_state": self.session_context.get("body_state"),
            "emotion_context": self.session_context.get("emotion_context"),
        }
        result = self.pipeline.invoke(context)
        if result:
            # Fix: only re-wrap plain mappings; with_structured_output normally
            # returns an InterestContext instance already.
            interest = (
                result
                if isinstance(result, InterestContext)
                else InterestContext(**result)
            )
            self.session_context["interest_context"] = str(interest)
            self.logger.warning("Update interest state %s", interest)
            if interest.interest_valence > 0.5:
                intention_manager.add_driver(interest)
        return True
+ """ + interlocutors = self.session_context.get("interlocutor", "").split(",") + interlocutors = [name.strip() for name in interlocutors if name.strip()] + + if not interlocutors: + # Reset person objects + self.session_context["person_objects"] = [] + self.scene_context.scene_document.person_ids = [] + self.scene_context.update_scene_document() + return True + + person_objects = [] + person_object_infos = [] + + for name in interlocutors: + # Load person object + person_object = mm.Person.load_people_object(name) + if person_object: + self.logger.warning("Person object loaded: %s", person_object) + else: + first_name, last_name = mm.Person.split_name(name) + person_object = mm.Person(first_name=first_name, last_name=last_name) + person_object.save() + self.logger.warning("New person object created: %s", person_object) + + person_objects.append(person_object) + + person_object_info = mm.Person.format_person_object(person_object) + if person_object_info: + person_object_infos.append(person_object_info) + + if person_object_infos: + # person_objects is used by prompt template + self.session_context["person_objects"] = person_object_infos + + for person_object in person_objects: + if person_object.id not in self.scene_context.scene_document.person_ids: + self.scene_context.scene_document.person_ids.append(person_object.id) + + self.scene_context.scene_document.save() + return True + + +class Diarykeeper(LLMEventHandler): + def __init__( + self, session_context: UserSessionContext, scene_context: SceneContext + ): + self.trigger_keys = ["ended_scene"] # make a diary when the scene ends + prompt = """You are a journalist who takes notes in the form of a diary. You help analyze and maintain the personal experiences, thoughts, and events from the perspective of a robot who is part of the conversation, according to the given conversation history between the robot and a human. 
+ +The date was: {{ date }} +{% if location %} +The place the conversation took place was: {{location}} +{% endif %} +{% if interlocutor %} +The human interlocutor was: {{interlocutor}} +{% endif %} +{% if objective %} +The objective of the robot was: {{objective}} +{% endif %} + +# Conversation History +{{history}} +""" + super(Diarykeeper, self).__init__(session_context, scene_context, prompt) + self.pipeline = self.prompt | ChatOpenAI( + model="gpt-4o", temperature=0.6 + ).with_structured_output(mm.DiaryData) + self.chat_history.n_messages = None + self.chat_history.ai_message_prompt = "Robot: " + self.context = Diarykeeper.Context() + + class Context(BaseModel): + uid: Optional[str] + sid: Optional[str] + objective: Optional[str] + location: Optional[str] + interlocutor: Optional[str] + history: Optional[str] + + def set_diary_context(self, **context): + self.context = Diarykeeper.Context(**context) + + def handle(self, event) -> bool: + context = { + "objective": self.context.objective, + "history": self.context.history, + "date": date.today(), + "location": self.context.location, + "interlocutor": self.context.interlocutor, + } + if self.context.history == ChatHistory.empty_history: + return True + result = self.pipeline.invoke(context) + if result: + self.session_context["diary"] = result + diary = mm.Diary( + uid=self.context.uid, + conversation_id=self.context.sid, + interlocutor=self.context.interlocutor, + **result, + ) + diary.save() + self.logger.warning("Diary %s", diary.json()) + return True + + +class AttributeExtractor(LLMEventHandler): + def __init__( + self, session_context: UserSessionContext, scene_context: SceneContext + ): + self.trigger_keys = ["input"] + prompt = """ +Given the conversation history as context and the user's response below, extract the attributes in the user response and return a JSON message with those attributes. Do not make up things out of user's response. 
class AttributeExtractor(LLMEventHandler):
    """Extracts personal attributes from the user's latest response and stores
    them on the matching mm.Person object."""

    def __init__(
        self, session_context: UserSessionContext, scene_context: SceneContext
    ):
        self.trigger_keys = ["input"]
        # Fix: prompt typos -- "Automatially" -> "Automatically",
        # "trival" -> "trivial".
        prompt = """
Given the conversation history as context and the user's response below, extract the attributes in the user response and return a JSON message with those attributes. Do not make up things out of user's response.

## Conversation History
{{history}}


{{input}}


{% if attribute_set %}
attributes to extract: {{attribute_set}}. If there is no such attribute, return nothing.
{% else %}
Automatically extract the personal attributes about the users. Do not extract trivial information. If there is no attributes, return nothing.
{% endif %}
"""
        # Modernized from super(AttributeExtractor, self) for consistency.
        super().__init__(session_context, scene_context, prompt)
        self.pipeline = self.prompt | ChatOpenAI(
            model="gpt-4o", temperature=0.6
        ).with_structured_output(Attributes)

    def handle(self, event) -> bool:
        """Extract attributes from the current input and persist them."""
        if not self.input:
            return True
        if not self.session_context.get("person_object_id"):
            return True
        # Use only the robot's most recent message as conversational context.
        history = ""
        messages = self.chat_history.filtered_messages(self.input)
        for message in reversed(messages):
            if message.type == "ai":
                history = message.content
                break
        context = {
            "input": self.input,
            "history": history,
            "attribute_set": self.session_context.get("attribute_set"),
        }
        result = self.pipeline.invoke(context)
        # Fix: everything below dereferences result.attributes, so guard the
        # whole tail on a non-empty result instead of only the deletion.
        if result and result.attributes:
            if "attribute_set" in self.session_context:
                # The one-shot attribute request has been served; clear it.
                del self.session_context["attribute_set"]
            self.logger.warning("Extracted attributes %s", result)

            # update interlocutor
            for attribute in result.attributes:
                if (
                    attribute.attribute == "name"
                    and self.session_context.get("interlocutor") != attribute.value
                ):
                    self.logger.warning("Update interlocutor %s", attribute.value)
                    self.session_context["interlocutor"] = attribute.value
                    time.sleep(3)  # wait for Person Object to be updated

            if self.session_context.get("person_object_id"):
                person_object = mm.Person.get(self.session_context["person_object_id"])
                if person_object:
                    person_object = person_object.run()
                    # update the attributes of the person object
                    for attribute in result.attributes:
                        person_object.add_attribute(
                            attribute.attribute, attribute.value
                        )
                    person_object.save()
        return True


class handlers:
    """Registry of available handler classes.

    The ``locals().update`` below re-exports each handler class as a class
    attribute (e.g. ``handlers.Diarykeeper``) so handlers can be looked up by
    name on this class.
    """

    _handler_classes = [
        AttributeExtractor,
        Diarykeeper,
        Evaluator,
        PersonObjectUpdator,
        GoalPlanner,
        PhysiologicalHandler,
        EmotionalHandler,
        Replanner,
        TuneRegulator,
    ]

    locals().update({handler.__name__: handler for handler in _handler_classes})

    @classmethod
    def get_all_handlers(cls):
        """Returns a list of all handler classes"""
        return cls._handler_classes


if __name__ == "__main__":
    # Fix: removed a duplicate ``import os`` (already imported at module level).
    logging.basicConfig(level=logging.INFO)

    # os.environ["LANGCHAIN_TRACING_V2"] = "true"
    from haipy.parameter_server_proxy import UserSessionContext

    session_context = UserSessionContext(uid="default", sid="default")
    listener = EventListener()
    listener.subscribe("default.default")
    # PhysiologicalHandler(session_context, None).set_event_listener(listener)
    # EmotionalHandler(session_context, None).set_event_listener(listener)
    GoalPlanner(session_context, None).set_event_listener(listener)
    # GoalUpdator(session_context, None).set_event_listener(listener)
    # ConflictResolver(session_context, None).set_event_listener(listener)
    # InterestHandler(session_context, None).set_event_listener(listener)
    # GoalEvaluator(session_context, None).set_event_listener(listener)
    while True:
        time.sleep(1)
+## +## This program is distributed in the hope that it will be useful, +## but WITHOUT ANY WARRANTY; without even the implied warranty of +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +## GNU General Public License for more details. +## +## You should have received a copy of the GNU General Public License +## along with this program. If not, see . +## + +INTERRUPTION = "interruption" +INTERRUPTION_RESUME = "interruption_resume" +RESET = "reset" +PLACEHOLDER_UTTERANCE = "placeholder_utterance" +SWITCH_LANGUAGE = "switch_language" +RESPONSIVITY = "responsivity" +MONITOR = "monitor" +HANDLE_FACE_EVENT = "handle_face_event" +LOAD_USER_PROFILE = "load_user_profile" +UPDATE_USER_PROFILE = "update_user_profile" +SET_AUTONOMOUS_MODE = "set_autonomous_mode" +SET_HYBRID_MODE = "set_hybrid_mode" diff --git a/modules/ros_chatbot/src/ros_chatbot/interact/base.py b/modules/ros_chatbot/src/ros_chatbot/interact/base.py new file mode 100644 index 0000000..584c599 --- /dev/null +++ b/modules/ros_chatbot/src/ros_chatbot/interact/base.py @@ -0,0 +1,42 @@ +## +## Copyright (C) 2017-2025 Hanson Robotics +## +## This program is free software: you can redistribute it and/or modify +## it under the terms of the GNU General Public License as published by +## the Free Software Foundation, either version 3 of the License, or +## (at your option) any later version. +## +## This program is distributed in the hope that it will be useful, +## but WITHOUT ANY WARRANTY; without even the implied warranty of +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +## GNU General Public License for more details. +## +## You should have received a copy of the GNU General Public License +## along with this program. If not, see . 
+## + +from abc import ABCMeta, abstractmethod + + +class EventGenerator(object, metaclass=ABCMeta): + @abstractmethod + def generate(self, text, lang): + """Detects intent for the text and language""" + pass + + @staticmethod + def create_event(type, payload=None): + # TODO: validate event payload + event = { + "type": type, + "payload": payload, + } + return event + + +class BasicEventGenerator(EventGenerator): + def __init__(self, type): + self._type = type + + def generate(self, payload): + return self.create_event(self._type, payload) diff --git a/modules/ros_chatbot/src/ros_chatbot/interact/controller_manager.py b/modules/ros_chatbot/src/ros_chatbot/interact/controller_manager.py new file mode 100644 index 0000000..30fbe31 --- /dev/null +++ b/modules/ros_chatbot/src/ros_chatbot/interact/controller_manager.py @@ -0,0 +1,251 @@ +## +## Copyright (C) 2017-2025 Hanson Robotics +## +## This program is free software: you can redistribute it and/or modify +## it under the terms of the GNU General Public License as published by +## the Free Software Foundation, either version 3 of the License, or +## (at your option) any later version. +## +## This program is distributed in the hope that it will be useful, +## but WITHOUT ANY WARRANTY; without even the implied warranty of +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +## GNU General Public License for more details. +## +## You should have received a copy of the GNU General Public License +## along with this program. If not, see . +## + +import functools +import logging +import os +import threading +import time +from collections import defaultdict + +import yaml +from benedict import benedict +from haipy.utils import envvar_yaml_loader, to_list + +from . 
from . import event_types
from .controllers import registered_controllers
from .controllers.base import AsyncAgentController, SyncAgentController
from .event_generator import StateEvent

logger = logging.getLogger(__name__)


def synchronized(func):
    """Decorator serializing calls to *func* per instance.

    Lazily creates a per-instance reentrant lock (stored as ``_sync_lock`` in
    the instance ``__dict__``) so decorated methods on the same object never
    interleave.
    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        lock = vars(self).get("_sync_lock", None)
        if lock is None:
            # setdefault keeps the first lock if another thread raced us here.
            lock = vars(self).setdefault("_sync_lock", threading.RLock())
        with lock:
            return func(self, *args, **kwargs)

    return wrapper


class ControllerManager(object):
    """Loads, configures and drives the interaction controllers.

    Controllers are instantiated from ``controllers.yaml`` specs and configured
    from ``control.yaml`` (both under ``$HR_CHATBOT_WORLD_DIR``). The manager
    also acts as the controllers' "store": they call :meth:`dispatch` to emit
    actions, which are forwarded to the caller-supplied ``action_handler``.
    """

    def __init__(self, state, action_handler):
        """
        state: the robot state object observed by the event generator
        action_handler: callable invoked with each dispatched action
        """
        self.state = state
        # StateEvent watches ``state`` and feeds events into on_event().
        self.event_generator = StateEvent(self.state, self.on_event)
        self.action_handler = action_handler
        self._events = []  # append-only log of all observed events
        self._listeners = defaultdict(list)  # event type -> [controllers]
        self.added_new_event = threading.Event()  # pulsed by on_event()

        envvar_yaml_loader()
        # load specs
        HR_CHATBOT_WORLD_DIR = os.environ.get("HR_CHATBOT_WORLD_DIR", "")
        spec_file = os.path.join(HR_CHATBOT_WORLD_DIR, "controllers.yaml")
        if spec_file and os.path.isfile(spec_file):
            with open(spec_file) as f:
                config = yaml.safe_load(f)
                self.specs = config["controllers"]
        else:
            self.specs = []
            logger.warning("Controller spec file was not found")

        # load config
        config_file = os.path.join(HR_CHATBOT_WORLD_DIR, "control.yaml")
        if config_file and os.path.isfile(config_file):
            with open(config_file) as f:
                config = yaml.safe_load(f)
                self.controller_configs = {
                    controller["id"]: benedict(controller.get("config", {}))
                    for controller in config["controllers"]
                }
        else:
            self.controller_configs = {}
            logger.warning("Controller config file was not found")

        self.install_controllers()

    def wait_for(self, event_types, timeout=None):
        """Waits for the coming of any of the events of a given type.

        Only events appended after this call starts are considered (scanning
        begins at the current end of the event log). Returns True when a
        matching event arrives, False on timeout.
        """
        end_time = None
        if timeout is not None:
            end_time = time.time() + timeout
        pos = len(self._events)
        while True:
            self.added_new_event.clear()
            if end_time is not None:
                if end_time < time.time():
                    logger.info("event timeout")
                    return False
            for event in self._events[pos:]:
                if event["type"] in to_list(event_types):
                    return True
            # Short poll interval bounds latency even if the Event is missed.
            self.added_new_event.wait(0.02)

    def get_controller(self, id):
        """Return the controller with the given id, or None."""
        return self.controllers.get(id)

    def install_controllers(self):
        """Instantiate controllers from specs and apply their configs.

        Raises ValueError for an unknown controller type; specs marked
        ``disabled`` are skipped. Each controller is subscribed to the event
        types it declares in ``subscribe_events``.
        """
        # install controllers
        self.controllers = {}
        for spec in self.specs:
            if spec["type"] not in registered_controllers:
                raise ValueError("Unknown controller type: %s" % spec["type"])
            if spec.get("disabled"):
                continue

            args = spec.get("args", {})
            args["state"] = self.state
            # The manager itself is the controllers' store (dispatch target).
            args["store"] = self
            cls = registered_controllers[spec["type"]]
            controller = cls(**args)
            for event in controller.subscribe_events:
                self.addEventListener(event, controller)
            self.controllers[controller.id] = controller
        if self.controllers:
            logger.info(
                "Added controllers %s", ", ".join(list(self.controllers.keys()))
            )

        # set controller config
        for id, controller in self.controllers.items():
            if id in self.controller_configs:
                config = self.controller_configs[id]
                controller.set_config(config)
            else:
                logger.info('controller "%s" has no config', id)

    def setup_controllers(self, cfg):
        """Sync controller enable flags (and placeholder settings) from *cfg*.

        ``cfg`` is expected to expose ``enable_<controller_id>`` attributes;
        flags for controllers that are not installed are forced back to False.
        """
        for controller_id in [
            "placeholder_utterance_controller",
            "language_switch_controller",
            "interruption_controller",
            "emotion_controller",
            "monitor_controller",
            "command_controller",
            "responsivity_controller",
            "user_acquisition_controller",
        ]:
            cfg_name = "enable_%s" % controller_id
            if cfg_name in cfg:
                controller_enabled = getattr(cfg, cfg_name)
            else:
                continue
            controller = self.get_controller(controller_id)
            if controller:
                if controller.enabled != controller_enabled:
                    controller.enabled = controller_enabled
                    logger.info(
                        "Controller %s is %s"
                        % (
                            controller.id,
                            "enabled" if controller.enabled else "disabled",
                        )
                    )
            else:
                if controller_enabled:
                    setattr(cfg, cfg_name, False)
                    logger.warning("No controller %s configured", controller_id)

        placeholder_utterance_controller = self.get_controller(
            "placeholder_utterance_controller"
        )
        if placeholder_utterance_controller:
            if cfg.placeholder_utterances:
                # One placeholder utterance per non-blank line.
                utterances = cfg.placeholder_utterances.splitlines()
                utterances = [
                    utterance.strip() for utterance in utterances if utterance.strip()
                ]
                placeholder_utterance_controller.set_placeholder_utterances(utterances)
            else:
                placeholder_utterance_controller.set_placeholder_utterances([])
            placeholder_utterance_controller.set_prob_escalate_step(
                cfg.placeholder_prob_step
            )

    def dispatch(self, action):
        """Dispatches the actions from controllers to the action handler."""
        self.action_handler(action)

    def wait_controller_finish(self):
        """Block until every asynchronous controller reports idle."""
        for controller in self.controllers.values():
            if isinstance(controller, AsyncAgentController):
                while True:
                    if controller.is_idle():
                        break
                    else:
                        logger.warning("wait for %s to finish", controller.id)
                        time.sleep(0.1)

    @synchronized
    def act(self):
        """Gets the synchronous actions.

        Polls each SyncAgentController that has pending events; errors in one
        controller are logged and do not stop the others. ``done()`` is always
        called afterwards.
        """
        actions = []
        for controller in self.controllers.values():
            if isinstance(controller, SyncAgentController):
                if not controller.events.empty():
                    try:
                        _actions = controller.act()
                        if _actions:
                            actions.extend(_actions)
                    except Exception as ex:
                        logger.exception(ex)
                    finally:
                        controller.done()
        return actions

    @synchronized
    def on_event(self, event):
        """Handles new events: log them and fan out to subscribed listeners."""
        self._events.append(event)
        # event_id = len(self._events)
        # payload_str = str(event["payload"])[:40]  # TODO: stringify payload
        # logger.warning("Event %d type: %s, payload: %s...", event_id, event['type'], payload_str)
        self.added_new_event.set()

        # Listeners subscribed to this specific type plus ALL_EVENTS listeners.
        event_listeners = []
        event_listeners += self._listeners.get(event["type"], [])
        event_listeners += self._listeners.get(event_types.ALL_EVENTS, [])
        for listener in event_listeners:
            if listener.enabled:
                logger.info("[%s] observe %s", listener.id, event)
                listener.observe(event)
            else:
                logger.debug("Event listener %r is disabled", listener.id)

    def addEventListener(self, type, listener):
        """Subscribe *listener* (a controller) to events of *type*."""
        self._listeners[type].append(listener)

    def register_event_generator(self, state, generator):
        """Register an extra event generator for a state key (delegated)."""
        self.event_generator.register_event_generator(state, generator)

    def reset(self):
        """Reset the event generator and every controller that supports it."""
        # reset event generator
        self.event_generator.reset()

        # reset controllers
        for controller in self.controllers.values():
            if hasattr(controller, "reset"):
                try:
                    controller.reset()
                except Exception as ex:
                    logger.exception(ex)

        logger.warning("Reset controller manager")
+## + +from .responsivity_controller import ResponsivityController +from .interruption_controller import InterruptionController +from .interruption_resume_controller import InterruptionResumeController +from .placeholder_utterance_controller import PlaceholderUtteranceController +from .language_switch_controller import LanguageSwitchController +from .command_controller import CommandController +from .emotion_controller import EmotionController +from .monitor_controller import MonitorController +from .user_acquisition_controller import UserAcquisitionController + +_controller_classes = [ + ResponsivityController, + InterruptionController, + InterruptionResumeController, + PlaceholderUtteranceController, + LanguageSwitchController, + CommandController, + EmotionController, + MonitorController, + UserAcquisitionController, +] + +registered_controllers = {cls.type: cls for cls in _controller_classes} diff --git a/modules/ros_chatbot/src/ros_chatbot/interact/controllers/base.py b/modules/ros_chatbot/src/ros_chatbot/interact/controllers/base.py new file mode 100644 index 0000000..abbddfa --- /dev/null +++ b/modules/ros_chatbot/src/ros_chatbot/interact/controllers/base.py @@ -0,0 +1,85 @@ +## +## Copyright (C) 2017-2025 Hanson Robotics +## +## This program is free software: you can redistribute it and/or modify +## it under the terms of the GNU General Public License as published by +## the Free Software Foundation, either version 3 of the License, or +## (at your option) any later version. +## +## This program is distributed in the hope that it will be useful, +## but WITHOUT ANY WARRANTY; without even the implied warranty of +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +## GNU General Public License for more details. +## +## You should have received a copy of the GNU General Public License +## along with this program. If not, see . 
+## + +import logging +from queue import Queue +from abc import ABCMeta, abstractmethod + +logger = logging.getLogger(__name__) + + +class AgentController(object): + type = "base" + + def __init__(self, id, store, state): + """ + id: the unique id of the controller + store: the action store + state: the robot state + """ + self.id = id + self.store = store + self.state = state + self.events = Queue(maxsize=10) + self.config = {} + self.subscribe_events = [] + + @property + def enabled(self): + return self.config.get("enabled", False) + + @enabled.setter + def enabled(self, enabled: bool): + self.config["enabled"] = enabled + + def set_config(self, config: dict): + self.config = config + if "enabled" in self.config: + self.enabled = self.config["enabled"] + + def observe(self, event: dict): + self.events.put(event) # {type, payload} + + def create_action(self, type: str, payload=None): + # TODO: validate action payload + action = {"type": type} + action["payload"] = payload or {} + return action + + +class AsyncAgentController(AgentController, metaclass=ABCMeta): + """Synchronized agent controller""" + + type = "AsyncAgentController" + + @abstractmethod + def is_idle(self): + """indicates whether the controller is idle""" + pass + + +class SyncAgentController(AgentController, metaclass=ABCMeta): + """Synchronized agent controller""" + + type = "SyncAgentController" + + @abstractmethod + def act(self): + pass + + def done(self): + self.event = None diff --git a/modules/ros_chatbot/src/ros_chatbot/interact/controllers/command_controller.py b/modules/ros_chatbot/src/ros_chatbot/interact/controllers/command_controller.py new file mode 100644 index 0000000..cfa16ba --- /dev/null +++ b/modules/ros_chatbot/src/ros_chatbot/interact/controllers/command_controller.py @@ -0,0 +1,48 @@ +## +## Copyright (C) 2017-2025 Hanson Robotics +## +## This program is free software: you can redistribute it and/or modify +## it under the terms of the GNU General Public License as published 
by +## the Free Software Foundation, either version 3 of the License, or +## (at your option) any later version. +## +## This program is distributed in the hope that it will be useful, +## but WITHOUT ANY WARRANTY; without even the implied warranty of +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +## GNU General Public License for more details. +## +## You should have received a copy of the GNU General Public License +## along with this program. If not, see . +## + +import logging + +from ros_chatbot.interact import action_types +from ros_chatbot.interact import event_types as types + +from .base import SyncAgentController + +logger = logging.getLogger(__name__) + + +class CommandController(SyncAgentController): + + type = "CommandController" + + def __init__(self, id, store, state): + super(CommandController, self).__init__(id, store, state) + self.subscribe_events = [types.USER_COMMAND, types.UTTERANCE] + + def act(self): + actions = [] + while not self.events.empty(): + event = self.events.get() + payload = event["payload"] + type = event["type"] + if type == types.USER_COMMAND: + command = payload["text"] + if command == ":reset": + actions.append(self.create_action(action_types.RESET)) + else: + logger.warning("Unknown command %s", command) + return actions diff --git a/modules/ros_chatbot/src/ros_chatbot/interact/controllers/emotion_controller.py b/modules/ros_chatbot/src/ros_chatbot/interact/controllers/emotion_controller.py new file mode 100644 index 0000000..bbaa5b5 --- /dev/null +++ b/modules/ros_chatbot/src/ros_chatbot/interact/controllers/emotion_controller.py @@ -0,0 +1,82 @@ +## +## Copyright (C) 2017-2025 Hanson Robotics +## +## This program is free software: you can redistribute it and/or modify +## it under the terms of the GNU General Public License as published by +## the Free Software Foundation, either version 3 of the License, or +## (at your option) any later version. 
import logging
import os
from collections import deque

from haipy.nlp.sentiment_classifier import TransformerSentimentClassifier
from ros_chatbot.interact import event_types as types

from .base import SyncAgentController

logger = logging.getLogger(__name__)


class EmotionController(SyncAgentController):
    """Tracks user sentiment over recent utterances.

    Classifies each UTTERANCE payload, keeps a sliding window of the last
    ``sentiment_length`` scores in the shared state (``user_sentiments``), and
    when the most significant (recency-decayed) score crosses a threshold,
    writes ``user_sentiment_trigger`` into the state. Emits no actions; it
    communicates purely through state updates.
    """

    type = "EmotionController"

    def __init__(self, id, store, state):
        super(EmotionController, self).__init__(id, store, state)
        host = os.environ.get("NLP_SERVER_HOST", "localhost")
        # Fix: os.environ values are strings, so the classifier previously
        # received an int (8401) by default but a str when the env var was
        # set. Normalize to int in both cases.
        port = int(os.environ.get("NLP_SERVER_PORT", 8401))
        self.classifier = TransformerSentimentClassifier(host, port)
        self.subscribe_events = [
            types.UTTERANCE,
        ]
        # Window size of retained sentiment scores.
        self.sentiment_length = 5
        # Per-position decay weights: older entries (front of deque) count less.
        self.decays = [0.6, 0.7, 0.8, 0.9, 1]

    def act(self):
        """Process queued utterances; always returns an empty action list."""
        actions = []
        while not self.events.empty():
            event = self.events.get()
            if event["type"] == types.UTTERANCE:
                payload = event["payload"]
                sentiment = self.classifier.detect_sentiment(
                    payload["text"], payload["lang"]
                )
                if sentiment:
                    # update sentiments window (seeded with zeros on first use)
                    state = self.state.getState()
                    if "user_sentiments" in state:
                        sentiments = state["user_sentiments"]
                    else:
                        sentiments = deque(maxlen=self.sentiment_length)
                        sentiments.extend([0] * self.sentiment_length)
                    sentiments.append(sentiment)
                    self.state.update(user_sentiments=sentiments)

                    # find the most significant sentiment after recency decay
                    state = self.state.getState()
                    sentiments = state["user_sentiments"]
                    decayed_sentiments = [
                        abs(i) * j for i, j in zip(sentiments, self.decays)
                    ]  # decay
                    abs_sentiments = [abs(i) for i in decayed_sentiments]
                    index = abs_sentiments.index(max(abs_sentiments))
                    significant_sentiment = sentiments[index]

                    # fire trigger; thresholds are deliberately asymmetric
                    # (negative sentiment triggers earlier than positive)
                    if significant_sentiment > 0.6 or significant_sentiment < -0.4:
                        self.state.update(user_sentiment_trigger=significant_sentiment)
                        logger.info("user sentiment %s", significant_sentiment)
        return actions
import logging
import random
import threading

from haipy.utils import to_list
from haipy.nlp.intent_classifier import IntentDetector
from ros_chatbot.interact import event_types as types
from ros_chatbot.interact import action_types as action_types

from .base import AsyncAgentController

logger = logging.getLogger(__name__)


class InterruptionController(AsyncAgentController):
    """Detects user interruptions while the robot is speaking.

    Events are processed on a dedicated daemon thread; detection rules
    (keywords, word-count limits, intents) come from this controller's config.
    Matching input is turned into an INTERRUPTION action and dispatched.
    """

    type = "InterruptionController"

    def __init__(self, id, store, state):
        super(InterruptionController, self).__init__(id, store, state)
        self.subscribe_events = [
            types.UTTERANCE,
            types.KEYWORDS,
        ]
        self._intent_detector = IntentDetector("soultalk")

        # Event currently being handled by the worker; None when idle.
        self.current_event = None
        # Set while enabled; the worker blocks on it when disabled.
        self.running = threading.Event()
        job = threading.Thread(target=self.run, name="Thread-%s" % self.id)
        job.daemon = True
        job.start()

    def observe(self, event):
        """Queue the event only when an interruption would be meaningful."""
        if not self.state.is_robot_speaking():
            logger.info("Nothing to interrupt")
        elif self.state.is_interruption_mode():
            logger.info("Interruption resuming mode")
        else:
            self.events.put(event)

    @AsyncAgentController.enabled.setter
    def enabled(self, enabled):
        # Overridden to also start/stop the worker loop via ``running``.
        self.config["enabled"] = enabled
        if enabled:
            self.running.set()
        else:
            self.running.clear()

    def is_idle(self):
        """True when nothing is queued and nothing is being processed."""
        return self.events.empty() and self.current_event is None

    def run(self):
        """Worker loop: classify each queued event and dispatch actions.

        KEYWORDS events trigger a resumable "soft" interruption; UTTERANCE
        events trigger non-resumable short-pause or long-input interruptions.
        """
        while True:
            if self.running.is_set():
                self.current_event = self.events.get()
                type = self.current_event["type"]
                payload = self.current_event["payload"]
                try:
                    if type == types.KEYWORDS:
                        text = payload["text"]
                        lang = payload["lang"]
                        if self.is_short_pause_interruption(text, lang):
                            logger.info("Short pause interruption detected")
                            action = self._make_action("soft_interruption", lang, True)
                            self.store.dispatch(action)
                    if type == types.UTTERANCE:
                        text = payload["text"]
                        lang = payload["lang"]
                        # if self.is_full_stop_interrupt(text, lang):
                        #     logger.warning("Full stop interruption detected")
                        #     action = self._make_action('full_stop_interruption', lang, False)
                        #     self.store.dispatch(action)
                        if self.is_short_pause_interruption(text, lang):
                            logger.info("Short pause interruption detected")
                            action = self._make_action(
                                "short_pause_interruption", lang, False
                            )
                            self.store.dispatch(action)
                        elif self.is_long_input_interruption(text, lang):
                            logger.info("Long input interruption detected")
                            action = self._make_action(
                                "long_input_interruption", lang, False
                            )
                            self.store.dispatch(action)
                except Exception as ex:
                    logger.exception(ex)
                self.current_event = None
            else:
                self.running.wait()

    def _make_action(self, action_type, lang, resume):
        """Build an INTERRUPTION action with an optional canned utterance.

        The utterance is drawn randomly from config key
        ``<action_type>.utterances.<lang>`` when present, else left empty.
        ``resume`` marks whether speech should resume after the interruption.
        """
        utterance = ""
        key = "%s.utterances.%s" % (action_type, lang)
        if key in self.config:
            utterances = to_list(self.config[key])
            utterance = random.choice(utterances)
        else:
            logger.info("No interruption utterances")
        payload = {
            "text": utterance,
            "lang": lang,
            "type": action_type,
            "resume": resume,
            "controller": self.id,
        }
        action = self.create_action(action_types.INTERRUPTION, payload)
        return action

    def is_full_stop_interrupt(self, text, lang):
        """Keyword or intent match for a full-stop interruption (en-US only)."""
        if lang == "en-US":
            if "full_stop_interruption.keywords" in self.config:
                keywords = to_list(self.config["full_stop_interruption.keywords"])
                for keyword in keywords:
                    if keyword.lower() in text.lower():
                        return True
            if "full_stop_interruption.intents" in self.config:
                try:
                    result = self._intent_detector.detect_intent(text, lang)
                except Exception as ex:
                    logger.error(ex)
                    return False
                if result:
                    logger.info("Intent %r", result)
                    intent = result["intent"]["name"]
                    if intent in to_list(self.config["full_stop_interruption.intents"]):
                        return True
        return False

    def is_soft_interruption(self, text, lang):
        """Keyword match for a soft interruption, rejected for long input."""
        if lang == "en-US":
            if "soft_interruption.max_input_words" in self.config:
                max_input_words = self.config["soft_interruption.max_input_words"]
                # NOTE(review): word count via split(" "); repeated spaces
                # inflate the count — confirm whether that is acceptable.
                input_length = text.split(" ")
                if len(input_length) > max_input_words:
                    logger.info("Long input (%s) not soft interrupt", len(input_length))
                    return False
            if "soft_interruption.keywords" in self.config:
                keywords = to_list(self.config["soft_interruption.keywords"])
                for keyword in keywords:
                    if keyword.lower() in text.lower():
                        return True
        return False

    def is_short_pause_interruption(self, text, lang):
        """Keyword match for a short-pause interruption (en-US only)."""
        if lang == "en-US":
            if "short_pause_interruption.max_input_words" in self.config:
                max_input_words = self.config[
                    "short_pause_interruption.max_input_words"
                ]
                input_length = text.split(" ")
                if len(input_length) > max_input_words:
                    logger.info(
                        "Long input (len=%s) is not short pause interrupt",
                        len(input_length),
                    )
                    return False
            if "short_pause_interruption.keywords" in self.config:
                keywords = to_list(self.config["short_pause_interruption.keywords"])
                for keyword in keywords:
                    if keyword.lower() in text.lower():
                        logger.info("Interruption words %s", keyword)
                        return True
        return False

    def is_long_input_interruption(self, text, lang):
        """Word-count threshold for a long-input interruption (en-US only)."""
        if lang == "en-US":
            if "long_input_interruption.min_input_words" in self.config:
                min_input_words = self.config["long_input_interruption.min_input_words"]
                input_length = text.split(" ")
                if len(input_length) >= min_input_words:
                    logger.info("Long input (%s) hard interrupt", len(input_length))
                    return True
        return False
License, or +## (at your option) any later version. +## +## This program is distributed in the hope that it will be useful, +## but WITHOUT ANY WARRANTY; without even the implied warranty of +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +## GNU General Public License for more details. +## +## You should have received a copy of the GNU General Public License +## along with this program. If not, see . +## + +import logging + +from ros_chatbot.interact import event_types as types +from ros_chatbot.interact import action_types + +from .base import AgentController + +logger = logging.getLogger(__name__) + + +class InterruptionResumeController(AgentController): + + type = "InterruptionResumeController" + + def __init__(self, id, store, state): + super(InterruptionResumeController, self).__init__(id, store, state) + self.subscribe_events = [ + types.ROBOT_INTERRUPTED, + ] + + def observe(self, event): + type = event["type"] + payload = event["payload"] + if type == types.ROBOT_INTERRUPTED: + if payload["text"]: + payload = { + "text": payload["text"], + "lang": payload["lang"], + "controller": self.id, + } + action = self.create_action(action_types.INTERRUPTION_RESUME, payload) + self.store.dispatch(action) diff --git a/modules/ros_chatbot/src/ros_chatbot/interact/controllers/language_switch_controller.py b/modules/ros_chatbot/src/ros_chatbot/interact/controllers/language_switch_controller.py new file mode 100644 index 0000000..62d5bce --- /dev/null +++ b/modules/ros_chatbot/src/ros_chatbot/interact/controllers/language_switch_controller.py @@ -0,0 +1,91 @@ +## +## Copyright (C) 2017-2025 Hanson Robotics +## +## This program is free software: you can redistribute it and/or modify +## it under the terms of the GNU General Public License as published by +## the Free Software Foundation, either version 3 of the License, or +## (at your option) any later version. 
import logging
import random

from haipy.nlp.intent_classifier import IntentDetector
from ros_chatbot.interact import event_types as types
from ros_chatbot.interact import action_types

from .base import SyncAgentController

logger = logging.getLogger(__name__)


class LanguageSwitchController(SyncAgentController):
    """Detects a spoken request to switch language and emits SWITCH_LANGUAGE.

    Uses a Dialogflow-backed intent detector; the target language entity is
    mapped through config ``target_language_code`` and an optional canned
    ``language_switch_response`` is attached to the action.
    """

    type = "LanguageSwitchController"

    def __init__(self, id, store, state):
        super(LanguageSwitchController, self).__init__(id, store, state)
        self.subscribe_events = [types.UTTERANCE]
        self.intent_detector = IntentDetector("dialogflow")

    def set_config(self, config):
        # NOTE(review): raises KeyError when either key is absent from the
        # config — presumably both are mandatory; verify against control.yaml.
        super(LanguageSwitchController, self).set_config(config)
        self.target_language_code = self.config["target_language_code"]
        self.language_switch_response = self.config["language_switch_response"]

    def reset(self):
        """Reset the intent detector's session state."""
        if self.intent_detector:
            self.intent_detector.reset()

    def act(self):
        """Drain queued utterances; emit a SWITCH_LANGUAGE action on a match."""
        actions = []
        while not self.events.empty():
            event = self.events.get()
            payload = event["payload"]
            type = event["type"]
            if type == types.UTTERANCE:
                if self.intent_detector:
                    intent = ""
                    confidence = 0
                    try:
                        result = self.intent_detector.detect_intent(
                            payload["text"], payload["lang"]
                        )
                        intent = result["intent"]["name"]
                        confidence = result["intent"]["confidence"]
                        logger.info("Intent %s confidence %s", intent, confidence)
                    except Exception as ex:
                        logger.error(ex)
                        # NOTE(review): bare return discards any actions already
                        # collected for earlier events and yields None (the
                        # caller treats falsy as "no actions") — confirm this
                        # early abort is intended rather than `continue`.
                        return
                    if confidence > 0.3 and intent == "language.switch":
                        target_lang = ""
                        for entity in result["entities"]:
                            # NOTE(review): "languge" looks misspelled but may
                            # deliberately match the entity name as defined in
                            # the Dialogflow agent — verify before "fixing".
                            if entity["entity"] == "languge":
                                target_lang = entity["value"]
                                logger.info("Detected target language %r", target_lang)
                                break
                        # switch language
                        if target_lang and target_lang in self.target_language_code:
                            target_lang = self.target_language_code[target_lang]
                            response = self.language_switch_response.get(target_lang)
                            if response and isinstance(response, list):
                                response = random.choice(response)
                            payload = {
                                "text": response,
                                "lang": target_lang,
                                "controller": self.id,
                            }
                            actions.append(
                                self.create_action(
                                    action_types.SWITCH_LANGUAGE, payload
                                )
                            )
        return actions
class MonitorController(AsyncAgentController):
    """Forwards every event it receives to the store as a MONITOR action.

    A daemon worker thread drains the event queue while the controller is
    enabled; while disabled it parks on the `running` flag.
    """

    type = "MonitorController"

    def __init__(self, id, store, state):
        super(MonitorController, self).__init__(id, store, state)
        # subscribe to everything -- this controller only observes
        self.subscribe_events = [types.ALL_EVENTS]

        # the event currently being forwarded (None while idle)
        self.current_event = None
        self.running = threading.Event()
        worker = threading.Thread(target=self.run, name="Thread-%s" % self.id)
        worker.daemon = True
        worker.start()

    @AsyncAgentController.enabled.setter
    def enabled(self, enabled):
        # mirror the flag into the config and gate the worker thread
        self.config["enabled"] = enabled
        if enabled:
            self.running.set()
        else:
            self.running.clear()

    def is_idle(self):
        """Return True when nothing is queued or in flight."""
        return self.current_event is None and self.events.empty()

    def run(self):
        """Worker loop: forward events while enabled, otherwise park."""
        while True:
            if not self.running.is_set():
                self.running.wait()
                continue
            self.current_event = self.events.get()
            self.store.dispatch(
                self.create_action(action_types.MONITOR, self.current_event)
            )
            self.current_event = None
class PlaceholderUtteranceController(AgentController):
    """Fills conversational silence with placeholder utterances.

    After a user utterance is observed, a background thread measures how long
    the robot has stayed silent; the probability of speaking a random
    placeholder grows by `prob_escalate_step` per half second of silence,
    capped at 1. A successful fill dispatches a PLACEHOLDER_UTTERANCE action.
    """

    type = "PlaceholderUtteranceController"

    def __init__(self, id, store, state):
        super(PlaceholderUtteranceController, self).__init__(id, store, state)
        self.subscribe_events = [types.UTTERANCE, types.ROBOT_SPEAKING]

        # gates the background loop; set/cleared by the `enabled` setter
        self.running = threading.Event()
        # language of the most recent user utterance (None until one is seen)
        self.current_language = None

        # explicit placeholder candidates; when empty, candidates are looked
        # up from config["utterances"] keyed by language instead (see run())
        self.utterance_cand = []

        # set when a user utterance arrives and a response is awaited
        self.wait_for_response = threading.Event()
        # timestamp of the latest user utterance (None = no pending silence)
        self.time_since_new_speech = None
        self.speak_to_robot_prob = 0
        # whether the user is judged to be speaking to the robot
        # (currently always set True on any utterance -- see observe())
        self.speak_to_robot_detected = False
        # probability increment per half second of silence
        self.prob_escalate_step = 0.25

        # (kept for reference) disabled awaken-phrase detection
        # self.robot_awaken_phrase_pattern = None
        # character = os.environ.get('HR_CHARACTER')
        # if character:
        #     self.robot_awaken_phrase_pattern = re.compile(
        #         r"(hi|hey|hello) {}".format(character), re.IGNORECASE)
        # else:
        #     logger.warning("No character name is found")
        #     self.robot_awaken_phrase_pattern = re.compile(
        #         r"(hi|hey|hello) {}".format("sophia"), re.IGNORECASE)

        job = threading.Thread(target=self.run)
        job.daemon = True
        job.start()

    def set_config(self, config):
        """Apply config; optionally override the escalation step."""
        super(PlaceholderUtteranceController, self).set_config(config)
        if "prob_escalate_step" in self.config:
            self.prob_escalate_step = self.config["prob_escalate_step"]

    def run(self):
        """Background loop polling every 0.5s for silence to fill."""
        while True:
            if self.running.is_set():

                # pick candidates: explicit list wins, otherwise per-language
                # candidates from the config (may be None for this language)
                utterance_cand = None
                if not self.utterance_cand:
                    if "utterances" in self.config:
                        utterance_cand = self.config["utterances"].get(
                            self.current_language
                        )
                else:
                    utterance_cand = self.utterance_cand

                if (
                    self.time_since_new_speech is not None
                    and self.speak_to_robot_detected
                ):
                    now = time.time()
                    silence = now - self.time_since_new_speech
                    logger.info("Silence %s", silence)
                    silence_factor = int(silence * 2)  # number of 1/2 second
                    if self.wait_for_response.is_set():
                        # fill placeholder
                        # probability escalates with silence, capped at 1
                        prob = min(1, silence_factor * self.prob_escalate_step)
                        if not utterance_cand:
                            # nothing to say in this language; stop the timer
                            logger.warning("No placeholder utterances")
                            self.time_since_new_speech = None
                        logger.info("Probability %s", prob)
                        if utterance_cand and random.random() < prob:
                            utterance = random.choice(utterance_cand)
                            payload = {
                                "text": utterance,
                                "lang": self.current_language,
                                "controller": self.id,
                            }
                            action = self.create_action(
                                action_types.PLACEHOLDER_UTTERANCE, payload
                            )
                            self.store.dispatch(action)
                            # one placeholder per silence window
                            self.time_since_new_speech = None
                time.sleep(0.5)
            else:
                # disabled: block until re-enabled
                self.running.wait()

    def reset(self):
        """Clear the silence timer and the awaiting-response flag."""
        self.time_since_new_speech = None
        self.wait_for_response.clear()

    @AgentController.enabled.setter
    def enabled(self, enabled: bool):
        # mirror the flag into the config and gate the background loop;
        # disabling also clears any pending silence state
        self.config["enabled"] = enabled
        if enabled:
            self.running.set()
        else:
            self.running.clear()
            self.reset()

    def observe(self, event):
        """Track robot speech (resets the timer) and user utterances (starts it)."""
        if event["type"] == types.ROBOT_SPEAKING:
            if event["payload"]:
                logger.info("Reset")
                self.reset()  # reset timer when robot starts to speak
            return

        if event["type"] == types.UTTERANCE:
            payload = event["payload"]

            # TODO: other means to detect the user speech event
            self.speak_to_robot_detected = True

            self.current_language = payload["lang"] or "en-US"
            if self.speak_to_robot_detected:
                self.time_since_new_speech = time.time()
                self.wait_for_response.set()

        # (kept for reference) disabled probability-decay logic
        # if self.robot_awaken_phrase_pattern and \
        #         self.robot_awaken_phrase_pattern.search(msg.utterance):
        #     if self.placeholder_config:
        #         self.speak_to_robot_prob = self.placeholder_config['speaking_initial_prob']
        #         logger.info("Awaken robot")
        # elif self.placeholder_config:
        #     # update the probability
        #     #self.speak_to_robot_prob -= self.placeholder_config['speaking_prob_decay_step'] # delay probability 10%
        #     #self.speak_to_robot_prob = max(0, self.speak_to_robot_prob)

        #     #logger.info("Speak to robot prob %s", self.speak_to_robot_prob)
        #     #self.speak_to_robot_detected = random.random() < self.speak_to_robot_prob
        #     #if self.speak_to_robot_detected:
        #     #    logger.info("Speak to robot detected")

    def set_placeholder_utterances(self, utterances):
        """Explicitly set the placeholder candidates (overrides config lookup)."""
        # TODO: set default utterances from config
        self.utterance_cand = utterances

    def set_prob_escalate_step(self, step):
        """Set the per-half-second probability escalation step."""
        self.prob_escalate_step = step
class ResponsivityController(SyncAgentController):
    """
    Track the likelihood of response given various conditions.

    The likelihood of response could be changed by
    - preset keyphrases
    - perception
    - control API

    """

    type = "ResponsivityController"

    def __init__(self, id, store, state):
        super(ResponsivityController, self).__init__(id, store, state)
        self.subscribe_events = [
            types.UTTERANCE,
        ]
        self._intent_detector = IntentDetector("soultalk")

    def is_full_stop(self, text, lang):
        """Return True when the utterance should silence the robot entirely.

        Checks configured keywords first, then configured intents.
        English (en-US) only.
        """
        if lang != "en-US":
            return False
        lowered = text.lower()
        if "full_stop.keywords" in self.config:
            if any(
                keyword.lower() in lowered
                for keyword in to_list(self.config["full_stop.keywords"])
            ):
                return True
        if "full_stop.intents" in self.config:
            try:
                result = self._intent_detector.detect_intent(text, lang)
            except Exception as ex:
                # detector failure: treat as "not a full stop"
                logger.error(ex)
                return False
            if result and result["intent"]["name"] in to_list(
                self.config["full_stop.intents"]
            ):
                return True
        return False

    def is_wakenup(self, text, lang):
        """Return True when the utterance should wake the robot back up.

        Checks configured keywords, then configured regular expressions.
        English (en-US) only.
        """
        if lang != "en-US":
            return False
        lowered = text.lower()
        if "wakeup.keywords" in self.config:
            if any(
                keyword.lower() in lowered
                for keyword in to_list(self.config["wakeup.keywords"])
            ):
                return True
        if "wakeup.regular_expressions" in self.config:
            for pattern in to_list(self.config["wakeup.regular_expressions"]):
                if re.match(pattern, text, re.IGNORECASE):
                    return True
        return False

    def act(self):
        """Drain queued utterances and emit RESPONSIVITY actions (0 or 1)."""
        actions = []
        while not self.events.empty():
            event = self.events.get()
            if event["type"] != types.UTTERANCE:
                continue
            text = event["payload"]["text"]
            lang = event["payload"]["lang"]
            if self.is_full_stop(text, lang):
                logger.warning("Full stop detected. Text %r", text)
                level = 0
            elif self.is_wakenup(text, lang):
                logger.warning("Wake up detected")
                level = 1
            else:
                continue
            actions.append(
                self.create_action(
                    action_types.RESPONSIVITY,
                    {"controller": self.id, "responsivity": level},
                )
            )
        return actions
class UserAcquisitionController(AsyncAgentController):
    """Turns face and user-profile events into user-acquisition actions.

    A daemon worker thread drains the event queue while the controller is
    enabled and dispatches HANDLE_FACE_EVENT / UPDATE_USER_PROFILE actions
    to the store.
    """

    type = "UserAcquisitionController"

    def __init__(self, id, store, state):
        super(UserAcquisitionController, self).__init__(id, store, state)
        self.subscribe_events = [
            types.FACE_EVENT,
            types.USER_PROFILE,
        ]

        # the event currently being processed (None while idle)
        self.current_event = None
        self.running = threading.Event()
        job = threading.Thread(target=self.run, name="Thread-%s" % self.id)
        job.daemon = True
        job.start()

    def is_idle(self):
        """Return True when nothing is queued or in flight."""
        return self.events.empty() and self.current_event is None

    @AsyncAgentController.enabled.setter
    def enabled(self, enabled):
        # mirror the flag into the config and gate the worker thread
        self.config["enabled"] = enabled
        if enabled:
            self.running.set()
        else:
            self.running.clear()

    def run(self):
        """Worker loop: translate queued events into store actions."""
        while True:
            if self.running.is_set():
                self.current_event = self.events.get()
                event_type = self.current_event["type"]
                if event_type == types.FACE_EVENT:
                    payload = self.current_event["payload"]
                    payload["controller"] = self.id
                    action = self.create_action(action_types.HANDLE_FACE_EVENT, payload)
                    self.store.dispatch(action)
                if event_type == types.USER_PROFILE:
                    payload = {
                        "controller": self.id,
                        "profile": self.current_event["payload"],
                    }
                    action = self.create_action(
                        action_types.UPDATE_USER_PROFILE, payload
                    )
                    # BUGFIX: the action was created but never dispatched, so
                    # user-profile updates silently went nowhere (compare the
                    # FACE_EVENT branch above).
                    self.store.dispatch(action)
                self.current_event = None
            else:
                self.running.wait()
class FaceEventGenerator(EventGenerator):
    """Turns raw face-id observations into FACE_EVENT events.

    Tracks which faces are currently visible; a face counts as lost only
    after it has been unseen for `face_lost_timeout` seconds.
    """

    class Face(object):
        # Lightweight record for a tracked face, keyed by id.
        def __init__(self, id) -> None:
            self.id = id
            # creation time doubles as the last time this face was seen
            self.last_seen = time.time()

        @property
        def is_known(self):
            # NOTE(review): ids starting with "U" presumably mean
            # unknown/unrecognized faces -- confirm with the face tracker.
            return not self.id.startswith("U")

        # Faces hash and compare by id only, so a freshly-seen Face object
        # is "equal" to a stale tracked one (used by the set math below).
        def __hash__(self):
            return hash(self.id)

        def __eq__(self, other):
            return self.id == other.id

    def __init__(self, face_lost_timeout=60):
        # currently tracked faces
        self.faces = set([])
        # seconds a face may go unseen before it is reported as lost
        self.face_lost_timeout = face_lost_timeout

    def reset(self):
        """Forget all tracked faces."""
        self.faces = set([])

    def generate(self, face_ids):
        """Return a FACE_EVENT for new/lost faces, or None if nothing changed.

        `face_ids` is the list of ids visible right now; falsy ids are
        ignored.
        """
        faces = set(
            [FaceEventGenerator.Face(face_id) for face_id in face_ids if face_id]
        )

        # drop tracked faces that have been unseen longer than the timeout
        lost_faces = []
        now = time.time()
        for face in self.faces - faces:
            if now - face.last_seen > self.face_lost_timeout:
                lost_faces.append(face)
                self.faces.remove(face)

        new_faces = faces - self.faces

        # update faces using newer faces
        # (set.union keeps the left operand's element for duplicates, so the
        # fresh Face objects with current last_seen timestamps win)
        self.faces = faces.union(self.faces)

        if new_faces or lost_faces:
            payload = {
                "new_faces": [face.id for face in new_faces],
                "lost_faces": [face.id for face in lost_faces],
            }
            event = self.create_event(events.FACE_EVENT, payload)
            return event


class PoseEventGenerator(EventGenerator):
    """Emits USER_POSE_DETECTED / USER_POSE_LOST on pose presence changes."""

    def __init__(self):
        # last reported pose count (None until the first observation)
        self.last = None

    def generate(self, payload):
        """Return a pose event when presence toggles, else None."""
        # TODO: fix duplicated poses issue
        # number_of_poses = len(payload) if payload else 0
        number_of_poses = 1 if payload else 0
        event = None
        if self.last != number_of_poses:
            if number_of_poses != 0:
                event = self.create_event(events.USER_POSE_DETECTED, payload)
            else:
                # only report "lost" if a pose had actually been seen before
                if self.last:
                    event = self.create_event(events.USER_POSE_LOST)
            self.last = number_of_poses
        return event


# the state key to event generator mapping
EVENT_GENERATORS = {
    "utterance": BasicEventGenerator(events.UTTERANCE),
    "robot_speaking": BasicEventGenerator(events.ROBOT_SPEAKING),
    "user_speaking": BasicEventGenerator(events.USER_SPEAKING),
    "command": BasicEventGenerator(events.USER_COMMAND),
    "keywords": BasicEventGenerator(events.KEYWORDS),
    "robot_interrupted": BasicEventGenerator(events.ROBOT_INTERRUPTED),
    "user_sentiment_trigger": BasicEventGenerator(events.USER_SENTIMENT_TRIGGER),
    "user_profile": BasicEventGenerator(events.USER_PROFILE),
    "poses": PoseEventGenerator(),
    "face_ids": FaceEventGenerator(10),  # 10s lost timeout for faces
}


class StateEvent(object):
    """The class that converts states to events.

    A daemon thread waits for state updates and feeds each changed state key
    through its registered generator; generated events are delivered to the
    `on_event` callback.
    """

    def __init__(self, state, on_event: Callable[[dict], None]):
        self._state = state
        self._callback = on_event
        self._state_pos = 0  # the position of state the event generated from
        self._generators = EVENT_GENERATORS  # the event generators

        job = threading.Thread(
            name="state_update_checker", target=self.state_update_checker
        )
        job.daemon = True
        job.start()

    def state_update_checker(self):
        """Background loop: block on state updates, then process the diff."""
        while True:
            self._state.wait_for_update()
            self._on_state_change()

    def register_event_generator(self, state, generator: EventGenerator):
        """Register (or replace) the generator for a state key."""
        self._generators[state] = generator

    def _on_state_change(self):
        # process only the diff entries added since the last run
        states = self._state.getState()
        state_keys = self._state.getStateChange()
        for state_key in state_keys[self._state_pos :]:
            self._state_pos += 1  # update the position
            generator = self._generators.get(state_key)
            if generator:
                msg_data = states.get(state_key)
                event = generator.generate(msg_data)
                if event:
                    try:
                        self._callback(event)
                    except Exception as ex:
                        # a failing callback must not kill the checker thread
                        logger.exception(ex)

    def reset(self):
        """Reset the state, the diff position, and all resettable generators."""
        logger.warning("Reset event generator")
        self._state.reset()
        self._state_pos = 0
        for generator in self._generators.values():
            if hasattr(generator, "reset"):
                generator.reset()
# Event type identifiers flowing from the event generators (see
# event_generator.EVENT_GENERATORS) to the agent controllers, which list
# the types they care about in `subscribe_events`.

UTTERANCE = "utterance"  # user utterance; payload carries "text" and "lang"
KEYWORDS = "keywords"  # keyword state change (mapped from the "keywords" state key)
ROBOT_INTERRUPTED = "robot_interrupted"  # robot speech cut off; payload has "text"/"lang"
ROBOT_SPEAKING = "robot_speaking"  # robot speech on/off; truthy payload = speaking
USER_SPEAKING = "user_speaking"  # user speech activity state change
USER_COMMAND = "command"  # user command (mapped from the "command" state key)
USER_SENTIMENT_TRIGGER = "user_sentiment_trigger"  # user sentiment state change
USER_POSE_DETECTED = "user_pose_detected"  # a user pose appeared (PoseEventGenerator)
USER_POSE_LOST = "user_pose_lost"  # all user poses gone (PoseEventGenerator)
USER_PROFILE = "user_profile"  # user profile data update
FACE_EVENT = "face_event"  # payload has "new_faces"/"lost_faces" id lists
ALL_EVENTS = "_all_events"  # wildcard subscription (e.g. MonitorController)
logger = logging.getLogger(__name__)


class State(object):
    """Thread-safe key/value state with a change log and an update signal.

    Every mutation is recorded (by key) in an ordered diff list so consumers
    can replay changes; `updated` is set whenever anything actually changes.
    Boolean states can be queried via dynamically generated `is_<key>()`
    accessors (see __getattr__).
    """

    def __init__(self):
        self._state = {}  # current key -> value mapping
        self._state_diffs = []  # ordered log of keys that changed
        self.lock = threading.RLock()
        self.updated = threading.Event()  # signalled on any state change

    def reset(self):
        """Drop all state and the change log, and signal an update."""
        with self.lock:
            self._state = {}
            self._state_diffs = []
            self.updated.set()
            logger.warning("Reset state")

    def update(self, **kwargs):
        """Set the given keys; log each key whose value actually changed."""
        with self.lock:
            changed = False
            for key, value in kwargs.items():
                # record both brand-new keys and genuine value changes
                if key not in self._state or value != self._state[key]:
                    self._state[key] = value
                    self._state_diffs.append(key)
                    changed = True
            if changed:
                self.updated.set()

    def wait_for_update(self, timeout=None):
        """Block until the state changes (or timeout); returns the wait flag."""
        flag = self.updated.wait(timeout)
        self.updated.clear()
        return flag

    def getState(self):
        """Return a shallow copy of the whole state mapping."""
        return copy.copy(self._state)

    def getStateValue(self, key):
        """Return a shallow copy of a single state value (None if unset)."""
        return copy.copy(self._state.get(key))

    def getStateChange(self):
        """Return a copy of the ordered change log."""
        return list(self._state_diffs)

    def __getattr__(self, attr):
        """
        Automatically get the attribute value from the internal state.

        e.g. is_robot_speaking will return the value of "robot_speaking"
        in the state.
        """

        def accessor():
            if not attr.startswith("is_"):
                raise AttributeError("No such attribute %r" % attr)
            key = attr[3:]
            if key not in self._state:
                return False  # if the state is not set return False
            value = self._state[key]
            if not isinstance(value, bool):
                raise AttributeError("State %r is not boolean" % key)
            return value

        return accessor


if __name__ == "__main__":
    state = State()
    state.update(abc=False, abcd=True)
    print(state.is_abc())
    print(state.is_abcd())
    print(state.is_abcde())
class AimlParserError(Exception):
    """Raised by the AIML parser when a document violates AIML structure
    (unexpected or malformed tags, missing required attributes, etc.)."""

    pass
  • elements inside + # elements, to keep track of whether or not an + # attribute-less "default"
  • element has been found yet. Only + # one default
  • is allowed in each element. We need + # a stack in order to correctly handle nested tags. + self._foundDefaultLiStack = [] + + # This stack of strings indicates what the current whitespace-handling + # behavior should be. Each string in the stack is either "default" or + # "preserve". When a new AIML element is encountered, a new string is + # pushed onto the stack, based on the value of the element's "xml:space" + # attribute (if absent, the top of the stack is pushed again). When + # ending an element, pop an object off the stack. + self._whitespaceBehaviorStack = ["default"] + + self._elemStack = [] + self._locator = Locator() + self.setDocumentLocator(self._locator) + + def getNumErrors(self): + "Return the number of errors found while parsing the current document." + return self._numParseErrors + + def setEncoding(self, encoding): + """Set the text encoding to use when encoding strings read from XML. + + Defaults to 'UTF-8'. + + """ + self._encoding = encoding + + def _location(self): + "Return a string describing the current location in the source file." + line = self._locator.getLineNumber() + column = self._locator.getColumnNumber() + return "(line %d, column %d)" % (line, column) + + def _pushWhitespaceBehavior(self, attr): + """Push a new string onto the whitespaceBehaviorStack. + + The string's value is taken from the "xml:space" attribute, if it exists + and has a legal value ("default" or "preserve"). Otherwise, the previous + stack element is duplicated. + + """ + assert ( + len(self._whitespaceBehaviorStack) > 0 + ), "Whitespace behavior stack should never be empty!" 
+ try: + if attr["xml:space"] == "default" or attr["xml:space"] == "preserve": + self._whitespaceBehaviorStack.append(attr["xml:space"]) + else: + raise AimlParserError( + "Invalid value for xml:space attribute " + self._location() + ) + except KeyError: + self._whitespaceBehaviorStack.append(self._whitespaceBehaviorStack[-1]) + + def startElementNS(self, name, qname, attr): + logger.info("QNAME:", qname) + logger.info("NAME:", name) + uri, elem = name + if elem == "bot": + logger.info("name:", attr.getValueByQName("name"), "a'ite?") + self.startElement(elem, attr) + pass + + def startElement(self, name, attr): + # Wrapper around _startElement, which catches errors in _startElement() + # and keeps going. + + # If we're inside an unknown element, ignore everything until we're + # out again. + if self._currentUnknown != "": + return + # If we're skipping the current category, ignore everything until + # it's finished. + if self._skipCurrentCategory: + return + + # process this start-element. + try: + self._startElement(name, attr) + except AimlParserError as msg: + # Print the error message + logger.error("PARSE ERROR: %s" % msg) + + self._numParseErrors += 1 # increment error count + # In case of a parse error, if we're inside a category, skip it. + if self._state >= self._STATE_InsideCategory: + self._skipCurrentCategory = True + + def _startElement(self, name, attr): + if name == "aiml": + # tags are only legal in the OutsideAiml state + if self._state != self._STATE_OutsideAiml: + raise AimlParserError("Unexpected tag " + self._location()) + self._state = self._STATE_InsideAiml + self._insideTopic = False + self._currentTopic = "" + try: + self._version = attr["version"] + except KeyError: + # This SHOULD be a syntax error, but so many AIML sets out there are missing + # "version" attributes that it just seems nicer to let it slide. 
+ # raise AimlParserError, "Missing 'version' attribute in tag "+self._location() + # print "WARNING: Missing 'version' attribute in tag "+self._location() + # print " Defaulting to version 1.0" + self._version = "1.0" + self._forwardCompatibleMode = self._version != "1.0.1" + self._pushWhitespaceBehavior(attr) + # Not sure about this namespace business yet... + # try: + # self._namespace = attr["xmlns"] + # if self._version == "1.0.1" and self._namespace != "http://alicebot.org/2001/AIML-1.0.1": + # raise AimlParserError, "Incorrect namespace for AIML v1.0.1 "+self._location() + # except KeyError: + # if self._version != "1.0": + # raise AimlParserError, "Missing 'version' attribute(s) in tag "+self._location() + elif self._state == self._STATE_OutsideAiml: + # If we're outside of an AIML element, we ignore all tags. + return + elif name == "topic": + # tags are only legal in the InsideAiml state, and only + # if we're not already inside a topic. + if (self._state != self._STATE_InsideAiml) or self._insideTopic: + raise AimlParserError("Unexpected tag").with_traceback( + self._location() + ) + try: + self._currentTopic = str(attr["name"]) + except KeyError: + raise AimlParserError( + 'Required "name" attribute missing in element ' + + self._location() + ) + self._insideTopic = True + elif name == "category": + # tags are only legal in the InsideAiml state + if self._state != self._STATE_InsideAiml: + raise AimlParserError("Unexpected tag " + self._location()) + self._state = self._STATE_InsideCategory + self._currentPattern = "" + self._currentPatternLocation = "" + self._currentThat = "" + # If we're not inside a topic, the topic is implicitly set to * + if not self._insideTopic: + self._currentTopic = "*" + self._elemStack = [] + self._pushWhitespaceBehavior(attr) + elif name == "pattern": + # tags are only legal in the InsideCategory state + if self._state != self._STATE_InsideCategory: + raise AimlParserError("Unexpected tag " + self._location()) + self._state 
= self._STATE_InsidePattern + elif name == "that" and self._state == self._STATE_AfterPattern: + # are legal either inside a