Skip to content

Commit 2f9ebde

Browse files
fix(app): openapi schema generation
Some tech debt related to dynamic pydantic schemas for invocations became problematic. Including the invocations and results in the event schemas was breaking pydantic's handling of ref schemas. I don't really understand why - I think it's a pydantic bug in a remote edge case that we are hitting. After many failed attempts I landed on this implementation, which is actually much tidier than what was in there before. - Create pydantic-enabled types for `AnyInvocation` and `AnyInvocationOutput` and use these in place of the janky dynamic unions. Actually, they are kinda the same, but better encapsulated. Use these in `Graph`, `GraphExecutionState`, `InvocationEventBase` and `InvocationCompleteEvent`. - Revise the custom openapi function to work with the new models. - Split out the custom openapi function to a separate file. Add a `post_transform` callback so consumers can customize the output schema. - Update makefile scripts.
1 parent e257a72 commit 2f9ebde

File tree

7 files changed

+177
-226
lines changed

7 files changed

+177
-226
lines changed

Makefile

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ help:
1818
@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
1919
@echo "installer-zip Build the installer .zip file for the current version"
2020
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
21+
@echo "openapi Generate the OpenAPI schema for the app, outputting to stdout"
2122

2223
# Runs ruff, fixing any safely-fixable errors and formatting
2324
ruff:
@@ -70,3 +71,6 @@ installer-zip:
7071
tag-release:
7172
cd installer && ./tag_release.sh
7273

74+
# Generate the OpenAPI Schema for the app
75+
openapi:
76+
python scripts/generate_openapi_schema.py

invokeai/app/api_app.py

Lines changed: 2 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -3,32 +3,26 @@
33
import mimetypes
44
import socket
55
from contextlib import asynccontextmanager
6-
from inspect import signature
76
from pathlib import Path
8-
from typing import Any
97

108
import torch
119
import uvicorn
1210
from fastapi import FastAPI
1311
from fastapi.middleware.cors import CORSMiddleware
1412
from fastapi.middleware.gzip import GZipMiddleware
1513
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
16-
from fastapi.openapi.utils import get_openapi
1714
from fastapi.responses import HTMLResponse
1815
from fastapi_events.handlers.local import local_handler
1916
from fastapi_events.middleware import EventHandlerASGIMiddleware
20-
from pydantic.json_schema import models_json_schema
2117
from torch.backends.mps import is_available as is_mps_available
2218

2319
# for PyCharm:
2420
# noinspection PyUnresolvedReferences
2521
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
2622
import invokeai.frontend.web as web_dir
2723
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
28-
from invokeai.app.invocations.model import ModelIdentifierField
2924
from invokeai.app.services.config.config_default import get_config
30-
from invokeai.app.services.events.events_common import EventBase
31-
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
25+
from invokeai.app.util.custom_openapi import get_openapi_func
3226
from invokeai.backend.util.devices import TorchDevice
3327

3428
from ..backend.util.logging import InvokeAILogger
@@ -45,11 +39,6 @@
4539
workflows,
4640
)
4741
from .api.sockets import SocketIO
48-
from .invocations.baseinvocation import (
49-
BaseInvocation,
50-
UIConfigBase,
51-
)
52-
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
5342

5443
app_config = get_config()
5544

@@ -119,84 +108,7 @@ async def lifespan(app: FastAPI):
119108
app.include_router(session_queue.session_queue_router, prefix="/api")
120109
app.include_router(workflows.workflows_router, prefix="/api")
121110

122-
123-
# Build a custom OpenAPI to include all outputs
124-
# TODO: can outputs be included on metadata of invocation schemas somehow?
125-
def custom_openapi() -> dict[str, Any]:
126-
if app.openapi_schema:
127-
return app.openapi_schema
128-
openapi_schema = get_openapi(
129-
title=app.title,
130-
description="An API for invoking AI image operations",
131-
version="1.0.0",
132-
routes=app.routes,
133-
separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
134-
)
135-
136-
# Add all outputs
137-
all_invocations = BaseInvocation.get_invocations()
138-
output_types = set()
139-
output_type_titles = {}
140-
for invoker in all_invocations:
141-
output_type = signature(invoker.invoke).return_annotation
142-
output_types.add(output_type)
143-
144-
output_schemas = models_json_schema(
145-
models=[(o, "serialization") for o in output_types], ref_template="#/components/schemas/{model}"
146-
)
147-
for schema_key, output_schema in output_schemas[1]["$defs"].items():
148-
# TODO: note that we assume the schema_key here is the TYPE.__name__
149-
# This could break in some cases, figure out a better way to do it
150-
output_type_titles[schema_key] = output_schema["title"]
151-
openapi_schema["components"]["schemas"][schema_key] = output_schema
152-
openapi_schema["components"]["schemas"][schema_key]["class"] = "output"
153-
154-
# Some models don't end up in the schemas as standalone definitions
155-
additional_schemas = models_json_schema(
156-
[
157-
(UIConfigBase, "serialization"),
158-
(InputFieldJSONSchemaExtra, "serialization"),
159-
(OutputFieldJSONSchemaExtra, "serialization"),
160-
(ModelIdentifierField, "serialization"),
161-
(ProgressImage, "serialization"),
162-
],
163-
ref_template="#/components/schemas/{model}",
164-
)
165-
for schema_key, schema_json in additional_schemas[1]["$defs"].items():
166-
openapi_schema["components"]["schemas"][schema_key] = schema_json
167-
168-
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
169-
"type": "object",
170-
"properties": {},
171-
"required": [],
172-
}
173-
174-
# Add a reference to the output type to additionalProperties of the invoker schema
175-
for invoker in all_invocations:
176-
invoker_name = invoker.__name__ # type: ignore [attr-defined] # this is a valid attribute
177-
output_type = signature(obj=invoker.invoke).return_annotation
178-
output_type_title = output_type_titles[output_type.__name__]
179-
invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"]
180-
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
181-
invoker_schema["output"] = outputs_ref
182-
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["properties"][invoker.get_type()] = outputs_ref
183-
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type())
184-
invoker_schema["class"] = "invocation"
185-
186-
# Add all event schemas
187-
for event in sorted(EventBase.get_events(), key=lambda e: e.__name__):
188-
json_schema = event.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
189-
if "$defs" in json_schema:
190-
for schema_key, schema in json_schema["$defs"].items():
191-
openapi_schema["components"]["schemas"][schema_key] = schema
192-
del json_schema["$defs"]
193-
openapi_schema["components"]["schemas"][event.__name__] = json_schema
194-
195-
app.openapi_schema = openapi_schema
196-
return app.openapi_schema
197-
198-
199-
app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid assignment
111+
app.openapi = get_openapi_func(app)
200112

201113

202114
@app.get("/docs", include_in_schema=False)

invokeai/app/invocations/baseinvocation.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,10 @@ def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
113113
def get_typeadapter(cls) -> TypeAdapter[Any]:
114114
"""Gets a pydantc TypeAdapter for the union of all invocation output types."""
115115
if not cls._typeadapter:
116-
InvocationOutputsUnion = TypeAliasType(
117-
"InvocationOutputsUnion", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
116+
AnyInvocationOutput = TypeAliasType(
117+
"AnyInvocationOutput", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
118118
)
119-
cls._typeadapter = TypeAdapter(InvocationOutputsUnion)
119+
cls._typeadapter = TypeAdapter(AnyInvocationOutput)
120120
return cls._typeadapter
121121

122122
@classmethod
@@ -125,12 +125,13 @@ def get_output_types(cls) -> Iterable[str]:
125125
return (i.get_type() for i in BaseInvocationOutput.get_outputs())
126126

127127
@staticmethod
128-
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
128+
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None:
129129
"""Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
130130
# Because we use a pydantic Literal field with default value for the invocation type,
131131
# it will be typed as optional in the OpenAPI schema. Make it required manually.
132132
if "required" not in schema or not isinstance(schema["required"], list):
133133
schema["required"] = []
134+
schema["class"] = "output"
134135
schema["required"].extend(["type"])
135136

136137
@classmethod
@@ -182,10 +183,10 @@ def register_invocation(cls, invocation: BaseInvocation) -> None:
182183
def get_typeadapter(cls) -> TypeAdapter[Any]:
183184
"""Gets a pydantc TypeAdapter for the union of all invocation types."""
184185
if not cls._typeadapter:
185-
InvocationsUnion = TypeAliasType(
186-
"InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
186+
AnyInvocation = TypeAliasType(
187+
"AnyInvocation", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
187188
)
188-
cls._typeadapter = TypeAdapter(InvocationsUnion)
189+
cls._typeadapter = TypeAdapter(AnyInvocation)
189190
return cls._typeadapter
190191

191192
@classmethod
@@ -221,7 +222,7 @@ def get_output_annotation(cls) -> BaseInvocationOutput:
221222
return signature(cls.invoke).return_annotation
222223

223224
@staticmethod
224-
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None:
225+
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
225226
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
226227
uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None))
227228
if uiconfig is not None:
@@ -237,6 +238,7 @@ def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *arg
237238
schema["version"] = uiconfig.version
238239
if "required" not in schema or not isinstance(schema["required"], list):
239240
schema["required"] = []
241+
schema["class"] = "invocation"
240242
schema["required"].extend(["type", "id"])
241243

242244
@abstractmethod
@@ -310,7 +312,7 @@ def invoke_internal(self, context: InvocationContext, services: "InvocationServi
310312
protected_namespaces=(),
311313
validate_assignment=True,
312314
json_schema_extra=json_schema_extra,
313-
json_schema_serialization_defaults_required=True,
315+
json_schema_serialization_defaults_required=False,
314316
coerce_numbers_to_str=True,
315317
)
316318

invokeai/app/services/events/events_common.py

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@
33

44
from fastapi_events.handlers.local import local_handler
55
from fastapi_events.registry.payload_schema import registry as payload_schema
6-
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, field_validator
6+
from pydantic import BaseModel, ConfigDict, Field
77

8-
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
98
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
109
from invokeai.app.services.session_queue.session_queue_common import (
1110
QUEUE_ITEM_STATUS,
@@ -14,6 +13,7 @@
1413
SessionQueueItem,
1514
SessionQueueStatus,
1615
)
16+
from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput
1717
from invokeai.app.util.misc import get_timestamp
1818
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
1919
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
@@ -98,17 +98,9 @@ class InvocationEventBase(QueueItemEventBase):
9898
item_id: int = Field(description="The ID of the queue item")
9999
batch_id: str = Field(description="The ID of the queue batch")
100100
session_id: str = Field(description="The ID of the session (aka graph execution state)")
101-
invocation: SerializeAsAny[BaseInvocation] = Field(description="The ID of the invocation")
101+
invocation: AnyInvocation = Field(description="The ID of the invocation")
102102
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
103103

104-
@field_validator("invocation", mode="plain")
105-
@classmethod
106-
def validate_invocation(cls, v: Any):
107-
"""Validates the invocation using the dynamic type adapter."""
108-
109-
invocation = BaseInvocation.get_typeadapter().validate_python(v)
110-
return invocation
111-
112104

113105
@payload_schema.register
114106
class InvocationStartedEvent(InvocationEventBase):
@@ -117,7 +109,7 @@ class InvocationStartedEvent(InvocationEventBase):
117109
__event_name__ = "invocation_started"
118110

119111
@classmethod
120-
def build(cls, queue_item: SessionQueueItem, invocation: BaseInvocation) -> "InvocationStartedEvent":
112+
def build(cls, queue_item: SessionQueueItem, invocation: AnyInvocation) -> "InvocationStartedEvent":
121113
return cls(
122114
queue_id=queue_item.queue_id,
123115
item_id=queue_item.item_id,
@@ -144,7 +136,7 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
144136
def build(
145137
cls,
146138
queue_item: SessionQueueItem,
147-
invocation: BaseInvocation,
139+
invocation: AnyInvocation,
148140
intermediate_state: PipelineIntermediateState,
149141
progress_image: ProgressImage,
150142
) -> "InvocationDenoiseProgressEvent":
@@ -182,19 +174,11 @@ class InvocationCompleteEvent(InvocationEventBase):
182174

183175
__event_name__ = "invocation_complete"
184176

185-
result: SerializeAsAny[BaseInvocationOutput] = Field(description="The result of the invocation")
186-
187-
@field_validator("result", mode="plain")
188-
@classmethod
189-
def validate_results(cls, v: Any):
190-
"""Validates the invocation result using the dynamic type adapter."""
191-
192-
result = BaseInvocationOutput.get_typeadapter().validate_python(v)
193-
return result
177+
result: AnyInvocationOutput = Field(description="The result of the invocation")
194178

195179
@classmethod
196180
def build(
197-
cls, queue_item: SessionQueueItem, invocation: BaseInvocation, result: BaseInvocationOutput
181+
cls, queue_item: SessionQueueItem, invocation: AnyInvocation, result: AnyInvocationOutput
198182
) -> "InvocationCompleteEvent":
199183
return cls(
200184
queue_id=queue_item.queue_id,
@@ -223,7 +207,7 @@ class InvocationErrorEvent(InvocationEventBase):
223207
def build(
224208
cls,
225209
queue_item: SessionQueueItem,
226-
invocation: BaseInvocation,
210+
invocation: AnyInvocation,
227211
error_type: str,
228212
error_message: str,
229213
error_traceback: str,

0 commit comments

Comments
 (0)