@@ -198,13 +198,20 @@ def _args_flatten_map(self, args: List[FuncParam]) -> Dict[str, Tuple[str, ...]]
198
198
199
199
def _model_flatten_map (self , model : TModel , prefix : str ) -> Generator :
200
200
field : FieldInfo
201
- for attr , field in model .model_fields .items ():
202
- field_name = field .alias or attr
203
- name = f"{ prefix } { self .FLATTEN_PATH_SEP } { field_name } "
204
- if is_pydantic_model (field .annotation ):
205
- yield from self ._model_flatten_map (field .annotation , name ) # type: ignore
206
- else :
207
- yield field_name , name
201
+ if get_origin (model ) in UNION_TYPES :
202
+ # If the model is a union type, process each type in the union
203
+ for arg in get_args (model ):
204
+ if type (arg ) is None :
205
+ continue # Skip NoneType
206
+ yield from self ._model_flatten_map (arg , prefix )
207
+ else :
208
+ for attr , field in model .model_fields .items ():
209
+ field_name = field .alias or attr
210
+ name = f"{ prefix } { self .FLATTEN_PATH_SEP } { field_name } "
211
+ if is_pydantic_model (field .annotation ):
212
+ yield from self ._model_flatten_map (field .annotation , name ) # type: ignore
213
+ else :
214
+ yield field_name , name
208
215
209
216
def _get_param_type (self , name : str , arg : inspect .Parameter ) -> FuncParam :
210
217
# _EMPTY = self.signature.empty
@@ -278,7 +285,11 @@ def _get_param_type(self, name: str, arg: inspect.Parameter) -> FuncParam:
278
285
def is_pydantic_model (cls : Any ) -> bool :
279
286
try :
280
287
if get_origin (cls ) in UNION_TYPES :
281
- return any (issubclass (arg , pydantic .BaseModel ) for arg in get_args (cls ))
288
+ return any (
289
+ issubclass (arg , pydantic .BaseModel )
290
+ for arg in get_args (cls )
291
+ if (type (arg ) is not None )
292
+ )
282
293
return issubclass (cls , pydantic .BaseModel )
283
294
except TypeError :
284
295
return False
@@ -321,14 +332,32 @@ def detect_collection_fields(
321
332
for attr in path [1 :]:
322
333
if hasattr (annotation_or_field , "annotation" ):
323
334
annotation_or_field = annotation_or_field .annotation
324
- annotation_or_field = next (
325
- (
326
- a
327
- for a in annotation_or_field .model_fields .values ()
328
- if a .alias == attr
329
- ),
330
- annotation_or_field .model_fields .get (attr ),
331
- ) # pragma: no cover
335
+
336
+ # check union types
337
+ if get_origin (annotation_or_field ) in UNION_TYPES :
338
+ for arg in get_args (annotation_or_field ):
339
+ if type (arg ) is None :
340
+ continue # Skip NoneType
341
+ if hasattr (arg , "model_fields" ):
342
+ annotation_or_field = next (
343
+ (
344
+ a
345
+ for a in arg .model_fields .values ()
346
+ if a .alias == attr
347
+ ),
348
+ arg .model_fields .get (attr ),
349
+ ) # pragma: no cover
350
+ else :
351
+ continue
352
+ else :
353
+ annotation_or_field = next (
354
+ (
355
+ a
356
+ for a in annotation_or_field .model_fields .values ()
357
+ if a .alias == attr
358
+ ),
359
+ annotation_or_field .model_fields .get (attr ),
360
+ ) # pragma: no cover
332
361
333
362
annotation_or_field = getattr (
334
363
annotation_or_field , "outer_type_" , annotation_or_field
0 commit comments