55from typing import Optional , Union
66
77from loguru import logger
8+ from result import Err , Ok , Result
89from sqlalchemy .ext .declarative import DeclarativeMeta
910from typing_extensions import TypeGuard
1011
@@ -18,13 +19,13 @@ def __init__(self, schema_factory: SchemaFactory, /):
1819 @abstractmethod
1920 def transform (
2021 self , rawtargets : Iterable [Union [ModuleType , DeclarativeMeta ]], depth : Optional [int ], /
21- ) -> Schema : ...
22+ ) -> Result [ Schema , str ] : ...
2223
2324
2425class JSONSchemaTransformer (AbstractTransformer ):
2526 def transform (
2627 self , rawtargets : Iterable [Union [ModuleType , DeclarativeMeta ]], depth : Optional [int ], /
27- ) -> Schema :
28+ ) -> Result [ Schema , str ] :
2829 definitions = {}
2930
3031 for item in rawtargets :
@@ -33,33 +34,46 @@ def transform(
3334 elif inspect .ismodule (item ):
3435 partial_definitions = self .transform_by_module (item , depth )
3536 else :
36- TypeError (f"Expected a class or module, got { item } " )
37+ return Err (f"Expected a class or module, got { item } " )
3738
38- definitions .update (partial_definitions )
39+ if partial_definitions .is_err ():
40+ return partial_definitions
3941
40- return definitions
42+ definitions . update ( partial_definitions . unwrap ())
4143
42- def transform_by_model (self , model : DeclarativeMeta , depth : Optional [int ], / ) -> Schema :
44+ return Ok (definitions )
45+
46+ def transform_by_model (
47+ self , model : DeclarativeMeta , depth : Optional [int ], /
48+ ) -> Result [Schema , str ]:
4349 return self .schema_factory (model , depth = depth )
4450
45- def transform_by_module (self , module : ModuleType , depth : Optional [int ], / ) -> Schema :
51+ def transform_by_module (
52+ self , module : ModuleType , depth : Optional [int ], /
53+ ) -> Result [Schema , str ]:
4654 subdefinitions = {}
4755 definitions = {}
4856 for basemodel in collect_models (module ):
49- schema = self .schema_factory (basemodel , depth = depth )
57+ schema_result = self .schema_factory (basemodel , depth = depth )
58+
59+ if schema_result .is_err ():
60+ return schema_result
61+
62+ schema = schema_result .unwrap ()
63+
5064 if "definitions" in schema :
5165 subdefinitions .update (schema .pop ("definitions" ))
5266 definitions [schema ["title" ]] = schema
5367 d = {}
5468 d .update (subdefinitions )
5569 d .update (definitions )
56- return {"definitions" : definitions }
70+ return Ok ( {"definitions" : definitions })
5771
5872
5973class OpenAPI2Transformer (AbstractTransformer ):
6074 def transform (
6175 self , rawtargets : Iterable [Union [ModuleType , DeclarativeMeta ]], depth : Optional [int ], /
62- ) -> Schema :
76+ ) -> Result [ Schema , str ] :
6377 definitions = {}
6478
6579 for target in rawtargets :
@@ -68,29 +82,46 @@ def transform(
6882 elif inspect .ismodule (target ):
6983 partial_definitions = self .transform_by_module (target , depth )
7084 else :
71- raise TypeError (f"Expected a class or module, got { target } " )
85+ return Err (f"Expected a class or module, got { target } " )
86+
87+ if partial_definitions .is_err ():
88+ return partial_definitions
7289
73- definitions .update (partial_definitions )
90+ definitions .update (partial_definitions . unwrap () )
7491
75- return {"definitions" : definitions }
92+ return Ok ( {"definitions" : definitions })
7693
77- def transform_by_model (self , model : DeclarativeMeta , depth : Optional [int ], / ) -> Schema :
94+ def transform_by_model (
95+ self , model : DeclarativeMeta , depth : Optional [int ], /
96+ ) -> Result [Schema , str ]:
7897 definitions = {}
79- schema = self .schema_factory (model , depth = depth )
98+ schema_result = self .schema_factory (model , depth = depth )
99+
100+ if schema_result .is_err ():
101+ return schema_result
102+
103+ schema = schema_result .unwrap ()
80104
81105 if "definitions" in schema :
82106 definitions .update (schema .pop ("definitions" ))
83107
84108 definitions [schema ["title" ]] = schema
85109
86- return definitions
110+ return Ok ( definitions )
87111
88- def transform_by_module (self , module : ModuleType , depth : Optional [int ], / ) -> Schema :
112+ def transform_by_module (
113+ self , module : ModuleType , depth : Optional [int ], /
114+ ) -> Result [Schema , str ]:
89115 subdefinitions = {}
90116 definitions = {}
91117
92118 for basemodel in collect_models (module ):
93- schema = self .schema_factory (basemodel , depth = depth )
119+ schema_result = self .schema_factory (basemodel , depth = depth )
120+
121+ if schema_result .is_err ():
122+ return schema_result
123+
124+ schema = schema_result .unwrap ()
94125
95126 if "definitions" in schema :
96127 subdefinitions .update (schema .pop ("definitions" ))
@@ -101,7 +132,7 @@ def transform_by_module(self, module: ModuleType, depth: Optional[int], /) -> Sc
101132 d .update (subdefinitions )
102133 d .update (definitions )
103134
104- return definitions
135+ return Ok ( definitions )
105136
106137
107138class OpenAPI3Transformer (OpenAPI2Transformer ):
@@ -118,8 +149,13 @@ def replace_ref(self, d: Union[dict, list], old_prefix: str, new_prefix: str, /)
118149
119150 def transform (
120151 self , rawtargets : Iterable [Union [ModuleType , DeclarativeMeta ]], depth : Optional [int ], /
121- ) -> Schema :
122- definitions = super ().transform (rawtargets , depth )
152+ ) -> Result [Schema , str ]:
153+ definitions_result = super ().transform (rawtargets , depth )
154+
155+ if definitions_result .is_err ():
156+ return Err (definitions_result .unwrap_err ())
157+
158+ definitions = definitions_result .unwrap ()
123159
124160 self .replace_ref (definitions , "#/definitions/" , "#/components/schemas/" )
125161
@@ -128,7 +164,8 @@ def transform(
128164 if "schemas" not in definitions ["components" ]:
129165 definitions ["components" ]["schemas" ] = {}
130166 definitions ["components" ]["schemas" ] = definitions .pop ("definitions" , {})
131- return definitions
167+
168+ return Ok (definitions )
132169
133170
134171def collect_models (module : ModuleType , / ) -> Iterator [DeclarativeMeta ]:
0 commit comments