Skip to content

Commit df81f32

Browse files
feat(nodes): improved pydantic type annotation massaging
When we do our field type overrides to allow invocations to be instantiated without all required fields, we were not modifying the annotation of the field but did set the default value of the field to `None`. This results in an error when doing a ser/de round trip. Here's what we end up doing: ```py from pydantic import BaseModel, Field class MyModel(BaseModel): foo: str = Field(default=None) ``` And here is a simple round-trip, which should not error but which does: ```py MyModel(**MyModel().model_dump()) # ValidationError: 1 validation error for MyModel # foo # Input should be a valid string [type=string_type, input_value=None, input_type=NoneType] # For further information visit https://errors.pydantic.dev/2.11/v/string_type ``` To fix this, we now check every incoming field and update its annotation to match its default value. In other words, when we override the default field value to `None`, we make its type annotation `<original type> | None`. This prevents the error during deserialization. This slightly alters the schema for all invocations and outputs - the values of all fields without default values are now typed as `<original type> | None`, reflecting the overrides. This means the autogenerated types for fields have also changed for fields without defaults: ```ts // Old image?: components["schemas"]["ImageField"]; // New image?: components["schemas"]["ImageField"] | null; ``` This does not break anything on the frontend.
1 parent 143487a commit df81f32

File tree

2 files changed

+102
-20
lines changed

2 files changed

+102
-20
lines changed

invokeai/app/invocations/baseinvocation.py

Lines changed: 56 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import inspect
66
import re
77
import sys
8+
import types
9+
import typing
810
import warnings
911
from abc import ABC, abstractmethod
1012
from enum import Enum
@@ -489,6 +491,18 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None
489491
return None
490492

491493

494+
def is_optional(annotation: Any) -> bool:
495+
"""
496+
Checks if the given annotation is optional (i.e. Optional[X], Union[X, None] or X | None).
497+
"""
498+
origin = typing.get_origin(annotation)
499+
# PEP 604 unions (int|None) have origin types.UnionType
500+
is_union = origin is typing.Union or origin is types.UnionType
501+
if not is_union:
502+
return False
503+
return any(arg is type(None) for arg in typing.get_args(annotation))
504+
505+
492506
def invocation(
493507
invocation_type: str,
494508
title: Optional[str] = None,
@@ -523,6 +537,18 @@ def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
523537

524538
validate_fields(cls.model_fields, invocation_type)
525539

540+
fields: dict[str, tuple[Any, FieldInfo]] = {}
541+
542+
for field_name, field_info in cls.model_fields.items():
543+
annotation = field_info.annotation
544+
assert annotation is not None, f"{field_name} on invocation {invocation_type} has no type annotation."
545+
assert isinstance(field_info.json_schema_extra, dict), (
546+
f"{field_name} on invocation {invocation_type} has a non-dict json_schema_extra, did you forget to use InputField?"
547+
)
548+
if field_info.default is None and not is_optional(annotation):
549+
annotation = annotation | None
550+
fields[field_name] = (annotation, field_info)
551+
526552
# Add OpenAPI schema extras
527553
uiconfig: dict[str, Any] = {}
528554
uiconfig["title"] = title
@@ -557,11 +583,17 @@ def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
557583
# Unfortunately, because the `GraphInvocation` uses a forward ref in its `graph` field's annotation, this does
558584
# not work. Instead, we have to create a new class with the type field and patch the original class with it.
559585

560-
invocation_type_annotation = Literal[invocation_type] # type: ignore
586+
invocation_type_annotation = Literal[invocation_type]
561587
invocation_type_field = Field(
562588
title="type", default=invocation_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
563589
)
564590

591+
# pydantic's Field function returns a FieldInfo, but they annotate it as returning a type so that type-checkers
592+
# don't get confused by something like this:
593+
# foo: str = Field() <-- this is a FieldInfo, not a str
594+
# Unfortunately this means we need to use type: ignore here to avoid type-checker errors
595+
fields["type"] = (invocation_type_annotation, invocation_type_field) # type: ignore
596+
565597
# Validate the `invoke()` method is implemented
566598
if "invoke" in cls.__abstractmethods__:
567599
raise ValueError(f'Invocation "{invocation_type}" must implement the "invoke" method')
@@ -583,17 +615,12 @@ def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
583615
)
584616

585617
docstring = cls.__doc__
586-
cls = create_model(
587-
cls.__qualname__,
588-
__base__=cls,
589-
__module__=cls.__module__,
590-
type=(invocation_type_annotation, invocation_type_field),
591-
)
592-
cls.__doc__ = docstring
618+
new_class = create_model(cls.__qualname__, __base__=cls, __module__=cls.__module__, **fields)
619+
new_class.__doc__ = docstring
593620

594-
InvocationRegistry.register_invocation(cls)
621+
InvocationRegistry.register_invocation(new_class)
595622

596-
return cls
623+
return new_class
597624

598625
return wrapper
599626

@@ -618,23 +645,32 @@ def wrapper(cls: Type[TBaseInvocationOutput]) -> Type[TBaseInvocationOutput]:
618645

619646
validate_fields(cls.model_fields, output_type)
620647

648+
fields: dict[str, tuple[Any, FieldInfo]] = {}
649+
650+
for field_name, field_info in cls.model_fields.items():
651+
annotation = field_info.annotation
652+
assert annotation is not None, f"{field_name} on invocation output {output_type} has no type annotation."
653+
assert isinstance(field_info.json_schema_extra, dict), (
654+
f"{field_name} on invocation output {output_type} has a non-dict json_schema_extra, did you forget to use InputField?"
655+
)
656+
if field_info.default is not PydanticUndefined and is_optional(annotation):
657+
annotation = annotation | None
658+
fields[field_name] = (annotation, field_info)
659+
621660
# Add the output type to the model.
622-
output_type_annotation = Literal[output_type] # type: ignore
661+
output_type_annotation = Literal[output_type]
623662
output_type_field = Field(
624663
title="type", default=output_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
625664
)
626665

666+
fields["type"] = (output_type_annotation, output_type_field) # type: ignore
667+
627668
docstring = cls.__doc__
628-
cls = create_model(
629-
cls.__qualname__,
630-
__base__=cls,
631-
__module__=cls.__module__,
632-
type=(output_type_annotation, output_type_field),
633-
)
634-
cls.__doc__ = docstring
669+
new_class = create_model(cls.__qualname__, __base__=cls, __module__=cls.__module__, **fields)
670+
new_class.__doc__ = docstring
635671

636-
InvocationRegistry.register_output(cls)
672+
InvocationRegistry.register_output(new_class)
637673

638-
return cls
674+
return new_class
639675

640676
return wrapper
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from typing import Any, Literal, Optional, Union
2+
3+
import pytest
4+
from pydantic import BaseModel
5+
6+
7+
class TestModel(BaseModel):
8+
foo: Literal["bar"] = "bar"
9+
10+
11+
@pytest.mark.parametrize(
12+
"input_type, expected",
13+
[
14+
(str, False),
15+
(list[str], False),
16+
(list[dict[str, Any]], False),
17+
(list[None], False),
18+
(list[dict[str, None]], False),
19+
(Any, False),
20+
(True, False),
21+
(False, False),
22+
(Union[str, False], False),
23+
(Union[str, True], False),
24+
(None, False),
25+
(str | None, True),
26+
(Union[str, None], True),
27+
(Optional[str], True),
28+
(str | int | None, True),
29+
(None | str | int, True),
30+
(Union[None, str], True),
31+
(Optional[str], True),
32+
(Optional[int], True),
33+
(Optional[str], True),
34+
(TestModel | None, True),
35+
(Union[TestModel, None], True),
36+
(Optional[TestModel], True),
37+
],
38+
)
39+
def test_is_optional(input_type: Any, expected: bool) -> None:
40+
"""
41+
Test the is_optional function.
42+
"""
43+
from invokeai.app.invocations.baseinvocation import is_optional
44+
45+
result = is_optional(input_type)
46+
assert result == expected, f"Expected {expected} but got {result} for input type {input_type}"

0 commit comments

Comments
 (0)