Skip to content

Commit cfbf132

Browse files
feat(nodes): validate default values for all fields
This prevents issues where the node is defined with an invalid default value, which would guarantee an error during a ser/de roundtrip. - Upstream issue requesting this functionality be built-in to pydantic: pydantic/pydantic#8722 - Upstream PR that implements the functionality: pydantic/pydantic-core#1593
1 parent e2fed1e commit cfbf132

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

invokeai/app/invocations/baseinvocation.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,31 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None
477477
return None
478478

479479

480+
class NoDefaultSentinel:
481+
pass
482+
483+
484+
def validate_field_default(field_name: str, invocation_type: str, annotation: Any, field_info: FieldInfo) -> None:
485+
"""Validates the default value of a field against its pydantic field definition."""
486+
487+
assert isinstance(field_info.json_schema_extra, dict), "json_schema_extra is not a dict"
488+
489+
# By the time we are doing this, we've already done some pydantic magic by overriding the original default value.
490+
# We store the original default value in the json_schema_extra dict, so we can validate it here.
491+
orig_default = field_info.json_schema_extra.get("orig_default", NoDefaultSentinel)
492+
493+
if orig_default is NoDefaultSentinel:
494+
return
495+
496+
TempDefaultValidator = create_model("TempDefaultValidator", field_to_validate=(annotation, field_info))
497+
498+
# Validate the default value against the annotation
499+
try:
500+
TempDefaultValidator(field_to_validate=orig_default)
501+
except Exception as e:
502+
raise ValueError(f"Default value for field {field_name} on invocation {invocation_type} is invalid, {e}") from e
503+
504+
480505
def is_optional(annotation: Any) -> bool:
481506
"""
482507
Checks if the given annotation is optional (i.e. Optional[X], Union[X, None] or X | None).
@@ -529,8 +554,12 @@ def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
529554
assert isinstance(field_info.json_schema_extra, dict), (
530555
f"{field_name} on invocation {invocation_type} has a non-dict json_schema_extra, did you forget to use InputField?"
531556
)
557+
558+
validate_field_default(field_name, invocation_type, annotation, field_info)
559+
532560
if field_info.default is None and not is_optional(annotation):
533561
annotation = annotation | None
562+
534563
fields[field_name] = (annotation, field_info)
535564

536565
# Add OpenAPI schema extras

0 commit comments

Comments
 (0)