62 changes: 61 additions & 1 deletion llm-service/app/routers/index/sessions/__init__.py
@@ -39,10 +39,10 @@
import json
import logging
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, Generator, Any

import time
from fastapi import APIRouter, Header, HTTPException
from fastapi.responses import StreamingResponse
from llama_index.core.base.llms.types import ChatResponse
@@ -58,6 +58,7 @@
)
from ....services.chat.suggested_questions import generate_suggested_questions
from ....services.chat_history.chat_history_manager import (
RagMessage,
RagStudioChatMessage,
chat_history_manager,
)
@@ -292,6 +293,7 @@ def generate_stream() -> Generator[str, None, None]:
query=request.query,
configuration=configuration,
user_name=remote_user,
response_id=None,
)

# If we get here and the cancel_event is set, the client has disconnected
@@ -317,6 +319,35 @@ def generate_stream() -> Generator[str, None, None]:
response: ChatResponse = item
# Check for cancellation between each response
if cancel_event.is_set():
print("Client disconnected between events")
if response.additional_kwargs.get("response_id"):
updated_response = RagStudioChatMessage(
id=response.additional_kwargs["response_id"],
session_id=session_id,
source_nodes=(
response.source_nodes
if hasattr(response, "source_nodes")
else []
),
inference_model=session.inference_model,
rag_message=RagMessage(
user=request.query,
assistant=(
response.message.content
if response.message.content
else ""
),
),
evaluations=[],
timestamp=time.time(),
condensed_question=None,
status="cancelled",
)
chat_history_manager.update_message(
session_id=session_id,
message_id=updated_response.id,
message=updated_response,
)
logger.info("Client disconnected during result processing")
break
if "chat_event" in response.additional_kwargs:
@@ -347,6 +378,35 @@ def generate_stream() -> Generator[str, None, None]:
logger.exception("Timeout: Failed to stream chat completion")
yield 'data: {{"error" : "Timeout: Failed to stream chat completion"}}\n\n'
except Exception as e:
if response.additional_kwargs.get("response_id"):
updated_response = RagStudioChatMessage(
id=response.additional_kwargs["response_id"],
session_id=session_id,
source_nodes=(
response.source_nodes
if hasattr(response, "source_nodes")
else []
),
inference_model=session.inference_model,
rag_message=RagMessage(
user=request.query,
assistant=(
response.message.content
if response.message.content
else ""
),
),
evaluations=[],
timestamp=time.time(),
condensed_question=None,
status="error",
error_message=str(e),
)
chat_history_manager.update_message(
session_id=session_id,
message_id=updated_response.id,
message=updated_response,
)
logger.exception("Failed to stream chat completion")
yield f'data: {{"error" : "{e}"}}\n\n'
finally:
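Stepping back from the diff: the router changes above follow a check-then-persist pattern — a pending message is written up front, and a terminal status ("cancelled" or "error") is written back if the client disconnects or the stream raises. A minimal, self-contained sketch of that pattern; `StoredMessage`, `HistoryStore`, and `consume` are illustrative stand-ins, not names from this PR:

```python
import threading
from dataclasses import dataclass, field
from typing import Iterator


@dataclass
class StoredMessage:
    id: str
    assistant: str = ""
    status: str = "pending"  # "pending" | "cancelled" | "error" | "success"


@dataclass
class HistoryStore:
    messages: dict[str, StoredMessage] = field(default_factory=dict)

    def update(self, message: StoredMessage) -> None:
        self.messages[message.id] = message


def consume(stream: Iterator[str], cancel_event: threading.Event,
            store: HistoryStore, response_id: str) -> None:
    # Persist a pending placeholder first, then finalize with a terminal status.
    message = StoredMessage(id=response_id)
    store.update(message)
    try:
        for chunk in stream:
            if cancel_event.is_set():
                message.status = "cancelled"
                return
            message.assistant += chunk
        message.status = "success"
    except Exception:
        message.status = "error"
        raise
    finally:
        store.update(message)


if __name__ == "__main__":
    store = HistoryStore()
    consume(iter(["Hello, ", "world"]), threading.Event(), store, "msg-1")
    print(store.messages["msg-1"].status)  # -> "success"
```

Persisting in a `finally` block keeps the terminal write in one place regardless of which exit path the stream takes; the real handler additionally copies whatever partial content and source nodes the last `ChatResponse` carried.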
64 changes: 5 additions & 59 deletions llm-service/app/services/chat/chat.py
@@ -40,22 +40,19 @@
import uuid
from typing import Optional

from llama_index.core.chat_engine.types import AgentChatResponse

from app.ai.vector_stores.vector_store_factory import VectorStoreFactory
from app.rag_types import RagPredictConfiguration
from app.services import evaluators, llm_completion
from app.services.chat.utils import retrieve_chat_history, format_source_nodes
from app.services import llm_completion
from app.services.chat.streaming_chat import finalize_response
from app.services.chat.utils import retrieve_chat_history
from app.services.chat_history.chat_history_manager import (
Evaluation,
RagMessage,
RagStudioChatMessage,
chat_history_manager,
)
from app.services.metadata_apis.session_metadata_api import Session
from app.services.mlflow import record_rag_mlflow_run, record_direct_llm_mlflow_run
from app.services.mlflow import record_direct_llm_mlflow_run
from app.services.query import querier
from app.services.query.querier import get_nodes_from_output
from app.services.query.query_configuration import QueryConfiguration

logger = logging.getLogger(__name__)
@@ -125,58 +122,6 @@ def _run_chat(
)


def finalize_response(
chat_response: AgentChatResponse,
condensed_question: str | None,
query: str,
query_configuration: QueryConfiguration,
response_id: str,
session: Session,
user_name: Optional[str],
) -> RagStudioChatMessage:
if condensed_question and (condensed_question.strip() == query.strip()):
condensed_question = None

orig_source_nodes = chat_response.source_nodes
source_nodes = get_nodes_from_output(chat_response.response, session)

# if node with id present in orig_source_nodes, then don't add it again
node_ids_present = set([node.node_id for node in orig_source_nodes])
for node in source_nodes:
if node.node_id not in node_ids_present:
orig_source_nodes.append(node)

chat_response.source_nodes = orig_source_nodes

evaluations = []
if len(chat_response.source_nodes) != 0:
relevance, faithfulness = evaluators.evaluate_response(
query, chat_response, session.inference_model
)
evaluations.append(Evaluation(name="relevance", value=relevance))
evaluations.append(Evaluation(name="faithfulness", value=faithfulness))
response_source_nodes = format_source_nodes(chat_response)
new_chat_message = RagStudioChatMessage(
id=response_id,
session_id=session.id,
source_nodes=response_source_nodes,
inference_model=session.inference_model,
rag_message=RagMessage(
user=query,
assistant=chat_response.response,
),
evaluations=evaluations,
timestamp=time.time(),
condensed_question=condensed_question,
)
record_rag_mlflow_run(
new_chat_message, query_configuration, response_id, session, user_name
)
chat_history_manager.append_to_history(session.id, [new_chat_message])

return new_chat_message


def direct_llm_chat(
session: Session, response_id: str, query: str, user_name: Optional[str]
) -> RagStudioChatMessage:
@@ -197,6 +142,7 @@ def direct_llm_chat(
),
timestamp=time.time(),
condensed_question=None,
status="success",
)
chat_history_manager.append_to_history(session.id, [new_chat_message])
return new_chat_message
84 changes: 78 additions & 6 deletions llm-service/app/services/chat/streaming_chat.py
@@ -47,23 +47,24 @@

from app.ai.vector_stores.vector_store_factory import VectorStoreFactory
from app.rag_types import RagPredictConfiguration
from app.services import llm_completion, models
from app.services.chat.chat import finalize_response
from app.services.chat.utils import retrieve_chat_history
from app.services import llm_completion, models, evaluators
from app.services.chat.utils import retrieve_chat_history, format_source_nodes
from app.services.chat_history.chat_history_manager import (
RagStudioChatMessage,
RagMessage,
chat_history_manager,
Evaluation,
)
from app.services.metadata_apis.session_metadata_api import Session
from app.services.mlflow import record_direct_llm_mlflow_run
from app.services.mlflow import record_direct_llm_mlflow_run, record_rag_mlflow_run
from app.services.query import querier
from app.services.query.chat_engine import (
FlexibleContextChatEngine,
build_flexible_chat_engine,
)
from app.services.query.querier import (
build_retriever,
get_nodes_from_output,
)
from app.services.query.query_configuration import QueryConfiguration

@@ -73,6 +74,7 @@ def stream_chat(
query: str,
configuration: RagPredictConfiguration,
user_name: Optional[str],
response_id: Optional[str] = None,
) -> Generator[ChatResponse, None, None]:
query_configuration = QueryConfiguration(
top_k=session.response_chunks,
@@ -86,7 +88,23 @@
use_streaming=not session.query_configuration.disable_streaming,
)

response_id = str(uuid.uuid4())
response_id = response_id or str(uuid.uuid4())
new_chat_message = RagStudioChatMessage(
id=response_id,
session_id=session.id,
source_nodes=[],
inference_model=session.inference_model,
evaluations=[],
rag_message=RagMessage(
user=query,
assistant="",
),
timestamp=time.time(),
condensed_question=None,
status="pending",
)
chat_history_manager.append_to_history(session.id, [new_chat_message])

total_data_sources_size: int = sum(
map(
lambda ds_id: VectorStoreFactory.for_chunks(ds_id).size() or 0,
@@ -216,5 +234,59 @@ def _stream_direct_llm_chat(
),
timestamp=time.time(),
condensed_question=None,
status="success",
)
chat_history_manager.append_to_history(session.id, [new_chat_message])
chat_history_manager.update_message(session.id, response_id, new_chat_message)


def finalize_response(
chat_response: AgentChatResponse,
condensed_question: str | None,
query: str,
query_configuration: QueryConfiguration,
response_id: str,
session: Session,
user_name: Optional[str],
) -> RagStudioChatMessage:
if condensed_question and (condensed_question.strip() == query.strip()):
condensed_question = None

orig_source_nodes = chat_response.source_nodes
source_nodes = get_nodes_from_output(chat_response.response, session)

# if node with id present in orig_source_nodes, then don't add it again
node_ids_present = set([node.node_id for node in orig_source_nodes])
for node in source_nodes:
if node.node_id not in node_ids_present:
orig_source_nodes.append(node)

chat_response.source_nodes = orig_source_nodes

evaluations = []
if len(chat_response.source_nodes) != 0:
relevance, faithfulness = evaluators.evaluate_response(
query, chat_response, session.inference_model
)
evaluations.append(Evaluation(name="relevance", value=relevance))
evaluations.append(Evaluation(name="faithfulness", value=faithfulness))
response_source_nodes = format_source_nodes(chat_response)
new_chat_message = RagStudioChatMessage(
id=response_id,
session_id=session.id,
source_nodes=response_source_nodes,
inference_model=session.inference_model,
rag_message=RagMessage(
user=query,
assistant=chat_response.response,
),
evaluations=evaluations,
timestamp=time.time(),
condensed_question=condensed_question,
status="success",
)
record_rag_mlflow_run(
new_chat_message, query_configuration, response_id, session, user_name
)
chat_history_manager.update_message(session.id, response_id, new_chat_message)

return new_chat_message
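For reference, the node-merging step in `finalize_response` above — append only those nodes extracted from the response text whose IDs are not already among the retriever's source nodes — reduces to the following sketch. Plain dicts stand in for the llama-index node objects; this is illustrative only:

```python
def merge_source_nodes(orig_nodes: list[dict], extracted_nodes: list[dict]) -> list[dict]:
    """Append extracted nodes whose node_id is not already present, preserving order."""
    seen = {node["node_id"] for node in orig_nodes}
    merged = list(orig_nodes)
    for node in extracted_nodes:
        if node["node_id"] not in seen:
            merged.append(node)
            seen.add(node["node_id"])
    return merged


print(merge_source_nodes(
    [{"node_id": "a"}, {"node_id": "b"}],
    [{"node_id": "b"}, {"node_id": "c"}],
))
# -> [{'node_id': 'a'}, {'node_id': 'b'}, {'node_id': 'c'}]
```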
18 changes: 18 additions & 0 deletions llm-service/app/services/chat_history/chat_history_manager.py
@@ -61,6 +61,8 @@ class RagStudioChatMessage(BaseModel):
evaluations: list[Evaluation]
timestamp: float
condensed_question: Optional[str]
status: Literal["pending", "error", "cancelled", "success"] = "success"
error_message: Optional[str] = None


class ChatHistoryManager(metaclass=ABCMeta):
@@ -84,6 +86,22 @@ def append_to_history(
) -> None:
pass

@abstractmethod
def update_message(
self, session_id: int, message_id: str, message: RagStudioChatMessage
) -> None:
"""Update an existing message by ID for the given session.

Implementations should overwrite both the user and assistant entries
corresponding to this message ID.
"""
pass

@abstractmethod
def delete_message(self, session_id: int, message_id: str) -> None:
"""Delete an existing message by ID for the given session."""
pass


def _create_chat_history_manager() -> ChatHistoryManager:
from app.services.chat_history.simple_chat_history_manager import (
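To make the contract of the two new abstract methods concrete, here is a minimal in-memory sketch. It is hypothetical and not part of this PR; it skips the other `ChatHistoryManager` methods and uses plain dicts keyed by `id` instead of the pydantic model:

```python
class InMemoryChatHistory:
    """Illustrative store keyed by session, mirroring update/delete semantics."""

    def __init__(self) -> None:
        self._sessions: dict[int, list[dict]] = {}

    def append_to_history(self, session_id: int, messages: list[dict]) -> None:
        self._sessions.setdefault(session_id, []).extend(messages)

    def update_message(self, session_id: int, message_id: str, message: dict) -> None:
        # Overwrite the message with a matching ID; silently no-op if absent,
        # matching the S3 implementation's behavior.
        history = self._sessions.get(session_id, [])
        for idx, existing in enumerate(history):
            if existing["id"] == message_id:
                history[idx] = message
                return

    def delete_message(self, session_id: int, message_id: str) -> None:
        history = self._sessions.get(session_id, [])
        self._sessions[session_id] = [m for m in history if m["id"] != message_id]
```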
45 changes: 45 additions & 0 deletions llm-service/app/services/chat_history/s3_chat_history_manager.py
@@ -154,3 +154,48 @@ def append_to_history(
f"Error appending to chat history for session {session_id}: {e}"
)
raise

def update_message(
self, session_id: int, message_id: str, message: RagStudioChatMessage
) -> None:
"""Update an existing message's content and metadata by ID in S3."""
s3_key = self._get_s3_key(session_id)
try:
chat_history_data = self.retrieve_chat_history(session_id=session_id)
updated = False
for idx, existing in enumerate(chat_history_data):
if existing.id == message_id:
chat_history_data[idx] = message
updated = True
break
if not updated:
return
chat_history_json = json.dumps(
[m.model_dump() for m in chat_history_data]
)
self.s3_client.put_object(
Bucket=self.bucket_name, Key=s3_key, Body=chat_history_json
)
except Exception as e:
logger.error(
f"Error updating chat message {message.id} for session {session_id}: {e}"
)
raise

def delete_message(self, session_id: int, message_id: str) -> None:
"""Delete a specific message by ID in S3-backed store."""
s3_key = self._get_s3_key(session_id)
try:
chat_history_data = self.retrieve_chat_history(session_id=session_id)
chat_history_data = [m for m in chat_history_data if m.id != message_id]
chat_history_json = json.dumps(
[m.model_dump() for m in chat_history_data]
)
self.s3_client.put_object(
Bucket=self.bucket_name, Key=s3_key, Body=chat_history_json
)
except Exception as e:
logger.error(
f"Error deleting chat message {message_id} for session {session_id}: {e}"
)
raise
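The S3 methods above persist the whole history with `model_dump()` and rewrite the object, so the new `status` and `error_message` fields round-trip through that JSON automatically because both have defaults. A trimmed-down pydantic sketch of that round trip (the real `RagStudioChatMessage` has more required fields):

```python
import json
from typing import Literal, Optional

from pydantic import BaseModel


class MiniChatMessage(BaseModel):
    id: str
    assistant: str
    status: Literal["pending", "error", "cancelled", "success"] = "success"
    error_message: Optional[str] = None


history = [MiniChatMessage(id="msg-1", assistant="partial answer", status="cancelled")]
payload = json.dumps([m.model_dump() for m in history])  # what gets written to S3
restored = [MiniChatMessage(**m) for m in json.loads(payload)]
assert restored[0].status == "cancelled" and restored[0].error_message is None
```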