diff --git a/ads/aqua/extension/__init__.py b/ads/aqua/extension/__init__.py index 5fe85bcfb..c7ab84886 100644 --- a/ads/aqua/extension/__init__.py +++ b/ads/aqua/extension/__init__.py @@ -13,6 +13,9 @@ from ads.aqua.extension.evaluation_handler import __handlers__ as __eval_handlers__ from ads.aqua.extension.finetune_handler import __handlers__ as __finetune_handlers__ from ads.aqua.extension.model_handler import __handlers__ as __model_handlers__ +from ads.aqua.extension.playground_handler import ( + __handlers__ as __playground_handlers__, +) from ads.aqua.extension.ui_handler import __handlers__ as __ui_handlers__ __handlers__ = ( @@ -22,6 +25,7 @@ + __deployment_handlers__ + __ui_handlers__ + __eval_handlers__ + + __playground_handlers__ ) diff --git a/ads/aqua/extension/playground_handler.py b/ads/aqua/extension/playground_handler.py new file mode 100644 index 000000000..e2d69349b --- /dev/null +++ b/ads/aqua/extension/playground_handler.py @@ -0,0 +1,370 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*-- + +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +import json +from dataclasses import dataclass, field +from typing import Dict + +from tornado.web import HTTPError +import tornado +import random + +from ads.aqua import logger +from ads.aqua.extension.base_handler import AquaAPIhandler +from ads.aqua.playground.entities import Message, Session, Thread +from ads.aqua.playground.model_invoker import ModelInvoker +from ads.aqua.playground.playground import MessageApp, SessionApp, ThreadApp +from ads.common.extended_enum import ExtendedEnumMeta +from ads.common.serializer import DataClassSerializable +from ads.common.utils import batch_convert_case + + +class Errors(str): + INVALID_INPUT_DATA_FORMAT = "Invalid format of input data." + NO_INPUT_DATA = "No input data provided." + MISSING_REQUIRED_PARAMETER = "Missing required parameter: '{}'" + + +@dataclass +class NewSessionRequest(DataClassSerializable): + """Dataclass representing the request on creating a new session.""" + + model_id: str = None + + +@dataclass +class PostMessageRequest(DataClassSerializable): + """Dataclass representing the request on posting a new message.""" + + session: Session = field(default_factory=Session) + thread: Thread = field(default_factory=Thread) + message: Message = field(default_factory=Message) + answer: Message = field(default_factory=Message) + + +class ChunkResponseStatus(str, metaclass=ExtendedEnumMeta): + SUCCESS = "success" + ERROR = "error" + + +@dataclass(repr=False) +class ChunkResponse(DataClassSerializable): + """Class representing server response. + + Attributes + ---------- + status: str + Response status. + message: (str, optional). Defaults to "". + The response message. + payload: (Dict, optional). Defaults to None. + The payload of the response. + """ + + status: str = None + message: str = None + payload: Dict = None + + +class AquaPlaygroundSessionHandler(AquaAPIhandler): + """ + Handles the management and interaction with Playground sessions. + + Methods + ------- + get(self, id="") + Retrieves a list of sessions or a specific session by ID. + post(self, *args, **kwargs) + Creates a new playground session. + read(self, id: str) + Reads the detailed information of a specific Playground session. + list(self) + Lists all the playground sessions. + + Raises + ------ + HTTPError: For various failure scenarios such as invalid input format, missing data, etc. + """ + + def get(self, id=""): + """ + Retrieve a list of all sessions or a specific session by its ID. + + Parameters + ---------- + id: (str, optional) + The ID of the session to retrieve. Defaults to an empty string, + which implies fetching all sessions. + + Returns + ------- + The session data or a list of sessions. + """ + if not id: + return self.list() + return self.read(id) + + def read(self, id: str): + """Read the information of a Playground session.""" + try: + return self.finish(SessionApp().get(id=id, include_threads=True)) + except Exception as ex: + raise HTTPError(500, str(ex)) + + def list(self): + """List playground sessions.""" + try: + return self.finish(SessionApp().list()) + except Exception as ex: + raise HTTPError(500, str(ex)) + + def post(self, *args, **kwargs): + """ + Creates a new Playground session by model ID. + The session data is extracted from the JSON body of the request. + If session for given model ID exists, then the existing session will be returned. + + Raises + ------ + HTTPError + If the input data is invalid or missing, or if an error occurs during session creation. + """ + try: + input_data = self.get_json_body() + except Exception as ex: + raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) + + if not input_data: + raise HTTPError(400, Errors.NO_INPUT_DATA) + + new_session_request = NewSessionRequest.from_dict( + batch_convert_case(input_data, to_fmt="snake") + ) + + if not new_session_request.model_id: + raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("modelId")) + + try: + self.finish(SessionApp().create(model_id=new_session_request.model_id)) + except Exception as ex: + raise HTTPError(500, str(ex)) + + +class AquaPlaygroundThreadHandler(AquaAPIhandler): + """ + Handles the management and interaction with Playground threads. + + Methods + ------- + get(self, thread_id="") + Retrieves a list of threads or a specific thread by ID. + post(self, *args, **kwargs) + Creates a new playground thread. + delete(self) + Deletes (soft delete) a specified thread by ID. + read(self, thread_id: str) + Reads the detailed information of a specific Playground thread. + list(self) + Lists all the threads in a session. + + Raises + ------ + HTTPError: For various failure scenarios such as invalid input format, missing data, etc. + """ + + def get(self, thread_id: str = ""): + """ + Retrieve a list of all threads or a specific thread by its ID. + + Parameters + ---------- + thread_id (str, optional) + The ID of the thread to retrieve. Defaults to an empty string, + which implies fetching all threads. + + Returns + ------- + The thread data or a list of threads. + """ + if not thread_id: + return self.list() + return self.read(thread_id) + + def read(self, thread_id: str): + """Read the information of a playground thread.""" + try: + return self.finish( + ThreadApp().get(thread_id=thread_id, include_messages=True) + ) + except Exception as ex: + raise HTTPError(500, str(ex)) + + def list(self): + """ + List playground threads. + + Args + ---- + session_id: str + The ID of the session to list associated threads. + """ + session_id = self.get_argument("session_id") + try: + return self.finish(ThreadApp().list(session_id=session_id)) + except Exception as ex: + raise HTTPError(500, str(ex)) + + async def post(self, *args, **kwargs): + """ + Adds a new message into the Playground thread. + If the thread doesn't exist yet, then it will be created. + """ + self.set_header("Content-Type", "application/json") + self.set_header("Transfer-Encoding", "chunked") + + try: + request_data: PostMessageRequest = PostMessageRequest.from_dict( + batch_convert_case(self.get_json_body(), to_fmt="snake") + ) + except Exception as ex: + logger.debug(ex) + error_msg = ChunkResponse( + status=ChunkResponseStatus.ERROR, + message=Errors.INVALID_INPUT_DATA_FORMAT, + ).to_json() + self.write(f"{len(error_msg):X}\r\n{error_msg}\r\n0\r\n\r\n") + await self.flush() + return + + thread_app = ThreadApp() + message_app = MessageApp() + # Register all entities in the DB + try: + # Add thread into DB if it not exists + new_thread = thread_app.create( + request_data.session.session_id, + name=request_data.thread.name, + thread_id=request_data.thread.id, + ) + + # Add user message into DB + new_user_message = message_app.create( + thread_id=new_thread.id, + content=request_data.message.content, + message_id=request_data.message.message_id, + parent_message_id=request_data.message.parent_message_id, + role=request_data.message.role, + rate=request_data.message.rate, + payload=request_data.message.payload, + model_params=request_data.message.model_params.to_dict(), + ) + + # Add system answer into DB + new_system_message = message_app.create( + thread_id=new_thread.id, + content=request_data.answer.content, + message_id=request_data.answer.message_id, + parent_message_id=request_data.answer.parent_message_id, + role=request_data.answer.role, + rate=request_data.answer.rate, + payload=request_data.answer.payload, + model_params=request_data.answer.model_params.to_dict(), + ) + + # Send initial OK status to the client + initial_response = ChunkResponse( + status=ChunkResponseStatus.SUCCESS, message="" + ).to_json() + + self.write(f"{len(initial_response):X}\r\n{initial_response}\r\n") + await self.flush() + except Exception as ex: + logger.debug(ex) + error_msg = ChunkResponse( + status=ChunkResponseStatus.ERROR, message=str(ex) + ).to_json() + self.write(f"{len(error_msg):X}\r\n{error_msg}\r\n0\r\n\r\n") + await self.flush() + return + + try: + model_response_text = "" + model_invoker = ModelInvoker( + endpoint=f"{request_data.session.model.endpoint.rstrip('/')}/predict", + prompt=request_data.message.content, + params=request_data.message.model_params.to_dict(), + ) + for item in model_invoker.invoke(): + if item.startswith("data"): + if "[DONE]" in item: + continue + item_json = json.loads(item[6:]) + else: + item_json = json.loads(item) + + if item_json.get("object") == "error": + # {"object":"error","message":"top_k must be -1 (disable), or at least 1, got 0.","type":"invalid_request_error","param":null,"code":null} + raise HTTPError(400, item_json.get("message")) + else: + chunk = ChunkResponse( + status=ChunkResponseStatus.SUCCESS, + message="", + payload=item_json["choices"][0]["text"], + ).to_json() + + model_response_text += item_json["choices"][0]["text"] + + # update system message in DB + message_app.update( + message_id=new_system_message.message_id, + content=model_response_text, + rate=new_system_message.rate, + status=new_system_message.status, + ) + + self.write(f"{len(chunk):X}\r\n{chunk}\r\n") + await self.flush() + + # Indicate the end of the response + self.write("0\r\n\r\n") + await self.flush() + except Exception as ex: + logger.debug(ex) + # Handle unexpected errors + error_msg = ChunkResponse( + status=ChunkResponseStatus.ERROR, message=str(ex) + ).to_json() + self.write(f"{len(error_msg):X}\r\n{error_msg}\r\n0\r\n\r\n") + await self.flush() + + def delete(self, *args, **kwargs): + """ + Deletes (soft delete) the thread by ID. + + Args + ---- + thread_id: str + The ID of the thread to be deleted. + """ + thread_id = self.get_argument("threadId") + if not thread_id: + raise HTTPError( + 400, Errors.Errors.MISSING_REQUIRED_PARAMETER.format("threadId") + ) + + # Only soft deleting with updating a status field. + try: + ThreadApp().deactivate(thread_id=thread_id) + self.set_status(204) # no content + self.finish() + except Exception as ex: + raise HTTPError(500, str(ex)) + + +__handlers__ = [ + ("playground/session/?([^/]*)", AquaPlaygroundSessionHandler), + ("playground/thread/?([^/]*)", AquaPlaygroundThreadHandler), +] diff --git a/ads/aqua/playground/__init__.py b/ads/aqua/playground/__init__.py new file mode 100644 index 000000000..9eadd9943 --- /dev/null +++ b/ads/aqua/playground/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*-- + +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ diff --git a/ads/aqua/playground/const.py b/ads/aqua/playground/const.py new file mode 100644 index 000000000..019b8f3a8 --- /dev/null +++ b/ads/aqua/playground/const.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*-- + +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +from ads.common.extended_enum import ExtendedEnumMeta + + +class Status(str, metaclass=ExtendedEnumMeta): + """Enumeration for the status of various entities like records, sessions, messages, etc.""" + + ACTIVE = "active" + ARCHIVED = "archived" + PENDING = "pending" + FAILED = "failed" + + +class MessageRate(str, metaclass=ExtendedEnumMeta): + """Enumeration for message rating.""" + + DEFAULT = 0 + LIKE = 1 + DISLIKE = -1 + + +class MessageRole(str, metaclass=ExtendedEnumMeta): + """Enumeration for message roles.""" + + USER = "user" + SYSTEM = "system" + + +class ObjectType(str, metaclass=ExtendedEnumMeta): + """The status of the record.""" + + SESSION = "session" + THREAD = "thread" + MESSAGE = "message" diff --git a/ads/aqua/playground/db_context.py b/ads/aqua/playground/db_context.py new file mode 100644 index 000000000..6b5a2a295 --- /dev/null +++ b/ads/aqua/playground/db_context.py @@ -0,0 +1,496 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*-- + +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +import os +from datetime import datetime +from typing import Dict, List, Optional + +from sqlalchemy import create_engine +from sqlalchemy.orm import scoped_session, sessionmaker + +from ads.aqua.playground.const import MessageRate, MessageRole, ObjectType, Status +from ads.aqua.playground.db_models import Base, MessageModel, SessionModel, ThreadModel +from ads.aqua.playground.entities import Message, Session, Thread +from ads.aqua.playground.errors import ( + MessageNotFoundError, + SessionNotFoundError, + ThreadNotFoundError, +) + +DATABASE_NAME = "playground.db" +DATABASE_PATH = os.environ.get("AQUA_PLAYGROUND") or os.path.join( + os.path.abspath(os.path.expanduser("~")), ".aqua" +) + + +OBJECT_MODEL_MAP = { + ObjectType.SESSION: SessionModel, + ObjectType.THREAD: ThreadModel, + ObjectType.MESSAGE: MessageModel, +} + + +class DBContext: + """ + A class to handle database operations for Playground sessions, threads, and messages. + + Attributes + ---------- + engine (Engine): SQLAlchemy engine instance for database connection. + DBSession (sessionmaker): Factory for creating new SQLAlchemy session objects. + """ + + def __init__(self, db_url: str, echo: bool = False): + """ + Initializes the database context with a given database URL. + + Parameters + ---------- + db_url (str): A database URL that indicates database dialect and connection arguments. + echo: (bool, optional). Whether to show the debug information or not. + """ + self.engine = create_engine(db_url, echo=echo, future=True) + self.DBSession = scoped_session( + sessionmaker(bind=self.engine, future=True, expire_on_commit=True) + ) + self.init_db() + + def init_db(self): + Base.metadata.create_all(self.engine) + + def get_sessions( + self, only_active: bool = True, include_threads: bool = False + ) -> List[Session]: + """ + Retrieves all threads for a specific playground db_session. + + Parameters + ---------- + only_active: (bool, optional). Defaults to True. + Whether to load all or only active sessions. + include_threads: (bool, optional). Defaults to False. + Whether to include the associated threads or not. + + Returns + ------- + List[SessionModel] + A list of playground sessions. + """ + with self.DBSession() as db_session: + db_session.expire_on_commit + query = db_session.query(SessionModel) + + if only_active: + query.filter_by(status=Status.ACTIVE) + + return [ + Session.from_db_model(session_model, include_threads=include_threads) + for session_model in query.all() + ] + + def get_session( + self, + session_id: str = None, + model_id: str = None, + include_threads: bool = False, + ) -> Optional[Session]: + """ + Retrieves a playground session by its session ID or model ID. + + Parameters + ---------- + session_id: (str, optional) + The unique session identifier for the playground db_session. + model_id: (str, optional) + The unique model identifier for the playground db_session. + include_threads: (bool, optional). Defaults to False. + Whether to include the associated threads or not. + + Returns + ------- + Optional[SessionModel] + The retrieved playground session if found, else None. + + Raises + ------ + ValueError + If neither session_id nor model_id was provided. + SessionNotFoundError + If session with the provided ID doesn't exist. + """ + + if not (session_id or model_id): + raise ValueError("Either session ID or model ID need to be provided.") + + with self.DBSession() as db_session: + query = db_session.query(SessionModel) + if session_id: + session_model = query.filter_by(id=session_id).first() + else: + session_model = query.filter_by(model_id=model_id).first() + + if not session_model: + raise SessionNotFoundError(session_id or model_id) + + return Session.from_db_model(session_model, include_threads=include_threads) + + def add_session( + self, + model_id: str, + model_name: str, + model_endpoint: str, + session_id: str = None, + status: str = Status.ACTIVE, + ) -> Session: + """ + Adds a new playground session to the database. + + Parameters + ---------- + model_id (str): The unique model identifier for the new db_session. + model_name (str): The name of the model. + model_endpoint (str): The model endpoint. + session_id (str, optional): The ID of the session. + status (str, optional): The status of the db_session. + + Returns + ------- + SessionModel + The newly created playground db_session. + """ + with self.DBSession() as db_session: + new_session = SessionModel( + id=session_id, + model_id=model_id, + model_name=model_name, + model_endpoint=model_endpoint, + created=datetime.now(), + updated=datetime.now(), + status=status, + ) + db_session.add(new_session) + db_session.commit() + + return Session.from_db_model(new_session) + + def get_session_threads( + self, + session_id: str, + only_active: bool = True, + ) -> List[Thread]: + """ + Retrieves all threads for a specific playground db_session. + + Parameters + ---------- + session_id: (str) + The ID of the session for which to retrieve threads. + only_active: (bool, optional). Defaults to True. + Whether to load all or only active sessions. + + Returns + ------- + List[ThreadModel] + A list of playground threads associated with the db_session. + """ + with self.DBSession() as db_session: + query = db_session.query(ThreadModel).filter_by( + playground_session_id=session_id + ) + + if only_active: + query.filter_by(status=Status.ACTIVE) + + return [Thread.from_db_model(thread_model) for thread_model in query.all()] + + def get_thread( + self, + thread_id: str = None, + include_messages: bool = False, + ) -> Optional[Thread]: + """ + Retrieves a playground thread by its ID. + + Parameters + ---------- + thread_id: str + The unique thread identifier. + include_messages: (bool, optional). Defaults to False. + Whether to include the associated messages or not. + + Returns + ------- + Optional[ThreadModel] + The retrieved playground thread if found, else None. + + Raises + ------ + ThreadNotFoundError + If thread with provided ID doesn't exist. + """ + with self.DBSession() as db_session: + thread_model = db_session.query(ThreadModel).filter_by(id=thread_id).first() + + if not thread_model: + raise ThreadNotFoundError(thread_id=thread_id) + + return Thread.from_db_model(thread_model, include_messages=include_messages) + + def add_thread( + self, + session_id: str, + name: str, + thread_id: str = None, + status: str = Status.ACTIVE, + ) -> Thread: + """ + Adds a new thread to an existing playground db_session. + + Parameters + ---------- + session_id (str): The ID of the session to which the thread belongs. + name (str): The name of the thread. + thread_id (str, optional): The thread ID. + status (str, optional): The status of the thread. Defaults to active. + + Returns + ------- + Thread + The newly created playground thread. + """ + with self.DBSession() as db_session: + new_thread = ThreadModel( + id=thread_id, + playground_session_id=session_id, + name=name, + created=datetime.now(), + updated=datetime.now(), + status=status, + ) + db_session.add(new_thread) + db_session.commit() + + return Thread.from_db_model(new_thread) + + def update_thread( + self, thread_id: str, name: str = None, status: str = None + ) -> Optional[Thread]: + """ + Updates a playground thread by its ID. + + Parameters + ---------- + thread_id: str + The unique thread identifier. + + Returns + ------- + Optional[Message] + Updated thread or None. + + Raises + ------ + ThreadNotFoundError + If message with provided ID doesn't exist. + """ + with self.DBSession() as db_session: + thread_model = db_session.query(ThreadModel).filter_by(id=thread_id).first() + + if not thread_model: + raise ThreadNotFoundError(thread_id=thread_id) + + thread_model.name = name or thread_model.name + thread_model.status = status or thread_model.status + thread_model.updated = datetime.now() + + db_session.commit() + return Thread.from_db_model(thread_model) + + def get_thread_messages( + self, thread_id: str, only_active: bool = True + ) -> List[Message]: + """ + Retrieves all messages in a specific playground thread. + + Parameters + ---------- + thread_id (str): The ID of the thread for which to retrieve messages. + only_active: (bool, optional). Defaults to True. + Whether to load all or only active messages. + + Returns + ------- + List[Message] + A list of playground messages in the thread. + """ + with self.DBSession() as db_session: + query = db_session.query(MessageModel).filter_by( + playground_thread_id=thread_id + ) + + if only_active: + query.filter_by(status=Status.ACTIVE) + + return [ + Message.from_db_model(message_model) for message_model in query.all() + ] + + def add_message_to_thread( + self, + thread_id: str, + content: str, + message_id: str = None, + parent_message_id: str = None, + role: str = MessageRole.USER, + rate: int = MessageRate.DEFAULT, + payload: Dict = None, + model_params: Dict = None, + status: str = Status.ACTIVE, + ) -> Message: + """ + Adds a message to a specific playground thread. + + Parameters + ---------- + thread_id (str): The ID of the thread to which the message will be added. + content (str): The text content of the message. + message_id (str, optional): The message ID. + parent_message_id (str, optional): The parent message. + payload (Dict, optional): The model payload. + model_params (Dict, optional): The model parameters. + status (str): The status of the message. + role (str): The role of the message (e.g., 'user', 'system'). + + Returns + ------- + Message + The newly created playground message. + """ + with self.DBSession() as db_session: + new_message = MessageModel( + id=message_id, + playground_thread_id=thread_id, + parent_id=parent_message_id, + content=content or "", + created=datetime.now(), + updated=datetime.now(), + status=status or Status.ACTIVE, + role=role or MessageRole.USER, + rate=rate or MessageRate.DEFAULT, + payload=payload or {}, + model_params=model_params or {}, + ) + db_session.add(new_message) + db_session.commit() + return Message.from_db_model(new_message) + + def get_message( + self, + message_id: str = None, + ) -> Optional[Message]: + """ + Retrieves a playground message by its ID. + + Parameters + ---------- + message_id: str + The unique message identifier. + + Returns + ------- + Optional[Message] + The retrieved playground message if found, else None. + + Raises + ------ + MessageNotFoundError + If message with provided ID doesn't exist. + """ + with self.DBSession() as db_session: + message_model = ( + db_session.query(MessageModel).filter_by(id=message_id).first() + ) + + if not message_model: + raise MessageNotFoundError(message_id == message_id) + + return Message.from_db_model(message_model) + + def update_message( + self, message_id: str, content: str = None, status: str = None, rate: int = None + ) -> Optional[Message]: + """ + Updates a playground message by its ID. + + Parameters + ---------- + message_id: str + The unique message identifier. + + Returns + ------- + Optional[Message] + The retrieved playground message if found, else None. + + Raises + ------ + MessageNotFoundError + If message with provided ID doesn't exist. + """ + with self.DBSession() as db_session: + message_model = ( + db_session.query(MessageModel).filter_by(id=message_id).first() + ) + + if not message_model: + raise MessageNotFoundError(message_id=message_id) + + message_model.content = content or message_model.content + message_model.status = status or message_model.status + message_model.rate = rate or message_model.rate + message_model.updated = datetime.now() + + db_session.commit() + return Message.from_db_model(message_model) + + def update_status(self, object_type: str, object_id: str, status: str): + """ + Update the status of a session, thread, or message. + + Parameters + ---------- + object_type (str): The type of object to update ('session', 'thread', or 'message'). + object_id (str): The ID of the object to update. + status (str): The new status to set for the object. + """ + with self.DBSession() as db_session: + if object_type in OBJECT_MODEL_MAP: + db_session.query(OBJECT_MODEL_MAP[object_type]).filter_by( + id=object_id + ).update({"status": status, "updated": datetime.now()}) + db_session.commit() + + def delete_object(self, object_type: str, object_id: str): + """ + Delete a session, thread, or message from the database. + + Parameters + ---------- + object_type (str): The type of object to delete ('session', 'thread', or 'message'). + object_id (str): The ID of the object to delete. + """ + with self.DBSession() as db_session: + if object_type in OBJECT_MODEL_MAP: + db_session.query(OBJECT_MODEL_MAP[object_type]).filter_by( + id=object_id + ).delete() + db_session.commit() + + +###################### INIT DB CONTEXT###################################### +os.makedirs(DATABASE_PATH, exist_ok=True) +db_context = DBContext(db_url=f"sqlite:///{DATABASE_PATH}/{DATABASE_NAME}") +############################################################################ diff --git a/ads/aqua/playground/db_models.py b/ads/aqua/playground/db_models.py new file mode 100644 index 000000000..758622795 --- /dev/null +++ b/ads/aqua/playground/db_models.py @@ -0,0 +1,155 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*-- + +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +from datetime import datetime +from typing import List +import uuid + +from sqlalchemy import JSON, TIMESTAMP, ForeignKey, Integer, String, Text +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship + + +class Base(DeclarativeBase): + pass + + +class SessionModel(Base): + """ + Represents a playground session table in the database. + + + Attributes: + id (Mapped[str]): The primary key, the UUID. + model_id (Mapped[str]): The id of the model. + model_name: (Mapped[str]): The name of the model. + model_endpoint: (Mapped[str]): The model endpoint. + created (Mapped[datetime]): The creating timestamp of the session. + updated (Mapped[datetime]): The updating timestamp of the session. + status (Mapped[str]): The status of the session. + """ + + __tablename__ = "playground_session" + + id: Mapped[str] = mapped_column( + String(36), primary_key=True, default=lambda: str(uuid.uuid4()) + ) + model_id: Mapped[str] = mapped_column(String, unique=True) + model_name: Mapped[str] = mapped_column(String) + model_endpoint: Mapped[str] = mapped_column(String) + created: Mapped[datetime] = mapped_column(TIMESTAMP) + updated: Mapped[datetime] = mapped_column(TIMESTAMP) + status: Mapped[str] = mapped_column(String) + + threads: Mapped[List["ThreadModel"]] = relationship( + "ThreadModel", back_populates="session", cascade="all, delete-orphan" + ) + settings: Mapped["SessionSettingsModel"] = relationship( + "SessionSettingsModel", back_populates="session", uselist=False + ) + + +class SessionSettingsModel(Base): + """ + Represents a session configuration table in the database. + + Attributes: + id (Mapped[str]): The primary key, the UUID. + playground_session_id (Mapped[str]): Foreign key linking to the SessionModel table. + model_params: (Mapped[Dict]): The model related parameters stored as JSON. + - max_tokens: int + - temperature: float + - top_k: int + - top_p: float + """ + + __tablename__ = "playground_session_settings" + + id: Mapped[str] = mapped_column( + String(36), primary_key=True, default=lambda: str(uuid.uuid4()) + ) + playground_session_id: Mapped[str] = mapped_column( + String, ForeignKey("playground_session.id"), unique=True + ) + model_params: Mapped[dict] = mapped_column(JSON) + + session: Mapped["SessionModel"] = relationship( + "SessionModel", back_populates="settings", uselist=False + ) + + +class ThreadModel(Base): + """ + Represents a thread table in the database. + + + Attributes: + id (Mapped[str]): The primary key, the UUID. + playground_session_id (Mapped[str]): Foreign key linking to the SessionModel table. + name (Mapped[str]): The name of the thread. + created (Mapped[datetime]): The creating timestamp of the thread. + updated (Mapped[datetime]): The updating timestamp of the thread. + status (Mapped[str]): The status of the thread. + """ + + __tablename__ = "playground_thread" + + id: Mapped[str] = mapped_column( + String(36), primary_key=True, default=lambda: str(uuid.uuid4()) + ) + playground_session_id: Mapped[str] = mapped_column( + String, ForeignKey("playground_session.id") + ) + name: Mapped[str] = mapped_column(String) + created: Mapped[datetime] = mapped_column(TIMESTAMP) + updated: Mapped[datetime] = mapped_column(TIMESTAMP) + status: Mapped[str] = mapped_column(String) + + session: Mapped["SessionModel"] = relationship( + "SessionModel", back_populates="threads" + ) + messages: Mapped[List["MessageModel"]] = relationship( + "MessageModel", back_populates="thread", cascade="all, delete-orphan" + ) + + +class MessageModel(Base): + """ + Represents a message table in the database. + + Attributes: + id (Mapped[str]): The primary key, the UUID. + parent_id (Mapped[str]): The parent message. + playground_thread_id (Mapped [str]): Foreign key linking to the ThreadModel table. + content (Mapped[str]): The message text. + payload (Mapped[dict]): The payload info. + model_params (Mapped[dict]): The model parameters. + created (Mapped[datetime]): The timestamp of the request. + updated (Mapped[datetime]): The timestamp of the request. + status (Mapped[str]): The status of the request. + rate (Mapped[int]): The rate of the response. [-1, 0, 1]. + role (Mapped[str]): The role of the message. [system, user] + """ + + __tablename__ = "playground_message" + id: Mapped[str] = mapped_column( + String(36), primary_key=True, default=lambda: str(uuid.uuid4()) + ) + parent_id: Mapped[str] = mapped_column(String, nullable=True) + playground_thread_id: Mapped[str] = mapped_column( + String, ForeignKey("playground_thread.id") + ) + content: Mapped[str] = mapped_column(Text) + payload: Mapped[dict] = mapped_column(JSON) + model_params: Mapped[dict] = mapped_column(JSON) + created: Mapped[datetime] = mapped_column(TIMESTAMP) + updated: Mapped[datetime] = mapped_column(TIMESTAMP) + status: Mapped[str] = mapped_column(String) + rate: Mapped[int] = mapped_column(Integer) + role: Mapped[str] = mapped_column(String) + + thread: Mapped["ThreadModel"] = relationship( + "ThreadModel", back_populates="messages" + ) diff --git a/ads/aqua/playground/entities.py b/ads/aqua/playground/entities.py new file mode 100644 index 000000000..c50b8927e --- /dev/null +++ b/ads/aqua/playground/entities.py @@ -0,0 +1,387 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*-- +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + + +import datetime +import uuid +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Dict, List, Optional + +from ads.aqua.playground.const import MessageRate, Status, MessageRole +from ads.aqua.playground.db_models import MessageModel, SessionModel, ThreadModel +from ads.common.serializer import DataClassSerializable + + +@dataclass +class SearchId: + """ + Class helper to encode a search id. + """ + + model_id: str = None + record_id: str = None + + @classmethod + def parse(cls, id: str) -> "SearchId": + """ + The method differentiates between record ID and model ID based on the content of the ID. + If the ID contains 'ocid', it is treated as model ID. + + Parameters + ---------- + id: str + Input ID to parse. + """ + result = cls() + + if not id: + raise ValueError( + "Incorrect id was provided. " + "It should either be a model ID or a record ID." + ) + + if "ocid" in (id or ""): + result.model_id = id + else: + result.record_id = id + + return result + + +@dataclass(repr=False) +class VLLModelParams(DataClassSerializable): + """ + Parameters specific to Versatile Large Language Model. + + Attributes + ---------- + model: (str, optional) + Model name. + max_tokens: (int, optional) + Maximum number of tokens to generate. + temperature: (float, optional) + Controls randomness in generation. + top_p: (float, optional) + Top probability mass. + frequency_penalty: (float, optional) + Penalizes new tokens based on their existing frequency. + presence_penalty: (float, optional) + Penalizes new tokens based on their presence. + top_k: (int, optional) + Keeps only top k candidates at each generation step. + echo: (bool, optional) + Echoes the input text in the output. + logprobs: (int, optional) + Number of log probabilities to return. + use_beam_search: (bool, optional) + Whether to use beam search for generation. + ignore_eos: (bool, optional) + Whether to ignore end-of-sequence tokens during generation. + n: (int, optional) + Number of output sequences to return for a given prompt. + best_of: (int, optional) + Controls how many completions to generate for each prompt. + stop: (List[str], optional) + Stop words or phrases to use when generating. + stream: (bool, optional) + Indicates whether the response should be streamed. + min_p: (float, optional) + Minimum probability threshold for token selection. + """ + + model: Optional[str] = "/opt/ds/model/deployed_model" + max_tokens: Optional[int] = 2048 + temperature: Optional[float] = 0.7 + top_p: Optional[float] = 1.0 + frequency_penalty: Optional[float] = 0.0 + presence_penalty: Optional[float] = 0.0 + top_k: Optional[int] = 1 + echo: Optional[bool] = False + logprobs: Optional[int] = None + use_beam_search: Optional[bool] = False + ignore_eos: Optional[bool] = False + n: Optional[int] = 1 + best_of: Optional[int] = 1 + stop: Optional[List[str]] = field(default_factory=list) + stream: Optional[bool] = True + min_p: Optional[float] = 0.0 + + def __post_init__(self): + if not self.model: + self.model = "/opt/ds/model/deployed_model" + if self.stream is None: + self.stream = True + + +@dataclass(repr=False) +class Message(DataClassSerializable): + """ + Data class representing a message in a thread. + + Attributes + ---------- + id: str + Unique identifier of the message. + parent_id: int + Identifier of the parent message. + thread_id: int + Identifier of the thread to which the message belongs. + session_id: int + Identifier of the session to which the message thread belongs. + content: str + The actual content of the message. + payload: Dict + Additional payload associated with the message. + status: str + Status of the message. Can be `active` or `archived`. + rate: int + Rating of the message, based on the MessageRate enum. + role: str + Role of the message, based on the MessageRole enum. + created: datetime.datetime + Creation timestamp of the message. + updated: datetime.datetime + Updating timestamp of the message. + answers: List[Message] + List of system messages for the user message. + """ + + message_id: str = None + parent_message_id: str = None + session_id: str = None + thread_id: str = None + content: str = None + payload: Dict = None + model_params: VLLModelParams = field(default_factory=VLLModelParams) + status: str = Status.ACTIVE + rate: int = MessageRate.DEFAULT + role: str = None + created: datetime.datetime = None + updated: datetime.datetime = None + answers: List["Message"] = field(default_factory=list) + + def __post_init__(self): + if not self.message_id: + self.message_id = str(uuid.uuid4()) + if not self.status: + self.status = Status.ACTIVE + if not self.rate: + self.rate = MessageRate.DEFAULT + if not self.role: + self.role = MessageRole.USER + + @classmethod + def from_db_model(cls, data: MessageModel) -> "Message": + """ + Creates Message instance from MessageModel object. + + Parameters + ---------- + data: MessageModel + The DB representation of the message. + + Returns + ------- + Message + + The instance of the playground message. + """ + return cls( + message_id=data.id, + parent_message_id=data.parent_id, + content=data.content, + thread_id=data.playground_thread_id, + session_id=data.thread.playground_session_id, + created=data.created, + updated=data.updated, + status=data.status, + rate=data.rate, + role=data.role, + payload=data.payload, + model_params=VLLModelParams.from_dict(data.model_params), + ) + + +@dataclass(repr=False) +class Thread(DataClassSerializable): + """ + Data class representing a thread in a session. + + Attributes + ---------- + id: str + Unique identifier of the thread. + name: str + Name of the thread. + session_id: str + Identifier of the session to which the thread belongs. + created: datetime.datetime + Creation timestamp of the thread. + updated: datetime.datetime + Updating timestamp of the thread. + status: str + Status of the message. Can be `active` or `archived`. + messages: List[Message] + List of messages in the thread. + """ + + id: str = None + name: str = None + session_id: str = None + created: datetime.datetime = None + updated: datetime.datetime = None + status: str = Status.ACTIVE + messages: List[Message] = field(default_factory=list) + + def __post_init__(self): + if not self.id: + self.id = str(uuid.uuid4()) + if not self.status: + self.status = Status.ACTIVE + + @classmethod + def from_db_model( + cls, data: ThreadModel, include_messages: bool = False + ) -> "Thread": + """ + Creates Thread instance from ThreadModel object. + + Parameters + ---------- + data: ThreadModel + The DB representation of the thread. + include_messages: (bool, optional) + Include the associated messages into the result. + + Returns + ------- + Thread + The instance of the playground thread. + """ + obj = cls( + id=data.id, + name=data.name, + session_id=data.playground_session_id, + created=data.created, + updated=data.updated, + status=data.status, + ) + + if include_messages and data.messages: + # Assign the list of answers to the parent messages + messages = [ + Message.from_db_model(data=message_model) + for message_model in data.messages + ] + + # Filter and return only root messages (those with parent_id == None) + result_messages = [msg for msg in messages if not msg.parent_message_id] + + # Group messages by parent ID + message_map = defaultdict(list) + for msg in messages: + message_map[msg.parent_message_id].append(msg) + + # Add child messages + for msg in result_messages: + msg.answers = message_map[msg.message_id] + + obj.messages = result_messages + + return obj + + +@dataclass(repr=False) +class ModelInfo(DataClassSerializable): + """ + Data class representing model deployment details. + + Attributes + ---------- + id: (str, optional) + The model deployment ID. + name: (str, optional) + The model deployment name. + endpoint: (str, optional) + The model deployment endpoint. + """ + + id: Optional[str] = None + name: Optional[str] = None + endpoint: Optional[str] = None + + +@dataclass(repr=False) +class Session(DataClassSerializable): + + """ + Data class representing a session in the Aqua Playground. + + Attributes + ---------- + id: str + Unique identifier of the session. + created: datetime.datetime + Creation timestamp of the session. + updated: datetime.datetime + Updating timestamp of the session. + status: str + Status of the message. Can be `active` or `archived`. + threads: List[Thread] + List of threads in the session. + model: ModelInfo + Model deployment details. + """ + + session_id: str = None + created: datetime.datetime = None + updated: datetime.datetime = None + status: str = Status.ACTIVE + threads: List[Thread] = field(default_factory=list) + model: ModelInfo = field(default_factory=ModelInfo) + + def __post_init__(self): + if not self.session_id: + self.session_id = str(uuid.uuid4()) + if not self.status: + self.status = Status.ACTIVE + + @classmethod + def from_db_model( + cls, data: SessionModel, include_threads: bool = False + ) -> "Session": + """ + Creates Session instance form SessionModel object. + + Parameters + ---------- + data: SessionModel + The DB representation of the session. + include_threads: (bool, optional) + Whether to include the threads into the result. + + Returns + ------- + Session + The instance of the playground session. + """ + + obj = cls( + session_id=data.id, + created=data.created, + updated=data.updated, + status=data.status, + model=ModelInfo( + id=data.model_id, name=data.model_name, endpoint=data.model_endpoint + ), + ) + + if include_threads and data.threads: + obj.threads = [ + Thread.from_db_model(thread_model) for thread_model in data.threads + ] + + return obj diff --git a/ads/aqua/playground/errors.py b/ads/aqua/playground/errors.py new file mode 100644 index 000000000..5d1c86019 --- /dev/null +++ b/ads/aqua/playground/errors.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*-- +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + + +class ModelDeploymentNotFoundError(Exception): + """Exception raised when the model deployment with the given ID cannot be found.""" + + def __init__(self, model_id: str): + super().__init__( + f"The model deployment with ID: `{model_id}` cannot be found. " + "Please check if the model deployment exists." + ) + + +class SessionNotFoundError(Exception): + """Exception raised when the session with the given ID cannot be found.""" + + def __init__(self, search_id: str): + super().__init__( + f"The session with provided ID: `{search_id}` cannot be found. " + "Please ensure that the session with given ID exists." + ) + + +class ThreadNotFoundError(Exception): + """Exception raised when the thread with the given ID cannot be found.""" + + def __init__(self, thread_id: str): + super().__init__( + f"The thread with provided ID: `{thread_id}` cannot be found. " + "Please ensure that the thread with given ID exists." + ) + + +class MessageNotFoundError(Exception): + """Exception raised when the message with the given ID cannot be found.""" + + def __init__(self, message_id: str): + super().__init__( + f"The message with provided ID: `{message_id}` cannot be found. " + "Please ensure that the message with given ID exists." + ) diff --git a/ads/aqua/playground/model_invoker.py b/ads/aqua/playground/model_invoker.py new file mode 100644 index 000000000..890441291 --- /dev/null +++ b/ads/aqua/playground/model_invoker.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*-- + +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +import json + +import requests +from requests.adapters import HTTPAdapter +from urllib3.util.retry import Retry + +from ads.common.auth import default_signer + +DEFAULT_RETRIES = 3 + +# The amount of time to wait between retry attempts for failed request. +DEFAULT_BACKOFF_FACTOR = 0.3 + + +class ModelInvoker: + """ + A class to invoke models via HTTP requests with retry logic. + + + Attributes + ---------- + endpoint (str): The URL endpoint to send the request. + prompt (str): The prompt to send in the request body. + params (dict): Additional parameters for the model. + retries (int): The number of retry attempts for the request. + backoff_factor (float): The factor to determine the delay between retries. + """ + + def __init__( + self, + endpoint: str, + prompt: str, + params: dict, + retries: int = DEFAULT_RETRIES, + backoff_factor: float = DEFAULT_BACKOFF_FACTOR, + auth=None, + ): + self.auth = auth or default_signer() + self.endpoint = endpoint + self.prompt = prompt + self.params = params + self.retries = retries + self.backoff_factor = backoff_factor + self.session = self._create_session_with_retries(retries, backoff_factor) + + def _create_session_with_retries( + self, retries: int, backoff_factor: float + ) -> requests.Session: + """ + Creates a requests Session with a mounted HTTPAdapter for retry logic. + + Returns + ------- + session (requests.Session): The configured session for HTTP requests. + """ + session = requests.Session() + retry_strategy = Retry( + total=retries, + status_forcelist=[429, 500, 502, 503, 504], + backoff_factor=backoff_factor, + ) + adapter = HTTPAdapter(max_retries=retry_strategy) + session.mount("https://", adapter) + return session + + def invoke(self): + """ + The generator that invokes the model endpoint with retries and streams the response. + + Yields + ------ + line (str): A line of the streamed response. + """ + headers = { + "Content-Type": "application/json", + "enable-streaming": "true", + } + + # print({"prompt": self.prompt, **self.params}) + + try: + response = self.session.post( + self.endpoint, + auth=self.auth["signer"], + headers=headers, + json={"prompt": self.prompt, **self.params}, + stream=True, + ) + + response.raise_for_status() + + for line in response.iter_lines(): + if line: + yield line.decode("utf-8") + + except requests.RequestException as e: + yield json.dumps({"object": "error", "message": str(e)}) diff --git a/ads/aqua/playground/playground.py b/ads/aqua/playground/playground.py new file mode 100644 index 000000000..35d153add --- /dev/null +++ b/ads/aqua/playground/playground.py @@ -0,0 +1,474 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*-- +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +""" +This Python module is part of Oracle's Artificial Intelligence QA Playground, +a tool for managing and interacting with AI Quick Actions models. It includes classes +for handling sessions, threads, messages, and model parameters, along with custom exceptions +and utilities for session and thread management. +""" + +from typing import Dict, Generator, List + +from ads.aqua import logger +from ads.aqua.playground.const import MessageRate, MessageRole, ObjectType, Status +from ads.aqua.playground.db_context import ObjectType, db_context +from ads.aqua.playground.entities import ( + Message, + SearchId, + Session, + Thread, + VLLModelParams, +) +from ads.aqua.playground.errors import SessionNotFoundError, ThreadNotFoundError +from ads.aqua.playground.model_invoker import ModelInvoker +from ads.common.decorator import require_nonempty_arg +from ads.model.deployment.model_deployment import ModelDeployment + + +class SessionApp: + """ + Application class containing APIs for managing Aqua Playground sessions. + + + Methods + ------- + list(self, only_active: bool = True, include_threads: bool = False) -> Session + Lists the registered playground sessions. + get(model_id: str) -> Session + Retrieves a session associated with the specified model ID. + deactivate(model_id: str) + Deactivates the session associated with the specified model ID. + activate(model_id: str) + Activates the session associated with the specified model ID. + """ + + def list( + self, only_active: bool = True, include_threads: bool = False + ) -> List[Session]: + """ + Lists the registered playground sessions. + + Parameters + ---------- + only_active: (bool, optional). Defaults to True. + Whether to load all or only active sessions. + include_threads: (bool, optional). Defaults to False. + Whether to include the associated threads or not. + + Returns + ------- + List[Session] + The list of playground sessions. + """ + return db_context.get_sessions( + only_active=only_active, include_threads=include_threads + ) + + @require_nonempty_arg("id", "Either session ID or model ID need to be provided.") + def get(self, id: str, include_threads: bool = False) -> Session: + """ + Retrieve a playground session details by its session ID or model ID. + + The method differentiates between session ID and model ID based on the content of the ID. + If the ID contains 'ocid', it is treated as model ID. + + Parameters + ---------- + id: str + The session ID or model ID of the playground session. + include_threads: (bool, optional) + Whether include threads in result or not. + + Returns + ------- + Optional[Session] + The retrieved playground session if found, else None. + """ + search_id = SearchId.parse(id) + return db_context.get_session( + session_id=search_id.record_id, + model_id=search_id.model_id, + include_threads=include_threads, + ) + + @require_nonempty_arg("model_id", "The model ID must be provided.") + def create(self, model_id: str) -> Session: + """ + Creates a new playground session for the given model ID. + If the session with the given model ID already exists, then it will be returned. + + Parameters + ---------- + model_id: str + The model ID to create the playground session for. + + Returns + ------- + Session + The playground session instance. + + Raises + ------ + ValueError + If model ID not provided. + """ + + try: + session = self.get(id=model_id, include_threads=True) + logger.info( + "A Session with the provided model ID already exists. " + "Returning the existing session." + ) + except SessionNotFoundError: + model_deployment = ModelDeployment.from_id(model_id) + session = db_context.add_session( + model_id=model_deployment.model_deployment_id, + model_name=model_deployment.display_name, + model_endpoint=model_deployment.url, + ) + + return session + + @require_nonempty_arg("session_id", "The session ID must be provided.") + def activate(self, session_id: str): + """ + Activates the session associated with the given ID. + + Parameters + ---------- + session: str + The ID of the playground session to deactivate. + + Raises + ------ + ValueError + If session ID not provided. + """ + db_context.update_status( + object_type=ObjectType.SESSION, object_id=session_id, status=Status.ACTIVE + ) + + @require_nonempty_arg("session_id", "The session ID must be provided.") + def deactivate(self, session_id: str): + """ + Deactivates the session associated with the given ID. + + Parameters + ---------- + session: str + The ID of the playground session to deactivate. + """ + db_context.update_status( + object_type=ObjectType.SESSION, object_id=session_id, status=Status.ARCHIVED + ) + + @require_nonempty_arg("prompt", "The message must be provided.") + @require_nonempty_arg("endpoint", "The model endpoint must be provided.") + def invoke_model( + self, + endpoint: str, + prompt: str, + params: Dict = None, + ) -> Generator[str, None, None]: + """ + Generator to invoke the model and streams the result. + + Parameters + ---------- + endpoint:str + The URL endpoint to send the request. + prompt: str + The content of the message to be posted. + params: (Dict, optional) + Model parameters to be associated with the message. + Currently supported VLLM+OpenAI parameters. + + --model-params '{ + "max_tokens":500, + "temperature": 0.5, + "top_k": 10, + "top_p": 0.5, + "model": "/opt/ds/model/deployed_model", + ...}' + + Yields: + str + A line of the streamed response. + """ + yield from ModelInvoker( + endpoint=endpoint, + prompt=prompt, + params=VLLModelParams.from_dict(params).to_dict(), + ).invoke() + + +class ThreadApp: + """ + Application class containing APIs for managing threads within Aqua Playground sessions. + + Methods + ------- + list(self, session_id: str, only_active: bool = True) -> List[Thread] + Lists the registered playground session threads by session ID. + get(thread_id: str) + Retrieves a thread by its ID. + create(self, session_id: str, name: str, thread_id: str = None, status: str = Status.ACTIVE) -> Thread + Creates a new playground thread for the given session ID. + deactivate(thread_id: str) + Deactivates the thread with the given ID. + activate(thread_id: str) + Activates the thread with the given ID. + """ + + @require_nonempty_arg("session_id", "The session ID must be provided.") + def list(self, session_id: str, only_active: bool = True) -> List[Thread]: + """ + Lists the registered playground threads by session ID. + + Parameters + ---------- + session_id: str + The session ID to get the playground threads for. + The model ID can be also provided. The session id will be retrieved by model ID. + only_active: (bool, optional). Defaults to True. + Whether to load all or only active threads. + + Returns + ------- + List[Thread] + The list of playground session threads. + """ + + session = SessionApp().get(id=session_id, include_threads=False) + return db_context.get_session_threads( + session_id=session.session_id, only_active=only_active + ) + + @require_nonempty_arg("thread_id", "The thread ID must be provided.") + def get(self, thread_id: str, include_messages: bool = False) -> Thread: + """ + Retrieve a thread based on its ID. + + Parameters + ---------- + thread_id: str + The ID of the thread to be retrieved. + include_messages: (bool, optional). Defaults to False. + Whether include messages in result or not. + + Returns + ------- + Thread + The playground session thread. + + Raise + ----- + ThreadNotFoundError + If thread doesn't exist. + """ + return db_context.get_thread( + thread_id=thread_id, include_messages=include_messages + ) + + @require_nonempty_arg("session_id", "The session ID must be provided.") + @require_nonempty_arg("name", "The name for the new thread must be provided.") + def create( + self, + session_id: str, + name: str, + thread_id: str = None, + status: str = Status.ACTIVE, + ) -> Thread: + """ + Creates a new playground thread for the given session ID or model ID. + + Parameters + ---------- + session_id: str + The session ID to create the playground thread for. + The model ID can be also provided. The session id will be retrieved by model ID. + name: str + The name of the thread. + thread_id: (str, optional) + The ID of the thread. Will be auto generated if not provided. + status: (str, optional) + The status of the thread. Can be either `active` or `archived`. + + Returns + ------- + Thread + The playground thread instance. + """ + session = SessionApp().get(id=session_id, include_threads=False) + thread = None + if thread_id: + try: + thread = db_context.update_thread( + thread_id=thread_id, name=name, status=status + ) + except ThreadNotFoundError: + pass + + if not thread: + thread = db_context.add_thread( + session_id=session.session_id, + name=name, + status=status, + thread_id=thread_id, + ) + + return thread + + @require_nonempty_arg("thread_id", "The thread ID must be provided.") + def deactivate(self, thread_id: str): + """ + Deactivates the thread with the specified ID. + + Parameters + ---------- + thread_id: str + The ID of the thread to be deactivated. + """ + db_context.update_status( + object_type=ObjectType.THREAD, object_id=thread_id, status=Status.ARCHIVED + ) + + @require_nonempty_arg("thread_id", "The thread ID must be provided.") + def activate(self, thread_id: str): + """ + Activates the thread with the specified ID. + + Parameters + ---------- + thread_id: str + The ID of the thread to be activated. + """ + db_context.update_status( + object_type=ObjectType.THREAD, object_id=thread_id, status=Status.ACTIVE + ) + + +class MessageApp: + """ + Application class containing APIs for managing messages within Aqua Playground thread. + + Methods + ------- + + create(self, thread_id: str, content: str, ...) -> Message + Posts a message to the specified thread. + """ + + @require_nonempty_arg("thread_id", "The session ID must be provided.") + def create( + self, + thread_id: str, + content: str, + message_id: str = None, + parent_message_id: str = None, + role: str = MessageRole.USER, + rate: int = MessageRate.DEFAULT, + payload: Dict = None, + model_params: Dict = None, + status: str = Status.ACTIVE, + ) -> Message: + """ + Creates a new message for the given thread ID. + + Parameters + ---------- + thread_id: str + The ID of the thread to which the message will be added. + content: str + The text content of the message. + message_id: (str, optional) + The message ID. + parent_message_id: (str, optional) + The parent message. + payload: (Dict, optional) + The model payload. + model_params: (Dict, optional) + The model parameters. + status: (str) + The status of the message. + role: (str) + The role of the message (e.g., 'user', 'system'). + + Returns + ------- + Message + The playground message instance. + """ + return db_context.add_message_to_thread( + thread_id=thread_id, + content=content, + message_id=message_id, + parent_message_id=parent_message_id, + role=role, + rate=rate, + payload=payload, + model_params=model_params, + status=status, + ) + + @require_nonempty_arg("message_id", "The message ID must be provided.") + def update( + self, + message_id: str, + content: str, + rate: int = MessageRate.DEFAULT, + status: str = Status.ACTIVE, + ) -> Message: + """ + Updates a message by provided ID. + + Parameters + ---------- + thread_id: str + The ID of the thread to which the message will be added. + content: str + The text content of the message. + message_id: (str, optional) + The message ID. + parent_message_id: (str, optional) + The parent message. + payload: (Dict, optional) + The model payload. + model_params: (Dict, optional) + The model parameters. + status: (str) + The status of the message. + role: (str) + The role of the message (e.g., 'user', 'system'). + + Returns + ------- + Message + The playground message instance. + """ + return db_context.update_message( + content=content, + message_id=message_id, + rate=rate, + status=status, + ) + + +class PlaygroundApp: + """ + Aqua Playground Application. + + Attributes + ---------- + session: SessionApp + Managing playground sessions. + thread: ThreadApp + Managing playground threads within sessions. + """ + + session = SessionApp() + thread = ThreadApp() + message = MessageApp() diff --git a/tests/unitary/with_extras/aqua/test_playground_entities.py b/tests/unitary/with_extras/aqua/test_playground_entities.py new file mode 100644 index 000000000..8129e916b --- /dev/null +++ b/tests/unitary/with_extras/aqua/test_playground_entities.py @@ -0,0 +1,260 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*-- + +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +from datetime import datetime +from unittest.mock import MagicMock + +import pytest + +from ads.aqua.playground.const import MessageRate, MessageRole, Status +from ads.aqua.playground.entities import ( + Message, + ModelInfo, + SearchId, + Session, + Thread, + VLLModelParams, +) + + +class TestSearchId: + def test_parse_valid_model_id(self): + # Test parsing a string that contains 'ocid' as a model ID + model_id = "test_id" + result = SearchId.parse(model_id) + assert result.model_id == model_id + assert result.record_id is None + + def test_parse_valid_record_id(self): + # Test parsing a string that is a numeric value as a record ID + record_id = "12345" + result = SearchId.parse(record_id) + assert result.record_id == int(record_id) + assert result.model_id is None + + def test_parse_invalid_id(self): + # Test parsing an invalid string which should raise ValueError + invalid_id = "invalid" + with pytest.raises(ValueError): + SearchId.parse(invalid_id) + + +class TestVLLModelParams: + def test_default_params(self): + # Test the default initialization of VLLModelParams + params = VLLModelParams() + assert params.model == "/opt/ds/model/deployed_model" + assert params.max_tokens == 2048 + assert params.temperature == 0.7 + assert params.top_p == 1.0 + assert params.frequency_penalty == 0.0 + assert params.presence_penalty == 0.0 + assert params.top_k == 0 + assert params.echo == False + assert params.logprobs == None + assert params.use_beam_search == False + assert params.ignore_eos == False + assert params.n == 1 + assert params.best_of == 1 + assert params.stop == None + assert params.stream == False + assert params.min_p == 0.0 + + def test_custom_params(self): + # Test initialization with custom values for VLLModelParams + params = VLLModelParams(model="custom_model", max_tokens=1024) + assert params.model == "custom_model" + assert params.max_tokens == 1024 + + def test_post_init(self): + # Test post initialization + params = VLLModelParams(model=None) + assert params.model == "/opt/ds/model/deployed_model" + + +class TestMessage: + def test_default_params(self): + # Test the default initialization of Message + params = Message() + assert params.message_id == None + assert params.parent_message_id == None + assert params.session_id == None + assert params.thread_id == None + assert params.content == None + assert params.payload == None + assert params.status == Status.ACTIVE + assert params.rate == MessageRate.DEFAULT + assert params.role == None + assert params.created == None + assert params.answers == [] + assert params.model_params == VLLModelParams() + + def test_from_db_model(self): + # Test creating a Message instance from a MessageModel object + # Assuming a mock MessageModel object + mock_date = datetime.now() + mock_message_model = MagicMock( + id=2, + parent_id=1, + playground_thread_id=3, + content="test", + payload={}, + model_params={ + "model": "test", + "max_tokens": 2048, + "temperature": 0.7, + "top_p": 1.0, + }, + created=mock_date, + status="active", + rate=0, + role="user", + ) + + message = Message.from_db_model(mock_message_model) + assert message.message_id == 2 + assert message.parent_message_id == 1 + assert message.thread_id == 3 + assert message.content == "test" + + assert message.payload == {} + assert message.model_params == VLLModelParams.from_dict( + {"model": "test", "max_tokens": 2048, "temperature": 0.7, "top_p": 1.0} + ) + assert message.created == mock_date + assert message.status == "active" + assert message.rate == 0 + assert message.role == "user" + + +class TestThread: + def test_default_params(self): + # Test the default initialization of Thread + params = Thread() + assert params.id == None + assert params.name == None + assert params.session_id == None + assert params.created == None + assert params.status == Status.ACTIVE + assert params.messages == [] + + def test_from_db_model(self): + # Test creating a Thread instance from a ThreadModel object including associated messages + # Assuming mock ThreadModel and MessageModel objects + mock_date = datetime.now() + mock_message_model_question = MagicMock( + id=1, + parent_id=None, + playground_thread_id=3, + content="question", + payload={}, + model_params={ + "model": "test", + "max_tokens": 2048, + "temperature": 0.7, + "top_p": 1.0, + }, + created=mock_date, + status="active", + rate=0, + role=MessageRole.USER, + ) + + mock_message_model_answer = MagicMock( + id=2, + parent_id=1, + playground_thread_id=3, + content="answer", + payload={}, + model_params={ + "model": "test", + "max_tokens": 2048, + "temperature": 0.7, + "top_p": 1.0, + }, + created=mock_date, + status="active", + rate=0, + role=MessageRole.SYSTEM, + ) + + mock_thread_model = MagicMock( + id=1, + name="test", + playground_session_id=1, + created=mock_date, + status=Status.ACTIVE, + messages=[mock_message_model_question, mock_message_model_answer], + ) + + thread = Thread.from_db_model(mock_thread_model, include_messages=True) + assert thread.id == 1 + assert len(thread.messages) == 1 + assert thread.messages[0].message_id == 1 + assert thread.messages[0].content == "question" + + assert len(thread.messages[0].answers) == 1 + assert thread.messages[0].answers[0].message_id == 2 + assert thread.messages[0].answers[0].content == "answer" + + assert thread.session_id == 1 + assert thread.created == mock_date + assert thread.status == Status.ACTIVE + + +class TestModelInfo: + def test_initialization(self): + # Test initialization of ModelInfo with specific model details + model_info = ModelInfo( + id="model1", name="Model One", endpoint="http://endpoint" + ) + assert model_info.id == "model1" + assert model_info.name == "Model One" + assert model_info.endpoint == "http://endpoint" + + +class TestSession: + def test_default_params(self): + # Test the default initialization of Session + params = Session() + assert params.session_id == None + assert params.created == None + assert params.status == Status.ACTIVE + assert params.threads == [] + assert params.model == ModelInfo() + + def test_from_db_model_with_threads(self): + # Test creating a Session instance from a SessionModel object including associated threads + # Assuming mock SessionModel and ThreadModel objects + mock_date = datetime.now() + mock_thread_model = MagicMock( + id=1, + name="test", + session_id=1, + created=mock_date, + status=Status.ACTIVE, + messages=[], + ) + + mock_session_model = MagicMock( + id=1, + name="test", + created=mock_date, + status=Status.ACTIVE, + threads=[mock_thread_model], + model_id="model_id", + model_name="model_name", + model_endpoint="model_endpoint", + ) + + session = Session.from_db_model(mock_session_model, include_threads=True) + assert session.session_id == 1 + assert len(session.threads) == 1 + assert session.threads[0].id == 1 + assert session.created == mock_date + assert session.model == ModelInfo( + id="model_id", name="model_name", endpoint="model_endpoint" + )