Skip to content

Commit

Permalink
Merge branch 'main' into file-fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
dawoodkhan82 authored Jan 31, 2024
2 parents 8f25762 + 68a54a7 commit 21f5cd0
Show file tree
Hide file tree
Showing 66 changed files with 1,152 additions and 409 deletions.
9 changes: 9 additions & 0 deletions .changeset/free-moose-guess.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
"@gradio/client": minor
"gradio": minor
"gradio_client": minor
---

feat:Improve chatbot streaming performance with diffs

Note that this PR changes the API format for generator functions, which would be a breaking change for any clients reading the EventStream directly
7 changes: 7 additions & 0 deletions .changeset/ninety-bobcats-fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
"gradio": minor
"gradio_client": minor
"website": minor
---

feat:Document the `gr.ParamViewer` component, and fix component preprocessing/postprocessing docstrings
5 changes: 5 additions & 0 deletions .changeset/spicy-rabbits-check.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"gradio": patch
---

feat:Fix hyphen-bug in gradio cc publish
2 changes: 1 addition & 1 deletion .github/workflows/deploy-chromatic.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
pr_number: ${{ steps.get-pr.outputs.number }}
pr_labels: ${{ steps.get-pr.outputs.pr_labels }}
steps:
- uses: 8BitJonny/gh-get-current-pr@2.2.0
- uses: 8BitJonny/gh-get-current-pr@3.0.0
id: get-pr
with:
filterOutDraft: true
Expand Down
32 changes: 30 additions & 2 deletions client/js/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ import {
set_space_hardware,
set_space_timeout,
hardware_types,
resolve_root
resolve_root,
apply_diff
} from "./utils.js";

import type {
Expand Down Expand Up @@ -288,6 +289,7 @@ export function api_factory(
const last_status: Record<string, Status["stage"]> = {};
let stream_open = false;
let pending_stream_messages: Record<string, any[]> = {}; // Event messages may be received by the SSE stream before the initial data POST request is complete. To resolve this race condition, we store the messages in a dictionary and process them when the POST request is complete.
let pending_diff_streams: Record<string, any[][]> = {};
let event_stream: EventSource | null = null;
const event_callbacks: Record<string, () => Promise<void>> = {};
const unclosed_events: Set<string> = new Set();
Expand Down Expand Up @@ -774,7 +776,8 @@ export function api_factory(
}
}
};
} else if (protocol == "sse_v1") {
} else if (protocol == "sse_v1" || protocol == "sse_v2") {
// latest API format. v2 introduces sending diffs for intermediate outputs in generative functions, which makes payloads lighter.
fire_event({
type: "status",
stage: "pending",
Expand Down Expand Up @@ -867,6 +870,9 @@ export function api_factory(
endpoint: _endpoint,
fn_index
});
if (data && protocol === "sse_v2") {
apply_diff_stream(event_id!, data);
}
}
if (data) {
fire_event({
Expand Down Expand Up @@ -904,6 +910,9 @@ export function api_factory(
if (event_callbacks[event_id]) {
delete event_callbacks[event_id];
}
if (event_id in pending_diff_streams) {
delete pending_diff_streams[event_id];
}
}
} catch (e) {
console.error("Unexpected client exception", e);
Expand Down Expand Up @@ -936,6 +945,25 @@ export function api_factory(
}
);

function apply_diff_stream(event_id: string, data: any): void {
let is_first_generation = !pending_diff_streams[event_id];
if (is_first_generation) {
pending_diff_streams[event_id] = [];
data.data.forEach((value: any, i: number) => {
pending_diff_streams[event_id][i] = value;
});
} else {
data.data.forEach((value: any, i: number) => {
let new_data = apply_diff(
pending_diff_streams[event_id][i],
value
);
pending_diff_streams[event_id][i] = new_data;
data.data[i] = new_data;
});
}
}

function fire_event<K extends EventType>(event: Event<K>): void {
const narrowed_listener_map: ListenerMap<K> = listener_map;
const listeners = narrowed_listener_map[event.type] || [];
Expand Down
2 changes: 1 addition & 1 deletion client/js/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ export interface Config {
show_api: boolean;
stylesheets: string[];
path: string;
protocol?: "sse_v1" | "sse" | "ws";
protocol?: "sse_v2" | "sse_v1" | "sse" | "ws";
}

export interface Payload {
Expand Down
59 changes: 59 additions & 0 deletions client/js/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,62 @@ export const hardware_types = [
"a10g-large",
"a100-large"
] as const;

function apply_edit(
target: any,
path: (number | string)[],
action: string,
value: any
): any {
if (path.length === 0) {
if (action === "replace") {
return value;
} else if (action === "append") {
return target + value;
}
throw new Error(`Unsupported action: ${action}`);
}

let current = target;
for (let i = 0; i < path.length - 1; i++) {
current = current[path[i]];
}

const last_path = path[path.length - 1];
switch (action) {
case "replace":
current[last_path] = value;
break;
case "append":
current[last_path] += value;
break;
case "add":
if (Array.isArray(current)) {
current.splice(Number(last_path), 0, value);
} else {
current[last_path] = value;
}
break;
case "delete":
if (Array.isArray(current)) {
current.splice(Number(last_path), 1);
} else {
delete current[last_path];
}
break;
default:
throw new Error(`Unknown action: ${action}`);
}
return target;
}

export function apply_diff(
obj: any,
diff: [string, (number | string)[], any][]
): any {
diff.forEach(([action, path, value]) => {
obj = apply_edit(obj, path, action, value);
});

return obj;
}
20 changes: 15 additions & 5 deletions client/python/gradio_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,12 @@ def submit(
inferred_fn_index = self._infer_fn_index(api_name, fn_index)

helper = None
if self.endpoints[inferred_fn_index].protocol in ("ws", "sse", "sse_v1"):
if self.endpoints[inferred_fn_index].protocol in (
"ws",
"sse",
"sse_v1",
"sse_v2",
):
helper = self.new_helper(inferred_fn_index)
end_to_end_fn = self.endpoints[inferred_fn_index].make_end_to_end_fn(helper)
future = self.executor.submit(end_to_end_fn, *args)
Expand Down Expand Up @@ -998,13 +1003,15 @@ def _predict(*data) -> tuple:
result = utils.synchronize_async(
self._sse_fn_v0, data, hash_data, helper
)
elif self.protocol == "sse_v1":
elif self.protocol == "sse_v1" or self.protocol == "sse_v2":
event_id = utils.synchronize_async(
self.client.send_data, data, hash_data
)
self.client.pending_event_ids.add(event_id)
self.client.pending_messages_per_event[event_id] = []
result = utils.synchronize_async(self._sse_fn_v1, helper, event_id)
result = utils.synchronize_async(
self._sse_fn_v1_v2, helper, event_id, self.protocol
)
else:
raise ValueError(f"Unsupported protocol: {self.protocol}")

Expand Down Expand Up @@ -1197,13 +1204,16 @@ async def _sse_fn_v0(self, data: dict, hash_data: dict, helper: Communicator):
self.client.cookies,
)

async def _sse_fn_v1(self, helper: Communicator, event_id: str):
return await utils.get_pred_from_sse_v1(
async def _sse_fn_v1_v2(
self, helper: Communicator, event_id: str, protocol: Literal["sse_v1", "sse_v2"]
):
return await utils.get_pred_from_sse_v1_v2(
helper,
self.client.headers,
self.client.cookies,
self.client.pending_messages_per_event,
event_id,
protocol,
)


Expand Down
30 changes: 29 additions & 1 deletion client/python/gradio_client/documentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ def document_fn(fn: Callable, cls) -> tuple[str, list[dict], dict, str | None]:
for param_name, param in signature.parameters.items():
if param_name.startswith("_"):
continue
if param_name == "self":
continue
if param_name in ["kwargs", "args"] and param_name not in parameters:
continue
parameter_doc = {
Expand All @@ -147,7 +149,7 @@ def document_fn(fn: Callable, cls) -> tuple[str, list[dict], dict, str | None]:
parameter_docs.append(parameter_doc)
assert (
len(parameters) == 0
), f"Documentation format for {fn.__name__} documents nonexistent parameters: {''.join(parameters.keys())}"
), f"Documentation format for {fn.__name__} documents nonexistent parameters: {', '.join(parameters.keys())}. Valid parameters: {', '.join(signature.parameters.keys())}"
if len(returns) == 0:
return_docs = {}
elif len(returns) == 1:
Expand Down Expand Up @@ -203,6 +205,24 @@ def generate_documentation():
for cls, fns in class_list:
fn_to_document = cls if inspect.isfunction(cls) else cls.__init__
_, parameter_doc, return_doc, _ = document_fn(fn_to_document, cls)
if (
hasattr(cls, "preprocess")
and callable(cls.preprocess) # type: ignore
and hasattr(cls, "postprocess")
and callable(cls.postprocess) # type: ignore
):
preprocess_doc = document_fn(cls.preprocess, cls) # type: ignore
postprocess_doc = document_fn(cls.postprocess, cls) # type: ignore
preprocess_doc, postprocess_doc = (
{
"parameter_doc": preprocess_doc[1],
"return_doc": preprocess_doc[2],
},
{
"parameter_doc": postprocess_doc[1],
"return_doc": postprocess_doc[2],
},
)
cls_description, cls_tags, cls_example = document_cls(cls)
cls_documentation = {
"class": cls,
Expand All @@ -214,6 +234,14 @@ def generate_documentation():
"example": cls_example,
"fns": [],
}
if (
hasattr(cls, "preprocess")
and callable(cls.preprocess) # type: ignore
and hasattr(cls, "postprocess")
and callable(cls.postprocess) # type: ignore
):
cls_documentation["preprocess"] = preprocess_doc # type: ignore
cls_documentation["postprocess"] = postprocess_doc # type: ignore
for fn_name in fns:
instance_attribute_fn = fn_name.startswith("*")
if instance_attribute_fn:
Expand Down
Loading

0 comments on commit 21f5cd0

Please sign in to comment.