Skip to content

Commit d87cdda

Browse files
committed
add union type and subtypes check in schema model signature
1 parent eecb05f commit d87cdda

File tree

1 file changed

+45
-16
lines changed

1 file changed

+45
-16
lines changed

ninja/signature/details.py

Lines changed: 45 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -198,13 +198,20 @@ def _args_flatten_map(self, args: List[FuncParam]) -> Dict[str, Tuple[str, ...]]
198198

199199
def _model_flatten_map(self, model: TModel, prefix: str) -> Generator:
200200
field: FieldInfo
201-
for attr, field in model.model_fields.items():
202-
field_name = field.alias or attr
203-
name = f"{prefix}{self.FLATTEN_PATH_SEP}{field_name}"
204-
if is_pydantic_model(field.annotation):
205-
yield from self._model_flatten_map(field.annotation, name) # type: ignore
206-
else:
207-
yield field_name, name
201+
if get_origin(model) in UNION_TYPES:
202+
# If the model is a union type, process each type in the union
203+
for arg in get_args(model):
204+
if type(arg) is None:
205+
continue # Skip NoneType
206+
yield from self._model_flatten_map(arg, prefix)
207+
else:
208+
for attr, field in model.model_fields.items():
209+
field_name = field.alias or attr
210+
name = f"{prefix}{self.FLATTEN_PATH_SEP}{field_name}"
211+
if is_pydantic_model(field.annotation):
212+
yield from self._model_flatten_map(field.annotation, name) # type: ignore
213+
else:
214+
yield field_name, name
208215

209216
def _get_param_type(self, name: str, arg: inspect.Parameter) -> FuncParam:
210217
# _EMPTY = self.signature.empty
@@ -278,7 +285,11 @@ def _get_param_type(self, name: str, arg: inspect.Parameter) -> FuncParam:
278285
def is_pydantic_model(cls: Any) -> bool:
279286
try:
280287
if get_origin(cls) in UNION_TYPES:
281-
return any(issubclass(arg, pydantic.BaseModel) for arg in get_args(cls))
288+
return any(
289+
issubclass(arg, pydantic.BaseModel)
290+
for arg in get_args(cls)
291+
if (type(arg) is not None)
292+
)
282293
return issubclass(cls, pydantic.BaseModel)
283294
except TypeError:
284295
return False
@@ -321,14 +332,32 @@ def detect_collection_fields(
321332
for attr in path[1:]:
322333
if hasattr(annotation_or_field, "annotation"):
323334
annotation_or_field = annotation_or_field.annotation
324-
annotation_or_field = next(
325-
(
326-
a
327-
for a in annotation_or_field.model_fields.values()
328-
if a.alias == attr
329-
),
330-
annotation_or_field.model_fields.get(attr),
331-
) # pragma: no cover
335+
336+
# check union types
337+
if get_origin(annotation_or_field) in UNION_TYPES:
338+
for arg in get_args(annotation_or_field):
339+
if type(arg) is None:
340+
continue # Skip NoneType
341+
if hasattr(arg, "model_fields"):
342+
annotation_or_field = next(
343+
(
344+
a
345+
for a in arg.model_fields.values()
346+
if a.alias == attr
347+
),
348+
arg.model_fields.get(attr),
349+
) # pragma: no cover
350+
else:
351+
continue
352+
else:
353+
annotation_or_field = next(
354+
(
355+
a
356+
for a in annotation_or_field.model_fields.values()
357+
if a.alias == attr
358+
),
359+
annotation_or_field.model_fields.get(attr),
360+
) # pragma: no cover
332361

333362
annotation_or_field = getattr(
334363
annotation_or_field, "outer_type_", annotation_or_field

0 commit comments

Comments
 (0)