Skip to content

Commit

Permalink
fix: vastly improve chat UI responsiveness by reordering Gradio events (
Browse files Browse the repository at this point in the history
#360) bump:patch
  • Loading branch information
taprosoft authored Oct 4, 2024
1 parent b01fc21 commit dfd00fe
Showing 1 changed file with 57 additions and 167 deletions.
224 changes: 57 additions & 167 deletions libs/ktem/ktem/pages/chat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
import asyncio
import csv
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Optional

import gradio as gr
from filelock import FileLock
from ktem.app import BasePage
from ktem.components import reasonings
from ktem.db.models import Conversation, engine
Expand Down Expand Up @@ -38,6 +34,7 @@
for (var i = 0; i < links.length; i++) {
links[i].onclick = openModal;
}
return [links.length]
}
"""

Expand All @@ -48,19 +45,18 @@ def __init__(self, app):
self._indices_input = []

self.on_building_ui()

self._preview_links = gr.State(value=None)
self._reasoning_type = gr.State(value=None)
self._llm_type = gr.State(value=None)
self._conversation_renamed = gr.State(value=False)
self.info_panel_expanded = gr.State(value=True)
self._info_panel_expanded = gr.State(value=True)

def on_building_ui(self):
with gr.Row():
self.state_chat = gr.State(STATE)
self.state_retrieval_history = gr.State([])
self.state_chat_history = gr.State([])
self.state_plot_history = gr.State([])
self.state_settings = gr.State({})
self.state_info_panel = gr.State("")
self.state_plot_panel = gr.State(None)

with gr.Column(scale=1, elem_id="conv-settings-panel") as self.conv_column:
Expand Down Expand Up @@ -203,37 +199,11 @@ def on_register_events(self):
],
concurrency_limit=20,
show_progress="minimal",
).success(
fn=self.backup_original_info,
inputs=[
self.chat_panel.chatbot,
self._app.settings_state,
self.info_panel,
self.state_chat_history,
],
outputs=[
self.state_chat_history,
self.state_settings,
self.state_info_panel,
],
).then(
fn=self.persist_data_source,
inputs=[
self.chat_control.conversation_id,
self._app.user_id,
self.info_panel,
self.state_plot_panel,
self.state_retrieval_history,
self.state_plot_history,
self.chat_panel.chatbot,
self.state_chat,
]
+ self._indices_input,
outputs=[
self.state_retrieval_history,
self.state_plot_history,
],
concurrency_limit=20,
fn=lambda: True,
inputs=None,
outputs=[self._preview_links],
js=pdfview_js,
).success(
fn=self.check_and_suggest_name_conv,
inputs=self.chat_panel.chatbot,
Expand All @@ -256,7 +226,23 @@ def on_register_events(self):
],
show_progress="hidden",
).then(
fn=None, inputs=None, outputs=None, js=pdfview_js
fn=self.persist_data_source,
inputs=[
self.chat_control.conversation_id,
self._app.user_id,
self.info_panel,
self.state_plot_panel,
self.state_retrieval_history,
self.state_plot_history,
self.chat_panel.chatbot,
self.state_chat,
]
+ self._indices_input,
outputs=[
self.state_retrieval_history,
self.state_plot_history,
],
concurrency_limit=20,
)

self.chat_panel.regen_btn.click(
Expand All @@ -281,23 +267,10 @@ def on_register_events(self):
concurrency_limit=20,
show_progress="minimal",
).then(
fn=self.persist_data_source,
inputs=[
self.chat_control.conversation_id,
self._app.user_id,
self.info_panel,
self.state_plot_panel,
self.state_retrieval_history,
self.state_plot_history,
self.chat_panel.chatbot,
self.state_chat,
]
+ self._indices_input,
outputs=[
self.state_retrieval_history,
self.state_plot_history,
],
concurrency_limit=20,
fn=lambda: True,
inputs=None,
outputs=[self._preview_links],
js=pdfview_js,
).success(
fn=self.check_and_suggest_name_conv,
inputs=self.chat_panel.chatbot,
Expand All @@ -320,37 +293,39 @@ def on_register_events(self):
],
show_progress="hidden",
).then(
fn=None, inputs=None, outputs=None, js=pdfview_js
fn=self.persist_data_source,
inputs=[
self.chat_control.conversation_id,
self._app.user_id,
self.info_panel,
self.state_plot_panel,
self.state_retrieval_history,
self.state_plot_history,
self.chat_panel.chatbot,
self.state_chat,
]
+ self._indices_input,
outputs=[
self.state_retrieval_history,
self.state_plot_history,
],
concurrency_limit=20,
)

self.chat_control.btn_info_expand.click(
fn=lambda is_expanded: (
gr.update(scale=INFO_PANEL_SCALES[is_expanded]),
not is_expanded,
),
inputs=self.info_panel_expanded,
outputs=[self.info_column, self.info_panel_expanded],
inputs=self._info_panel_expanded,
outputs=[self.info_column, self._info_panel_expanded],
)

self.chat_panel.chatbot.like(
fn=self.is_liked,
inputs=[self.chat_control.conversation_id],
outputs=None,
).success(
self.save_log,
inputs=[
self.chat_control.conversation_id,
self.chat_panel.chatbot,
self._app.settings_state,
self.info_panel,
self.state_chat_history,
self.state_settings,
self.state_info_panel,
gr.State(getattr(flowsettings, "KH_APP_DATA_DIR", "logs")),
],
outputs=None,
)

self.chat_control.btn_new.click(
self.chat_control.new_conv,
inputs=self._app.user_id,
Expand Down Expand Up @@ -701,7 +676,15 @@ def is_liked(self, convo_id, liked: gr.LikeData):

def message_selected(self, retrieval_history, plot_history, msg: gr.SelectData):
index = msg.index[0]
return retrieval_history[index], plot_history[index]
try:
retrieval_content, plot_content = (
retrieval_history[index],
plot_history[index],
)
except IndexError:
retrieval_content, plot_content = gr.update(), None

return retrieval_content, plot_content

def create_pipeline(
self,
Expand Down Expand Up @@ -889,96 +872,3 @@ def check_and_suggest_name_conv(self, chat_history):
renamed = True

return new_name, renamed

def backup_original_info(
self, chat_history, settings, info_pannel, original_chat_history
):
original_chat_history.append(chat_history[-1])
return original_chat_history, settings, info_pannel

def save_log(
self,
conversation_id,
chat_history,
settings,
info_panel,
original_chat_history,
original_settings,
original_info_panel,
log_dir,
):
if not Path(log_dir).exists():
Path(log_dir).mkdir(parents=True)

lock = FileLock(Path(log_dir) / ".lock")
# get current date
today = datetime.now()
formatted_date = today.strftime("%d%m%Y_%H")

with Session(engine) as session:
statement = select(Conversation).where(Conversation.id == conversation_id)
result = session.exec(statement).one()

data_source = deepcopy(result.data_source)
likes = data_source.get("likes", [])
if not likes:
return

feedback = likes[-1][-1]
message_index = likes[-1][0]

current_message = chat_history[message_index[0]]
original_message = original_chat_history[message_index[0]]
is_original = all(
[
current_item == original_item
for current_item, original_item in zip(
current_message, original_message
)
]
)

dataframe = [
[
conversation_id,
message_index,
current_message[0],
current_message[1],
chat_history,
settings,
info_panel,
feedback,
is_original,
original_message[1],
original_chat_history,
original_settings,
original_info_panel,
]
]

with lock:
log_file = Path(log_dir) / f"{formatted_date}_log.csv"
is_log_file_exist = log_file.is_file()
with open(log_file, "a") as f:
writer = csv.writer(f)
# write headers
if not is_log_file_exist:
writer.writerow(
[
"Conversation ID",
"Message ID",
"Question",
"Answer",
"Chat History",
"Settings",
"Evidences",
"Feedback",
"Original/ Rewritten",
"Original Answer",
"Original Chat History",
"Original Settings",
"Original Evidences",
]
)

writer.writerows(dataframe)

0 comments on commit dfd00fe

Please sign in to comment.