Skip to content

Commit 2b52fe2

Browse files
committed
fix: create monkey-patch to force swagger to use multi-part forms
This is a terrible, horrible, no good, very bad patch that I wish I didn't have to do. However, its the only way to get the Swagger behavior we need. It applies the patch in this PR python-restx/flask-restx#542 on the flask-restx repo, however it seems unlikely it will ever be accepted. Checking for the post-patch hash is necessary when running pytest, as the unit tests run in the same process, meaning the imports remain in memory, but many of the fixtures are re-run for each test. This includes the fixture that sets up the main Flask application, so the monkey patch function was being re-executed more than once, and the hash check was failing. The fix now checks if the code matches the patched hash first, and if it does assumes the patch has already been applied and exits the monkey patching function early. Add logging messages documenting the success or failure of applying the patch.
1 parent 9374ea7 commit 2b52fe2

File tree

2 files changed

+128
-0
lines changed

2 files changed

+128
-0
lines changed

src/dioptra/restapi/app.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939

4040
from .__version__ import __version__ as DIOPTRA_VERSION
4141
from .db import db
42+
from .patches import monkey_patch_flask_restx
4243

4344
LOGGER: BoundLogger = structlog.stdlib.get_logger()
4445

@@ -66,6 +67,8 @@ def create_app(env: Optional[str] = None, injector: Optional[Injector] = None) -
6667
from .routes import register_routes
6768
from .v1.users.service import load_user as v1_load_user
6869

70+
monkey_patch_flask_restx()
71+
6972
if env is None:
7073
env = os.getenv("DIOPTRA_RESTAPI_ENV", "test")
7174

src/dioptra/restapi/patches.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# This Software (Dioptra) is being made available as a public service by the
2+
# National Institute of Standards and Technology (NIST), an Agency of the United
3+
# States Department of Commerce. This software was developed in part by employees of
4+
# NIST and in part by NIST contractors. Copyright in portions of this software that
5+
# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant
6+
# to Title 17 United States Code Section 105, works of NIST employees are not
7+
# subject to copyright protection in the United States. However, NIST may hold
8+
# international copyright in software created by its employees and domestic
9+
# copyright (or licensing rights) in portions of software that were assigned or
10+
# licensed to NIST. To the extent that NIST holds copyright in this software, it is
11+
# being made available under the Creative Commons Attribution 4.0 International
12+
# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts
13+
# of the software developed or licensed by NIST.
14+
#
15+
# ACCESS THE FULL CC BY 4.0 LICENSE HERE:
16+
# https://creativecommons.org/licenses/by/4.0/legalcode
17+
import hashlib
18+
import inspect
19+
from typing import Any
20+
21+
import structlog
22+
from structlog.stdlib import BoundLogger
23+
24+
EXPECTED_SERIALIZE_OPERATION_SHA256_HASH = "57241f0a33ed5e1771e5032d1e6f6994685185ed526b9ca2c70f4f27684d1f92" # noqa: B950; fmt: skip
25+
PATCHED_SERIALIZE_OPERATION_SHA256_HASH = "8a51bc04c8dcb81820548d9de53a9606faf0681ffc3684102744c69fbd076437" # noqa: B950; fmt: skip
26+
27+
LOGGER: BoundLogger = structlog.stdlib.get_logger()
28+
29+
30+
def monkey_patch_flask_restx() -> None:
31+
"""
32+
Monkey patch flask_restx.Swagger.serialize_operation to force Swagger docs to use
33+
the multipart/form-data content type for multi-file uploads instead of the
34+
application/x-www-form-urlencoded content type.
35+
36+
This monkey-patch applies the proposed change in this PR
37+
https://github.com/python-restx/flask-restx/pull/542.
38+
"""
39+
import flask_restx
40+
from flask_restx.utils import not_none
41+
42+
serialize_operation_sha256_hash = get_source_code_hash(
43+
flask_restx.Swagger.serialize_operation
44+
)
45+
46+
if serialize_operation_sha256_hash == PATCHED_SERIALIZE_OPERATION_SHA256_HASH:
47+
LOGGER.debug(
48+
"flask_restx.Swagger.serialize_operation already patched",
49+
sha256_hash=serialize_operation_sha256_hash,
50+
)
51+
return None
52+
53+
if serialize_operation_sha256_hash != EXPECTED_SERIALIZE_OPERATION_SHA256_HASH:
54+
LOGGER.error(
55+
"Source code hash changed",
56+
reason="hash of flask_restx.Swagger.serialize_operation did not match",
57+
expected_hash=EXPECTED_SERIALIZE_OPERATION_SHA256_HASH,
58+
sha256_hash=serialize_operation_sha256_hash,
59+
)
60+
raise RuntimeError(
61+
"Source code hash changed (reason: hash of "
62+
"flask_restx.Swagger.serialize_operation did not match "
63+
f"{EXPECTED_SERIALIZE_OPERATION_SHA256_HASH}): "
64+
f"{serialize_operation_sha256_hash}"
65+
)
66+
67+
def serialize_operation_patched(self, doc, method):
68+
operation = {
69+
"responses": self.responses_for(doc, method) or None,
70+
"summary": doc[method]["docstring"]["summary"],
71+
"description": self.description_for(doc, method) or None,
72+
"operationId": self.operation_id_for(doc, method),
73+
"parameters": self.parameters_for(doc[method]) or None,
74+
"security": self.security_for(doc, method),
75+
}
76+
# Handle 'produces' mimetypes documentation
77+
if "produces" in doc[method]:
78+
operation["produces"] = doc[method]["produces"]
79+
# Handle deprecated annotation
80+
if doc.get("deprecated") or doc[method].get("deprecated"):
81+
operation["deprecated"] = True
82+
# Handle form exceptions:
83+
doc_params = list(doc.get("params", {}).values())
84+
all_params = doc_params + (operation["parameters"] or [])
85+
if all_params and any(p["in"] == "formData" for p in all_params):
86+
if any(p["type"] == "file" for p in all_params):
87+
operation["consumes"] = ["multipart/form-data"]
88+
elif any(
89+
p["type"] == "array" and p["collectionFormat"] == "multi"
90+
for p in all_params
91+
if "collectionFormat" in p
92+
):
93+
operation["consumes"] = ["multipart/form-data"]
94+
else:
95+
operation["consumes"] = [
96+
"application/x-www-form-urlencoded",
97+
"multipart/form-data",
98+
]
99+
operation.update(self.vendor_fields(doc, method))
100+
return not_none(operation)
101+
102+
flask_restx.Swagger.serialize_operation = serialize_operation_patched
103+
LOGGER.info(
104+
"flask_restx.Swagger.serialize_operation patched successfully"
105+
)
106+
107+
108+
def get_source_code_hash(obj: Any) -> str:
109+
"""Generate a hash of the underlying source code of a Python object.
110+
111+
Args:
112+
obj: The Python object for which to generate a source code hash.
113+
114+
Returns:
115+
The hash of the source code of the Python object.
116+
"""
117+
118+
hash_sha256 = hashlib.sha256()
119+
source_lines, _ = inspect.getsourcelines(obj)
120+
source_lines = [line.rstrip() for line in source_lines]
121+
122+
for line in source_lines:
123+
hash_sha256.update(line.encode("utf-8"))
124+
125+
return hash_sha256.hexdigest()

0 commit comments

Comments
 (0)