From 3c8c4ac2db284e1cb503c397205a79a6dcc27e23 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Wed, 31 Jan 2024 09:36:21 -0800 Subject: [PATCH 1/4] Document the `gr.ParamViewer` component, and fix component preprocessing/postprocessing docstrings (#7116) * add paramviewer to docs * add changeset * type fixing * round to * format * revert to make test pass * remove msg * add changeset * annotated image * annotated image push * audio * changes * changes * audio * audio * changes * remove borders from code * working * fixes * add changeset * formatting * build fix * changes * ls * more components * format * some more * more * component * plots * buttons * simple backend * number * more * more * simple templates * format * fixes * backend * fix tests * fix test * fixes * formatting --------- Co-authored-by: gradio-pr-bot Co-authored-by: aliabd --- .changeset/ninety-bobcats-fix.md | 7 ++ client/python/gradio_client/documentation.py | 30 +++++- demo/paramviewer_component/run.ipynb | 1 + demo/paramviewer_component/run.py | 21 ++++ gradio/_simple_templates/simpledropdown.py | 32 +++---- gradio/_simple_templates/simpleimage.py | 7 +- gradio/_simple_templates/simpletextbox.py | 23 ++--- gradio/blocks.py | 2 +- gradio/component_meta.py | 2 +- gradio/components/annotated_image.py | 53 +++++----- gradio/components/audio.py | 23 +++-- gradio/components/bar_plot.py | 40 ++++---- gradio/components/base.py | 8 ++ gradio/components/button.py | 24 +++-- gradio/components/chatbot.py | 26 +++-- gradio/components/checkbox.py | 20 +++- gradio/components/checkboxgroup.py | 19 +++- gradio/components/clear_button.py | 18 +++- gradio/components/code.py | 22 +++-- gradio/components/color_picker.py | 19 +++- gradio/components/dataframe.py | 23 +++-- gradio/components/dataset.py | 20 +++- gradio/components/dropdown.py | 22 +++-- gradio/components/duplicate_button.py | 4 +- gradio/components/fallback.py | 14 +++ gradio/components/file.py | 28 ++++-- gradio/components/file_explorer.py | 23 +++-- gradio/components/gallery.py | 49 +++++----- gradio/components/highlighted_text.py | 26 +++-- gradio/components/html.py | 18 +++- gradio/components/image.py | 30 ++++-- gradio/components/image_editor.py | 21 +++- gradio/components/json_component.py | 24 +++-- gradio/components/label.py | 43 +++++---- gradio/components/line_plot.py | 26 +++-- gradio/components/login_button.py | 2 +- gradio/components/logout_button.py | 2 +- gradio/components/markdown.py | 25 +++-- gradio/components/model3d.py | 20 +++- gradio/components/number.py | 21 ++-- gradio/components/paramviewer.py | 40 +++++--- gradio/components/plot.py | 26 +++-- gradio/components/radio.py | 15 +-- gradio/components/scatter_plot.py | 26 +++-- gradio/components/slider.py | 17 +++- gradio/components/state.py | 16 +++- gradio/components/textbox.py | 17 +++- gradio/components/upload_button.py | 24 +++-- gradio/components/video.py | 44 +++++---- gradio/events.py | 4 +- js/_website/src/lib/assets/style.css | 5 +- .../[[version]]/docs/[doc]/+page.server.ts | 43 ++++++--- .../[[version]]/docs/[doc]/+page.svelte | 96 +++++++++++++------ test/test_gradio_component_cli.py | 4 +- 54 files changed, 833 insertions(+), 382 deletions(-) create mode 100644 .changeset/ninety-bobcats-fix.md create mode 100644 demo/paramviewer_component/run.ipynb create mode 100644 demo/paramviewer_component/run.py diff --git a/.changeset/ninety-bobcats-fix.md b/.changeset/ninety-bobcats-fix.md new file mode 100644 index 0000000000000..eb5eee5ae7447 --- /dev/null +++ b/.changeset/ninety-bobcats-fix.md @@ -0,0 +1,7 @@ +--- +"gradio": minor +"gradio_client": minor +"website": minor +--- + +feat:Document the `gr.ParamViewer` component, and fix component preprocessing/postprocessing docstrings diff --git a/client/python/gradio_client/documentation.py b/client/python/gradio_client/documentation.py index af1337646761d..5d030c263fec6 100644 --- a/client/python/gradio_client/documentation.py +++ b/client/python/gradio_client/documentation.py @@ -123,6 +123,8 @@ def document_fn(fn: Callable, cls) -> tuple[str, list[dict], dict, str | None]: for param_name, param in signature.parameters.items(): if param_name.startswith("_"): continue + if param_name == "self": + continue if param_name in ["kwargs", "args"] and param_name not in parameters: continue parameter_doc = { @@ -147,7 +149,7 @@ def document_fn(fn: Callable, cls) -> tuple[str, list[dict], dict, str | None]: parameter_docs.append(parameter_doc) assert ( len(parameters) == 0 - ), f"Documentation format for {fn.__name__} documents nonexistent parameters: {''.join(parameters.keys())}" + ), f"Documentation format for {fn.__name__} documents nonexistent parameters: {', '.join(parameters.keys())}. Valid parameters: {', '.join(signature.parameters.keys())}" if len(returns) == 0: return_docs = {} elif len(returns) == 1: @@ -203,6 +205,24 @@ def generate_documentation(): for cls, fns in class_list: fn_to_document = cls if inspect.isfunction(cls) else cls.__init__ _, parameter_doc, return_doc, _ = document_fn(fn_to_document, cls) + if ( + hasattr(cls, "preprocess") + and callable(cls.preprocess) # type: ignore + and hasattr(cls, "postprocess") + and callable(cls.postprocess) # type: ignore + ): + preprocess_doc = document_fn(cls.preprocess, cls) # type: ignore + postprocess_doc = document_fn(cls.postprocess, cls) # type: ignore + preprocess_doc, postprocess_doc = ( + { + "parameter_doc": preprocess_doc[1], + "return_doc": preprocess_doc[2], + }, + { + "parameter_doc": postprocess_doc[1], + "return_doc": postprocess_doc[2], + }, + ) cls_description, cls_tags, cls_example = document_cls(cls) cls_documentation = { "class": cls, @@ -214,6 +234,14 @@ def generate_documentation(): "example": cls_example, "fns": [], } + if ( + hasattr(cls, "preprocess") + and callable(cls.preprocess) # type: ignore + and hasattr(cls, "postprocess") + and callable(cls.postprocess) # type: ignore + ): + cls_documentation["preprocess"] = preprocess_doc # type: ignore + cls_documentation["postprocess"] = postprocess_doc # type: ignore for fn_name in fns: instance_attribute_fn = fn_name.startswith("*") if instance_attribute_fn: diff --git a/demo/paramviewer_component/run.ipynb b/demo/paramviewer_component/run.ipynb new file mode 100644 index 0000000000000..9d02fce2f8b2c --- /dev/null +++ b/demo/paramviewer_component/run.ipynb @@ -0,0 +1 @@ +{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: paramviewer_component"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "with gr.Blocks() as demo:\n", " gr.Markdown(\"The `round()` function in Python takes two parameters\")\n", " gr.ParamViewer(\n", " {\n", " \"number\": { \n", " \"type\": \"int | float\", \n", " \"description\": \"The number to round\", \n", " \"default\": None\n", " },\n", " \"ndigits\": { \n", " \"type\": \"int\", \n", " \"description\": \"The number of digits to round to\", \n", " \"default\": \"0\"\n", " }\n", " }\n", " )\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/paramviewer_component/run.py b/demo/paramviewer_component/run.py new file mode 100644 index 0000000000000..8fc7ef986e89c --- /dev/null +++ b/demo/paramviewer_component/run.py @@ -0,0 +1,21 @@ +import gradio as gr + +with gr.Blocks() as demo: + gr.Markdown("The `round()` function in Python takes two parameters") + gr.ParamViewer( + { + "number": { + "type": "int | float", + "description": "The number to round", + "default": None + }, + "ndigits": { + "type": "int", + "description": "The number of digits to round to", + "default": "0" + } + } + ) + +if __name__ == "__main__": + demo.launch() \ No newline at end of file diff --git a/gradio/_simple_templates/simpledropdown.py b/gradio/_simple_templates/simpledropdown.py index 15d7631db6d77..f295958e14d99 100644 --- a/gradio/_simple_templates/simpledropdown.py +++ b/gradio/_simple_templates/simpledropdown.py @@ -10,10 +10,6 @@ class SimpleDropdown(FormComponent): """ Creates a very simple dropdown listing choices from which entries can be selected. - Preprocessing: Preprocessing: passes the value of the selected dropdown entry as a {str}. - Postprocessing: expects a {str} corresponding to the value of the dropdown entry to be selected. - Examples-format: a {str} representing the drop down value to select. - Demos: sentence_builder, titanic_survival """ EVENTS = [Events.change, Events.input, Events.select] @@ -41,7 +37,7 @@ def __init__( value: default value selected in dropdown. If None, no value is selected by default. If callable, the function will be called whenever the app loads to set the initial value of the component. label: component name in interface. info: additional component description. - every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. show_label: if True, will display label. scale: relative width compared to adjacent Components in a Row. For example, if Component A has scale=2, and Component B has scale=1, A will be twice as wide as B. Should be an integer. min_width: minimum pixel width, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in this Component being narrower than min_width, the min_width parameter will be respected first. @@ -49,11 +45,9 @@ def __init__( visible: If False, component will be hidden. elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles. elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles. - render: bool = True, + render: If False, component will not render be rendered in the Blocks context. Should be used if the intention is to assign event listeners now but render the component later. """ self.choices = ( - # Although we expect choices to be a list of lists, it can be a list of tuples if the Gradio app - # is loaded with gr.load() since Python tuples are converted to lists in JSON. [tuple(c) if isinstance(c, (tuple, list)) else (str(c), c) for c in choices] if choices else [] @@ -82,14 +76,14 @@ def api_info(self) -> dict[str, Any]: def example_inputs(self) -> Any: return self.choices[0][1] if self.choices else None - def preprocess(self, x: str | int | float | None) -> str | int | float | None: + def preprocess(self, payload: str | int | float | None) -> str | int | float | None: """ Parameters: - x: selected choice + payload: the value of the selected dropdown choice Returns: - selected choice + Passes the value of the selected dropdown choice as a `str | int | float`. """ - return x + return payload def _warn_if_invalid_choice(self, y): if y not in [value for _, value in self.choices]: @@ -97,11 +91,17 @@ def _warn_if_invalid_choice(self, y): f"The value passed into gr.Dropdown() is not in the list of choices. Please update the list of choices to include: {y}." ) - def postprocess(self, y): - if y is None: + def postprocess(self, value): + """ + Parameters: + value: Expects a `str | int | float` corresponding to the value of the dropdown entry to be selected. + Returns: + Returns the value of the selected dropdown entry. + """ + if value is None: return None - self._warn_if_invalid_choice(y) - return y + self._warn_if_invalid_choice(value) + return value def process_example(self, input_data): return next((c[0] for c in self.choices if c[1] == input_data), None) diff --git a/gradio/_simple_templates/simpleimage.py b/gradio/_simple_templates/simpleimage.py index 22de90f672e13..3adf351187cc0 100644 --- a/gradio/_simple_templates/simpleimage.py +++ b/gradio/_simple_templates/simpleimage.py @@ -18,9 +18,6 @@ class SimpleImage(Component): """ Creates an image component that can be used to upload images (as an input) or display images (as an output). - Preprocessing: passes the uploaded image as a {str} filepath. - Postprocessing: expects a {str} or {pathlib.Path} filepath to an image and displays the image. - Examples-format: a {str} local filepath or URL to an image. """ EVENTS = [ @@ -85,7 +82,7 @@ def preprocess(self, payload: FileData | None) -> str | None: Parameters: payload: A FileData object containing the image data. Returns: - A string containing the path to the image. + A `str` containing the path to the image. """ if payload is None: return None @@ -94,7 +91,7 @@ def preprocess(self, payload: FileData | None) -> str | None: def postprocess(self, value: str | Path | None) -> FileData | None: """ Parameters: - value: A string or pathlib.Path object containing the path to the image. + value: Expects a `str` or `pathlib.Path` object containing the path to the image. Returns: A FileData object containing the image data. """ diff --git a/gradio/_simple_templates/simpletextbox.py b/gradio/_simple_templates/simpletextbox.py index ec8fb60e7e6f0..8021a6e2d0f3e 100644 --- a/gradio/_simple_templates/simpletextbox.py +++ b/gradio/_simple_templates/simpletextbox.py @@ -9,9 +9,6 @@ class SimpleTextbox(FormComponent): """ Creates a very simple textbox for user to enter string input or display string output. - Preprocessing: passes textbox value as a {str} into the function. - Postprocessing: expects a {str} returned from function and sets textbox value to it. - Examples-format: a {str} representing the textbox input. """ EVENTS = [ @@ -42,7 +39,7 @@ def __init__( value: default text to provide in textbox. If callable, the function will be called whenever the app loads to set the initial value of the component. placeholder: placeholder hint to provide behind textbox. label: component name in interface. - every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. show_label: if True, will display label. scale: relative width compared to adjacent Components in a Row. For example, if Component A has scale=2, and Component B has scale=1, A will be twice as wide as B. Should be an integer. min_width: minimum pixel width, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in this Component being narrower than min_width, the min_width parameter will be respected first. @@ -69,25 +66,23 @@ def __init__( render=render, ) - def preprocess(self, x: str | None) -> str | None: + def preprocess(self, payload: str | None) -> str | None: """ - Preprocesses input (converts it to a string) before passing it to the function. Parameters: - x: text + payload: the text entered in the textarea. Returns: - text + Passes text value as a {str} into the function. """ - return None if x is None else str(x) + return None if payload is None else str(payload) - def postprocess(self, y: str | None) -> str | None: + def postprocess(self, value: str | None) -> str | None: """ - Postproccess the function output y by converting it to a str before passing it to the frontend. Parameters: - y: function output to postprocess. + value: Expects a {str} returned from function and sets textarea value to it. Returns: - text + The value to display in the textarea. """ - return None if y is None else str(y) + return None if value is None else str(value) def api_info(self) -> dict[str, Any]: return {"type": "string"} diff --git a/gradio/blocks.py b/gradio/blocks.py index bae484de25f1d..b431e2306cded 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -841,7 +841,7 @@ def set_event_trigger( batch: whether this function takes in a batch of inputs max_batch_size: the maximum batch size to send to the function cancels: a list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method. - every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled. + every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. collects_event_data: whether to collect event data for this event trigger_after: if set, this event will be triggered after 'trigger_after' function index trigger_only_on_success: if True, this event will only be triggered if the previous event was successful (only applies if `trigger_after` is set) diff --git a/gradio/component_meta.py b/gradio/component_meta.py index d76fa999eb388..ab721353ff230 100644 --- a/gradio/component_meta.py +++ b/gradio/component_meta.py @@ -49,7 +49,7 @@ def {{ event }}(self, preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component). postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser. cancels: A list of other events to cancel when this listener is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method. Functions that have not yet run (or generators that are iterating) will be cancelled, but functions that are currently running will be allowed to finish. - every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled. + every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. trigger_mode: If "once" (default for all events except `.change()`) would not allow any submissions while an event is pending. If set to "multiple", unlimited submissions are allowed while pending, and "always_last" (default for `.change()` event) would allow a second submission after the pending event is complete. js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components. concurrency_limit: If set, this is the maximum number of this event that can be running simultaneously. Can be set to None to mean no concurrency_limit (any number of this event can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `Blocks.queue()`, which itself is 1 by default). diff --git a/gradio/components/annotated_image.py b/gradio/components/annotated_image.py index 4565ab37b8bac..ea01878ab6101 100644 --- a/gradio/components/annotated_image.py +++ b/gradio/components/annotated_image.py @@ -5,8 +5,8 @@ from typing import Any, List import numpy as np +import PIL.Image from gradio_client.documentation import document, set_documentation_group -from PIL import Image as _Image # using _ to minimize namespace pollution from gradio import processing_utils, utils from gradio.components.base import Component @@ -15,7 +15,7 @@ set_documentation_group("component") -_Image.init() # fixes https://github.com/gradio-app/gradio/issues/2843 +PIL.Image.init() # fixes https://github.com/gradio-app/gradio/issues/2843 class Annotation(GradioModel): @@ -31,9 +31,8 @@ class AnnotatedImageData(GradioModel): @document() class AnnotatedImage(Component): """ - Displays a base image and colored subsections on top of that image. Subsections can take the from of rectangles (e.g. object detection) or masks (e.g. image segmentation). - Preprocessing: this component does *not* accept input. - Postprocessing: expects a {Tuple[numpy.ndarray | PIL.Image | str, List[Tuple[numpy.ndarray | Tuple[int, int, int, int], str]]]} consisting of a base image and a list of subsections, that are either (x1, y1, x2, y2) tuples identifying object boundaries, or 0-1 confidence masks of the same shape as the image. A label is provided for each subsection. + Creates a component to displays a base image and colored annotations on top of that image. Annotations can take the from of rectangles (e.g. object detection) or masks (e.g. image segmentation). + As this component does not accept user input, it is rarely used as an input component. Demos: image_segmentation """ @@ -45,7 +44,7 @@ class AnnotatedImage(Component): def __init__( self, value: tuple[ - np.ndarray | _Image.Image | str, + np.ndarray | PIL.Image.Image | str, list[tuple[np.ndarray | tuple[int, int, int, int], str]], ] | None = None, @@ -67,17 +66,17 @@ def __init__( ): """ Parameters: - value: Tuple of base image and list of (subsection, label) pairs. - show_legend: If True, will show a legend of the subsections. + value: Tuple of base image and list of (annotation, label) pairs. + show_legend: If True, will show a legend of the annotations. height: The height of the image, specified in pixels if a number is passed, or in CSS units if a string is passed. width: The width of the image, specified in pixels if a number is passed, or in CSS units if a string is passed. color_map: A dictionary mapping labels to colors. The colors must be specified as hex codes. label: The label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to. - every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. show_label: if True, will display label. container: If True, will place the component in a container - providing some extra padding around the border. - scale: relative width compared to adjacent Components in a Row. For example, if Component A has scale=2, and Component B has scale=1, A will be twice as wide as B. Should be an integer. - min_width: minimum pixel width, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in this Component being narrower than min_width, the min_width parameter will be respected first. + scale: Relative width compared to adjacent Components in a Row. For example, if Component A has scale=2, and Component B has scale=1, A will be twice as wide as B. Should be an integer. + min_width: Minimum pixel width, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in this Component being narrower than min_width, the min_width parameter will be respected first. visible: If False, component will be hidden. elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles. elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles. @@ -101,32 +100,47 @@ def __init__( value=value, ) + def preprocess( + self, payload: AnnotatedImageData | None + ) -> tuple[str, list[tuple[str, str]]] | None: + """ + Parameters: + payload: Tuple of base image and list of annotations. + Returns: + Passes its value as a `tuple` consisting of a `str` filepath to a base image and `list` of annotations. Each annotation itself is `tuple` of a mask (as a `str` filepath to image) and a `str` label. + """ + if payload is None: + return None + base_img = payload.image.path + annotations = [(a.image.path, a.label) for a in payload.annotations] + return (base_img, annotations) + def postprocess( self, value: tuple[ - np.ndarray | _Image.Image | str, + np.ndarray | PIL.Image.Image | str, list[tuple[np.ndarray | tuple[int, int, int, int], str]], ] | None, ) -> AnnotatedImageData | None: """ Parameters: - value: Tuple of base image and list of subsections, with each subsection a two-part tuple where the first element is a 4 element bounding box or a 0-1 confidence mask, and the second element is the label. + value: Expects a a tuple of a base image and list of annotations: a `tuple[Image, list[Annotation]]`. The `Image` itself can be `str` filepath, `numpy.ndarray`, or `PIL.Image`. Each `Annotation` is a `tuple[Mask, str]`. The `Mask` can be either a `tuple` of 4 `int`'s representing the bounding box coordinates (x1, y1, x2, y2), or 0-1 confidence mask in the form of a `numpy.ndarray` of the same shape as the image, while the second element of the `Annotation` tuple is a `str` label. Returns: - Tuple of base image file and list of subsections, with each subsection a two-part tuple where the first element image path of the mask, and the second element is the label. + Tuple of base image file and list of annotations, with each annotation a two-part tuple where the first element image path of the mask, and the second element is the label. """ if value is None: return None base_img = value[0] if isinstance(base_img, str): base_img_path = base_img - base_img = np.array(_Image.open(base_img)) + base_img = np.array(PIL.Image.open(base_img)) elif isinstance(base_img, np.ndarray): base_file = processing_utils.save_img_array_to_cache( base_img, cache_dir=self.GRADIO_CACHE ) base_img_path = str(utils.abspath(base_file)) - elif isinstance(base_img, _Image.Image): + elif isinstance(base_img, PIL.Image.Image): base_file = processing_utils.save_pil_to_cache( base_img, cache_dir=self.GRADIO_CACHE ) @@ -171,7 +185,7 @@ def hex_to_rgb(value): colored_mask[:, :, 2] = rgb_color[2] * solid_mask colored_mask[:, :, 3] = mask_array * 255 - colored_mask_img = _Image.fromarray((colored_mask).astype(np.uint8)) + colored_mask_img = PIL.Image.fromarray((colored_mask).astype(np.uint8)) mask_file = processing_utils.save_pil_to_cache( colored_mask_img, cache_dir=self.GRADIO_CACHE @@ -188,8 +202,3 @@ def hex_to_rgb(value): def example_inputs(self) -> Any: return {} - - def preprocess( - self, payload: AnnotatedImageData | None - ) -> AnnotatedImageData | None: - return payload diff --git a/gradio/components/audio.py b/gradio/components/audio.py index 57edcc84f8496..af8d90ab383fb 100644 --- a/gradio/components/audio.py +++ b/gradio/components/audio.py @@ -49,9 +49,6 @@ class Audio( ): """ Creates an audio component that can be used to upload/record audio (as an input) or display audio (as an output). - Preprocessing: depending on `type`, passes the uploaded audio as {str} filepath or a {Tuple(int, numpy.array)} corresponding to (sample rate in Hz, audio data). If the latter, the audio data is a 16-bit int array whose values range from -32768 to 32767 and shape of the audio data array is (samples,) for mono audio or (samples, channels) for multi-channel audio. - Postprocessing: expects a {Tuple(int, numpy.array)} corresponding to (sample rate in Hz, audio data as a float or int numpy array) or as a {str} or {pathlib.Path} filepath or URL to an audio file, or bytes for binary content (recommended for streaming). Note: When converting audio data from float format to WAV, the audio is normalized by its peak value to avoid distortion or clipping in the resulting audio. - Examples-format: a {str} filepath to a local file that contains audio. Demos: main_note, generate_tone, reverse_audio Guides: real-time-speech-recognition """ @@ -105,11 +102,11 @@ def __init__( sources: A list of sources permitted for audio. "upload" creates a box where user can drop an audio file, "microphone" creates a microphone input. The first element in the list will be used as the default source. If None, defaults to ["upload", "microphone"], or ["microphone"] if `streaming` is True. type: The format the audio file is converted to before being passed into the prediction function. "numpy" converts the audio to a tuple consisting of: (int sample rate, numpy.array for the data), "filepath" passes a str path to a temporary file containing the audio. label: The label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to. - every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. show_label: if True, will display label. container: If True, will place the component in a container - providing some extra padding around the border. - scale: relative width compared to adjacent Components in a Row. For example, if Component A has scale=2, and Component B has scale=1, A will be twice as wide as B. Should be an integer. - min_width: minimum pixel width, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in this Component being narrower than min_width, the min_width parameter will be respected first. + scale: Relative width compared to adjacent Components in a Row. For example, if Component A has scale=2, and Component B has scale=1, A will be twice as wide as B. Should be an integer. + min_width: Minimum pixel width, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in this Component being narrower than min_width, the min_width parameter will be respected first. interactive: If True, will allow users to upload and edit an audio file. If False, can only be used to play audio. If not provided, this is inferred based on whether the component is used as an input or output. visible: If False, component will be hidden. streaming: If set to True when used in a `live` interface as an input, will automatically stream webcam feed. When used set as an output, takes audio chunks yield from the backend and combines them into one streaming audio output. @@ -189,7 +186,13 @@ def example_inputs(self) -> Any: def preprocess( self, payload: FileData | None - ) -> tuple[int, np.ndarray] | str | None: + ) -> str | tuple[int, np.ndarray] | None: + """ + Parameters: + payload: audio data as a FileData object, or None. + Returns: + passes audio as one of these formats (depending on `type`): a `str` filepath, or `tuple` of (sample rate in Hz, audio data as numpy array). If the latter, the audio data is a 16-bit `int` array whose values range from -32768 to 32767 and shape of the audio data array is (samples,) for mono audio or (samples, channels) for multi-channel audio. + """ if payload is None: return payload @@ -229,13 +232,13 @@ def preprocess( ) def postprocess( - self, value: tuple[int, np.ndarray] | str | Path | bytes | None + self, value: str | Path | bytes | tuple[int, np.ndarray] | None ) -> FileData | bytes | None: """ Parameters: - value: audio data in either of the following formats: a tuple of (sample_rate, data), or a string filepath or URL to an audio file, or None. + value: expects audio data in any of these formats: a `str` or `pathlib.Path` filepath or URL to an audio file, or a `bytes` object (recommended for streaming), or a `tuple` of (sample rate in Hz, audio data as numpy array). Note: if audio is supplied as a numpy array, the audio will be normalized by its peak value to avoid distortion or clipping in the resulting audio. Returns: - base64 url data + FileData object, bytes, or None. """ orig_name = None if value is None: diff --git a/gradio/components/bar_plot.py b/gradio/components/bar_plot.py index 3e5e4b58ffea9..b761ecc2fcd54 100644 --- a/gradio/components/bar_plot.py +++ b/gradio/components/bar_plot.py @@ -16,10 +16,8 @@ @document() class BarPlot(Plot): """ - Create a bar plot. - - Preprocessing: this component does *not* accept input. - Postprocessing: expects a pandas dataframe with the data to plot. + Creates a bar plot component to display data from a pandas DataFrame (as output). As this component does + not accept user input, it is rarely used as an input component. Demos: bar_plot, chicago-bikeshare-dashboard """ @@ -75,7 +73,7 @@ def __init__( ): """ Parameters: - value: The pandas dataframe containing the data to display in a scatter plot. + value: The pandas dataframe containing the data to display in a scatter plot. If a callable is provided, the function will be called whenever the app loads to set the initial value of the plot. x: Column corresponding to the x axis. y: Column corresponding to the y axis. color: The column to determine the bar color. Must be categorical (discrete values). @@ -97,7 +95,7 @@ def __init__( interactive: Whether users should be able to interact with the plot by panning or zooming with their mouse or trackpad. label: The (optional) label to display on the top left corner of the plot. show_label: Whether the label should be displayed. - every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. visible: Whether the plot should be visible. elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles. elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles. @@ -172,8 +170,8 @@ def create_plot( "none", ] | None = None, - height: int | None = None, - width: int | None = None, + height: int | str | None = None, + width: int | str | None = None, y_lim: list[int] | None = None, interactive: bool | None = True, sort: Literal["x", "y", "-x", "-y"] | None = None, @@ -244,7 +242,7 @@ def create_plot( } if tooltip: - encodings["tooltip"] = tooltip + encodings["tooltip"] = tooltip # type: ignore chart = ( alt.Chart(value) # type: ignore @@ -257,11 +255,24 @@ def create_plot( return chart - def postprocess( - self, value: pd.DataFrame | dict | None - ) -> AltairPlotData | dict | None: + def preprocess(self, payload: AltairPlotData) -> AltairPlotData: + """ + Parameters: + payload: The data to display in a bar plot. + Returns: + (Rarely used) passes the data displayed in the bar plot as an AltairPlotData dataclass, which includes the plot information as a JSON string, as well as the type of plot (in this case, "bar"). + """ + return payload + + def postprocess(self, value: pd.DataFrame | None) -> AltairPlotData | None: + """ + Parameters: + value: Expects a pandas DataFrame containing the data to display in the bar plot. The DataFrame should contain at least two columns, one for the x-axis (corresponding to this component's `x` argument) and one for the y-axis (corresponding to `y`). + Returns: + The data to display in a bar plot, in the form of an AltairPlotData dataclass, which includes the plot information as a JSON string, as well as the type of plot (in this case, "bar"). + """ # if None or update - if value is None or isinstance(value, dict): + if value is None: return value if self.x is None or self.y is None: raise ValueError("No value provided for required parameters `x` and `y`.") @@ -292,6 +303,3 @@ def postprocess( def example_inputs(self) -> dict[str, Any]: return {} - - def preprocess(self, payload: AltairPlotData) -> AltairPlotData: - return payload diff --git a/gradio/components/base.py b/gradio/components/base.py index 5966f52589e58..223cb5e686aae 100644 --- a/gradio/components/base.py +++ b/gradio/components/base.py @@ -49,6 +49,10 @@ class ComponentBase(ABC, metaclass=ComponentMeta): def preprocess(self, payload: Any) -> Any: """ Any preprocessing needed to be performed on function input. + Parameters: + payload: The input data received by the component from the frontend. + Returns: + The preprocessed input data sent to the user's function in the backend. """ return payload @@ -56,6 +60,10 @@ def preprocess(self, payload: Any) -> Any: def postprocess(self, value): """ Any postprocessing needed to be performed on function output. + Parameters: + value: The output data received by the component from the user's function in the backend. + Returns: + The postprocessed output data sent to the frontend. """ return value diff --git a/gradio/components/button.py b/gradio/components/button.py index 20a4d1537d73f..a0001cdcef4ff 100644 --- a/gradio/components/button.py +++ b/gradio/components/button.py @@ -15,11 +15,7 @@ @document() class Button(Component): """ - Used to create a button, that can be assigned arbitrary click() events. The label (value) of the button can be used as an input or set via the output of a function. - - Preprocessing: passes the button value as a {str} into the function - Postprocessing: expects a {str} to be returned from a function, which is set as the label of the button - Demos: blocks_inputs, blocks_kinematics + Creates a button that can be assigned arbitrary .click() events. The value (label) of the button can be used as an input to the function (rarely used) or set via the output of a function. """ EVENTS = [Events.click] @@ -44,7 +40,7 @@ def __init__( """ Parameters: value: Default text for the button to display. If callable, the function will be called whenever the app loads to set the initial value of the component. - every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. variant: 'primary' for main call-to-action, 'secondary' for a more subdued style, 'stop' for a stop button. size: Size of the button. Can be "sm" or "lg". icon: URL or path to the icon file to display within the button. If None, no icon will be displayed. Must be within the working directory of the Gradio app or an external URL. @@ -77,10 +73,22 @@ def __init__( def skip_api(self): return True - def preprocess(self, payload: str) -> str: + def preprocess(self, payload: str | None) -> str | None: + """ + Parameters: + payload: string corresponding to the button label + Returns: + (Rarely used) the `str` corresponding to the button label when the button is clicked + """ return payload - def postprocess(self, value: str) -> str: + def postprocess(self, value: str | None) -> str | None: + """ + Parameters: + value: string corresponding to the button label + Returns: + Expects a `str` value that is set as the button label + """ return value def example_inputs(self) -> Any: diff --git a/gradio/components/chatbot.py b/gradio/components/chatbot.py index 69d3ac55dc435..657d4bbc41088 100644 --- a/gradio/components/chatbot.py +++ b/gradio/components/chatbot.py @@ -29,9 +29,9 @@ class ChatbotData(GradioRootModel): @document() class Chatbot(Component): """ - Displays a chatbot output showing both user submitted messages and responses. Supports a subset of Markdown including bold, italics, code, tables. Also supports audio/video/image files, which are displayed in the Chatbot, and other kinds of files which are displayed as links. - Preprocessing: passes the messages in the Chatbot as a {List[List[str | None | Tuple]]}, i.e. a list of lists. The inner list has 2 elements: the user message and the response message. See `Postprocessing` for the format of these messages. - Postprocessing: expects function to return a {List[List[str | None | Tuple]]}, i.e. a list of lists. The inner list should have 2 elements: the user message and the response message. The individual messages can be (1) strings in valid Markdown, (2) tuples if sending files: (a filepath or URL to a file, [optional string alt text]) -- if the file is image/video/audio, it is displayed in the Chatbot, or (3) None, in which case the message is not displayed. + Creates a chatbot that displays user-submitted messages and responses. Supports a subset of Markdown including bold, italics, code, tables. + Also supports audio/video/image files, which are displayed in the Chatbot, and other kinds of files which are displayed as links. This + component is usually used as an output component. Demos: chatbot_simple, chatbot_multimodal Guides: creating-a-chatbot @@ -73,7 +73,7 @@ def __init__( Parameters: value: Default value to show in chatbot. If callable, the function will be called whenever the app loads to set the initial value of the component. label: The label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to. - every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. show_label: if True, will display label. container: If True, will place the component in a container - providing some extra padding around the border. scale: relative width compared to adjacent Components in a Row. For example, if Component A has scale=2, and Component B has scale=1, A will be twice as wide as B. Should be an integer. @@ -151,8 +151,14 @@ def _preprocess_chat_messages( def preprocess( self, - payload: ChatbotData, - ) -> list[list[str | tuple[str] | tuple[str, str] | None]]: + payload: ChatbotData | None, + ) -> list[list[str | tuple[str] | tuple[str, str] | None]] | None: + """ + Parameters: + payload: data as a ChatbotData object + Returns: + Passes the messages in the chatbot as a `list[list[str | None | tuple]]`, i.e. a list of lists. The inner list has 2 elements: the user message and the response message. Each message can be (1) a string in valid Markdown, (2) a tuple if there are displayed files: (a filepath or URL to a file, [optional string alt text]), or (3) None, if there is no message displayed. + """ if payload is None: return payload processed_messages = [] @@ -194,8 +200,14 @@ def _postprocess_chat_messages( def postprocess( self, - value: list[list[str | tuple[str] | tuple[str, str] | None] | tuple], + value: list[list[str | tuple[str] | tuple[str, str] | None] | tuple] | None, ) -> ChatbotData: + """ + Parameters: + value: expects a `list[list[str | None | tuple]]`, i.e. a list of lists. The inner list should have 2 elements: the user message and the response message. The individual messages can be (1) strings in valid Markdown, (2) tuples if sending files: (a filepath or URL to a file, [optional string alt text]) -- if the file is image/video/audio, it is displayed in the Chatbot, or (3) None, in which case the message is not displayed. + Returns: + an object of type ChatbotData + """ if value is None: return ChatbotData(root=[]) processed_messages = [] diff --git a/gradio/components/checkbox.py b/gradio/components/checkbox.py index 44ef1f9604026..be74a4dc074db 100644 --- a/gradio/components/checkbox.py +++ b/gradio/components/checkbox.py @@ -15,11 +15,9 @@ @document() class Checkbox(FormComponent): """ - Creates a checkbox that can be set to `True` or `False`. + Creates a checkbox that can be set to `True` or `False`. Can be used as an input to pass a boolean value to a function or as an output + to display a boolean value. - Preprocessing: passes the status of the checkbox as a {bool} into the function. - Postprocessing: expects a {bool} returned from the function and, if it is True, checks the checkbox. - Examples-format: a {bool} representing whether the box is checked. Demos: sentence_builder, titanic_survival """ @@ -47,7 +45,7 @@ def __init__( value: if True, checked by default. If callable, the function will be called whenever the app loads to set the initial value of the component. label: The label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to. info: additional component description. - every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. show_label: if True, will display label. container: If True, will place the component in a container - providing some extra padding around the border. scale: relative width compared to adjacent Components in a Row. For example, if Component A has scale=2, and Component B has scale=1, A will be twice as wide as B. Should be an integer. @@ -81,7 +79,19 @@ def example_inputs(self) -> bool: return True def preprocess(self, payload: bool | None) -> bool | None: + """ + Parameters: + payload: the status of the checkbox + Returns: + Passes the status of the checkbox as a `bool`. + """ return payload def postprocess(self, value: bool | None) -> bool | None: + """ + Parameters: + value: Expects a `bool` value that is set as the status of the checkbox + Returns: + The same `bool` value that is set as the status of the checkbox + """ return value diff --git a/gradio/components/checkboxgroup.py b/gradio/components/checkboxgroup.py index f46e251607497..97cf6e530d670 100644 --- a/gradio/components/checkboxgroup.py +++ b/gradio/components/checkboxgroup.py @@ -15,10 +15,7 @@ @document() class CheckboxGroup(FormComponent): """ - Creates a set of checkboxes of which a subset can be checked. - Preprocessing: passes the list of checked checkboxes as a {List[str | int | float]} or their indices as a {List[int]} into the function, depending on `type`. - Postprocessing: expects a {List[str | int | float]}, each element of which becomes a checked checkbox. - Examples-format: a {List[str | int | float]} representing the values to be checked. + Creates a set of checkboxes. Can be used as an input to pass a set of values to a function or as an output to display values, a subset of which are selected. Demos: sentence_builder, titanic_survival """ @@ -50,7 +47,7 @@ def __init__( type: Type of value to be returned by component. "value" returns the list of strings of the choices selected, "index" returns the list of indices of the choices selected. label: The label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to. info: Additional component description. - every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise.sed (e.g. to cancel it) via this component's .load_event attribute. show_label: If True, will display label. container: If True, will place the component in a container - providing some extra padding around the border. scale: Relative width compared to adjacent Components in a Row. For example, if Component A has scale=2, and Component B has scale=1, A will be twice as wide as B. Should be an integer. @@ -103,6 +100,12 @@ def api_info(self) -> dict[str, Any]: def preprocess( self, payload: list[str | int | float] ) -> list[str | int | float] | list[int | None]: + """ + Parameters: + payload: the list of checked checkboxes' values + Returns: + Passes the list of checked checkboxes as a `list[str | int | float]` or their indices as a `list[int]` into the function, depending on `type`. + """ if self.type == "value": return payload elif self.type == "index": @@ -119,6 +122,12 @@ def preprocess( def postprocess( self, value: list[str | int | float] | str | int | float | None ) -> list[str | int | float]: + """ + Parameters: + value: Expects a `list[str | int | float]` of values or a single `str | int | float` value, the checkboxes with these values are checked. + Returns: + the list of checked checkboxes' values + """ if value is None: return [] if not isinstance(value, list): diff --git a/gradio/components/clear_button.py b/gradio/components/clear_button.py index 2ad807afc6f92..711f08346791f 100644 --- a/gradio/components/clear_button.py +++ b/gradio/components/clear_button.py @@ -111,11 +111,23 @@ def add(self, components: None | Component | list[Component]) -> ClearButton: ) return self - def postprocess(self, value: str | None) -> str | None: - return value - def preprocess(self, payload: str | None) -> str | None: + """ + Parameters: + payload: string corresponding to the button label + Returns: + (Rarely used) the `str` corresponding to the button label when the button is clicked + """ return payload + def postprocess(self, value: str | None) -> str | None: + """ + Parameters: + value: string corresponding to the button label + Returns: + Expects a `str` value that is set as the button label + """ + return value + def example_inputs(self) -> Any: return None diff --git a/gradio/components/code.py b/gradio/components/code.py index 19fae0add4480..a3bbbb2928d09 100644 --- a/gradio/components/code.py +++ b/gradio/components/code.py @@ -16,9 +16,7 @@ @document("languages") class Code(Component): """ - Creates a Code editor for entering, editing or viewing code. - Preprocessing: passes a {str} of code into the function. - Postprocessing: expects the function to return a {str} of code or a single-element {tuple}: {(string_filepath,)} + Creates a code editor for viewing code (as an ouptut component), or for entering and editing code (as an input component). """ languages = [ @@ -77,7 +75,7 @@ def __init__( """ Parameters: value: Default value to show in the code editor. If callable, the function will be called whenever the app loads to set the initial value of the component. - every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. language: The language to display the code as. Supported languages listed in `gr.Code.languages`. label: The label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to. interactive: Whether user should be able to enter code or only view it. @@ -110,10 +108,22 @@ def __init__( value=value, ) - def preprocess(self, payload: Any) -> Any: + def preprocess(self, payload: str | None) -> str | None: + """ + Parameters: + payload: string corresponding to the code + Returns: + Passes the code entered as a `str`. + """ return payload - def postprocess(self, value: tuple | str | None) -> None | str: + def postprocess(self, value: tuple[str] | str | None) -> None | str: + """ + Parameters: + value: Expects a `str` of code or a single-element `tuple`: (filepath,) with the `str` path to a file containing the code. + Returns: + Returns the code as a `str`. + """ if value is None: return None elif isinstance(value, tuple): diff --git a/gradio/components/color_picker.py b/gradio/components/color_picker.py index 6575c9d42c0b7..fa1993f4b22eb 100644 --- a/gradio/components/color_picker.py +++ b/gradio/components/color_picker.py @@ -15,10 +15,7 @@ @document() class ColorPicker(Component): """ - Creates a color picker for user to select a color as string input. - Preprocessing: passes selected color value as a {str} into the function. - Postprocessing: expects a {str} returned from function and sets color picker value to it. - Examples-format: a {str} with a hexadecimal representation of a color, e.g. "#ff0000" for red. + Creates a color picker for user to select a color as string input. Can be used as an input to pass a color value to a function or as an output to display a color value. Demos: color_picker, color_generator """ @@ -46,7 +43,7 @@ def __init__( value: default text to provide in color picker. If callable, the function will be called whenever the app loads to set the initial value of the component. label: The label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to. info: additional component description. - every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise.sed (e.g. to cancel it) via this component's .load_event attribute. show_label: if True, will display label. container: If True, will place the component in a container - providing some extra padding around the border. scale: relative width compared to adjacent Components in a Row. For example, if Component A has scale=2, and Component B has scale=1, A will be twice as wide as B. Should be an integer. @@ -80,12 +77,24 @@ def api_info(self) -> dict[str, Any]: return {"type": "string"} def preprocess(self, payload: str | None) -> str | None: + """ + Parameters: + payload: Color as hex string + Returns: + Passes selected color value as a hex `str` into the function. + """ if payload is None: return None else: return str(payload) def postprocess(self, value: str | None) -> str | None: + """ + Parameters: + value: Expects a hex `str` returned from function and sets color picker value to it. + Returns: + A `str` value that is set as the color picker value. + """ if value is None: return None else: diff --git a/gradio/components/dataframe.py b/gradio/components/dataframe.py index a69580b8005dd..13cdf45d93bd1 100644 --- a/gradio/components/dataframe.py +++ b/gradio/components/dataframe.py @@ -54,10 +54,7 @@ class DataframeData(GradioModel): @document() class Dataframe(Component): """ - Accepts or displays 2D input through a spreadsheet-like component for dataframes. - Preprocessing: passes the uploaded spreadsheet data as a {pandas.DataFrame}, {numpy.array}, {polars.DataFrame}, or {List[List]} depending on `type` - Postprocessing: expects a {pandas.DataFrame}, {pandas.Styler}, {numpy.array}, {polars.DataFrame}, {List[List]}, {List}, a {Dict} with keys `data` (and optionally `headers`), or {str} path to a csv, which is rendered in the spreadsheet. - Examples-format: a {str} filepath to a csv with data, a pandas dataframe, a polars dataframe, or a list of lists (excluding headers) where each sublist is a row of data. + This component displays a table of value spreadsheet-like component. Can be used to display data as an output component, or as an input to collect data from the user. Demos: filter_records, matrix_transpose, tax_calculator, sort_records """ @@ -111,7 +108,7 @@ def __init__( latex_delimiters: A list of dicts of the form {"left": open delimiter (str), "right": close delimiter (str), "display": whether to display in newline (bool)} that will be used to render LaTeX expressions. If not provided, `latex_delimiters` is set to `[{ "left": "$$", "right": "$$", "display": True }]`, so only expressions enclosed in $$ delimiters will be rendered as LaTeX, and in a new line. Pass in an empty list to disable LaTeX rendering. For more information, see the [KaTeX documentation](https://katex.org/docs/autorender.html). Only applies to columns whose datatype is "markdown". label: The label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to. show_label: if True, will display label. - every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. height: The maximum height of the dataframe, specified in pixels if a number is passed, or in CSS units if a string is passed. If more rows are created than can fit in the height, a scrollbar will appear. scale: relative width compared to adjacent Components in a Row. For example, if Component A has scale=2, and Component B has scale=1, A will be twice as wide as B. Should be an integer. min_width: minimum pixel width, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in this Component being narrower than min_width, the min_width parameter will be respected first. @@ -193,7 +190,13 @@ def __init__( def preprocess( self, payload: DataframeData - ) -> pd.DataFrame | np.ndarray | pl.DataFrame | list: + ) -> pd.DataFrame | np.ndarray | pl.DataFrame | list[list]: + """ + Parameters: + payload: the uploaded spreadsheet data as an object with `headers` and `data` attributes + Returns: + Passes the uploaded spreadsheet data as a `pandas.DataFrame`, `numpy.array`, `polars.DataFrame`, or native 2D Python `list[list]` depending on `type` + """ if self.type == "pandas": if payload.headers is not None: return pd.DataFrame(payload.data, columns=payload.headers) @@ -208,7 +211,7 @@ def preprocess( if self.type == "numpy": return np.array(payload.data) elif self.type == "array": - return payload.data + return payload.data # type: ignore else: raise ValueError( "Unknown type: " @@ -228,6 +231,12 @@ def postprocess( | str | None, ) -> DataframeData: + """ + Parameters: + value: Expects data any of these formats: `pandas.DataFrame`, `pandas.Styler`, `numpy.array`, `polars.DataFrame`, `list[list]`, `list`, or a `dict` with keys 'data' (and optionally 'headers'), or `str` path to a csv, which is rendered as the spreadsheet. + Returns: + the uploaded spreadsheet data as an object with `headers` and `data` attributes + """ if value is None: return self.postprocess(self.empty_input) if isinstance(value, dict): diff --git a/gradio/components/dataset.py b/gradio/components/dataset.py index e778751c12aed..846b114b14b31 100644 --- a/gradio/components/dataset.py +++ b/gradio/components/dataset.py @@ -19,10 +19,7 @@ @document() class Dataset(Component): """ - Used to create an output widget for showing datasets. Used to render the examples - box. - Preprocessing: passes the selected sample either as a {list} of data (if type="value") or as an {int} index (if type="index") - Postprocessing: expects a {list} of {lists} corresponding to the dataset data. + Creates a gallery or table to display data samples. This component is designed for internal use to display examples. """ EVENTS = [Events.click, Events.select] @@ -47,6 +44,7 @@ def __init__( ): """ Parameters: + label: The label for this component, appears above the component. components: Which component types to show in this dataset widget, can be passed in as a list of string names or Components instances. The following components are supported in a Dataset: Audio, Checkbox, CheckboxGroup, ColorPicker, Dataframe, Dropdown, File, HTML, Image, Markdown, Model3D, Number, Radio, Slider, Textbox, TimeSeries, Video samples: a nested list of samples. Each sublist within the outer list represents a data sample, and each element within the sublist represents an value for each component headers: Column headers in the Dataset widget, should be the same len as components. If not provided, inferred from component labels @@ -127,13 +125,25 @@ def get_config(self): return config - def preprocess(self, payload: int) -> int | list[list] | None: + def preprocess(self, payload: int) -> int | list | None: + """ + Parameters: + payload: the index of the selected example in the dataset + Returns: + Passes the selected sample either as a `list` of data corresponding to each input component (if `type` is "value") or as an `int` index (if `type` is "index") + """ if self.type == "index": return payload elif self.type == "values": return self.samples[payload] def postprocess(self, samples: list[list]) -> dict: + """ + Parameters: + samples: Expects a `list[list]` corresponding to the dataset data, can be used to update the dataset. + Returns: + Returns the updated dataset data as a `dict` with the key "samples". + """ return { "samples": samples, "__type__": "update", diff --git a/gradio/components/dropdown.py b/gradio/components/dropdown.py index cba2e7185289f..efab899d97665 100644 --- a/gradio/components/dropdown.py +++ b/gradio/components/dropdown.py @@ -16,10 +16,8 @@ @document() class Dropdown(FormComponent): """ - Creates a dropdown of choices from which entries can be selected. - Preprocessing: passes the value of the selected dropdown entry as a {str} or its index as an {int} into the function, depending on `type`. - Postprocessing: expects a {str} corresponding to the value of the dropdown entry to be selected. - Examples-format: a {str} representing the drop down value to select. + Creates a dropdown of choices from which a single entry or multiple entries can be selected (as an input component) or displayed (as an output component). + Demos: sentence_builder, titanic_survival """ @@ -59,7 +57,7 @@ def __init__( filterable: If True, user will be able to type into the dropdown and filter the choices by typing. Can only be set to False if `allow_custom_value` is False. label: The label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to. info: additional component description. - every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. show_label: if True, will display label. container: If True, will place the component in a container - providing some extra padding around the border. scale: relative width compared to adjacent Components in a Row. For example, if Component A has scale=2, and Component B has scale=1, A will be twice as wide as B. Should be an integer. @@ -68,7 +66,7 @@ def __init__( visible: If False, component will be hidden. elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles. elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles. - render: If False, component will not render be rendered in the Blocks context. Should be used if the intention is to assign event listeners now but render the component later. + render: If False, component will not be rendered in the Blocks context. Should be used if the intention is to assign event listeners now but render the component later. """ self.choices = ( # Although we expect choices to be a list of tuples, it can be a list of tuples if the Gradio app @@ -136,6 +134,12 @@ def example_inputs(self) -> Any: def preprocess( self, payload: str | int | float | list[str | int | float] | None ) -> str | int | float | list[str | int | float] | list[int | None] | None: + """ + Parameters: + payload: the value of the selected dropdown choice(s) + Returns: + Passes the value of the selected dropdown choice as a `str | int | float` or its index as an `int` into the function, depending on `type`. Or, if `multiselect` is True, passes the values of the selected dropdown choices as a list of correspoding values/indices instead. + """ if self.type == "value": return payload elif self.type == "index": @@ -167,6 +171,12 @@ def _warn_if_invalid_choice(self, value): def postprocess( self, value: str | int | float | list[str | int | float] | None ) -> str | int | float | list[str | int | float] | None: + """ + Parameters: + value: Expects a `str | int | float` corresponding to the value of the dropdown entry to be selected. Or, if `multiselect` is True, expects a `list` of values corresponding to the selected dropdown entries. + Returns: + Returns the values of the selected dropdown entry or entries. + """ if value is None: return None if self.multiselect: diff --git a/gradio/components/duplicate_button.py b/gradio/components/duplicate_button.py index 4012712352bda..e6be4d68cc57e 100644 --- a/gradio/components/duplicate_button.py +++ b/gradio/components/duplicate_button.py @@ -17,8 +17,6 @@ class DuplicateButton(Button): """ Button that triggers a Spaces Duplication, when the demo is on Hugging Face Spaces. Does nothing locally. - Preprocessing: passes the button value as a {str} into the function - Postprocessing: expects a {str} to be returned from a function, which is set as the label of the button """ is_template = True @@ -44,7 +42,7 @@ def __init__( """ Parameters: value: Default text for the button to display. If callable, the function will be called whenever the app loads to set the initial value of the component. - every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. variant: 'primary' for main call-to-action, 'secondary' for a more subdued style, 'stop' for a stop button. size: Size of the button. Can be "sm" or "lg". icon: URL or path to the icon file to display within the button. If None, no icon will be displayed. diff --git a/gradio/components/fallback.py b/gradio/components/fallback.py index 3c9b98a6315fd..9e3359a7559e7 100644 --- a/gradio/components/fallback.py +++ b/gradio/components/fallback.py @@ -3,9 +3,23 @@ class Fallback(Component): def preprocess(self, payload): + """ + This docstring is used to generate the docs for this custom component. + Parameters: + payload: the data to be preprocessed, sent from the frontend + Returns: + the data after preprocessing, sent to the user's function in the backend + """ return payload def postprocess(self, value): + """ + This docstring is used to generate the docs for this custom component. + Parameters: + payload: the data to be postprocessed, sent from the user's function in the backend + Returns: + the data after postprocessing, sent to the frontend + """ return value def example_inputs(self): diff --git a/gradio/components/file.py b/gradio/components/file.py index eb0b35cd5a281..d198e5620da2b 100644 --- a/gradio/components/file.py +++ b/gradio/components/file.py @@ -20,11 +20,9 @@ @document() class File(Component): """ - Creates a file component that allows uploading generic file (when used as an input) and or displaying generic files (output). - Preprocessing: passes the uploaded file as a {tempfile._TemporaryFileWrapper} or {List[tempfile._TemporaryFileWrapper]} depending on `file_count` (or a {bytes}/{List[bytes]} depending on `type`) - Postprocessing: expects function to return a {str} path to a file, or {List[str]} consisting of paths to files. - Examples-format: a {str} path to a local file that populates the component. - Demos: zip_to_json, zip_files + Creates a file component that allows uploading one or more generic files (when used as an input) or displaying generic files (as output). + + Demo: zip_files, zip_to_json """ EVENTS = [Events.change, Events.select, Events.clear, Events.upload] @@ -56,7 +54,7 @@ def __init__( file_types: List of file extensions or types of files to be uploaded (e.g. ['image', '.json', '.mp4']). "file" allows any file to be uploaded, "image" allows only image files to be uploaded, "audio" allows only audio files to be uploaded, "video" allows only video files to be uploaded, "text" allows only text files to be uploaded. type: Type of value to be returned by component. "file" returns a temporary file object with the same base name as the uploaded file, whose full path can be retrieved by file_obj.name, "binary" returns an bytes object. label: The label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to. - every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise.sed (e.g. to cancel it) via this component's .load_event attribute. show_label: if True, will display label. container: If True, will place the component in a container - providing some extra padding around the border. scale: relative width compared to adjacent Components in a Row. For example, if Component A has scale=2, and Component B has scale=1, A will be twice as wide as B. Should be an integer. @@ -125,7 +123,13 @@ def _process_single_file(self, f: FileData) -> NamedString | bytes: def preprocess( self, payload: ListFiles | FileData | None - ) -> bytes | NamedString | list[bytes | NamedString] | None: + ) -> bytes | str | list[bytes] | list[str] | None: + """ + Parameters: + payload: File information as a FileData object, or a list of FileData objects. + Returns: + Passes the file as a `str` or `bytes` object, or a list of `str` or list of `bytes` objects, depending on `type` and `file_count`. + """ if payload is None: return None if self.file_count == "single": @@ -135,11 +139,17 @@ def preprocess( return self._process_single_file(payload) else: if isinstance(payload, ListFiles): - return [self._process_single_file(f) for f in payload] + return [self._process_single_file(f) for f in payload] # type: ignore else: - return [self._process_single_file(payload)] + return [self._process_single_file(payload)] # type: ignore def postprocess(self, value: str | list[str] | None) -> ListFiles | FileData | None: + """ + Parameters: + value: Expects a `str` filepath, or a `list[str]` of filepaths. + Returns: + File information as a FileData object, or a list of FileData objects. + """ if value is None: return None if isinstance(value, list): diff --git a/gradio/components/file_explorer.py b/gradio/components/file_explorer.py index 302ae7e1eeb93..38d317c38ffd2 100644 --- a/gradio/components/file_explorer.py +++ b/gradio/components/file_explorer.py @@ -24,11 +24,10 @@ class FileExplorerData(GradioRootModel): @document() class FileExplorer(Component): """ - Creates a file explorer component that allows users to browse and select files on the machine hosting the Gradio app. - Preprocessing: passes the selected file or directory as a {str} path (relative to root) or {list[str}} depending on `file_count` - Postprocessing: expects function to return a {str} path to a file, or {List[str]} consisting of paths to files. - Examples-format: a {str} path to a local file that populates the component. - Demos: zip_to_json, zip_files + Creates a file explorer component that allows users to browse files on the machine hosting the Gradio app. As an input component, + it also allows users to select files to be used as input to a function, while as an output component, it displays selected files. + + Demos: file_explorer """ EVENTS = ["change"] @@ -64,7 +63,7 @@ def __init__( root_dir: Path to root directory to select files from. If not provided, defaults to current working directory. ignore_glob: The glob-tyle pattern that will be used to exclude files from the list. For example, "*.py" will exclude all .py files from the list. See the Python glob documentation at https://docs.python.org/3/library/glob.html for more information. label: The label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to. - every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise.sed (e.g. to cancel it) via this component's .load_event attribute. show_label: if True, will display label. container: If True, will place the component in a container - providing some extra padding around the border. scale: relative width compared to adjacent Components in a Row. For example, if Component A has scale=2, and Component B has scale=1, A will be twice as wide as B. Should be an integer. @@ -112,6 +111,12 @@ def example_inputs(self) -> Any: return ["Users", "gradio", "app.py"] def preprocess(self, payload: FileExplorerData | None) -> list[str] | str | None: + """ + Parameters: + payload: List of selected files as a FileExplorerData object. + Returns: + Passes the selected file or directory as a `str` path (relative to `root`) or `list[str}` depending on `file_count` + """ if payload is None: return None @@ -136,6 +141,12 @@ def _strip_root(self, path): return path def postprocess(self, value: str | list[str] | None) -> FileExplorerData | None: + """ + Parameters: + value: Expects function to return a `str` path to a file, or `list[str]` consisting of paths to files. + Returns: + A FileExplorerData object containing the selected files as a list of strings. + """ if value is None: return None diff --git a/gradio/components/gallery.py b/gradio/components/gallery.py index e0a43f7551f93..b61f5b54b96d3 100644 --- a/gradio/components/gallery.py +++ b/gradio/components/gallery.py @@ -35,9 +35,8 @@ class GalleryData(GradioRootModel): @document() class Gallery(Component): """ - Used to display a list of images as a gallery that can be scrolled through. - Preprocessing: A list of (image, caption) tuples. Each image is a filepath, numpy array or PIL.image depending on the `type` parameter. {List[tuple[str | PIL.Image | numpy.array, str | None]]}. - Postprocessing: expects a list of images in any format, {List[numpy.array | PIL.Image | str | pathlib.Path]}, or a {List} of (image, {str} caption) tuples and displays them. + Creates a gallery component that allows displaying a grid of images, and optionally captions. If used as an input, the user can upload images to the gallery. + If used as an output, the user can click on individual images to view them at a higher resolution. Demos: fake_gan """ @@ -79,7 +78,7 @@ def __init__( Parameters: value: List of images to display in the gallery by default. If callable, the function will be called whenever the app loads to set the initial value of the component. label: The label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to. - every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. show_label: if True, will display label. container: If True, will place the component in a container - providing some extra padding around the border. scale: relative width compared to adjacent Components in a Row. For example, if Component A has scale=2, and Component B has scale=1, A will be twice as wide as B. Should be an integer. @@ -134,15 +133,37 @@ def __init__( interactive=interactive, ) + def preprocess( + self, payload: GalleryData | None + ) -> ( + List[tuple[str, str | None]] + | List[tuple[_Image.Image, str | None]] + | List[tuple[np.ndarray, str | None]] + | None + ): + """ + Parameters: + payload: a list of images, or list of (image, caption) tuples + Returns: + Passes the list of images as a list of (image, caption) tuples, or a list of (image, None) tuples if no captions are provided (which is usually the case). The image can be a `str` file path, a `numpy` array, or a `PIL.Image` object depending on `type`. + """ + if payload is None or not payload.root: + return None + data = [] + for gallery_element in payload.root: + image = self.convert_to_type(gallery_element.image.path, self.type) # type: ignore + data.append((image, gallery_element.caption)) + return data + def postprocess( self, value: list[GalleryImageType | CaptionedGalleryImageType] | None, ) -> GalleryData: """ Parameters: - value: list of images, or list of (image, caption) tuples + value: Expects the function to return a `list` of images, or `list` of (image, `str` caption) tuples. Each image can be a `str` file path, a `numpy` array, or a `PIL.Image` object. Returns: - list of string file paths to images in temp directory + a list of images, or list of (image, caption) tuples """ if value is None: return GalleryData(root=[]) @@ -193,22 +214,6 @@ def convert_to_type(img: str, type: Literal["filepath", "numpy", "pil"]): converted_image = np.array(converted_image) return converted_image - def preprocess( - self, payload: GalleryData | None - ) -> ( - List[tuple[str, str | None]] - | List[tuple[_Image.Image, str | None]] - | List[tuple[np.ndarray, str | None]] - | None - ): - if payload is None or not payload.root: - return None - data = [] - for gallery_element in payload.root: - image = self.convert_to_type(gallery_element.image.path, self.type) # type: ignore - data.append((image, gallery_element.caption)) - return data - def example_inputs(self) -> Any: return [ "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png" diff --git a/gradio/components/highlighted_text.py b/gradio/components/highlighted_text.py index ac0c7d96ac583..95e9c0b289cd4 100644 --- a/gradio/components/highlighted_text.py +++ b/gradio/components/highlighted_text.py @@ -26,8 +26,6 @@ class HighlightedTextData(GradioRootModel): class HighlightedText(Component): """ Displays text that contains spans that are highlighted by category or numerical value. - Preprocessing: passes a list of tuples as a {List[Tuple[str, float | str | None]]]} into the function. If no labels are provided, the text will be displayed as a single span. - Postprocessing: expects a {List[Tuple[str, float | str]]]} consisting of spans of text and their associated labels, or a {Dict} with two keys: (1) "text" whose value is the complete text, and (2) "entities", which is a list of dictionaries, each of which have the keys: "entity" (consisting of the entity label, can alternatively be called "entity_group"), "start" (the character index where the label starts), and "end" (the character index where the label ends). Entities should not overlap. Demos: diff_texts, text_analysis Guides: named-entity-recognition @@ -65,7 +63,7 @@ def __init__( combine_adjacent: If True, will merge the labels of adjacent tokens belonging to the same category. adjacent_separator: Specifies the separator to be used between tokens if combine_adjacent is True. label: The label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to. - every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. show_label: if True, will display label. container: If True, will place the component in a container - providing some extra padding around the border. scale: relative width compared to adjacent Components in a Row. For example, if Component A has scale=2, and Component B has scale=1, A will be twice as wide as B. Should be an integer. @@ -98,14 +96,27 @@ def __init__( def example_inputs(self) -> Any: return {"value": [{"token": "Hello", "class_or_confidence": "1"}]} + def preprocess( + self, payload: HighlightedTextData | None + ) -> list[tuple[str, str | float | None]] | None: + """ + Parameters: + payload: An instance of HighlightedTextData + Returns: + Passes the value as a list of tuples as a `list[tuple]` into the function. Each `tuple` consists of a `str` substring of the text (so the entire text is included) and `str | float | None` label, which is the category or confidence of that substring. + """ + if payload is None: + return None + return payload.model_dump() # type: ignore + def postprocess( self, value: list[tuple[str, str | float | None]] | dict | None ) -> HighlightedTextData | None: """ Parameters: - value: List of (word, category) tuples, or a dictionary of two keys: "text", and "entities", which itself is a list of dictionaries, each of which have the keys: "entity" (or "entity_group"), "start", and "end" + value: Expects a list of (word, category) tuples, or a dictionary of two keys: "text", and "entities", which itself is a list of dictionaries, each of which have the keys: "entity" (or "entity_group"), "start", and "end" Returns: - List of (word, category) tuples + An instance of HighlightedTextData """ if value is None: return None @@ -165,8 +176,3 @@ def postprocess( for o in value ] ) - - def preprocess(self, payload: HighlightedTextData | None) -> dict | None: - if payload is None: - return None - return payload.model_dump() diff --git a/gradio/components/html.py b/gradio/components/html.py index 8534446a2d12c..064d01366f486 100644 --- a/gradio/components/html.py +++ b/gradio/components/html.py @@ -15,9 +15,7 @@ @document() class HTML(Component): """ - Used to display arbitrary HTML output. - Preprocessing: this component does *not* accept input. - Postprocessing: expects a valid HTML {str}. + Creates a component to display arbitrary HTML output. As this component does not accept user input, it is rarely used as an input component. Demos: text_analysis Guides: key-features @@ -41,7 +39,7 @@ def __init__( Parameters: value: Default value. If callable, the function will be called whenever the app loads to set the initial value of the component. label: The label for this component. Is used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to. - every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. show_label: This parameter has no effect. visible: If False, component will be hidden. elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles. @@ -63,9 +61,21 @@ def example_inputs(self) -> Any: return "

Hello

" def preprocess(self, payload: str | None) -> str | None: + """ + Parameters: + payload: string corresponding to the HTML + Returns: + (Rarely used) passes the HTML as a `str`. + """ return payload def postprocess(self, value: str | None) -> str | None: + """ + Parameters: + value: Expects a `str` consisting of valid HTML. + Returns: + Returns the HTML string. + """ return value def api_info(self) -> dict[str, Any]: diff --git a/gradio/components/image.py b/gradio/components/image.py index bd03991ac1970..b56de66f3df86 100644 --- a/gradio/components/image.py +++ b/gradio/components/image.py @@ -7,8 +7,8 @@ from typing import Any, Literal, cast import numpy as np +import PIL.Image from gradio_client.documentation import document, set_documentation_group -from PIL import Image as _Image # using _ to minimize namespace pollution from PIL import ImageOps import gradio.image_utils as image_utils @@ -18,16 +18,14 @@ from gradio.events import Events set_documentation_group("component") -_Image.init() # fixes https://github.com/gradio-app/gradio/issues/2843 +PIL.Image.init() # fixes https://github.com/gradio-app/gradio/issues/2843 @document() class Image(StreamingInput, Component): """ Creates an image component that can be used to upload images (as an input) or display images (as an output). - Preprocessing: passes the uploaded image as a {numpy.array}, {PIL.Image} or {str} filepath depending on `type`. For SVGs, the `type` parameter is ignored and the filepath of the SVG is returned. - Postprocessing: expects a {numpy.array}, {PIL.Image} or {str} or {pathlib.Path} filepath to an image and displays the image. - Examples-format: a {str} local filepath or URL to an image. + Demos: image_mod, image_mod_default_image Guides: image-classification-in-pytorch, image-classification-in-tensorflow, image-classification-with-vision-transformers, create-your-own-friends-with-a-gan """ @@ -44,7 +42,7 @@ class Image(StreamingInput, Component): def __init__( self, - value: str | _Image.Image | np.ndarray | None = None, + value: str | PIL.Image.Image | np.ndarray | None = None, *, height: int | str | None = None, width: int | str | None = None, @@ -78,7 +76,7 @@ def __init__( sources: List of sources for the image. "upload" creates a box where user can drop an image file, "webcam" allows user to take snapshot from their webcam, "clipboard" allows users to paste an image from the clipboard. If None, defaults to ["upload", "webcam", "clipboard"] if streaming is False, otherwise defaults to ["webcam"]. type: The format the image is converted before being passed into the prediction function. "numpy" converts the image to a numpy array with shape (height, width, 3) and values from 0 to 255, "pil" converts the image to a PIL image object, "filepath" passes a str path to a temporary file containing the image. If the image is SVG, the `type` is ignored and the filepath of the SVG is returned. label: The label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to. - every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. show_label: if True, will display label. show_download_button: If True, will display button to download image. container: If True, will place the component in a container - providing some extra padding around the border. @@ -145,7 +143,13 @@ def __init__( def preprocess( self, payload: FileData | None - ) -> np.ndarray | _Image.Image | str | None: + ) -> np.ndarray | PIL.Image.Image | str | None: + """ + Parameters: + payload: image data in the form of a FileData object + Returns: + Passes the uploaded image as a `numpy.array`, `PIL.Image` or `str` filepath depending on `type`. For SVGs, the `type` parameter is ignored and the filepath of the SVG is returned. + """ if payload is None: return payload file_path = Path(payload.path) @@ -162,7 +166,7 @@ def preprocess( if suffix.lower() == "svg": return str(file_path) - im = _Image.open(file_path) + im = PIL.Image.open(file_path) exif = im.getexif() # 274 is the code for image rotation and 1 means "correct orientation" if exif.get(274, 1) != 1 and hasattr(ImageOps, "exif_transpose"): @@ -184,8 +188,14 @@ def preprocess( ) def postprocess( - self, value: np.ndarray | _Image.Image | str | Path | None + self, value: np.ndarray | PIL.Image.Image | str | Path | None ) -> FileData | None: + """ + Parameters: + value: Expects a `numpy.array`, `PIL.Image`, or `str` or `pathlib.Path` filepath to an image which is displayed. + Returns: + Returns the image as a `FileData` object. + """ if value is None: return None if isinstance(value, str) and value.lower().endswith(".svg"): diff --git a/gradio/components/image_editor.py b/gradio/components/image_editor.py index a3bfe63965155..68c6f5c401648 100644 --- a/gradio/components/image_editor.py +++ b/gradio/components/image_editor.py @@ -90,10 +90,9 @@ def __post_init__(self): @document() class ImageEditor(Component): """ - Creates an image component that can be used to upload and edit images (as an input) or display images (as an output). - Preprocessing: passes the uploaded images as a dictionary with keys: `background`, `layers`, and `composite`. The values corresponding to `background` and `composite` are images, while `layers` is a list of images. The images are of type PIL.Image, np.array, or str filepath, depending on the `type` parameter. - Postprocessing: expects a dictionary with keys: `background`, `layers`, and `composite`. The values corresponding to `background` and `composite` should be images or None, while `layers` should be a list of images. Images can be of type PIL.Image, np.array, or str filepath/URL. Or, the value can be simply a single image, in which case it will be used as the background. - Examples-format: a dictionary with keys: `background`, `layers`, and `composite`. The values corresponding to `background` and `composite` should be strings or None, while `layers` should be a list of strings. The image corresponding to `composite`, if not None, is used as the example image. Otherwise, the image corresonding to `background` is used. The strings should be filepaths or URLs. Or, the value can be simply a single string filepath/URL to an image, which is used directly as the example image. + Creates an image component that, as an input, can be used to upload and edit images using simple editing tools such + as brushes, strokes, cropping, and layers. Or, as an output, this component can be used to display images. + Demos: image_editor """ @@ -149,7 +148,7 @@ def __init__( sources: List of sources that can be used to set the background image. "upload" creates a box where user can drop an image file, "webcam" allows user to take snapshot from their webcam, "clipboard" allows users to paste an image from the clipboard. type: The format the images are converted to before being passed into the prediction function. "numpy" converts the images to numpy arrays with shape (height, width, 3) and values from 0 to 255, "pil" converts the images to PIL image objects, "filepath" passes images as str filepaths to temporary copies of the images. label: The label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to. - every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. show_label: if True, will display label. show_download_button: If True, will display button to download image. container: If True, will place the component in a container - providing some extra padding around the border. @@ -250,6 +249,12 @@ def convert_and_format_image( ) def preprocess(self, payload: EditorData | None) -> EditorValue | None: + """ + Parameters: + payload: An instance of `EditorData` consisting of the background image, layers, and composite image. + Returns: + Passes the uploaded images as an instance of EditorValue, which is just a `dict` with keys: 'background', 'layers', and 'composite'. The values corresponding to 'background' and 'composite' are images, while 'layers' is a `list` of images. The images are of type `PIL.Image`, `np.array`, or `str` filepath, depending on the `type` parameter. + """ if payload is None: return payload @@ -267,6 +272,12 @@ def preprocess(self, payload: EditorData | None) -> EditorValue | None: } def postprocess(self, value: EditorValue | ImageType | None) -> EditorData | None: + """ + Parameters: + value: Expects a EditorValue, which is just a dictionary with keys: 'background', 'layers', and 'composite'. The values corresponding to 'background' and 'composite' should be images or None, while `layers` should be a list of images. Images can be of type `PIL.Image`, `np.array`, or `str` filepath/URL. Or, the value can be simply a single image (`ImageType`), in which case it will be used as the background. + Returns: + An instance of `EditorData` consisting of the background image, layers, and composite image. + """ if value is None: return None elif isinstance(value, dict): diff --git a/gradio/components/json_component.py b/gradio/components/json_component.py index c444e828ebfa8..ae9d4b28e38ca 100644 --- a/gradio/components/json_component.py +++ b/gradio/components/json_component.py @@ -17,9 +17,7 @@ @document() class JSON(Component): """ - Used to display arbitrary JSON output prettily. - Preprocessing: this component does *not* accept input. - Postprocessing: expects a {str} filepath to a file containing valid JSON -- or a {list} or {dict} that is valid JSON + Used to display arbitrary JSON output prettily. As this component does not accept user input, it is rarely used as an input component. Demos: zip_to_json, blocks_xray """ @@ -45,7 +43,7 @@ def __init__( Parameters: value: Default value. If callable, the function will be called whenever the app loads to set the initial value of the component. label: The label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to. - every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. show_label: if True, will display label. container: If True, will place the component in a container - providing some extra padding around the border. scale: relative width compared to adjacent Components in a Row. For example, if Component A has scale=2, and Component B has scale=1, A will be twice as wide as B. Should be an integer. @@ -69,7 +67,22 @@ def __init__( value=value, ) + def preprocess(self, payload: dict | list | None) -> dict | list | None: + """ + Parameters: + payload: JSON value as a `dict` or `list` + Returns: + Passes the JSON value as a `dict` or `list` depending on the value. + """ + return payload + def postprocess(self, value: dict | list | str | None) -> dict | list | None: + """ + Parameters: + value: Expects a `str` filepath to a file containing valid JSON -- or a `list` or `dict` that is valid JSON + Returns: + Returns the JSON as a `list` or `dict`. + """ if value is None: return None if isinstance(value, str): @@ -77,9 +90,6 @@ def postprocess(self, value: dict | list | str | None) -> dict | list | None: else: return value - def preprocess(self, payload: dict | list | str | None) -> dict | list | str | None: - return payload - def example_inputs(self) -> Any: return {"foo": "bar"} diff --git a/gradio/components/label.py b/gradio/components/label.py index bb910dfd27a7f..1a273718db1bd 100644 --- a/gradio/components/label.py +++ b/gradio/components/label.py @@ -29,9 +29,8 @@ class LabelData(GradioModel): @document() class Label(Component): """ - Displays a classification label, along with confidence scores of top categories, if provided. - Preprocessing: this component does *not* accept input. - Postprocessing: expects a {Dict[str, float]} of classes and confidences, or {str} with just the class or an {int}/{float} for regression outputs, or a {str} path to a .json file containing a json dictionary in the structure produced by Label.postprocess(). + Displays a classification label, along with confidence scores of top categories, if provided. As this component does not + accept user input, it is rarely used as an input component. Demos: main_note, titanic_survival Guides: image-classification-in-pytorch, image-classification-in-tensorflow, image-classification-with-vision-transformers @@ -63,7 +62,7 @@ def __init__( value: Default value to show in the component. If a str or number is provided, simply displays the string or number. If a {Dict[str, float]} of classes and confidences is provided, displays the top class on top and the `num_top_classes` below, along with their confidence bars. If callable, the function will be called whenever the app loads to set the initial value of the component. num_top_classes: number of most confident classes to show. label: The label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to. - every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. show_label: if True, will display label. container: If True, will place the component in a container - providing some extra padding around the border. scale: relative width compared to adjacent Components in a Row. For example, if Component A has scale=2, and Component B has scale=1, A will be twice as wide as B. Should be an integer. @@ -90,9 +89,32 @@ def __init__( value=value, ) + def preprocess( + self, payload: LabelData | None + ) -> dict[str, float] | str | int | float | None: + """ + Parameters: + payload: An instance of `LabelData` containing the label and confidences. + Returns: + Depending on the value, passes the label as a `str | int | float`, or the labels and confidences as a `dict[str, float]`. + """ + if payload is None: + return None + if payload.confidences is None: + return payload.label + return { + d["label"]: d["confidence"] for d in payload.model_dump()["confidences"] + } + def postprocess( - self, value: dict[str, float] | str | float | None + self, value: dict[str, float] | str | int | float | None ) -> LabelData | dict | None: + """ + Parameters: + value: Expects a `dict[str, float]` of classes and confidences, or `str` with just the class or an `int | float` for regression outputs, or a `str` path to a .json file containing a json dictionary in one of the preceding formats. + Returns: + Returns a `LabelData` object with the label and confidences, or a `dict` of the same format, or a `str` or `int` or `float` if the input was a single label. + """ if value is None or value == {}: return {} if isinstance(value, str) and value.endswith(".json") and Path(value).exists(): @@ -121,17 +143,6 @@ def postprocess( f"Instead, got a {type(value)}" ) - def preprocess( - self, payload: LabelData | None - ) -> dict[str, float] | str | float | None: - if payload is None: - return None - if payload.confidences is None: - return payload.label - return { - d["label"]: d["confidence"] for d in payload.model_dump()["confidences"] - } - def example_inputs(self) -> Any: return { "label": "Cat", diff --git a/gradio/components/line_plot.py b/gradio/components/line_plot.py index b019e072f7aaf..66642fb538cea 100644 --- a/gradio/components/line_plot.py +++ b/gradio/components/line_plot.py @@ -16,10 +16,8 @@ @document() class LinePlot(Plot): """ - Create a line plot. - - Preprocessing: this component does *not* accept input. - Postprocessing: expects a pandas dataframe with the data to plot. + Creates a line plot component to display data from a pandas DataFrame (as output). As this component does + not accept user input, it is rarely used as an input component. Demos: line_plot, live_dashboard """ @@ -111,7 +109,7 @@ def __init__( interactive: Whether users should be able to interact with the plot by panning or zooming with their mouse or trackpad. label: The (optional) label to display on the top left corner of the plot. show_label: Whether the label should be displayed. - every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. visible: Whether the plot should be visible. elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles. elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles. @@ -286,9 +284,24 @@ def create_plot( return chart + def preprocess(self, payload: AltairPlotData | None) -> AltairPlotData | None: + """ + Parameters: + payload: The data to display in a line plot. + Returns: + (Rarely used) passes the data displayed in the line plot as an AltairPlotData dataclass, which includes the plot information as a JSON string, as well as the type of plot (in this case, "line"). + """ + return payload + def postprocess( self, value: pd.DataFrame | dict | None ) -> AltairPlotData | dict | None: + """ + Parameters: + value: Expects a pandas DataFrame containing the data to display in the line plot. The DataFrame should contain at least two columns, one for the x-axis (corresponding to this component's `x` argument) and one for the y-axis (corresponding to `y`). + Returns: + The data to display in a line plot, in the form of an AltairPlotData dataclass, which includes the plot information as a JSON string, as well as the type of plot (in this case, "line"). + """ # if None or update if value is None or isinstance(value, dict): return value @@ -322,6 +335,3 @@ def postprocess( def example_inputs(self) -> Any: return None - - def preprocess(self, value: AltairPlotData | None) -> AltairPlotData | None: - return value diff --git a/gradio/components/login_button.py b/gradio/components/login_button.py index 04b1a97b34c66..969a501b38500 100644 --- a/gradio/components/login_button.py +++ b/gradio/components/login_button.py @@ -17,7 +17,7 @@ @document() class LoginButton(Button): """ - Button that redirects the user to Sign with Hugging Face using OAuth. + Creates a button that redirects the user to Sign with Hugging Face using OAuth. """ is_template = True diff --git a/gradio/components/logout_button.py b/gradio/components/logout_button.py index c80822e0e0017..c0ae12ea366bc 100644 --- a/gradio/components/logout_button.py +++ b/gradio/components/logout_button.py @@ -14,7 +14,7 @@ @document() class LogoutButton(Button): """ - Button to log out a user from a Space. + Creates a Button to log out a user from a Space using OAuth. Note: `LogoutButton` component is deprecated. Please use `gr.LoginButton` instead which handles both the login and logout processes. diff --git a/gradio/components/markdown.py b/gradio/components/markdown.py index cdf1a45da4bf9..09f47e25cef9b 100644 --- a/gradio/components/markdown.py +++ b/gradio/components/markdown.py @@ -16,9 +16,8 @@ @document() class Markdown(Component): """ - Used to render arbitrary Markdown output. Can also render latex enclosed by dollar signs. - Preprocessing: this component does *not* accept input. - Postprocessing: expects a valid {str} that can be rendered as Markdown. + Used to render arbitrary Markdown output. Can also render latex enclosed by dollar signs. As this component does not accept user input, + it is rarely used as an input component. Demos: blocks_hello, blocks_kinematics Guides: key-features @@ -47,7 +46,7 @@ def __init__( Parameters: value: Value to show in Markdown component. If callable, the function will be called whenever the app loads to set the initial value of the component. label: The label for this component. Is used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to. - every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. show_label: This parameter has no effect. rtl: If True, sets the direction of the rendered text to right-to-left. Default is False, which renders text left-to-right. latex_delimiters: A list of dicts of the form {"left": open delimiter (str), "right": close delimiter (str), "display": whether to display in newline (bool)} that will be used to render LaTeX expressions. If not provided, `latex_delimiters` is set to `[{ "left": "$$", "right": "$$", "display": True }]`, so only expressions enclosed in $$ delimiters will be rendered as LaTeX, and in a new line. Pass in an empty list to disable LaTeX rendering. For more information, see the [KaTeX documentation](https://katex.org/docs/autorender.html). @@ -78,15 +77,27 @@ def __init__( value=value, ) + def preprocess(self, payload: str | None) -> str | None: + """ + Parameters: + payload: the `str` of Markdown corresponding to the displayed value. + Returns: + Passes the `str` of Markdown corresponding to the displayed value. + """ + return payload + def postprocess(self, value: str | None) -> str | None: + """ + Parameters: + value: Expects a valid `str` that can be rendered as Markdown. + Returns: + The same `str` as the input, but with leading and trailing whitespace removed. + """ if value is None: return None unindented_y = inspect.cleandoc(value) return unindented_y - def preprocess(self, payload: str | None) -> str | None: - return payload - def example_inputs(self) -> Any: return "# Hello!" diff --git a/gradio/components/model3d.py b/gradio/components/model3d.py index 8d50113125217..2eee3d9ffd2a9 100644 --- a/gradio/components/model3d.py +++ b/gradio/components/model3d.py @@ -17,9 +17,7 @@ @document() class Model3D(Component): """ - Component allows users to upload or view 3D Model files (.obj, .glb, or .gltf). - Preprocessing: This component passes the uploaded file as a {str}filepath. - Postprocessing: expects function to return a {str} or {pathlib.Path} filepath of type (.obj, glb, or .gltf) + Creates a component allows users to upload or view 3D Model files (.obj, .glb, .stl, or .gltf). Demos: model3D Guides: how-to-use-3D-model-component @@ -58,7 +56,7 @@ def __init__( ): """ Parameters: - value: path to (.obj, glb, or .gltf) file to show in model3D viewer. If callable, the function will be called whenever the app loads to set the initial value of the component. + value: path to (.obj, .glb, .stl, or .gltf) file to show in model3D viewer. If callable, the function will be called whenever the app loads to set the initial value of the component. clear_color: background color of scene, should be a tuple of 4 floats between 0 and 1 representing RGBA values. camera_position: initial camera position of scene, provided as a tuple of `(alpha, beta, radius)`. Each value is optional. If provided, `alpha` and `beta` should be in degrees reflecting the angular position along the longitudinal and latitudinal axes, respectively. Radius corresponds to the distance from the center of the object to the camera. zoom_speed: the speed of zooming in and out of the scene when the cursor wheel is rotated or when screen is pinched on a mobile device. Should be a positive float, increase this value to make zooming faster, decrease to make it slower. Affects the wheelPrecision property of the camera. @@ -67,7 +65,7 @@ def __init__( interactive: if True, will allow users to upload a file; if False, can only be used to display files. If not provided, this is inferred based on whether the component is used as an input or output. label: The label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to. show_label: if True, will display label. - every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. container: If True, will place the component in a container - providing some extra padding around the border. scale: relative width compared to adjacent Components in a Row. For example, if Component A has scale=2, and Component B has scale=1, A will be twice as wide as B. Should be an integer. min_width: minimum pixel width, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in this Component being narrower than min_width, the min_width parameter will be respected first. @@ -97,11 +95,23 @@ def __init__( ) def preprocess(self, payload: FileData | None) -> str | None: + """ + Parameters: + payload: the uploaded file as an instance of `FileData`. + Returns: + Passes the uploaded file as a {str} filepath to the function. + """ if payload is None: return payload return payload.path def postprocess(self, value: str | Path | None) -> FileData | None: + """ + Parameters: + value: Expects function to return a {str} or {pathlib.Path} filepath of type (.obj, .glb, .stl, or .gltf) + Returns: + The uploaded file as an instance of `FileData`. + """ if value is None: return value return FileData(path=str(value), orig_name=Path(value).name) diff --git a/gradio/components/number.py b/gradio/components/number.py index 3f44edb43d6ac..41dca968561b1 100644 --- a/gradio/components/number.py +++ b/gradio/components/number.py @@ -17,9 +17,6 @@ class Number(FormComponent): """ Creates a numeric field for user to enter numbers as input or display numeric output. - Preprocessing: passes field value as a {float} or {int} into the function, depending on `precision`. - Postprocessing: expects an {int} or {float} returned from the function and sets field value to it. - Examples-format: a {float} or {int} representing the number's value. Demos: tax_calculator, titanic_survival, blocks_simple_squares """ @@ -52,7 +49,7 @@ def __init__( value: default value. If callable, the function will be called whenever the app loads to set the initial value of the component. label: The label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to. info: additional component description. - every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. show_label: if True, will display label. container: If True, will place the component in a container - providing some extra padding around the border. scale: relative width compared to adjacent Components in a Row. For example, if Component A has scale=2, and Component B has scale=1, A will be twice as wide as B. Should be an integer. @@ -108,7 +105,13 @@ def _round_to_precision(num: float | int, precision: int | None) -> float | int: else: return round(num, precision) - def preprocess(self, payload: float | None) -> float | None: + def preprocess(self, payload: float | None) -> float | int | None: + """ + Parameters: + payload: the field value. + Returns: + Passes field value as a `float` or `int` into the function, depending on `precision`. + """ if payload is None: return None elif self.minimum is not None and payload < self.minimum: @@ -119,7 +122,13 @@ def preprocess(self, payload: float | None) -> float | None: ) return self._round_to_precision(payload, self.precision) - def postprocess(self, value: float | None) -> float | None: + def postprocess(self, value: float | int | None) -> float | int | None: + """ + Parameters: + value: Expects an `int` or `float` returned from the function and sets field value to it. + Returns: + The (optionally rounded) field value as a `float` or `int` depending on `precision`. + """ if value is None: return None return self._round_to_precision(value, self.precision) diff --git a/gradio/components/paramviewer.py b/gradio/components/paramviewer.py index 66220d54733dd..4e3ca5fdd9c02 100644 --- a/gradio/components/paramviewer.py +++ b/gradio/components/paramviewer.py @@ -2,6 +2,8 @@ from typing import Literal, TypedDict +from gradio_client.documentation import document, set_documentation_group + from gradio.components.base import Component from gradio.events import Events @@ -9,12 +11,19 @@ class Parameter(TypedDict): type: str description: str - default: str + default: str | None + + +set_documentation_group("component") +@document() class ParamViewer(Component): """ - Displays an interactive table of parameters and their descriptions and default values width syntax highlighting + Displays an interactive table of parameters and their descriptions and default values with syntax highlighting. For each parameter, + the user should provide a type (e.g. a `str`), a human-readable description, and a default value. As this component does not accept user input, + it is rarely used as an input component.Internally, this component is used to display the parameters of components in the Custom + Component Gallery (https://www.gradio.app/custom-components/gallery). """ EVENTS = [ @@ -24,7 +33,7 @@ class ParamViewer(Component): def __init__( self, - value: list[Parameter] | None = None, + value: dict[str, Parameter] | None = None, language: Literal["python", "typescript"] = "python", linkify: list[str] | None = None, every: float | None = None, @@ -34,12 +43,11 @@ def __init__( Parameters: value: A list of dictionaries with keys "type", "description", and "default" for each parameter. language: The language to display the code in. One of "python" or "typescript". - linkify: A list of strings to linkify. If a string is found in the description, it will be linked to the corresponding url. + linkify: A list of strings to linkify. If any of these strings is found in the description, it will be rendered as a link. every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. render: If False, component will not render be rendered in the Blocks context. Should be used if the intention is to assign event listeners now but render the component later. - """ - self.value = value + self.value = value or {} self.language = language self.linkify = linkify super().__init__( @@ -48,26 +56,32 @@ def __init__( render=render, ) - def preprocess(self, payload: list[Parameter]) -> list[Parameter]: + def preprocess(self, payload: dict[str, Parameter]) -> dict[str, Parameter]: """ Parameters: - payload: A list of dictionaries with keys "type", "description", and "default" for each parameter. + payload: A `dict[str, dict]`. The key in the outer dictionary is the parameter name, while the inner dictionary has keys "type", "description", and (optionally) "default" for each parameter. Returns: - A list of dictionaries with keys "type", "description", and "default" for each parameter. + (Rarely used) passes value as a `dict[str, dict]`. The key in the outer dictionary is the parameter name, while the inner dictionary has keys "type", "description", and (optionally) "default" for each parameter. """ return payload - def postprocess(self, value: list[Parameter]) -> list[Parameter]: + def postprocess(self, value: dict[str, Parameter]) -> dict[str, Parameter]: """ Parameters: - value: A list of dictionaries with keys "type", "description", and "default" for each parameter. + value: Expects value as a `dict[str, dict]`. The key in the outer dictionary is the parameter name, while the inner dictionary has keys "type", "description", and (optionally) "default" for each parameter. Returns: - A list of dictionaries with keys "type", "description", and "default" for each parameter. + The same value. """ return value def example_inputs(self): - return [{"type": "numpy", "description": "any valid json", "default": "None"}] + return { + "array": { + "type": "numpy", + "description": "any valid json", + "default": "None", + } + } def api_info(self): return {"type": {}, "description": "any valid json"} diff --git a/gradio/components/plot.py b/gradio/components/plot.py index 673b51dd59dee..4ce1cd445a855 100644 --- a/gradio/components/plot.py +++ b/gradio/components/plot.py @@ -4,10 +4,9 @@ import json from types import ModuleType -from typing import Any, Callable, Literal +from typing import Any, Literal import altair as alt -import pandas as pd from gradio_client.documentation import document, set_documentation_group from gradio import processing_utils @@ -31,9 +30,8 @@ class AltairPlotData(PlotData): @document() class Plot(Component): """ - Used to display various kinds of plots (matplotlib, plotly, or bokeh are supported). - Preprocessing: this component does *not* accept input. - Postprocessing: expects either a {matplotlib.figure.Figure}, a {plotly.graph_objects._figure.Figure}, or a {dict} corresponding to a bokeh plot (json_item format) + Creates a plot component to display various kinds of plots (matplotlib, plotly, altair, or bokeh plots are supported). As this component does + not accept user input, it is rarely used as an input component. Demos: altair_plot, outbreak_forecast, blocks_kinematics, stock_forecast, map_airbnb Guides: plot-component-for-maps @@ -44,7 +42,7 @@ class Plot(Component): def __init__( self, - value: Callable | None | pd.DataFrame = None, + value: Any | None = None, *, label: str | None = None, every: float | None = None, @@ -61,7 +59,7 @@ def __init__( Parameters: value: Optionally, supply a default plot object to display, must be a matplotlib, plotly, altair, or bokeh figure, or a callable. If callable, the function will be called whenever the app loads to set the initial value of the component. label: The label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to. - every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. show_label: if True, will display label. container: If True, will place the component in a container - providing some extra padding around the border. scale: relative width compared to adjacent Components in a Row. For example, if Component A has scale=2, and Component B has scale=1, A will be twice as wide as B. Should be an integer. @@ -98,12 +96,24 @@ def get_config(self): return config def preprocess(self, payload: PlotData | None) -> PlotData | None: + """ + Parameters: + payload: The data to display in the plot. + Returns: + (Rarely used) passes the data displayed in the plot as an PlotData dataclass, which includes the plot information as a JSON string, as well as the type of chart and the plotting library. + """ return payload def example_inputs(self) -> Any: return None - def postprocess(self, value) -> PlotData | None: + def postprocess(self, value: Any) -> PlotData | None: + """ + Parameters: + value: Expects plot data in one of these formats: a matplotlib.Figure, bokeh.Model, plotly.Figure, or altair.Chart object. + Returns: + PlotData: A dataclass containing the plot data as a JSON string, as well as the type of chart and the plotting library. + """ import matplotlib.figure if value is None: diff --git a/gradio/components/radio.py b/gradio/components/radio.py index 5ba9f45d5ff23..d30340a89fce3 100644 --- a/gradio/components/radio.py +++ b/gradio/components/radio.py @@ -16,9 +16,6 @@ class Radio(FormComponent): """ Creates a set of (string or numeric type) radio buttons of which only one can be selected. - Preprocessing: passes the value of the selected radio button as a {str} or {int} or {float} or its index as an {int} into the function, depending on `type`. - Postprocessing: expects a {str} or {int} or {float} corresponding to the value of the radio button to be selected. - Examples-format: a {str} representing the radio option to select. Demos: sentence_builder, titanic_survival, blocks_essay """ @@ -51,7 +48,7 @@ def __init__( type: Type of value to be returned by component. "value" returns the string of the choice selected, "index" returns the index of the choice selected. label: The label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to. info: Additional component description. - every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. show_label: if True, will display label. container: If True, will place the component in a container - providing some extra padding around the border. scale: Relative width compared to adjacent Components in a Row. For example, if Component A has scale=2, and Component B has scale=1, A will be twice as wide as B. Should be an integer. @@ -97,9 +94,9 @@ def example_inputs(self) -> Any: def preprocess(self, payload: str | int | float | None) -> str | int | float | None: """ Parameters: - payload: selected choice + payload: Selected choice in the radio group Returns: - value of the selected choice as string or index within choice list + Passes the value of the selected radio button as a `str | int | float`, or its index as an `int` into the function, depending on `type`. """ if self.type == "value": return payload @@ -117,6 +114,12 @@ def preprocess(self, payload: str | int | float | None) -> str | int | float | N ) def postprocess(self, value: str | int | float | None) -> str | int | float | None: + """ + Parameters: + value: Expects a `str | int | float` corresponding to the value of the radio button to be selected + Returns: + The same value + """ return value def api_info(self) -> dict[str, Any]: diff --git a/gradio/components/scatter_plot.py b/gradio/components/scatter_plot.py index 3e5dd4120751e..4bac3350bed79 100644 --- a/gradio/components/scatter_plot.py +++ b/gradio/components/scatter_plot.py @@ -17,10 +17,8 @@ @document() class ScatterPlot(Plot): """ - Create a scatter plot. - - Preprocessing: this component does *not* accept input. - Postprocessing: expects a pandas dataframe with the data to plot. + Creates a scatter plot component to display data from a pandas DataFrame (as output). As this component does + not accept user input, it is rarely used as an input component. Demos: scatter_plot Guides: creating-a-dashboard-from-bigquery-data @@ -127,7 +125,7 @@ def __init__( caption: The (optional) caption to display below the plot. interactive: Whether users should be able to interact with the plot by panning or zooming with their mouse or trackpad. label: The (optional) label to display on the top left corner of the plot. - every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. show_label: Whether the label should be displayed. visible: Whether the plot should be visible. elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles. @@ -309,9 +307,24 @@ def create_plot( return chart + def preprocess(self, payload: AltairPlotData | None) -> AltairPlotData | None: + """ + Parameters: + payload: The data to display in a scatter plot. + Returns: + (Rarely used) passes the data displayed in the scatter plot as an AltairPlotData dataclass, which includes the plot information as a JSON string, as well as the type of plot (in this case, "scatter"). + """ + return payload + def postprocess( self, value: pd.DataFrame | dict | None ) -> AltairPlotData | dict | None: + """ + Parameters: + value: Expects a pandas DataFrame containing the data to display in the scatter plot. The DataFrame should contain at least two columns, one for the x-axis (corresponding to this component's `x` argument) and one for the y-axis (corresponding to `y`). + Returns: + The data to display in a scatter plot, in the form of an AltairPlotData dataclass, which includes the plot information as a JSON string, as well as the type of plot (in this case, "scatter"). + """ # if None or update if value is None or isinstance(value, dict): return value @@ -347,6 +360,3 @@ def postprocess( def example_inputs(self) -> Any: return None - - def preprocess(self, payload: AltairPlotData | None) -> AltairPlotData | None: - return payload diff --git a/gradio/components/slider.py b/gradio/components/slider.py index 62c031dc7eb7f..3737e4267ec29 100644 --- a/gradio/components/slider.py +++ b/gradio/components/slider.py @@ -18,9 +18,6 @@ class Slider(FormComponent): """ Creates a slider that ranges from {minimum} to {maximum} with a step size of {step}. - Preprocessing: passes slider value as a {float} into the function. - Postprocessing: expects an {int} or {float} returned from function and sets slider value to it as long as it is within range. - Examples-format: A {float} or {int} representing the slider's value. Demos: sentence_builder, slider_release, interface_random_slider, blocks_random_slider Guides: create-your-own-friends-with-a-gan @@ -57,7 +54,7 @@ def __init__( step: increment between slider values. label: The label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to. info: additional component description. - every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. show_label: if True, will display label. container: If True, will place the component in a container - providing some extra padding around the border. scale: relative width compared to adjacent Components in a Row. For example, if Component A has scale=2, and Component B has scale=1, A will be twice as wide as B. Should be an integer. @@ -115,7 +112,19 @@ def get_random_value(self): return value def postprocess(self, value: float | None) -> float: + """ + Parameters: + value: Expects an {int} or {float} returned from function and sets slider value to it as long as it is within range (otherwise, sets to minimum value). + Returns: + The value of the slider within the range. + """ return self.minimum if value is None else value def preprocess(self, payload: float) -> float: + """ + Parameters: + payload: slider value + Returns: + Passes slider value as a {float} into the function. + """ return payload diff --git a/gradio/components/state.py b/gradio/components/state.py index 4841c108ef304..d96f5bc17b816 100644 --- a/gradio/components/state.py +++ b/gradio/components/state.py @@ -19,8 +19,6 @@ class State(Component): Special hidden component that stores session state across runs of the demo by the same user. The value of the State variable is cleared when the user refreshes the page. - Preprocessing: No preprocessing is performed - Postprocessing: No postprocessing is performed Demos: interface_state, blocks_simple_squares Guides: real-time-speech-recognition """ @@ -44,12 +42,24 @@ def __init__( raise TypeError( f"The initial value of `gr.State` must be able to be deepcopied. The initial value of type {type(value)} cannot be deepcopied." ) from err - super().__init__(value=self.value) + super().__init__(value=self.value, render=render) def preprocess(self, payload: Any) -> Any: + """ + Parameters: + payload: Value + Returns: + Passes a value of arbitrary type through. + """ return payload def postprocess(self, value: Any) -> Any: + """ + Parameters: + value: Expects a value of arbitrary type, as long as it can be deepcopied. + Returns: + Passes a value of arbitrary type through. + """ return value def api_info(self) -> dict[str, Any]: diff --git a/gradio/components/textbox.py b/gradio/components/textbox.py index c16cb651b379c..c1be9029c0a04 100644 --- a/gradio/components/textbox.py +++ b/gradio/components/textbox.py @@ -18,9 +18,6 @@ class Textbox(FormComponent): """ Creates a textarea for user to enter string input or display string output. - Preprocessing: passes textarea value as a {str} into the function. - Postprocessing: expects a {str} returned from function and sets textarea value to it. - Examples-format: a {str} representing the textbox input. Demos: hello_world, diff_texts, sentence_builder Guides: creating-a-chatbot, real-time-speech-recognition @@ -69,7 +66,7 @@ def __init__( placeholder: placeholder hint to provide behind textarea. label: The label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to. info: additional component description. - every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. show_label: if True, will display label. container: If True, will place the component in a container - providing some extra padding around the border. scale: relative width compared to adjacent Components in a Row. For example, if Component A has scale=2, and Component B has scale=1, A will be twice as wide as B. Should be an integer. @@ -118,9 +115,21 @@ def __init__( self.text_align = text_align def preprocess(self, payload: str | None) -> str | None: + """ + Parameters: + payload: the text entered in the textarea. + Returns: + Passes text value as a {str} into the function. + """ return None if payload is None else str(payload) def postprocess(self, value: str | None) -> str | None: + """ + Parameters: + value: Expects a {str} returned from function and sets textarea value to it. + Returns: + The value to display in the textarea. + """ return None if value is None else str(value) def api_info(self) -> dict[str, Any]: diff --git a/gradio/components/upload_button.py b/gradio/components/upload_button.py index 62da9d44dec8d..30ac1229f01ad 100644 --- a/gradio/components/upload_button.py +++ b/gradio/components/upload_button.py @@ -21,9 +21,7 @@ class UploadButton(Component): """ Used to create an upload button, when clicked allows a user to upload files that satisfy the specified file type or generic files (if file_type not set). - Preprocessing: passes the uploaded file as a {file-object} or {List[file-object]} depending on `file_count` (or a {bytes}/{List[bytes]} depending on `type`) - Postprocessing: expects function to return a {str} path to a file, or {List[str]} consisting of paths to files. - Examples-format: a {str} path to a local file that populates the component. + Demos: upload_button """ @@ -53,7 +51,7 @@ def __init__( Parameters: label: Text to display on the button. Defaults to "Upload a File". value: File or list of files to upload by default. - every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. variant: 'primary' for main call-to-action, 'secondary' for a more subdued style, 'stop' for a stop button. visible: If False, component will be hidden. size: Size of the button. Can be "sm" or "lg". @@ -139,7 +137,13 @@ def _process_single_file(self, f: FileData) -> bytes | NamedString: def preprocess( self, payload: ListFiles | FileData | None - ) -> bytes | NamedString | list[bytes | NamedString] | None: + ) -> bytes | str | list[bytes] | list[str] | None: + """ + Parameters: + payload: File information as a FileData object, or a list of FileData objects. + Returns: + Passes the file as a `str` or `bytes` object, or a list of `str` or list of `bytes` objects, depending on `type` and `file_count`. + """ if payload is None: return None @@ -150,11 +154,17 @@ def preprocess( return self._process_single_file(payload) else: if isinstance(payload, ListFiles): - return [self._process_single_file(f) for f in payload] + return [self._process_single_file(f) for f in payload] # type: ignore else: - return [self._process_single_file(payload)] + return [self._process_single_file(payload)] # type: ignore def postprocess(self, value: str | list[str] | None) -> ListFiles | FileData | None: + """ + Parameters: + value: Expects a `str` filepath, or a `list[str]` of filepaths. + Returns: + File information as a FileData object, or a list of FileData objects. + """ if value is None: return None if isinstance(value, list): diff --git a/gradio/components/video.py b/gradio/components/video.py index ef02697de3206..d5478521539ce 100644 --- a/gradio/components/video.py +++ b/gradio/components/video.py @@ -36,9 +36,7 @@ class Video(Component): combinations are .mp4 with h264 codec, .ogg with theora codec, and .webm with vp9 codec. If the component detects that the output video would not be playable in the browser it will attempt to convert it to a playable mp4 video. If the conversion fails, the original video is returned. - Preprocessing: passes the uploaded video as a {str} filepath or URL whose extension can be modified by `format`. - Postprocessing: expects a {str} or {pathlib.Path} filepath to a video which is displayed, or a {Tuple[str | pathlib.Path, str | pathlib.Path | None]} where the first element is a filepath to a video and the second element is an optional filepath to a subtitle file. - Examples-format: a {str} filepath to a local file that contains the video, or a {Tuple[str, str]} where the first element is a filepath to a video file and the second element is a filepath to a subtitle file. + Demos: video_identity, video_subtitle """ @@ -95,7 +93,7 @@ def __init__( height: The height of the displayed video, specified in pixels if a number is passed, or in CSS units if a string is passed. width: The width of the displayed video, specified in pixels if a number is passed, or in CSS units if a string is passed. label: The label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to. - every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. show_label: if True, will display label. container: If True, will place the component in a container - providing some extra padding around the border. scale: relative width compared to adjacent Components in a Row. For example, if Component A has scale=2, and Component B has scale=1, A will be twice as wide as B. Should be an integer. @@ -161,6 +159,12 @@ def __init__( ) def preprocess(self, payload: VideoData | None) -> str | None: + """ + Parameters: + payload: An instance of VideoData containing the video and subtitle files. + Returns: + Passes the uploaded video as a `str` filepath or URL whose extension can be modified by `format`. + """ if payload is None: return None assert payload.video.path @@ -221,32 +225,40 @@ def preprocess(self, payload: VideoData | None) -> str | None: return str(file_name) def postprocess( - self, y: str | Path | tuple[str | Path, str | Path | None] | None + self, value: str | Path | tuple[str | Path, str | Path | None] | None ) -> VideoData | None: - if y is None or y == [None, None] or y == (None, None): + """ + Parameters: + value: Expects a {str} or {pathlib.Path} filepath to a video which is displayed, or a {Tuple[str | pathlib.Path, str | pathlib.Path | None]} where the first element is a filepath to a video and the second element is an optional filepath to a subtitle file. + Returns: + VideoData object containing the video and subtitle files. + """ + if value is None or value == [None, None] or value == (None, None): return None - if isinstance(y, (str, Path)): - processed_files = (self._format_video(y), None) + if isinstance(value, (str, Path)): + processed_files = (self._format_video(value), None) - elif isinstance(y, (tuple, list)): - if len(y) != 2: + elif isinstance(value, (tuple, list)): + if len(value) != 2: raise ValueError( - f"Expected lists of length 2 or tuples of length 2. Received: {y}" + f"Expected lists of length 2 or tuples of length 2. Received: {value}" ) - if not (isinstance(y[0], (str, Path)) and isinstance(y[1], (str, Path))): + if not ( + isinstance(value[0], (str, Path)) and isinstance(value[1], (str, Path)) + ): raise TypeError( - f"If a tuple is provided, both elements must be strings or Path objects. Received: {y}" + f"If a tuple is provided, both elements must be strings or Path objects. Received: {value}" ) - video = y[0] - subtitle = y[1] + video = value[0] + subtitle = value[1] processed_files = ( self._format_video(video), self._format_subtitle(subtitle), ) else: - raise Exception(f"Cannot process type as video: {type(y)}") + raise Exception(f"Cannot process type as video: {type(value)}") assert processed_files[0] return VideoData(video=processed_files[0], subtitles=processed_files[1]) diff --git a/gradio/events.py b/gradio/events.py index 751f82b9c6db0..5922fa4597c24 100644 --- a/gradio/events.py +++ b/gradio/events.py @@ -224,7 +224,7 @@ def event_trigger( preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component). postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser. cancels: A list of other events to cancel when this listener is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method. Functions that have not yet run (or generators that are iterating) will be cancelled, but functions that are currently running will be allowed to finish. - every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled. + every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. trigger_mode: If "once" (default for all events except `.change()`) would not allow any submissions while an event is pending. If set to "multiple", unlimited submissions are allowed while pending, and "always_last" (default for `.change()` event) would allow a second submission after the pending event is complete. js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components. concurrency_limit: If set, this is the maximum number of this event that can be running simultaneously. Can be set to None to mean no concurrency_limit (any number of this event can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `Blocks.queue()`, which itself is 1 by default). @@ -348,7 +348,7 @@ def on( preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component). postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser. cancels: A list of other events to cancel when this listener is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method. Functions that have not yet run (or generators that are iterating) will be cancelled, but functions that are currently running will be allowed to finish. - every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled. + every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs', return should be a list of values for output components. concurrency_limit: If set, this is the maximum number of this event that can be running simultaneously. Can be set to None to mean no concurrency_limit (any number of this event can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `Blocks.queue()`, which itself is 1 by default). concurrency_id: If set, this is the id of the concurrency group. Events with the same concurrency_id will be limited by the lowest set concurrency_limit. diff --git a/js/_website/src/lib/assets/style.css b/js/_website/src/lib/assets/style.css index 945d02417d2a9..9977dbfd6943b 100644 --- a/js/_website/src/lib/assets/style.css +++ b/js/_website/src/lib/assets/style.css @@ -229,13 +229,10 @@ li > ul { padding-bottom: 0 !important; } -code { +pre > code { font-weight: 500 !important; - border: 1px solid #e2e8f0; border-radius: 0.25rem; - padding: 0.175rem 0.25rem; font-size: 0.8em !important; - background: #f7fafc; } h1 > code, diff --git a/js/_website/src/routes/[[version]]/docs/[doc]/+page.server.ts b/js/_website/src/routes/[[version]]/docs/[doc]/+page.server.ts index f7f9775c6d6a8..bdd213968326f 100644 --- a/js/_website/src/routes/[[version]]/docs/[doc]/+page.server.ts +++ b/js/_website/src/routes/[[version]]/docs/[doc]/+page.server.ts @@ -81,6 +81,34 @@ export async function load({ params, parent }) { }); } + if ("preprocess" in obj && "postprocess" in obj) { + obj.preprocess.return_doc.doc = style_formatted_text( + obj.preprocess.return_doc.doc + ); + obj.postprocess.parameter_doc[0].doc = style_formatted_text( + obj.postprocess.parameter_doc[0].doc + ); + + let preprocess_code_snippet = Prism.highlight( + `def predict( + value: ${obj.preprocess.return_doc.annotation} +) + ...`, + Prism.languages[language], + "python" + ); + + let postprocess_code_snippet = Prism.highlight( + `def predict(ยทยทยท) -> ${obj.postprocess.parameter_doc[0].annotation} + ... + return value`, + Prism.languages[language], + "python" + ); + obj.preprocess_code_snippet = preprocess_code_snippet; + obj.postprocess_code_snippet = postprocess_code_snippet; + } + if (obj.example) { obj.highlighted_example = Prism.highlight( obj.example, @@ -136,21 +164,6 @@ export async function load({ params, parent }) { } } if ("tags" in obj) { - if ("preprocessing" in obj.tags) { - obj.tags.preprocessing = style_formatted_text( - obj.tags.preprocessing - ); - } - if ("postprocessing" in obj.tags) { - obj.tags.postprocessing = style_formatted_text( - obj.tags.postprocessing - ); - } - if ("examples_format" in obj.tags) { - obj.tags.examples_format = style_formatted_text( - obj.tags.examples_format - ); - } if ("events" in obj.tags) { obj.tags.events = style_formatted_text(obj.tags.events); } diff --git a/js/_website/src/routes/[[version]]/docs/[doc]/+page.svelte b/js/_website/src/routes/[[version]]/docs/[doc]/+page.svelte index 8ead7556b3405..51ccbb6c7ec36 100644 --- a/js/_website/src/routes/[[version]]/docs/[doc]/+page.svelte +++ b/js/_website/src/routes/[[version]]/docs/[doc]/+page.svelte @@ -214,40 +214,82 @@ > -

{@html obj.description}

+

{@html obj.description}

{#if mode === "components"} -
-

- Behavior - -

-

- As input: - {@html obj.tags.preprocessing} -

-

- As output: - {@html obj.tags.postprocessing} -

- {#if obj.tags.examples_format} -

- Format expected for examples: +

+ Behavior + - {@html obj.tags.examples_format} +

+

+ As input component: + {@html obj.preprocess.return_doc.doc}

- {/if} - {#if obj.tags.events && obj.tags.events.length > 0} +

+ Your function should accept one of these types: +

+
+
{@html obj.preprocess_code_snippet}
+
+ +

+ As output component: + {@html obj.postprocess.parameter_doc[0].doc} +

+

+ Your function should return one of these types: +

+
+
{@html obj.postprocess_code_snippet}
+
+ {#if obj.tags.events && obj.tags.events.length > 0} +

+ Supported events: + {@html obj.tags.events} +

+ {/if} +
+ {:else} +
+

+ Behavior + +

- Supported events: - {@html obj.tags.events} + As input: + {@html obj.tags.preprocessing}

- {/if} -
+

+ As output: + {@html obj.tags.postprocessing} +

+ {#if obj.tags.examples_format} +

+ Format expected for examples: + {@html obj.tags.examples_format} +

+ {/if} + {#if obj.tags.events && obj.tags.events.length > 0} +

+ Supported events: + {@html obj.tags.events} +

+ {/if} + + {/if} {/if} {#if obj.example} diff --git a/test/test_gradio_component_cli.py b/test/test_gradio_component_cli.py index e66de55d5ef85..eacc15c2a2cdd 100644 --- a/test/test_gradio_component_cli.py +++ b/test/test_gradio_component_cli.py @@ -85,9 +85,9 @@ def test_do_not_replace_class_name_in_import_statement(tmp_path): configure_metadata=False, ) code = (tmp_path / "backend" / "gradio_myimage" / "myimage.py").read_text() - assert "from PIL import Image as _Image" in code + assert "import PIL.Image" in code assert "class MyImage" in code - assert "_Image.Image" in code + assert "PIL.Image.Image" in code def test_raises_if_directory_exists(tmp_path): From 59b9715f820f848fda5a3cf80a24e3e5aee57d05 Mon Sep 17 00:00:00 2001 From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com> Date: Wed, 31 Jan 2024 09:54:08 -0800 Subject: [PATCH 2/4] chore(deps): update 8bitjonny/gh-get-current-pr action to v3 (#7232) Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> --- .github/workflows/deploy-chromatic.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/deploy-chromatic.yml b/.github/workflows/deploy-chromatic.yml index 1842b64b92965..153317f8fbb7d 100644 --- a/.github/workflows/deploy-chromatic.yml +++ b/.github/workflows/deploy-chromatic.yml @@ -17,7 +17,7 @@ jobs: pr_number: ${{ steps.get-pr.outputs.number }} pr_labels: ${{ steps.get-pr.outputs.pr_labels }} steps: - - uses: 8BitJonny/gh-get-current-pr@2.2.0 + - uses: 8BitJonny/gh-get-current-pr@3.0.0 id: get-pr with: filterOutDraft: true From 6a7e98bfefdc9530f0390f0d780edc5a35266d56 Mon Sep 17 00:00:00 2001 From: Freddy Boulton Date: Wed, 31 Jan 2024 10:33:13 -0800 Subject: [PATCH 3/4] Fix hyphen-bug in gradio cc publish (#7229) * Fix hyphen-bug * add changeset * add changeset --------- Co-authored-by: gradio-pr-bot --- .changeset/spicy-rabbits-check.md | 5 +++++ gradio/cli/commands/components/publish.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) create mode 100644 .changeset/spicy-rabbits-check.md diff --git a/.changeset/spicy-rabbits-check.md b/.changeset/spicy-rabbits-check.md new file mode 100644 index 0000000000000..e38665f88295b --- /dev/null +++ b/.changeset/spicy-rabbits-check.md @@ -0,0 +1,5 @@ +--- +"gradio": patch +--- + +feat:Fix hyphen-bug in gradio cc publish diff --git a/gradio/cli/commands/components/publish.py b/gradio/cli/commands/components/publish.py index c38ec6da9f86b..5e09f872b6579 100644 --- a/gradio/cli/commands/components/publish.py +++ b/gradio/cli/commands/components/publish.py @@ -141,7 +141,7 @@ def _publish( ] wheel_file = max( (p for p in distribution_files if p.suffix == ".whl"), - key=lambda s: semantic_version.Version(str(s).split("-")[1]), + key=lambda s: semantic_version.Version(str(s.name).split("-")[1]), ) if not wheel_file: raise ValueError( From 68a54a7a310d8d7072fdae930bf1cfdf12c45a7f Mon Sep 17 00:00:00 2001 From: aliabid94 Date: Wed, 31 Jan 2024 10:39:46 -0800 Subject: [PATCH 4/4] Improve chatbot streaming performance with diffs (#7102) * changes * add changeset * changes * add changeset * changes * channges * changes * changes * changes * changes * changes * changes * changes * changes * changes * changes * changes * changes * changes * changes * canges * changes * changes * changes * Update free-moose-guess.md * changes --------- Co-authored-by: Ali Abid Co-authored-by: gradio-pr-bot Co-authored-by: Ali Abid --- .changeset/free-moose-guess.md | 9 ++++ client/js/src/client.ts | 32 +++++++++++- client/js/src/types.ts | 2 +- client/js/src/utils.ts | 59 +++++++++++++++++++++ client/python/gradio_client/client.py | 20 ++++++-- client/python/gradio_client/utils.py | 74 ++++++++++++++++++++++++--- gradio/blocks.py | 45 +++++++++++++++- gradio/components/textbox.py | 4 +- gradio/queueing.py | 5 +- gradio/utils.py | 44 ++++++++++++++++ test/test_blocks.py | 43 ++++++++++++++++ 11 files changed, 312 insertions(+), 25 deletions(-) create mode 100644 .changeset/free-moose-guess.md diff --git a/.changeset/free-moose-guess.md b/.changeset/free-moose-guess.md new file mode 100644 index 0000000000000..78c667340246c --- /dev/null +++ b/.changeset/free-moose-guess.md @@ -0,0 +1,9 @@ +--- +"@gradio/client": minor +"gradio": minor +"gradio_client": minor +--- + +feat:Improve chatbot streaming performance with diffs + +Note that this PR changes the API format for generator functions, which would be a breaking change for any clients reading the EventStream directly diff --git a/client/js/src/client.ts b/client/js/src/client.ts index 6355d5a04cade..00ae24048e54d 100644 --- a/client/js/src/client.ts +++ b/client/js/src/client.ts @@ -11,7 +11,8 @@ import { set_space_hardware, set_space_timeout, hardware_types, - resolve_root + resolve_root, + apply_diff } from "./utils.js"; import type { @@ -288,6 +289,7 @@ export function api_factory( const last_status: Record = {}; let stream_open = false; let pending_stream_messages: Record = {}; // Event messages may be received by the SSE stream before the initial data POST request is complete. To resolve this race condition, we store the messages in a dictionary and process them when the POST request is complete. + let pending_diff_streams: Record = {}; let event_stream: EventSource | null = null; const event_callbacks: Record Promise> = {}; const unclosed_events: Set = new Set(); @@ -774,7 +776,8 @@ export function api_factory( } } }; - } else if (protocol == "sse_v1") { + } else if (protocol == "sse_v1" || protocol == "sse_v2") { + // latest API format. v2 introduces sending diffs for intermediate outputs in generative functions, which makes payloads lighter. fire_event({ type: "status", stage: "pending", @@ -867,6 +870,9 @@ export function api_factory( endpoint: _endpoint, fn_index }); + if (data && protocol === "sse_v2") { + apply_diff_stream(event_id!, data); + } } if (data) { fire_event({ @@ -904,6 +910,9 @@ export function api_factory( if (event_callbacks[event_id]) { delete event_callbacks[event_id]; } + if (event_id in pending_diff_streams) { + delete pending_diff_streams[event_id]; + } } } catch (e) { console.error("Unexpected client exception", e); @@ -936,6 +945,25 @@ export function api_factory( } ); + function apply_diff_stream(event_id: string, data: any): void { + let is_first_generation = !pending_diff_streams[event_id]; + if (is_first_generation) { + pending_diff_streams[event_id] = []; + data.data.forEach((value: any, i: number) => { + pending_diff_streams[event_id][i] = value; + }); + } else { + data.data.forEach((value: any, i: number) => { + let new_data = apply_diff( + pending_diff_streams[event_id][i], + value + ); + pending_diff_streams[event_id][i] = new_data; + data.data[i] = new_data; + }); + } + } + function fire_event(event: Event): void { const narrowed_listener_map: ListenerMap = listener_map; const listeners = narrowed_listener_map[event.type] || []; diff --git a/client/js/src/types.ts b/client/js/src/types.ts index 2b1869855ef0a..4e93a762b53f1 100644 --- a/client/js/src/types.ts +++ b/client/js/src/types.ts @@ -20,7 +20,7 @@ export interface Config { show_api: boolean; stylesheets: string[]; path: string; - protocol?: "sse_v1" | "sse" | "ws"; + protocol?: "sse_v2" | "sse_v1" | "sse" | "ws"; } export interface Payload { diff --git a/client/js/src/utils.ts b/client/js/src/utils.ts index 5883cfe0b16c9..36839113561d6 100644 --- a/client/js/src/utils.ts +++ b/client/js/src/utils.ts @@ -239,3 +239,62 @@ export const hardware_types = [ "a10g-large", "a100-large" ] as const; + +function apply_edit( + target: any, + path: (number | string)[], + action: string, + value: any +): any { + if (path.length === 0) { + if (action === "replace") { + return value; + } else if (action === "append") { + return target + value; + } + throw new Error(`Unsupported action: ${action}`); + } + + let current = target; + for (let i = 0; i < path.length - 1; i++) { + current = current[path[i]]; + } + + const last_path = path[path.length - 1]; + switch (action) { + case "replace": + current[last_path] = value; + break; + case "append": + current[last_path] += value; + break; + case "add": + if (Array.isArray(current)) { + current.splice(Number(last_path), 0, value); + } else { + current[last_path] = value; + } + break; + case "delete": + if (Array.isArray(current)) { + current.splice(Number(last_path), 1); + } else { + delete current[last_path]; + } + break; + default: + throw new Error(`Unknown action: ${action}`); + } + return target; +} + +export function apply_diff( + obj: any, + diff: [string, (number | string)[], any][] +): any { + diff.forEach(([action, path, value]) => { + obj = apply_edit(obj, path, action, value); + }); + + return obj; +} diff --git a/client/python/gradio_client/client.py b/client/python/gradio_client/client.py index 19a9cdb7418c7..0748e7cc45826 100644 --- a/client/python/gradio_client/client.py +++ b/client/python/gradio_client/client.py @@ -428,7 +428,12 @@ def submit( inferred_fn_index = self._infer_fn_index(api_name, fn_index) helper = None - if self.endpoints[inferred_fn_index].protocol in ("ws", "sse", "sse_v1"): + if self.endpoints[inferred_fn_index].protocol in ( + "ws", + "sse", + "sse_v1", + "sse_v2", + ): helper = self.new_helper(inferred_fn_index) end_to_end_fn = self.endpoints[inferred_fn_index].make_end_to_end_fn(helper) future = self.executor.submit(end_to_end_fn, *args) @@ -998,13 +1003,15 @@ def _predict(*data) -> tuple: result = utils.synchronize_async( self._sse_fn_v0, data, hash_data, helper ) - elif self.protocol == "sse_v1": + elif self.protocol == "sse_v1" or self.protocol == "sse_v2": event_id = utils.synchronize_async( self.client.send_data, data, hash_data ) self.client.pending_event_ids.add(event_id) self.client.pending_messages_per_event[event_id] = [] - result = utils.synchronize_async(self._sse_fn_v1, helper, event_id) + result = utils.synchronize_async( + self._sse_fn_v1_v2, helper, event_id, self.protocol + ) else: raise ValueError(f"Unsupported protocol: {self.protocol}") @@ -1197,13 +1204,16 @@ async def _sse_fn_v0(self, data: dict, hash_data: dict, helper: Communicator): self.client.cookies, ) - async def _sse_fn_v1(self, helper: Communicator, event_id: str): - return await utils.get_pred_from_sse_v1( + async def _sse_fn_v1_v2( + self, helper: Communicator, event_id: str, protocol: Literal["sse_v1", "sse_v2"] + ): + return await utils.get_pred_from_sse_v1_v2( helper, self.client.headers, self.client.cookies, self.client.pending_messages_per_event, event_id, + protocol, ) diff --git a/client/python/gradio_client/utils.py b/client/python/gradio_client/utils.py index 630f16c2aa736..d520f556e7b2c 100644 --- a/client/python/gradio_client/utils.py +++ b/client/python/gradio_client/utils.py @@ -2,6 +2,7 @@ import asyncio import base64 +import copy import hashlib import json import mimetypes @@ -17,7 +18,7 @@ from enum import Enum from pathlib import Path from threading import Lock -from typing import Any, Callable, Optional, TypedDict +from typing import Any, Callable, Literal, Optional, TypedDict import fsspec.asyn import httpx @@ -381,22 +382,19 @@ async def get_pred_from_sse_v0( return task.result() -async def get_pred_from_sse_v1( +async def get_pred_from_sse_v1_v2( helper: Communicator, headers: dict[str, str], cookies: dict[str, str] | None, pending_messages_per_event: dict[str, list[Message | None]], event_id: str, + protocol: Literal["sse_v1", "sse_v2"], ) -> dict[str, Any] | None: done, pending = await asyncio.wait( [ asyncio.create_task(check_for_cancel(helper, headers, cookies)), asyncio.create_task( - stream_sse_v1( - helper, - pending_messages_per_event, - event_id, - ) + stream_sse_v1_v2(helper, pending_messages_per_event, event_id, protocol) ), ], return_when=asyncio.FIRST_COMPLETED, @@ -411,6 +409,9 @@ async def get_pred_from_sse_v1( assert len(done) == 1 for task in done: + exception = task.exception() + if exception: + raise exception return task.result() @@ -502,13 +503,15 @@ async def stream_sse_v0( raise -async def stream_sse_v1( +async def stream_sse_v1_v2( helper: Communicator, pending_messages_per_event: dict[str, list[Message | None]], event_id: str, + protocol: Literal["sse_v1", "sse_v2"], ) -> dict[str, Any]: try: pending_messages = pending_messages_per_event[event_id] + pending_responses_for_diffs = None while True: if len(pending_messages) > 0: @@ -540,6 +543,19 @@ async def stream_sse_v1( log=log_message, ) output = msg.get("output", {}).get("data", []) + if ( + msg["msg"] == ServerMessage.process_generating + and protocol == "sse_v2" + ): + if pending_responses_for_diffs is None: + pending_responses_for_diffs = list(output) + else: + for i, value in enumerate(output): + prev_output = pending_responses_for_diffs[i] + new_output = apply_diff(prev_output, value) + pending_responses_for_diffs[i] = new_output + output[i] = new_output + if output and status_update.code != Status.FINISHED: try: result = helper.prediction_processor(*output) @@ -557,6 +573,48 @@ async def stream_sse_v1( raise +def apply_diff(obj, diff): + obj = copy.deepcopy(obj) + + def apply_edit(target, path, action, value): + if len(path) == 0: + if action == "replace": + return value + elif action == "append": + return target + value + else: + raise ValueError(f"Unsupported action: {action}") + + current = target + for i in range(len(path) - 1): + current = current[path[i]] + + last_path = path[-1] + if action == "replace": + current[last_path] = value + elif action == "append": + current[last_path] += value + elif action == "add": + if isinstance(current, list): + current.insert(int(last_path), value) + else: + current[last_path] = value + elif action == "delete": + if isinstance(current, list): + del current[int(last_path)] + else: + del current[last_path] + else: + raise ValueError(f"Unknown action: {action}") + + return target + + for action, path, value in diff: + obj = apply_edit(obj, path, action, value) + + return obj + + ######################## # Data processing utils ######################## diff --git a/gradio/blocks.py b/gradio/blocks.py index b431e2306cded..f9661086d59fb 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -539,6 +539,7 @@ def __init__( self.enable_queue = True self.max_threads = 40 self.pending_streams = defaultdict(dict) + self.pending_diff_streams = defaultdict(dict) self.show_error = True self.head = head if css is not None and os.path.exists(css): @@ -1483,6 +1484,38 @@ def handle_streaming_outputs( data[i] = output_data return data + def handle_streaming_diffs( + self, + fn_index: int, + data: list, + session_hash: str | None, + run: int | None, + final: bool, + ) -> list: + if session_hash is None or run is None: + return data + first_run = run not in self.pending_diff_streams[session_hash] + if first_run: + self.pending_diff_streams[session_hash][run] = [None] * len(data) + last_diffs = self.pending_diff_streams[session_hash][run] + + for i in range(len(self.dependencies[fn_index]["outputs"])): + if final: + data[i] = last_diffs[i] + continue + + if first_run: + last_diffs[i] = data[i] + else: + prev_chunk = last_diffs[i] + last_diffs[i] = data[i] + data[i] = utils.diff(prev_chunk, data[i]) + + if final: + del self.pending_diff_streams[session_hash][run] + + return data + async def process_api( self, fn_index: int, @@ -1565,11 +1598,19 @@ async def process_api( data = self.postprocess_data(fn_index, result["prediction"], state) is_generating, iterator = result["is_generating"], result["iterator"] if is_generating or was_generating: + run = id(old_iterator) if was_generating else id(iterator) data = self.handle_streaming_outputs( fn_index, data, session_hash=session_hash, - run=id(old_iterator) if was_generating else id(iterator), + run=run, + ) + data = self.handle_streaming_diffs( + fn_index, + data, + session_hash=session_hash, + run=run, + final=not is_generating, ) block_fn.total_runtime += result["duration"] @@ -1611,7 +1652,7 @@ def get_config_file(self): "is_colab": utils.colab_check(), "stylesheets": self.stylesheets, "theme": self.theme.name, - "protocol": "sse_v1", + "protocol": "sse_v2", "body_css": { "body_background_fill": self.theme._get_computed_value( "body_background_fill" diff --git a/gradio/components/textbox.py b/gradio/components/textbox.py index c1be9029c0a04..6dd9b8c73112e 100644 --- a/gradio/components/textbox.py +++ b/gradio/components/textbox.py @@ -6,9 +6,7 @@ from gradio_client.documentation import document, set_documentation_group -from gradio.components.base import ( - FormComponent, -) +from gradio.components.base import FormComponent from gradio.events import Events set_documentation_group("component") diff --git a/gradio/queueing.py b/gradio/queueing.py index c871d7989ec8f..8ed2f5cf766a1 100644 --- a/gradio/queueing.py +++ b/gradio/queueing.py @@ -584,10 +584,7 @@ async def process_events( response = None err = e for event in awake_events: - if response is None: - relevant_response = err - else: - relevant_response = old_response or old_err + relevant_response = response or err or old_err self.send_message( event, ServerMessage.process_completed, diff --git a/gradio/utils.py b/gradio/utils.py index a819953856dd9..ce6021375eef1 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -1043,3 +1043,47 @@ def __setitem__(self, key: K, value: V) -> None: def get_cache_folder() -> Path: return Path(os.environ.get("GRADIO_EXAMPLES_CACHE", "gradio_cached_examples")) + + +def diff(old, new): + def compare_objects(obj1, obj2, path=None): + if path is None: + path = [] + edits = [] + + if obj1 == obj2: + return edits + + if type(obj1) != type(obj2): + edits.append(("replace", path, obj2)) + return edits + + if isinstance(obj1, str) and obj2.startswith(obj1): + edits.append(("append", path, obj2[len(obj1) :])) + return edits + + if isinstance(obj1, list): + common_length = min(len(obj1), len(obj2)) + for i in range(common_length): + edits.extend(compare_objects(obj1[i], obj2[i], path + [i])) + for i in range(common_length, len(obj1)): + edits.append(("delete", path + [i], None)) + for i in range(common_length, len(obj2)): + edits.append(("add", path + [i], obj2[i])) + return edits + + if isinstance(obj1, dict): + for key in obj1: + if key in obj2: + edits.extend(compare_objects(obj1[key], obj2[key], path + [key])) + else: + edits.append(("delete", path + [key], None)) + for key in obj2: + if key not in obj1: + edits.append(("add", path + [key], obj2[key])) + return edits + + edits.append(("replace", path, obj2)) + return edits + + return compare_objects(old, new) diff --git a/test/test_blocks.py b/test/test_blocks.py index cc4627c4ba5de..e01c4dd730bce 100644 --- a/test/test_blocks.py +++ b/test/test_blocks.py @@ -265,6 +265,49 @@ def generator(string): assert outputs == ["a", "b", "c"] demo.queue().launch(prevent_thread_lock=True) + def test_varying_output_forms_with_generators(self, connect): + generations = [ + {"a": 1}, + {"a": 1, "b": [1, 3]}, + {"b": [1, 3, 2]}, + 1, + 2, + 3, + [1, 2, {"x": 4, "y": 6}], + {"data": [1, 2, {"x": 4, "y": 6}]}, + None, + 1.2, + ] + + def generator(): + yield from generations + + def generator_random(): + indices = list(range(len(generations))) + random.shuffle(indices) + for i in indices: + time.sleep(random.random() / 5) + yield generations[i] + + with gr.Blocks() as demo: + btn1 = gr.Button() + btn2 = gr.Button() + output_json = gr.JSON() + btn1.click(generator, None, output_json, api_name="generator") + btn2.click(generator_random, None, output_json, api_name="generator_random") + + with connect(demo) as client: + outputs = [] + for output in client.submit(api_name="/generator"): + outputs.append(output) + assert outputs == generations + + outputs = [] + for output in client.submit(api_name="/generator_random"): + outputs.append(output) + for generation in generations: + assert generation in outputs + def test_socket_reuse(self): try: io = gr.Interface(lambda x: x, gr.Textbox(), gr.Textbox())