diff --git a/.changeset/fix-model-override.md b/.changeset/fix-model-override.md new file mode 100644 index 000000000..faadd8060 --- /dev/null +++ b/.changeset/fix-model-override.md @@ -0,0 +1,53 @@ +--- +default: patch +--- + +# Fix overriding of object property class + +Fixed issue #1123, in which a property could end up with the wrong type when combining two object schemas with `allOf`, if the type of the property was itself an object but had a different schema in each. Example: + +```yaml + ModelA: + properties: + status: + type: string + result: + - $ref: "#/components/schemas/BaseResult" + + ModelB: + allOf: + - $ref: "#/components/schemas/ModelA" + - properties: + result: + - $ref: "#/components/schemas/ExtendedResult" + + ModelC: + allOf: + - $ref: "#/components/schemas/ModelA" + - properties: + result: + - $ref: "#/components/schemas/UnrelatedResult" + + BaseResult: + properties: + prop1: + type: string + + ExtendedResult: + allOf: + - $ref: "#/components/schemas/BaseResult" + - properties: + prop2: + type: string + + UnrelatedResult: + properties: + prop3: + type: string +``` + +Previously, in the generated classes for both `ModelB` and `ModelC`, the type of `result` was being incorrectly set to `BaseResult`. + +The new behavior is, when computing `allOf: [A, B]` where `A` and `B` are both objects, any property `P` whose name exists in both schemas will have a schema equivalent to `allOf: [A.P, B.P]`. This is consistent with the basic definition of `allOf`. + +When translating this into Python code, the generator will use a type that correctly describes the combined schema for the property. If the combined schema is exactly equal in shape to either `A.P` or `B.P` (implying that one was already derived from the other using `allOf`) then it will reuse the corresponding Python class. Otherwise it will create a new class, just as it would for an inline schema that used `allOf`. Therefore in the example above, the type of `ModelB.result` is `ExtendedResult`, but the type of `ModelC.result` is a new class called `ModelCResult` that includes all the properties from `BaseResult` and `UnrelatedResult`. diff --git a/openapi_python_client/parser/properties/merge_properties.py b/openapi_python_client/parser/properties/merge_properties.py index db6424a7c..fdb479090 100644 --- a/openapi_python_client/parser/properties/merge_properties.py +++ b/openapi_python_client/parser/properties/merge_properties.py @@ -1,9 +1,16 @@ from __future__ import annotations +from itertools import chain + +from openapi_python_client import schema as oai +from openapi_python_client import utils +from openapi_python_client.config import Config from openapi_python_client.parser.properties.date import DateProperty from openapi_python_client.parser.properties.datetime import DateTimeProperty from openapi_python_client.parser.properties.file import FileProperty from openapi_python_client.parser.properties.literal_enum_property import LiteralEnumProperty +from openapi_python_client.parser.properties.model_property import ModelProperty, _gather_property_data +from openapi_python_client.parser.properties.schemas import Class __all__ = ["merge_properties"] @@ -27,7 +34,12 @@ STRING_WITH_FORMAT_TYPES = (DateProperty, DateTimeProperty, FileProperty) -def merge_properties(prop1: Property, prop2: Property) -> Property | PropertyError: # noqa: PLR0911 +def merge_properties( # noqa:PLR0911 + prop1: Property, + prop2: Property, + parent_name: str, + config: Config, +) -> Property | PropertyError: """Attempt to create a new property that incorporates the behavior of both. This is used when merging schemas with allOf, when two schemas define a property with the same name. @@ -57,7 +69,7 @@ def merge_properties(prop1: Property, prop2: Property) -> Property | PropertyErr if isinstance(prop1, LiteralEnumProperty) or isinstance(prop2, LiteralEnumProperty): return _merge_with_literal_enum(prop1, prop2) - if (merged := _merge_same_type(prop1, prop2)) is not None: + if (merged := _merge_same_type(prop1, prop2, parent_name, config)) is not None: return merged if (merged := _merge_numeric(prop1, prop2)) is not None: @@ -71,7 +83,9 @@ def merge_properties(prop1: Property, prop2: Property) -> Property | PropertyErr ) -def _merge_same_type(prop1: Property, prop2: Property) -> Property | None | PropertyError: +def _merge_same_type( + prop1: Property, prop2: Property, parent_name: str, config: Config +) -> Property | None | PropertyError: if type(prop1) is not type(prop2): return None @@ -79,8 +93,11 @@ def _merge_same_type(prop1: Property, prop2: Property) -> Property | None | Prop # It's always OK to redefine a property with everything exactly the same return prop1 + if isinstance(prop1, ModelProperty) and isinstance(prop2, ModelProperty): + return _merge_models(prop1, prop2, parent_name, config) + if isinstance(prop1, ListProperty) and isinstance(prop2, ListProperty): - inner_property = merge_properties(prop1.inner_property, prop2.inner_property) # type: ignore + inner_property = merge_properties(prop1.inner_property, prop2.inner_property, "", config) # type: ignore if isinstance(inner_property, PropertyError): return PropertyError(detail=f"can't merge list properties: {inner_property.detail}") prop1.inner_property = inner_property @@ -90,6 +107,62 @@ def _merge_same_type(prop1: Property, prop2: Property) -> Property | None | Prop return _merge_common_attributes(prop1, prop2) +def _merge_models( + prop1: ModelProperty, prop2: ModelProperty, parent_name: str, config: Config +) -> Property | PropertyError: + # The logic here is basically equivalent to what we would do for a schema that was + # "allOf: [prop1, prop2]". We apply the property merge logic recursively and create a new third + # schema if the result cannot be fully described by one or the other. If it *can* be fully + # described by one or the other, then we can simply reuse the class for that one: for instance, + # in a common case where B is an object type that extends A and fully includes it, so that + # "allOf: [A, B]" would be the same as B, then it's valid to use the existing B model class. + # We would still call _merge_common_attributes in that case, to handle metadat like "description". + for prop in [prop1, prop2]: + if prop.needs_post_processing(): + # This means not all of the details of the schema have been filled in, possibly due to a + # forward reference. That may be resolved in a later pass, but for now we can't proceed. + return PropertyError(f"Schema for {prop} in allOf was not processed", data=prop.data) + + if _model_is_extension_of(prop1, prop2, parent_name, config): + return _merge_common_attributes(prop1, prop2) + elif _model_is_extension_of(prop2, prop1, parent_name, config): + return _merge_common_attributes(prop2, prop1, prop2) + + # Neither of the schemas is a superset of the other, so merging them will result in a new type. + merged_props: dict[str, Property] = {p.name: p for p in chain(prop1.required_properties, prop1.optional_properties)} + for model in [prop1, prop2]: + for sub_prop in chain(model.required_properties, model.optional_properties): + if sub_prop.name in merged_props: + merged_prop = merge_properties(merged_props[sub_prop.name], sub_prop, parent_name, config) + if isinstance(merged_prop, PropertyError): + return merged_prop + merged_props[sub_prop.name] = merged_prop + else: + merged_props[sub_prop.name] = sub_prop + + prop_details = _gather_property_data(merged_props.values()) + + name = prop2.name + class_string = f"{utils.pascal_case(parent_name)}{utils.pascal_case(name)}" + class_info = Class.from_string(string=class_string, config=config) + roots = prop1.roots.union(prop2.roots).difference({prop1.class_info.name, prop2.class_info.name}) + roots.add(class_info.name) + prop = ModelProperty( + class_info=class_info, + data=oai.Schema.model_construct(allOf=[prop1.data, prop2.data]), + roots=roots, + details=prop_details, + description=prop2.description or prop1.description, + default=None, + required=prop2.required or prop1.required, + name=name, + python_name=utils.PythonIdentifier(value=name, prefix=config.field_prefix), + example=prop2.example or prop1.example, + ) + + return prop + + def _merge_string_with_format(prop1: Property, prop2: Property) -> Property | None | PropertyError: """Merge a string that has no format with a string that has a format""" # Here we need to use the DateProperty/DateTimeProperty/FileProperty as the base so that we preserve @@ -196,3 +269,23 @@ def _merge_common_attributes(base: PropertyT, *extend_with: PropertyProtocol) -> def _values_are_subset(prop1: EnumProperty, prop2: EnumProperty) -> bool: return set(prop1.values.items()) <= set(prop2.values.items()) + + +def _model_is_extension_of( + extended_model: ModelProperty, base_model: ModelProperty, parent_name: str, config: Config +) -> bool: + def _properties_are_extension_of(extended_list: list[Property], base_list: list[Property]) -> bool: + for p2 in base_list: + if not [p1 for p1 in extended_list if _property_is_extension_of(p2, p1, parent_name, config)]: + return False + return True + + return _properties_are_extension_of( + extended_model.required_properties, base_model.required_properties + ) and _properties_are_extension_of(extended_model.optional_properties, base_model.optional_properties) + + +def _property_is_extension_of(extended_prop: Property, base_prop: Property, parent_name: str, config: Config) -> bool: + return base_prop.name == extended_prop.name and ( + base_prop == extended_prop or merge_properties(base_prop, extended_prop, parent_name, config) == extended_prop + ) diff --git a/openapi_python_client/parser/properties/model_property.py b/openapi_python_client/parser/properties/model_property.py index 762624501..e3b4ead3b 100644 --- a/openapi_python_client/parser/properties/model_property.py +++ b/openapi_python_client/parser/properties/model_property.py @@ -1,9 +1,10 @@ from __future__ import annotations +from collections.abc import Iterable from itertools import chain -from typing import Any, ClassVar, NamedTuple +from typing import Any, ClassVar -from attrs import define, evolve +from attrs import define, evolve, field from ... import Config, utils from ... import schema as oai @@ -14,6 +15,17 @@ from .schemas import Class, ReferencePath, Schemas, parse_reference_path +@define +class ModelDetails: + """Container for basic attributes of a model schema that can be computed separately""" + + required_properties: list[Property] | None = None + optional_properties: list[Property] | None = None + additional_properties: Property | None = None + relative_imports: set[str] = field(factory=set) + lazy_imports: set[str] = field(factory=set) + + @define class ModelProperty(PropertyProtocol): """A property which refers to another Schema""" @@ -27,11 +39,7 @@ class ModelProperty(PropertyProtocol): data: oai.Schema description: str roots: set[ReferencePath | utils.ClassName] - required_properties: list[Property] | None - optional_properties: list[Property] | None - relative_imports: set[str] | None - lazy_imports: set[str] | None - additional_properties: Property | None + details: ModelDetails _json_type_string: ClassVar[str] = "dict[str, Any]" template: ClassVar[str] = "model_property.py.jinja" @@ -75,22 +83,14 @@ def build( class_string = title class_info = Class.from_string(string=class_string, config=config) model_roots = {*roots, class_info.name} - required_properties: list[Property] | None = None - optional_properties: list[Property] | None = None - relative_imports: set[str] | None = None - lazy_imports: set[str] | None = None - additional_properties: Property | None = None + details = ModelDetails() if process_properties: data_or_err, schemas = _process_property_data( data=data, schemas=schemas, class_info=class_info, config=config, roots=model_roots ) if isinstance(data_or_err, PropertyError): return data_or_err, schemas - property_data, additional_properties = data_or_err - required_properties = property_data.required_props - optional_properties = property_data.optional_props - relative_imports = property_data.relative_imports - lazy_imports = property_data.lazy_imports + details = data_or_err for root in roots: if isinstance(root, utils.ClassName): continue @@ -100,11 +100,7 @@ def build( class_info=class_info, data=data, roots=model_roots, - required_properties=required_properties, - optional_properties=optional_properties, - relative_imports=relative_imports, - lazy_imports=lazy_imports, - additional_properties=additional_properties, + details=details, description=data.description or "", default=None, required=required, @@ -125,6 +121,31 @@ def build( ) return prop, schemas + def needs_post_processing(self) -> bool: + return not ( + isinstance(self.details.required_properties, list) and isinstance(self.details.optional_properties, list) + ) + + @property + def required_properties(self) -> list[Property]: + return self.details.required_properties or [] + + @property + def optional_properties(self) -> list[Property]: + return self.details.optional_properties or [] + + @property + def additional_properties(self) -> Property | None: + return self.details.additional_properties + + @property + def relative_imports(self) -> set[str]: + return self.details.relative_imports + + @property + def lazy_imports(self) -> set[str] | None: + return self.details.lazy_imports + @classmethod def convert_value(cls, value: Any) -> Value | None | PropertyError: if value is not None: @@ -132,7 +153,7 @@ def convert_value(cls, value: Any) -> Value | None | PropertyError: return None def __attrs_post_init__(self) -> None: - if self.relative_imports: + if self.details.relative_imports: self.set_relative_imports(self.relative_imports) @property @@ -174,7 +195,7 @@ def set_relative_imports(self, relative_imports: set[str]) -> None: Args: relative_imports: The set of relative import strings """ - object.__setattr__(self, "relative_imports", {ri for ri in relative_imports if self.self_import not in ri}) + self.details.relative_imports = {ri for ri in relative_imports if self.self_import not in ri} def set_lazy_imports(self, lazy_imports: set[str]) -> None: """Set the lazy imports set for this ModelProperty, filtering out self imports @@ -182,7 +203,7 @@ def set_lazy_imports(self, lazy_imports: set[str]) -> None: Args: lazy_imports: The set of lazy import strings """ - object.__setattr__(self, "lazy_imports", {li for li in lazy_imports if self.self_import not in li}) + self.details.lazy_imports = {li for li in lazy_imports if self.self_import not in li} def get_type_string( self, @@ -229,35 +250,25 @@ def _resolve_naming_conflict(first: Property, second: Property, config: Config) return None -class _PropertyData(NamedTuple): - optional_props: list[Property] - required_props: list[Property] - relative_imports: set[str] - lazy_imports: set[str] - schemas: Schemas - - -def _process_properties( # noqa: PLR0912, PLR0911 +def _process_properties( # noqa: PLR0911 *, data: oai.Schema, schemas: Schemas, class_name: utils.ClassName, config: Config, roots: set[ReferencePath | utils.ClassName], -) -> _PropertyData | PropertyError: +) -> tuple[ModelDetails | PropertyError, Schemas]: from . import property_from_data from .merge_properties import merge_properties properties: dict[str, Property] = {} - relative_imports: set[str] = set() - lazy_imports: set[str] = set() required_set = set(data.required or []) def _add_if_no_conflict(new_prop: Property) -> PropertyError | None: nonlocal properties name_conflict = properties.get(new_prop.name) - merged_prop = merge_properties(name_conflict, new_prop) if name_conflict else new_prop + merged_prop = merge_properties(name_conflict, new_prop, class_name, config) if name_conflict else new_prop if isinstance(merged_prop, PropertyError): merged_prop.header = f"Found conflicting properties named {new_prop.name} when creating {class_name}" return merged_prop @@ -281,21 +292,19 @@ def _add_if_no_conflict(new_prop: Property) -> PropertyError | None: if isinstance(sub_prop, oai.Reference): ref_path = parse_reference_path(sub_prop.ref) if isinstance(ref_path, ParseError): - return PropertyError(detail=ref_path.detail, data=sub_prop) + return PropertyError(detail=ref_path.detail, data=sub_prop), schemas sub_model = schemas.classes_by_reference.get(ref_path) if sub_model is None: - return PropertyError(f"Reference {sub_prop.ref} not found") + return PropertyError(f"Reference {sub_prop.ref} not found"), schemas if not isinstance(sub_model, ModelProperty): - return PropertyError("Cannot take allOf a non-object") + return PropertyError("Cannot take allOf a non-object"), schemas # Properties of allOf references first should be processed first - if not ( - isinstance(sub_model.required_properties, list) and isinstance(sub_model.optional_properties, list) - ): - return PropertyError(f"Reference {sub_model.name} in allOf was not processed", data=sub_prop) + if sub_model.needs_post_processing(): + return PropertyError(f"Reference {sub_model.name} in allOf was not processed", data=sub_prop), schemas for prop in chain(sub_model.required_properties, sub_model.optional_properties): err = _add_if_no_conflict(prop) if err is not None: - return err + return err, schemas schemas.add_dependencies(ref_path=ref_path, roots=roots) else: unprocessed_props.extend(sub_prop.properties.items() if sub_prop.properties else []) @@ -316,25 +325,26 @@ def _add_if_no_conflict(new_prop: Property) -> PropertyError | None: if not isinstance(prop_or_error, PropertyError): prop_or_error = _add_if_no_conflict(prop_or_error) if isinstance(prop_or_error, PropertyError): - return prop_or_error + return prop_or_error, schemas + + return _gather_property_data(properties.values()), schemas - required_properties = [] - optional_properties = [] - for prop in properties.values(): - if prop.required: - required_properties.append(prop) - else: - optional_properties.append(prop) +def _gather_property_data(properties: Iterable[Property]) -> ModelDetails: + required_properties: list[Property] = [] + optional_properties: list[Property] = [] + relative_imports: set[str] = set() + lazy_imports: set[str] = set() + for prop in properties: + (required_properties if prop.required else optional_properties).append(prop) lazy_imports.update(prop.get_lazy_imports(prefix="..")) relative_imports.update(prop.get_imports(prefix="..")) - - return _PropertyData( - optional_props=optional_properties, - required_props=required_properties, + return ModelDetails( + optional_properties=optional_properties, + required_properties=required_properties, relative_imports=relative_imports, lazy_imports=lazy_imports, - schemas=schemas, + additional_properties=None, ) @@ -389,13 +399,12 @@ def _process_property_data( class_info: Class, config: Config, roots: set[ReferencePath | utils.ClassName], -) -> tuple[tuple[_PropertyData, Property | None] | PropertyError, Schemas]: - property_data = _process_properties( +) -> tuple[ModelDetails | PropertyError, Schemas]: + model_details, schemas = _process_properties( data=data, schemas=schemas, class_name=class_info.name, config=config, roots=roots ) - if isinstance(property_data, PropertyError): - return property_data, schemas - schemas = property_data.schemas + if isinstance(model_details, PropertyError): + return model_details, schemas additional_properties, schemas = _get_additional_properties( schema_additional=data.additionalProperties, @@ -409,10 +418,11 @@ def _process_property_data( elif additional_properties is None: pass else: - property_data.relative_imports.update(additional_properties.get_imports(prefix="..")) - property_data.lazy_imports.update(additional_properties.get_lazy_imports(prefix="..")) + model_details = evolve(model_details, additional_properties=additional_properties) + model_details.relative_imports.update(additional_properties.get_imports(prefix="..")) + model_details.lazy_imports.update(additional_properties.get_lazy_imports(prefix="..")) - return (property_data, additional_properties), schemas + return model_details, schemas def process_model(model_prop: ModelProperty, *, schemas: Schemas, config: Config) -> Schemas | PropertyError: @@ -434,11 +444,8 @@ def process_model(model_prop: ModelProperty, *, schemas: Schemas, config: Config if isinstance(data_or_err, PropertyError): return data_or_err - property_data, additional_properties = data_or_err + model_prop.details = data_or_err + model_prop.set_relative_imports(data_or_err.relative_imports) + model_prop.set_lazy_imports(data_or_err.lazy_imports) - object.__setattr__(model_prop, "required_properties", property_data.required_props) - object.__setattr__(model_prop, "optional_properties", property_data.optional_props) - model_prop.set_relative_imports(property_data.relative_imports) - model_prop.set_lazy_imports(property_data.lazy_imports) - object.__setattr__(model_prop, "additional_properties", additional_properties) return schemas diff --git a/openapi_python_client/parser/properties/protocol.py b/openapi_python_client/parser/properties/protocol.py index 9a5b51828..0ee64088d 100644 --- a/openapi_python_client/parser/properties/protocol.py +++ b/openapi_python_client/parser/properties/protocol.py @@ -185,3 +185,7 @@ def is_base_type(self) -> bool: ListProperty.__name__, UnionProperty.__name__, } + + def needs_post_processing(self) -> bool: + """Returns true if the parser should call process_model() on this property in a second pass.""" + return False diff --git a/tests/conftest.py b/tests/conftest.py index 969e57cbd..d47b776ce 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -26,6 +26,7 @@ UnionProperty, ) from openapi_python_client.parser.properties.float import FloatProperty +from openapi_python_client.parser.properties.model_property import ModelDetails from openapi_python_client.parser.properties.protocol import PropertyType, Value from openapi_python_client.schema.openapi_schema_pydantic import Parameter from openapi_python_client.schema.parameter_location import ParameterLocation @@ -65,15 +66,25 @@ def _factory(**kwargs): "class_info": Class(name=ClassName("MyClass", ""), module_name=PythonIdentifier("my_module", "")), "data": oai.Schema.model_construct(), "roots": set(), - "required_properties": None, - "optional_properties": None, - "relative_imports": None, - "lazy_imports": None, - "additional_properties": None, "python_name": "", "example": "", **kwargs, } + # shortcuts for setting attributes within ModelDetails + if "details" not in kwargs: + detail_args = {} + for arg_name in [ + "required_properties", + "optional_properties", + "additional_properties", + "relative_imports", + "lazy_imports", + ]: + if arg_name in kwargs: + detail_args[arg_name] = kwargs[arg_name] + kwargs.pop(arg_name) + kwargs["details"] = ModelDetails(**detail_args) + return ModelProperty(**kwargs) return _factory diff --git a/tests/test_parser/test_properties/test_merge_properties.py b/tests/test_parser/test_properties/test_merge_properties.py index 819f9ec26..71b30901b 100644 --- a/tests/test_parser/test_properties/test_merge_properties.py +++ b/tests/test_parser/test_properties/test_merge_properties.py @@ -7,6 +7,7 @@ from openapi_python_client.parser.properties.float import FloatProperty from openapi_python_client.parser.properties.int import IntProperty from openapi_python_client.parser.properties.merge_properties import merge_properties +from openapi_python_client.parser.properties.model_property import ModelDetails, ModelProperty from openapi_python_client.parser.properties.protocol import Value from openapi_python_client.parser.properties.schemas import Class from openapi_python_client.parser.properties.string import StringProperty @@ -21,6 +22,7 @@ def test_merge_basic_attributes_same_type( string_property_factory, list_property_factory, model_property_factory, + config, ): basic_props = [ boolean_property_factory(default=Value(python_code="True", raw_value="True")), @@ -28,20 +30,20 @@ def test_merge_basic_attributes_same_type( float_property_factory(default=Value("1.5", 1.5)), string_property_factory(default=StringProperty.convert_value("x")), list_property_factory(), - model_property_factory(), + model_property_factory(required_properties=[], optional_properties=[]), ] for basic_prop in basic_props: with_required = evolve(basic_prop, required=True) - assert merge_properties(basic_prop, with_required) == with_required - assert merge_properties(with_required, basic_prop) == with_required + assert merge_properties(basic_prop, with_required, "", config) == with_required + assert merge_properties(with_required, basic_prop, "", config) == with_required without_default = evolve(basic_prop, default=None) - assert merge_properties(basic_prop, without_default) == basic_prop - assert merge_properties(without_default, basic_prop) == basic_prop + assert merge_properties(basic_prop, without_default, "", config) == basic_prop + assert merge_properties(without_default, basic_prop, "", config) == basic_prop with_desc1 = evolve(basic_prop, description="desc1") with_desc2 = evolve(basic_prop, description="desc2") - assert merge_properties(basic_prop, with_desc1) == with_desc1 - assert merge_properties(with_desc1, basic_prop) == with_desc1 - assert merge_properties(with_desc1, with_desc2) == with_desc2 + assert merge_properties(basic_prop, with_desc1, "", config) == with_desc1 + assert merge_properties(with_desc1, basic_prop, "", config) == with_desc1 + assert merge_properties(with_desc1, with_desc2, "", config) == with_desc2 def test_incompatible_types( @@ -51,6 +53,7 @@ def test_incompatible_types( string_property_factory, list_property_factory, model_property_factory, + config, ): props = [ boolean_property_factory(default=True), @@ -64,21 +67,21 @@ def test_incompatible_types( for prop1, prop2 in permutations(props, 2): if {prop1.__class__, prop2.__class__} == {IntProperty, FloatProperty}: continue # the int+float case is covered in another test - error = merge_properties(prop1, prop2) + error = merge_properties(prop1, prop2, "", config) assert isinstance(error, PropertyError), f"Expected {type(prop1)} and {type(prop2)} to be incompatible" -def test_merge_int_with_float(int_property_factory, float_property_factory): +def test_merge_int_with_float(int_property_factory, float_property_factory, config): int_prop = int_property_factory(description="desc1") float_prop = float_property_factory(default=Value("2", 2), description="desc2") - assert merge_properties(int_prop, float_prop) == ( + assert merge_properties(int_prop, float_prop, "", config) == ( evolve(int_prop, default=Value("2", 2), description=float_prop.description) ) - assert merge_properties(float_prop, int_prop) == evolve(int_prop, default=Value("2", 2)) + assert merge_properties(float_prop, int_prop, "", config) == evolve(int_prop, default=Value("2", 2)) float_prop_with_non_int_default = evolve(float_prop, default=Value("2.5", 2.5)) - error = merge_properties(int_prop, float_prop_with_non_int_default) + error = merge_properties(int_prop, float_prop_with_non_int_default, "", config) assert isinstance(error, PropertyError), "Expected invalid default to error" assert error.detail == "Invalid int value: 2.5" @@ -90,6 +93,7 @@ def test_merge_with_any( float_property_factory, string_property_factory, model_property_factory, + config, ): original_desc = "description" props = [ @@ -101,8 +105,8 @@ def test_merge_with_any( ] any_prop = any_property_factory() for prop in props: - assert merge_properties(any_prop, prop) == prop - assert merge_properties(prop, any_prop) == prop + assert merge_properties(any_prop, prop, "", config) == prop + assert merge_properties(prop, any_prop, "", config) == prop @pytest.mark.parametrize("literal_enums", (False, True)) @@ -135,13 +139,13 @@ def test_merge_enums(literal_enums, enum_property_factory, literal_enum_property enum_with_fewer_values.class_info = Class.from_string(string="FewerValuesEnum", config=config) enum_with_more_values.class_info = Class.from_string(string="MoreValuesEnum", config=config) - assert merge_properties(enum_with_fewer_values, enum_with_more_values) == evolve( + assert merge_properties(enum_with_fewer_values, enum_with_more_values, "", config) == evolve( enum_with_more_values, values=enum_with_fewer_values.values, class_info=enum_with_fewer_values.class_info, description=enum_with_fewer_values.description, ) - assert merge_properties(enum_with_more_values, enum_with_fewer_values) == evolve( + assert merge_properties(enum_with_more_values, enum_with_fewer_values, "", config) == evolve( enum_with_fewer_values, example=enum_with_more_values.example, ) @@ -149,7 +153,7 @@ def test_merge_enums(literal_enums, enum_property_factory, literal_enum_property @pytest.mark.parametrize("literal_enums", (False, True)) def test_merge_string_with_string_enum( - literal_enums, string_property_factory, enum_property_factory, literal_enum_property_factory + literal_enums, string_property_factory, enum_property_factory, literal_enum_property_factory, config ): string_prop = string_property_factory(default=Value("A", "A"), description="desc1", example="example1") enum_prop = ( @@ -170,8 +174,8 @@ def test_merge_string_with_string_enum( ) ) - assert merge_properties(string_prop, enum_prop) == evolve(enum_prop, required=True) - assert merge_properties(enum_prop, string_prop) == evolve( + assert merge_properties(string_prop, enum_prop, "", config) == evolve(enum_prop, required=True) + assert merge_properties(enum_prop, string_prop, "", config) == evolve( enum_prop, required=True, default=Value("'A'" if literal_enums else "test.A", "A"), @@ -182,7 +186,7 @@ def test_merge_string_with_string_enum( @pytest.mark.parametrize("literal_enums", (False, True)) def test_merge_int_with_int_enum( - literal_enums, int_property_factory, enum_property_factory, literal_enum_property_factory + literal_enums, int_property_factory, enum_property_factory, literal_enum_property_factory, config ): int_prop = int_property_factory(default=Value("1", 1), description="desc1", example="example1") enum_prop = ( @@ -203,8 +207,8 @@ def test_merge_int_with_int_enum( ) ) - assert merge_properties(int_prop, enum_prop) == evolve(enum_prop, required=True) - assert merge_properties(enum_prop, int_prop) == evolve( + assert merge_properties(int_prop, enum_prop, "", config) == evolve(enum_prop, required=True) + assert merge_properties(enum_prop, int_prop, "", config) == evolve( enum_prop, required=True, description=int_prop.description, example=int_prop.example ) @@ -219,6 +223,7 @@ def test_merge_with_incompatible_enum( enum_property_factory, literal_enum_property_factory, model_property_factory, + config, ): props = [ boolean_property_factory(), @@ -241,11 +246,11 @@ def test_merge_with_incompatible_enum( ) for prop in props: if not isinstance(prop, StringProperty): - assert isinstance(merge_properties(prop, string_enum_prop), PropertyError) - assert isinstance(merge_properties(string_enum_prop, prop), PropertyError) + assert isinstance(merge_properties(prop, string_enum_prop, "", config), PropertyError) + assert isinstance(merge_properties(string_enum_prop, prop, "", config), PropertyError) if not isinstance(prop, IntProperty): - assert isinstance(merge_properties(prop, int_enum_prop), PropertyError) - assert isinstance(merge_properties(int_enum_prop, prop), PropertyError) + assert isinstance(merge_properties(prop, int_enum_prop, "", config), PropertyError) + assert isinstance(merge_properties(int_enum_prop, prop, "", config), PropertyError) def test_merge_string_with_formatted_string( @@ -253,6 +258,7 @@ def test_merge_string_with_formatted_string( date_time_property_factory, file_property_factory, string_property_factory, + config, ): string_prop = string_property_factory(description="a plain string") string_prop_with_invalid_default = string_property_factory( @@ -264,19 +270,19 @@ def test_merge_string_with_formatted_string( file_property_factory(description="a file"), ] for formatted_prop in formatted_props: - merged1 = merge_properties(string_prop, formatted_prop) + merged1 = merge_properties(string_prop, formatted_prop, "", config) assert isinstance(merged1, formatted_prop.__class__) assert merged1.description == formatted_prop.description - merged2 = merge_properties(formatted_prop, string_prop) + merged2 = merge_properties(formatted_prop, string_prop, "", config) assert isinstance(merged2, formatted_prop.__class__) assert merged2.description == string_prop.description - assert isinstance(merge_properties(string_prop_with_invalid_default, formatted_prop), PropertyError) - assert isinstance(merge_properties(formatted_prop, string_prop_with_invalid_default), PropertyError) + assert isinstance(merge_properties(string_prop_with_invalid_default, formatted_prop, "", config), PropertyError) + assert isinstance(merge_properties(formatted_prop, string_prop_with_invalid_default, "", config), PropertyError) -def test_merge_lists(int_property_factory, list_property_factory, string_property_factory): +def test_merge_lists(int_property_factory, list_property_factory, string_property_factory, config): string_prop_1 = string_property_factory(description="desc1") string_prop_2 = string_property_factory(example="desc2") int_prop = int_property_factory() @@ -284,8 +290,134 @@ def test_merge_lists(int_property_factory, list_property_factory, string_propert list_prop_2 = list_property_factory(inner_property=string_prop_2) list_prop_3 = list_property_factory(inner_property=int_prop) - assert merge_properties(list_prop_1, list_prop_2) == evolve( - list_prop_1, inner_property=merge_properties(string_prop_1, string_prop_2) + assert merge_properties(list_prop_1, list_prop_2, "", config) == evolve( + list_prop_1, inner_property=merge_properties(string_prop_1, string_prop_2, "", config) + ) + + assert isinstance(merge_properties(list_prop_1, list_prop_3, "", config), PropertyError) + + +def test_merge_related_models(model_property_factory, string_property_factory, config): + base_model = model_property_factory( + name="BaseModel", + details=ModelDetails( + required_properties=[ + string_property_factory(name="req_1", description="base description"), + ], + optional_properties=[ + string_property_factory(name="opt_1"), + ], + ), + description="desc_1", + example="example_1", + class_info=Class.from_string(string="BaseModel", config=config), + ) + extension_model = model_property_factory( + name="ExtensionModel", + details=ModelDetails( + required_properties=[ + string_property_factory(name="req_1", description="extended description"), + string_property_factory(name="req_2"), + ], + optional_properties=[ + string_property_factory(name="opt_1"), + string_property_factory(name="opt_2"), + ], + ), + description="desc_2", + class_info=Class.from_string(string="DerivedModel", config=config), + ) + + assert merge_properties(base_model, extension_model, "", config) == evolve( + extension_model, example=base_model.example + ) + assert merge_properties(extension_model, base_model, "", config) == evolve( + extension_model, description=base_model.description, example=base_model.example ) - assert isinstance(merge_properties(list_prop_1, list_prop_3), PropertyError) + +def test_merge_unrelated_models(model_property_factory, string_property_factory, config): + model_1 = model_property_factory( + name="propName", + details=ModelDetails( + required_properties=[ + string_property_factory(name="req_1", required=True), + ], + optional_properties=[ + string_property_factory(name="opt_1", required=False), + ], + ), + description="desc_1", + example="example_1", + class_info=Class.from_string(string="Model1", config=config), + ) + model_2 = model_property_factory( + name="propName", + details=ModelDetails( + required_properties=[ + string_property_factory(name="req_1", required=True), + string_property_factory(name="req_2", required=True), + ], + optional_properties=[ + string_property_factory(name="opt_2", required=False), + ], + ), + description="desc_2", + class_info=Class.from_string(string="Model2", config=config), + ) + + result = merge_properties(model_1, model_2, "ParentSchema", config) + + assert isinstance(result, ModelProperty) + assert [p.name for p in result.required_properties] == ["req_1", "req_2"] + assert [p.name for p in result.optional_properties] == ["opt_1", "opt_2"] + assert result.class_info.name == "ParentSchemaPropName" + assert result.description == model_2.description + + +def test_merge_models_with_incompatible_property( + model_property_factory, string_property_factory, int_property_factory, config +): + model_1 = model_property_factory( + name="propName", + details=ModelDetails( + required_properties=[ + string_property_factory(name="prop1", required=True), + ], + optional_properties=[], + ), + class_info=Class.from_string(string="Model1", config=config), + ) + model_2 = model_property_factory( + name="propName", + details=ModelDetails( + required_properties=[ + int_property_factory(name="prop1", required=True), + ], + optional_properties=[], + ), + class_info=Class.from_string(string="Model2", config=config), + ) + + result = merge_properties(model_1, model_2, "ParentSchema", config) + + assert isinstance(result, PropertyError) + assert result.detail == "str can't be merged with int" + + +def test_merge_models_not_yet_processed(model_property_factory, string_property_factory, int_property_factory, config): + model_1 = model_property_factory( + name="propName", + details=ModelDetails(required_properties=None, optional_properties=None), + class_info=Class.from_string(string="Model1", config=config), + ) + model_2 = model_property_factory( + name="propName", + details=ModelDetails(required_properties=None, optional_properties=None), + class_info=Class.from_string(string="Model2", config=config), + ) + + result = merge_properties(model_1, model_2, "ParentSchema", config) + + assert isinstance(result, PropertyError) + assert "not processed" in result.detail diff --git a/tests/test_parser/test_properties/test_model_property.py b/tests/test_parser/test_properties/test_model_property.py index a51fd984b..8459cea2d 100644 --- a/tests/test_parser/test_properties/test_model_property.py +++ b/tests/test_parser/test_properties/test_model_property.py @@ -6,7 +6,12 @@ import openapi_python_client.schema as oai from openapi_python_client.parser.errors import PropertyError from openapi_python_client.parser.properties import Schemas, StringProperty -from openapi_python_client.parser.properties.model_property import ANY_ADDITIONAL_PROPERTY, _process_properties +from openapi_python_client.parser.properties.model_property import ( + ANY_ADDITIONAL_PROPERTY, + ModelDetails, + ModelProperty, + _process_properties, +) MODULE_NAME = "openapi_python_client.parser.properties.model_property" @@ -338,7 +343,7 @@ def test_conflicting_properties_different_types( } ) - result = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) + result, _ = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) assert isinstance(result, PropertyError) @@ -349,14 +354,14 @@ def test_process_properties_reference_not_exist(self, config): }, ) - result = _process_properties(data=data, class_name="", schemas=Schemas(), config=config, roots={"root"}) + result, _ = _process_properties(data=data, class_name="", schemas=Schemas(), config=config, roots={"root"}) assert isinstance(result, PropertyError) def test_process_properties_all_of_reference_not_exist(self, config): data = oai.Schema.model_construct(allOf=[oai.Reference.model_construct(ref="#/components/schema/NotExist")]) - result = _process_properties(data=data, class_name="", schemas=Schemas(), config=config, roots={"root"}) + result, _ = _process_properties(data=data, class_name="", schemas=Schemas(), config=config, roots={"root"}) assert isinstance(result, PropertyError) @@ -364,15 +369,15 @@ def test_process_properties_model_property_roots(self, model_property_factory, c roots = {"root"} data = oai.Schema(properties={"test_model_property": oai.Schema.model_construct(type="object")}) - result = _process_properties(data=data, class_name="", schemas=Schemas(), config=config, roots=roots) + result, _ = _process_properties(data=data, class_name="", schemas=Schemas(), config=config, roots=roots) - assert all(root in result.optional_props[0].roots for root in roots) + assert all(root in result.optional_properties[0].roots for root in roots) def test_invalid_reference(self, config): data = oai.Schema.model_construct(allOf=[oai.Reference.model_construct(ref="ThisIsNotGood")]) schemas = Schemas() - result = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) + result, _ = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) assert isinstance(result, PropertyError) @@ -384,7 +389,7 @@ def test_non_model_reference(self, enum_property_factory, config): } ) - result = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) + result, _ = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) assert isinstance(result, PropertyError) @@ -396,7 +401,7 @@ def test_reference_not_processed(self, model_property_factory, config): } ) - result = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) + result, _ = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) assert isinstance(result, PropertyError) @@ -419,8 +424,8 @@ def test_allof_string_and_string_enum( } ) - result = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) - assert result.required_props[0] == enum_property + result, _ = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) + assert result.required_properties[0] == enum_property def test_allof_string_enum_and_string( self, model_property_factory, enum_property_factory, string_property_factory, config @@ -442,8 +447,8 @@ def test_allof_string_enum_and_string( } ) - result = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) - assert result.optional_props[0] == enum_property + result, _ = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) + assert result.optional_properties[0] == enum_property def test_allof_int_and_int_enum(self, model_property_factory, enum_property_factory, int_property_factory, config): data = oai.Schema.model_construct( @@ -460,8 +465,8 @@ def test_allof_int_and_int_enum(self, model_property_factory, enum_property_fact } ) - result = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) - assert result.required_props[0] == enum_property + result, _ = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) + assert result.required_properties[0] == enum_property def test_allof_enum_incompatible_type( self, model_property_factory, enum_property_factory, int_property_factory, config @@ -480,7 +485,7 @@ def test_allof_enum_incompatible_type( } ) - result = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) + result, _ = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) assert isinstance(result, PropertyError) def test_allof_string_enums(self, model_property_factory, enum_property_factory, config): @@ -504,8 +509,8 @@ def test_allof_string_enums(self, model_property_factory, enum_property_factory, } ) - result = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) - assert result.required_props[0] == enum_property1 + result, _ = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) + assert result.required_properties[0] == enum_property1 def test_allof_int_enums(self, model_property_factory, enum_property_factory, config): data = oai.Schema.model_construct( @@ -528,8 +533,8 @@ def test_allof_int_enums(self, model_property_factory, enum_property_factory, co } ) - result = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) - assert result.required_props[0] == enum_property2 + result, _ = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) + assert result.required_properties[0] == enum_property2 def test_allof_enums_are_not_subsets(self, model_property_factory, enum_property_factory, config): data = oai.Schema.model_construct( @@ -552,7 +557,7 @@ def test_allof_enums_are_not_subsets(self, model_property_factory, enum_property } ) - result = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) + result, _ = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) assert isinstance(result, PropertyError) def test_duplicate_properties(self, model_property_factory, string_property_factory, config): @@ -567,9 +572,9 @@ def test_duplicate_properties(self, model_property_factory, string_property_fact } ) - result = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) + result, _ = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) - assert result.optional_props == [prop], "There should only be one copy of duplicate properties" + assert result.optional_properties == [prop], "There should only be one copy of duplicate properties" @pytest.mark.parametrize("first_required", [True, False]) @pytest.mark.parametrize("second_required", [True, False]) @@ -598,18 +603,18 @@ def test_mixed_requirements( ) roots = {"root"} - result = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots=roots) + result, schemas = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots=roots) required = first_required or second_required expected_prop = string_property_factory( required=required, ) - assert result.schemas.dependencies == {"/First": roots, "/Second": roots} + assert schemas.dependencies == {"/First": roots, "/Second": roots} if not required: - assert result.optional_props == [expected_prop] + assert result.optional_properties == [expected_prop] else: - assert result.required_props == [expected_prop] + assert result.required_properties == [expected_prop] def test_direct_properties_non_ref(self, string_property_factory, config): data = oai.Schema.model_construct( @@ -625,10 +630,10 @@ def test_direct_properties_non_ref(self, string_property_factory, config): ) schemas = Schemas() - result = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) + result, _ = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) - assert result.optional_props == [string_property_factory(name="second", required=False)] - assert result.required_props == [string_property_factory(name="first", required=True)] + assert result.optional_properties == [string_property_factory(name="second", required=False)] + assert result.required_properties == [string_property_factory(name="first", required=True)] def test_conflicting_property_names(self, config): data = oai.Schema.model_construct( @@ -638,7 +643,7 @@ def test_conflicting_property_names(self, config): } ) schemas = Schemas() - result = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) + result, _ = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) assert isinstance(result, PropertyError) def test_merge_inline_objects(self, model_property_factory, enum_property_factory, config): @@ -660,10 +665,10 @@ def test_merge_inline_objects(self, model_property_factory, enum_property_factor ) schemas = Schemas() - result = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) + result, _ = _process_properties(data=data, schemas=schemas, class_name="", config=config, roots={"root"}) assert not isinstance(result, PropertyError) - assert len(result.optional_props) == 1 - prop1 = result.optional_props[0] + assert len(result.optional_properties) == 1 + prop1 = result.optional_properties[0] assert isinstance(prop1, StringProperty) assert prop1.description == "desc" assert prop1.default == StringProperty.convert_value("a") @@ -674,7 +679,7 @@ def test_process_model_error(self, mocker, model_property_factory, config): from openapi_python_client.parser.properties import Schemas from openapi_python_client.parser.properties.model_property import process_model - model_prop = model_property_factory() + model_prop: ModelProperty = model_property_factory(details=ModelDetails()) schemas = Schemas() process_property_data = mocker.patch(f"{MODULE_NAME}._process_property_data") process_property_data.return_value = (PropertyError(), schemas) @@ -682,36 +687,36 @@ def test_process_model_error(self, mocker, model_property_factory, config): result = process_model(model_prop=model_prop, schemas=schemas, config=config) assert result == PropertyError() - assert model_prop.required_properties is None - assert model_prop.optional_properties is None - assert model_prop.relative_imports is None + assert model_prop.needs_post_processing() + assert model_prop.required_properties == [] + assert model_prop.optional_properties == [] + assert model_prop.relative_imports == set() assert model_prop.additional_properties is None - def test_process_model(self, mocker, model_property_factory, config): + def test_process_model(self, mocker, model_property_factory, string_property_factory, config): from openapi_python_client.parser.properties import Schemas - from openapi_python_client.parser.properties.model_property import _PropertyData, process_model + from openapi_python_client.parser.properties.model_property import ModelDetails, process_model model_prop = model_property_factory() schemas = Schemas() - property_data = _PropertyData( - required_props=["required"], - optional_props=["optional"], + model_details = ModelDetails( + required_properties=["required"], + optional_properties=["optional"], relative_imports={"relative"}, lazy_imports={"lazy"}, - schemas=schemas, + additional_properties=string_property_factory(), ) - additional_properties = True process_property_data = mocker.patch(f"{MODULE_NAME}._process_property_data") - process_property_data.return_value = ((property_data, additional_properties), schemas) + process_property_data.return_value = (model_details, schemas) result = process_model(model_prop=model_prop, schemas=schemas, config=config) assert result == schemas - assert model_prop.required_properties == property_data.required_props - assert model_prop.optional_properties == property_data.optional_props - assert model_prop.relative_imports == property_data.relative_imports - assert model_prop.lazy_imports == property_data.lazy_imports - assert model_prop.additional_properties == additional_properties + assert model_prop.required_properties == model_details.required_properties + assert model_prop.optional_properties == model_details.optional_properties + assert model_prop.relative_imports == model_details.relative_imports + assert model_prop.lazy_imports == model_details.lazy_imports + assert model_prop.additional_properties == model_details.additional_properties def test_set_relative_imports(model_property_factory): @@ -720,6 +725,9 @@ def test_set_relative_imports(model_property_factory): class_info = Class("ClassName", module_name="module_name") relative_imports = {f"from ..models.{class_info.module_name} import {class_info.name}"} - model_property = model_property_factory(class_info=class_info, relative_imports=relative_imports) + model_property = model_property_factory( + class_info=class_info, + details=ModelDetails(relative_imports=relative_imports), + ) assert model_property.relative_imports == set() diff --git a/tests/test_parser/test_properties/test_protocol.py b/tests/test_parser/test_properties/test_protocol.py index 1d4111750..0b9c627f3 100644 --- a/tests/test_parser/test_properties/test_protocol.py +++ b/tests/test_parser/test_properties/test_protocol.py @@ -85,3 +85,11 @@ def test_get_base_json_type_string(quoted, expected, any_property_factory, mocke mocker.patch.object(AnyProperty, "_json_type_string", "str") p = any_property_factory() assert p.get_base_json_type_string(quoted=quoted) is expected + + +def test_needs_post_processing(any_property_factory, model_property_factory): + p1 = any_property_factory() + assert p1.needs_post_processing() is False + + p2 = model_property_factory() + assert p2.needs_post_processing() is True