Improve chatbot streaming performance with diffs (#7102)
* changes

* add changeset

* changes

* add changeset

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* Update free-moose-guess.md

* changes

---------

Co-authored-by: Ali Abid <[email protected]>
Co-authored-by: gradio-pr-bot <[email protected]>
Co-authored-by: Ali Abid <[email protected]>
4 people authored Jan 31, 2024
1 parent 6a7e98b commit 68a54a7
Showing 11 changed files with 312 additions and 25 deletions.
9 changes: 9 additions & 0 deletions .changeset/free-moose-guess.md
@@ -0,0 +1,9 @@
---
"@gradio/client": minor
"gradio": minor
"gradio_client": minor
---

feat:Improve chatbot streaming performance with diffs

Note that this PR changes the API format for generator functions, which would be a breaking change for any clients reading the EventStream directly
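
For clients that read the EventStream directly, the practical impact of the breaking change above is that under sse_v2 only the first intermediate output arrives in full; later ones arrive as diffs to apply on top of it. A minimal, hedged Python sketch of reassembling outputs (the message shape and the "process_generating" literal are assumed from the stream_sse_v1_v2 changes further down; apply_diff is the helper added to gradio_client/utils.py in this commit):

from gradio_client.utils import apply_diff  # added in this commit

def accumulate_outputs(messages):
    """Reassemble intermediate outputs from sse_v2 messages.

    `messages` is assumed to be an iterable of parsed SSE payloads shaped like
    those handled in stream_sse_v1_v2 below; "process_generating" is assumed to
    match ServerMessage.process_generating.
    """
    outputs = None
    for msg in messages:
        if msg.get("msg") != "process_generating":
            continue
        data = msg.get("output", {}).get("data", [])
        if outputs is None:
            # First generation: the payload is the full output.
            outputs = list(data)
        else:
            # Later generations: each entry is a diff against the previous output.
            outputs = [apply_diff(prev, d) for prev, d in zip(outputs, data)]
        yield list(outputs)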
32 changes: 30 additions & 2 deletions client/js/src/client.ts
@@ -11,7 +11,8 @@ import {
set_space_hardware,
set_space_timeout,
hardware_types,
resolve_root
resolve_root,
apply_diff
} from "./utils.js";

import type {
@@ -288,6 +289,7 @@ export function api_factory(
const last_status: Record<string, Status["stage"]> = {};
let stream_open = false;
let pending_stream_messages: Record<string, any[]> = {}; // Event messages may be received by the SSE stream before the initial data POST request is complete. To resolve this race condition, we store the messages in a dictionary and process them when the POST request is complete.
let pending_diff_streams: Record<string, any[][]> = {};
let event_stream: EventSource | null = null;
const event_callbacks: Record<string, () => Promise<void>> = {};
const unclosed_events: Set<string> = new Set();
@@ -774,7 +776,8 @@ export function api_factory(
}
}
};
} else if (protocol == "sse_v1") {
} else if (protocol == "sse_v1" || protocol == "sse_v2") {
// latest API format. v2 introduces sending diffs for intermediate outputs in generative functions, which makes payloads lighter.
fire_event({
type: "status",
stage: "pending",
@@ -867,6 +870,9 @@
endpoint: _endpoint,
fn_index
});
if (data && protocol === "sse_v2") {
apply_diff_stream(event_id!, data);
}
}
if (data) {
fire_event({
@@ -904,6 +910,9 @@
if (event_callbacks[event_id]) {
delete event_callbacks[event_id];
}
if (event_id in pending_diff_streams) {
delete pending_diff_streams[event_id];
}
}
} catch (e) {
console.error("Unexpected client exception", e);
@@ -936,6 +945,25 @@
}
);

function apply_diff_stream(event_id: string, data: any): void {
let is_first_generation = !pending_diff_streams[event_id];
if (is_first_generation) {
pending_diff_streams[event_id] = [];
data.data.forEach((value: any, i: number) => {
pending_diff_streams[event_id][i] = value;
});
} else {
data.data.forEach((value: any, i: number) => {
let new_data = apply_diff(
pending_diff_streams[event_id][i],
value
);
pending_diff_streams[event_id][i] = new_data;
data.data[i] = new_data;
});
}
}

function fire_event<K extends EventType>(event: Event<K>): void {
const narrowed_listener_map: ListenerMap<K> = listener_map;
const listeners = narrowed_listener_map[event.type] || [];
2 changes: 1 addition & 1 deletion client/js/src/types.ts
@@ -20,7 +20,7 @@ export interface Config {
show_api: boolean;
stylesheets: string[];
path: string;
protocol?: "sse_v1" | "sse" | "ws";
protocol?: "sse_v2" | "sse_v1" | "sse" | "ws";
}

export interface Payload {
59 changes: 59 additions & 0 deletions client/js/src/utils.ts
@@ -239,3 +239,62 @@ export const hardware_types = [
"a10g-large",
"a100-large"
] as const;

function apply_edit(
target: any,
path: (number | string)[],
action: string,
value: any
): any {
if (path.length === 0) {
if (action === "replace") {
return value;
} else if (action === "append") {
return target + value;
}
throw new Error(`Unsupported action: ${action}`);
}

let current = target;
for (let i = 0; i < path.length - 1; i++) {
current = current[path[i]];
}

const last_path = path[path.length - 1];
switch (action) {
case "replace":
current[last_path] = value;
break;
case "append":
current[last_path] += value;
break;
case "add":
if (Array.isArray(current)) {
current.splice(Number(last_path), 0, value);
} else {
current[last_path] = value;
}
break;
case "delete":
if (Array.isArray(current)) {
current.splice(Number(last_path), 1);
} else {
delete current[last_path];
}
break;
default:
throw new Error(`Unknown action: ${action}`);
}
return target;
}

export function apply_diff(
obj: any,
diff: [string, (number | string)[], any][]
): any {
diff.forEach(([action, path, value]) => {
obj = apply_edit(obj, path, action, value);
});

return obj;
}
20 changes: 15 additions & 5 deletions client/python/gradio_client/client.py
@@ -428,7 +428,12 @@ def submit(
inferred_fn_index = self._infer_fn_index(api_name, fn_index)

helper = None
if self.endpoints[inferred_fn_index].protocol in ("ws", "sse", "sse_v1"):
if self.endpoints[inferred_fn_index].protocol in (
"ws",
"sse",
"sse_v1",
"sse_v2",
):
helper = self.new_helper(inferred_fn_index)
end_to_end_fn = self.endpoints[inferred_fn_index].make_end_to_end_fn(helper)
future = self.executor.submit(end_to_end_fn, *args)
@@ -998,13 +1003,15 @@ def _predict(*data) -> tuple:
result = utils.synchronize_async(
self._sse_fn_v0, data, hash_data, helper
)
elif self.protocol == "sse_v1":
elif self.protocol == "sse_v1" or self.protocol == "sse_v2":
event_id = utils.synchronize_async(
self.client.send_data, data, hash_data
)
self.client.pending_event_ids.add(event_id)
self.client.pending_messages_per_event[event_id] = []
result = utils.synchronize_async(self._sse_fn_v1, helper, event_id)
result = utils.synchronize_async(
self._sse_fn_v1_v2, helper, event_id, self.protocol
)
else:
raise ValueError(f"Unsupported protocol: {self.protocol}")

@@ -1197,13 +1204,16 @@ async def _sse_fn_v0(self, data: dict, hash_data: dict, helper: Communicator):
self.client.cookies,
)

async def _sse_fn_v1(self, helper: Communicator, event_id: str):
return await utils.get_pred_from_sse_v1(
async def _sse_fn_v1_v2(
self, helper: Communicator, event_id: str, protocol: Literal["sse_v1", "sse_v2"]
):
return await utils.get_pred_from_sse_v1_v2(
helper,
self.client.headers,
self.client.cookies,
self.client.pending_messages_per_event,
event_id,
protocol,
)


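For callers of the Python client, the sse_v2 handling above is transparent: diffs are reassembled inside stream_sse_v1_v2 before results are surfaced. An illustrative usage sketch (the Space name and api_name are placeholders, not part of this commit):

from gradio_client import Client

# Placeholder Space and endpoint; any app with a generator endpoint behaves the same.
client = Client("user/streaming-chatbot-demo")
job = client.submit("Tell me a story", api_name="/chat")

result = job.result()          # final output, fully reconstructed from diffs
intermediates = job.outputs()  # intermediate outputs observed while streaming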
74 changes: 66 additions & 8 deletions client/python/gradio_client/utils.py
@@ -2,6 +2,7 @@

import asyncio
import base64
import copy
import hashlib
import json
import mimetypes
@@ -17,7 +18,7 @@
from enum import Enum
from pathlib import Path
from threading import Lock
from typing import Any, Callable, Optional, TypedDict
from typing import Any, Callable, Literal, Optional, TypedDict

import fsspec.asyn
import httpx
@@ -381,22 +382,19 @@ async def get_pred_from_sse_v0(
return task.result()


async def get_pred_from_sse_v1(
async def get_pred_from_sse_v1_v2(
helper: Communicator,
headers: dict[str, str],
cookies: dict[str, str] | None,
pending_messages_per_event: dict[str, list[Message | None]],
event_id: str,
protocol: Literal["sse_v1", "sse_v2"],
) -> dict[str, Any] | None:
done, pending = await asyncio.wait(
[
asyncio.create_task(check_for_cancel(helper, headers, cookies)),
asyncio.create_task(
stream_sse_v1(
helper,
pending_messages_per_event,
event_id,
)
stream_sse_v1_v2(helper, pending_messages_per_event, event_id, protocol)
),
],
return_when=asyncio.FIRST_COMPLETED,
@@ -411,6 +409,9 @@ async def get_pred_from_sse_v1(

assert len(done) == 1
for task in done:
exception = task.exception()
if exception:
raise exception
return task.result()


@@ -502,13 +503,15 @@ async def stream_sse_v0(
raise


async def stream_sse_v1(
async def stream_sse_v1_v2(
helper: Communicator,
pending_messages_per_event: dict[str, list[Message | None]],
event_id: str,
protocol: Literal["sse_v1", "sse_v2"],
) -> dict[str, Any]:
try:
pending_messages = pending_messages_per_event[event_id]
pending_responses_for_diffs = None

while True:
if len(pending_messages) > 0:
@@ -540,6 +543,19 @@ async def stream_sse_v1(
log=log_message,
)
output = msg.get("output", {}).get("data", [])
if (
msg["msg"] == ServerMessage.process_generating
and protocol == "sse_v2"
):
if pending_responses_for_diffs is None:
pending_responses_for_diffs = list(output)
else:
for i, value in enumerate(output):
prev_output = pending_responses_for_diffs[i]
new_output = apply_diff(prev_output, value)
pending_responses_for_diffs[i] = new_output
output[i] = new_output

if output and status_update.code != Status.FINISHED:
try:
result = helper.prediction_processor(*output)
@@ -557,6 +573,48 @@
raise


def apply_diff(obj, diff):
obj = copy.deepcopy(obj)

def apply_edit(target, path, action, value):
if len(path) == 0:
if action == "replace":
return value
elif action == "append":
return target + value
else:
raise ValueError(f"Unsupported action: {action}")

current = target
for i in range(len(path) - 1):
current = current[path[i]]

last_path = path[-1]
if action == "replace":
current[last_path] = value
elif action == "append":
current[last_path] += value
elif action == "add":
if isinstance(current, list):
current.insert(int(last_path), value)
else:
current[last_path] = value
elif action == "delete":
if isinstance(current, list):
del current[int(last_path)]
else:
del current[last_path]
else:
raise ValueError(f"Unknown action: {action}")

return target

for action, path, value in diff:
obj = apply_edit(obj, path, action, value)

return obj


########################
# Data processing utils
########################
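A short worked example of the (action, path, value) diff format consumed by apply_diff above, exercising the append and replace actions (the values are illustrative):

from gradio_client.utils import apply_diff

prev = [[["Hi", "Hello"]], "partial"]
diff = [
    ["append", [0, 0, 1], " there"],  # "Hello" -> "Hello there"
    ["replace", [1], "done"],         # second output is replaced outright
]
print(apply_diff(prev, diff))
# [[['Hi', 'Hello there']], 'done']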