Skip to content

Commit 3011150

Browse files
feat(nodes): validate default values for all fields
This prevents issues where the node is defined with an invalid default value, which would guarantee an error during a ser/de roundtrip. - Upstream issue requesting this functionality be built-in to pydantic: pydantic/pydantic#8722 - Upstream PR that implements the functionality: pydantic/pydantic-core#1593
1 parent 05aa1fc commit 3011150

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

invokeai/app/invocations/baseinvocation.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,31 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None
491491
return None
492492

493493

494+
class NoDefaultSentinel:
495+
pass
496+
497+
498+
def validate_field_default(field_name: str, invocation_type: str, annotation: Any, field_info: FieldInfo) -> None:
499+
"""Validates the default value of a field against its pydantic field definition."""
500+
501+
assert isinstance(field_info.json_schema_extra, dict), "json_schema_extra is not a dict"
502+
503+
# By the time we are doing this, we've already done some pydantic magic by overriding the original default value.
504+
# We store the original default value in the json_schema_extra dict, so we can validate it here.
505+
orig_default = field_info.json_schema_extra.get("orig_default", NoDefaultSentinel)
506+
507+
if orig_default is NoDefaultSentinel:
508+
return
509+
510+
TempDefaultValidator = create_model("TempDefaultValidator", field_to_validate=(annotation, field_info))
511+
512+
# Validate the default value against the annotation
513+
try:
514+
TempDefaultValidator(field_to_validate=orig_default)
515+
except Exception as e:
516+
raise ValueError(f"Default value for field {field_name} on invocation {invocation_type} is invalid, {e}") from e
517+
518+
494519
def is_optional(annotation: Any) -> bool:
495520
"""
496521
Checks if the given annotation is optional (i.e. Optional[X], Union[X, None] or X | None).
@@ -545,8 +570,12 @@ def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
545570
assert isinstance(field_info.json_schema_extra, dict), (
546571
f"{field_name} on invocation {invocation_type} has a non-dict json_schema_extra, did you forget to use InputField?"
547572
)
573+
574+
validate_field_default(field_name, invocation_type, annotation, field_info)
575+
548576
if field_info.default is None and not is_optional(annotation):
549577
annotation = annotation | None
578+
550579
fields[field_name] = (annotation, field_info)
551580

552581
# Add OpenAPI schema extras

0 commit comments

Comments
 (0)