|
3 | 3 | import mimetypes
|
4 | 4 | import socket
|
5 | 5 | from contextlib import asynccontextmanager
|
6 |
| -from inspect import signature |
7 | 6 | from pathlib import Path
|
8 |
| -from typing import Any |
9 | 7 |
|
10 | 8 | import torch
|
11 | 9 | import uvicorn
|
12 | 10 | from fastapi import FastAPI
|
13 | 11 | from fastapi.middleware.cors import CORSMiddleware
|
14 | 12 | from fastapi.middleware.gzip import GZipMiddleware
|
15 | 13 | from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
16 |
| -from fastapi.openapi.utils import get_openapi |
17 | 14 | from fastapi.responses import HTMLResponse
|
18 | 15 | from fastapi_events.handlers.local import local_handler
|
19 | 16 | from fastapi_events.middleware import EventHandlerASGIMiddleware
|
20 |
| -from pydantic.json_schema import models_json_schema |
21 | 17 | from torch.backends.mps import is_available as is_mps_available
|
22 | 18 |
|
23 | 19 | # for PyCharm:
|
24 | 20 | # noinspection PyUnresolvedReferences
|
25 | 21 | import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
26 | 22 | import invokeai.frontend.web as web_dir
|
27 | 23 | from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
28 |
| -from invokeai.app.invocations.model import ModelIdentifierField |
29 | 24 | from invokeai.app.services.config.config_default import get_config
|
30 |
| -from invokeai.app.services.events.events_common import EventBase |
31 |
| -from invokeai.app.services.session_processor.session_processor_common import ProgressImage |
| 25 | +from invokeai.app.util.custom_openapi import get_openapi_func |
32 | 26 | from invokeai.backend.util.devices import TorchDevice
|
33 | 27 |
|
34 | 28 | from ..backend.util.logging import InvokeAILogger
|
|
45 | 39 | workflows,
|
46 | 40 | )
|
47 | 41 | from .api.sockets import SocketIO
|
48 |
| -from .invocations.baseinvocation import ( |
49 |
| - BaseInvocation, |
50 |
| - UIConfigBase, |
51 |
| -) |
52 |
| -from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra |
53 | 42 |
|
54 | 43 | app_config = get_config()
|
55 | 44 |
|
@@ -119,84 +108,7 @@ async def lifespan(app: FastAPI):
|
119 | 108 | app.include_router(session_queue.session_queue_router, prefix="/api")
|
120 | 109 | app.include_router(workflows.workflows_router, prefix="/api")
|
121 | 110 |
|
122 |
| - |
123 |
| -# Build a custom OpenAPI to include all outputs |
124 |
| -# TODO: can outputs be included on metadata of invocation schemas somehow? |
125 |
| -def custom_openapi() -> dict[str, Any]: |
126 |
| - if app.openapi_schema: |
127 |
| - return app.openapi_schema |
128 |
| - openapi_schema = get_openapi( |
129 |
| - title=app.title, |
130 |
| - description="An API for invoking AI image operations", |
131 |
| - version="1.0.0", |
132 |
| - routes=app.routes, |
133 |
| - separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/ |
134 |
| - ) |
135 |
| - |
136 |
| - # Add all outputs |
137 |
| - all_invocations = BaseInvocation.get_invocations() |
138 |
| - output_types = set() |
139 |
| - output_type_titles = {} |
140 |
| - for invoker in all_invocations: |
141 |
| - output_type = signature(invoker.invoke).return_annotation |
142 |
| - output_types.add(output_type) |
143 |
| - |
144 |
| - output_schemas = models_json_schema( |
145 |
| - models=[(o, "serialization") for o in output_types], ref_template="#/components/schemas/{model}" |
146 |
| - ) |
147 |
| - for schema_key, output_schema in output_schemas[1]["$defs"].items(): |
148 |
| - # TODO: note that we assume the schema_key here is the TYPE.__name__ |
149 |
| - # This could break in some cases, figure out a better way to do it |
150 |
| - output_type_titles[schema_key] = output_schema["title"] |
151 |
| - openapi_schema["components"]["schemas"][schema_key] = output_schema |
152 |
| - openapi_schema["components"]["schemas"][schema_key]["class"] = "output" |
153 |
| - |
154 |
| - # Some models don't end up in the schemas as standalone definitions |
155 |
| - additional_schemas = models_json_schema( |
156 |
| - [ |
157 |
| - (UIConfigBase, "serialization"), |
158 |
| - (InputFieldJSONSchemaExtra, "serialization"), |
159 |
| - (OutputFieldJSONSchemaExtra, "serialization"), |
160 |
| - (ModelIdentifierField, "serialization"), |
161 |
| - (ProgressImage, "serialization"), |
162 |
| - ], |
163 |
| - ref_template="#/components/schemas/{model}", |
164 |
| - ) |
165 |
| - for schema_key, schema_json in additional_schemas[1]["$defs"].items(): |
166 |
| - openapi_schema["components"]["schemas"][schema_key] = schema_json |
167 |
| - |
168 |
| - openapi_schema["components"]["schemas"]["InvocationOutputMap"] = { |
169 |
| - "type": "object", |
170 |
| - "properties": {}, |
171 |
| - "required": [], |
172 |
| - } |
173 |
| - |
174 |
| - # Add a reference to the output type to additionalProperties of the invoker schema |
175 |
| - for invoker in all_invocations: |
176 |
| - invoker_name = invoker.__name__ # type: ignore [attr-defined] # this is a valid attribute |
177 |
| - output_type = signature(obj=invoker.invoke).return_annotation |
178 |
| - output_type_title = output_type_titles[output_type.__name__] |
179 |
| - invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"] |
180 |
| - outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"} |
181 |
| - invoker_schema["output"] = outputs_ref |
182 |
| - openapi_schema["components"]["schemas"]["InvocationOutputMap"]["properties"][invoker.get_type()] = outputs_ref |
183 |
| - openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type()) |
184 |
| - invoker_schema["class"] = "invocation" |
185 |
| - |
186 |
| - # Add all event schemas |
187 |
| - for event in sorted(EventBase.get_events(), key=lambda e: e.__name__): |
188 |
| - json_schema = event.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}") |
189 |
| - if "$defs" in json_schema: |
190 |
| - for schema_key, schema in json_schema["$defs"].items(): |
191 |
| - openapi_schema["components"]["schemas"][schema_key] = schema |
192 |
| - del json_schema["$defs"] |
193 |
| - openapi_schema["components"]["schemas"][event.__name__] = json_schema |
194 |
| - |
195 |
| - app.openapi_schema = openapi_schema |
196 |
| - return app.openapi_schema |
197 |
| - |
198 |
| - |
199 |
| -app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid assignment |
| 111 | +app.openapi = get_openapi_func(app) |
200 | 112 |
|
201 | 113 |
|
202 | 114 | @app.get("/docs", include_in_schema=False)
|
|
0 commit comments