From 8d8dc89116be035d5787c18ea9ec08e95c9af011 Mon Sep 17 00:00:00 2001 From: Victoria Litvinova Date: Fri, 30 Jun 2023 09:21:24 -0700 Subject: [PATCH 1/4] init changes --- .../iothub/providers/message_endpoint.py | 56 ++++++++++++++----- azext_iot/tests/iothub/conftest.py | 1 + 2 files changed, 43 insertions(+), 14 deletions(-) diff --git a/azext_iot/iothub/providers/message_endpoint.py b/azext_iot/iothub/providers/message_endpoint.py index 69d68b2cc..fc5b053e6 100644 --- a/azext_iot/iothub/providers/message_endpoint.py +++ b/azext_iot/iothub/providers/message_endpoint.py @@ -42,7 +42,12 @@ def __init__( rg: Optional[str] = None, ): super(MessageEndpoint, self).__init__(cmd, hub_name, rg, dataplane=False) - self.support_cosmos = hasattr(self.hub_resource.properties.routing.endpoints, "cosmos_db_sql_collections") + # Temporary flag to check for which cosmos property to look for. + self.support_cosmos = 0 + if hasattr(self.hub_resource.properties.routing.endpoints, "cosmos_db_sql_collections"): + self.support_cosmos = 1 + if hasattr(self.hub_resource.properties.routing.endpoints, "cosmos_db_sql_containers"): + self.support_cosmos = 2 self.cli = EmbeddedCLI(cli_ctx=self.cmd.cli_ctx) def create( @@ -179,16 +184,22 @@ def create( del new_endpoint["connectionString"] new_endpoint.update({ "databaseName": database_name, - "collectionName": container_name, "primaryKey": primary_key, "secondaryKey": secondary_key, "partitionKeyName": partition_key_name, "partitionKeyTemplate": partition_key_template, }) - # TODO @vilit - why is this None if empty - if endpoints.cosmos_db_sql_collections is None: - endpoints.cosmos_db_sql_collections = [] - endpoints.cosmos_db_sql_collections.append(new_endpoint) + # TODO @vilit - None checks for when the service breaks things + if self.support_cosmos == 2: + new_endpoint["containerName"] = container_name + if endpoints.cosmos_db_sql_containers is None: + endpoints.cosmos_db_sql_containers = [] + endpoints.cosmos_db_sql_containers.append(new_endpoint) + if self.support_cosmos == 1: + new_endpoint["collectionName"] = container_name + if endpoints.cosmos_db_sql_collections is None: + endpoints.cosmos_db_sql_collections = [] + endpoints.cosmos_db_sql_collections.append(new_endpoint) elif endpoint_type.lower() == EndpointType.AzureStorageContainer.value: if fetch_connection_string: # try to get connection string @@ -325,8 +336,10 @@ def update( original_endpoint.endpoint_uri = parsed_cs["AccountEndpoint"] if database_name: original_endpoint.database_name = database_name - if container_name: + if container_name and self.support_cosmos == 2: original_endpoint.container_name = container_name + if container_name and self.support_cosmos == 1: + original_endpoint.collection_name = container_name if partition_key_name: original_endpoint.partition_key_name = None if partition_key_name == "" else partition_key_name if partition_key_template: @@ -374,7 +387,9 @@ def _show_by_type(self, endpoint_name: str, endpoint_type: Optional[str] = None) endpoint_list.extend(endpoints.service_bus_topics) if endpoint_type is None or endpoint_type.lower() == EndpointType.AzureStorageContainer.value: endpoint_list.extend(endpoints.storage_containers) - if self.support_cosmos and (endpoint_type is None or endpoint_type.lower() == EndpointType.CosmosDBContainer.value): + if self.support_cosmos == 2 and (endpoint_type is None or endpoint_type.lower() == EndpointType.CosmosDBContainer.value): + endpoint_list.extend(endpoints.cosmos_db_sql_containers) + if self.support_cosmos == 1 and (endpoint_type is None or endpoint_type.lower() == EndpointType.CosmosDBContainer.value): endpoint_list.extend(endpoints.cosmos_db_sql_collections) for endpoint in endpoint_list: @@ -402,7 +417,9 @@ def list(self, endpoint_type: Optional[str] = None): return endpoints.service_bus_queues elif EndpointType.ServiceBusTopic.value == endpoint_type: return endpoints.service_bus_topics - elif EndpointType.CosmosDBContainer.value == endpoint_type and self.support_cosmos: + elif EndpointType.CosmosDBContainer.value == endpoint_type and self.support_cosmos == 2: + return endpoints.cosmos_db_sql_containers + elif EndpointType.CosmosDBContainer.value == endpoint_type and self.support_cosmos == 1: return endpoints.cosmos_db_sql_collections elif EndpointType.CosmosDBContainer.value == endpoint_type: raise InvalidArgumentValueError(INVALID_CLI_CORE_FOR_COSMOS) @@ -418,7 +435,7 @@ def delete( endpoints = self.hub_resource.properties.routing.endpoints if endpoint_type: endpoint_type = endpoint_type.lower() - if EndpointType.CosmosDBContainer.value == endpoint_type and not self.support_cosmos: + if EndpointType.CosmosDBContainer.value == endpoint_type and self.support_cosmos == 0: raise InvalidArgumentValueError(INVALID_CLI_CORE_FOR_COSMOS) if self.hub_resource.properties.routing.enrichments or self.hub_resource.properties.routing.routes: @@ -438,7 +455,9 @@ def delete( endpoint_names.extend([e.name for e in endpoints.service_bus_queues]) if not endpoint_type or endpoint_type == EndpointType.ServiceBusTopic.value: endpoint_names.extend([e.name for e in endpoints.service_bus_topics]) - if self.support_cosmos and not endpoint_type or endpoint_type == EndpointType.CosmosDBContainer.value: + if self.support_cosmos == 2 and not endpoint_type or endpoint_type == EndpointType.CosmosDBContainer.value: + endpoint_names.extend([e.name for e in endpoints.cosmos_db_sql_containers]) + if self.support_cosmos == 1 and not endpoint_type or endpoint_type == EndpointType.CosmosDBContainer.value: endpoint_names.extend([e.name for e in endpoints.cosmos_db_sql_collections]) if not endpoint_type or endpoint_type == EndpointType.AzureStorageContainer.value: endpoint_names.extend([e.name for e in endpoints.storage_containers]) @@ -486,7 +505,12 @@ def delete( endpoints.service_bus_queues = [e for e in endpoints.service_bus_queues if e.name.lower() != endpoint_name] if not endpoint_type or EndpointType.ServiceBusTopic.value == endpoint_type: endpoints.service_bus_topics = [e for e in endpoints.service_bus_topics if e.name.lower() != endpoint_name] - if self.support_cosmos and not endpoint_type or EndpointType.CosmosDBContainer.value == endpoint_type: + if self.support_cosmos == 2 and not endpoint_type or EndpointType.CosmosDBContainer.value == endpoint_type: + cosmos_db_endpoints = endpoints.cosmos_db_sql_containers if endpoints.cosmos_db_sql_containers else [] + endpoints.cosmos_db_sql_containers = [ + e for e in cosmos_db_endpoints if e.name.lower() != endpoint_name + ] + if self.support_cosmos == 1 and not endpoint_type or EndpointType.CosmosDBContainer.value == endpoint_type: cosmos_db_endpoints = endpoints.cosmos_db_sql_collections if endpoints.cosmos_db_sql_collections else [] endpoints.cosmos_db_sql_collections = [ e for e in cosmos_db_endpoints if e.name.lower() != endpoint_name @@ -501,7 +525,9 @@ def delete( endpoints.service_bus_queues = [] elif EndpointType.ServiceBusTopic.value == endpoint_type: endpoints.service_bus_topics = [] - elif EndpointType.CosmosDBContainer.value == endpoint_type and self.support_cosmos: + elif EndpointType.CosmosDBContainer.value == endpoint_type and self.support_cosmos == 2: + endpoints.cosmos_db_sql_containers = [] + elif EndpointType.CosmosDBContainer.value == endpoint_type and self.support_cosmos == 1: endpoints.cosmos_db_sql_collections = [] elif EndpointType.AzureStorageContainer.value == endpoint_type: endpoints.storage_containers = [] @@ -510,7 +536,9 @@ def delete( endpoints.event_hubs = [] endpoints.service_bus_queues = [] endpoints.service_bus_topics = [] - if self.support_cosmos: + if self.support_cosmos == 2: + endpoints.cosmos_db_sql_containers = [] + if self.support_cosmos == 1: endpoints.cosmos_db_sql_collections = [] endpoints.storage_containers = [] diff --git a/azext_iot/tests/iothub/conftest.py b/azext_iot/tests/iothub/conftest.py index 6286af6c4..b927b9462 100644 --- a/azext_iot/tests/iothub/conftest.py +++ b/azext_iot/tests/iothub/conftest.py @@ -693,6 +693,7 @@ def _cosmos_db_provisioner(): collection_name = generate_hub_depenency_id() partition_key_path = "/test" location = "eastus" + print(f"--locations regionName={location}") cosmos_obj = cli.invoke( "cosmosdb create --name {} --resource-group {} --locations regionName={} failoverPriority=0".format( account_name, RG, location From 28eb1c3a97af8387163e7906b52d799c9c4b2af0 Mon Sep 17 00:00:00 2001 From: Victoria Litvinova Date: Tue, 18 Jul 2023 17:31:49 -0700 Subject: [PATCH 2/4] fix unit test + add back compat test --- .../test_iothub_message_endpoint_unit.py | 213 +++++++++++++++++- 1 file changed, 211 insertions(+), 2 deletions(-) diff --git a/azext_iot/tests/iothub/message_endpoint/test_iothub_message_endpoint_unit.py b/azext_iot/tests/iothub/message_endpoint/test_iothub_message_endpoint_unit.py index 3b384aee3..28c0156a4 100644 --- a/azext_iot/tests/iothub/message_endpoint/test_iothub_message_endpoint_unit.py +++ b/azext_iot/tests/iothub/message_endpoint/test_iothub_message_endpoint_unit.py @@ -52,6 +52,36 @@ def create_mock_endpoint(): hub_mock.properties.routing.endpoints.service_bus_queues = [create_mock_endpoint()] hub_mock.properties.routing.endpoints.service_bus_topics = [create_mock_endpoint()] hub_mock.properties.routing.endpoints.storage_containers = [create_mock_endpoint()] + hub_mock.properties.routing.endpoints.cosmos_db_sql_containers = [create_mock_endpoint()] + + def initialize_mock_client(self, *args): + self.client = mocker.MagicMock() + self.client.begin_create_or_update.return_value = generic_response + return hub_mock + + find_resource.side_effect = initialize_mock_client + + yield find_resource + + +@pytest.fixture() +def fixture_update_endpoint_backwards_comp_ops(mocker): + # Parse connection string + mocker.patch(parse_cosmos_db_cstring_path, return_value={ + "AccountKey": "get_cosmos_db_account_key", + "AccountEndpoint": "get_cosmos_db_account_endpoint" + }) + + # Hub Resource + find_resource = mocker.patch(path_find_resource, autospec=True) + + def create_mock_endpoint(): + endpoint = mocker.Mock() + endpoint.name = endpoint_name + return endpoint + + hub_mock = mocker.MagicMock() + del hub_mock.properties.routing.endpoints.cosmos_db_sql_containers hub_mock.properties.routing.endpoints.cosmos_db_sql_collections = [create_mock_endpoint()] def initialize_mock_client(self, *args): @@ -721,7 +751,7 @@ def test_message_endpoint_update_cosmos_db_sql_container(self, mocker, fixture_c assert req.get("resource_group_name") == resource_group hub_resource = fixture_find_resource.call_args[0][0].client.begin_create_or_update.call_args[0][2] # TODO: @vilit fix once service fixes their naming - endpoints = hub_resource.properties.routing.endpoints.cosmos_db_sql_collections + endpoints = hub_resource.properties.routing.endpoints.cosmos_db_sql_containers assert len(endpoints) == 1 endpoint = endpoints[0] @@ -800,7 +830,7 @@ def test_message_endpoint_update_cosmos_db_sql_container(self, mocker, fixture_c else: assert isinstance(endpoint.authentication_type, mock) - def test_message_endpoint_update_cosmos_db_sql_collections_error(self, fixture_cmd, fixture_update_endpoint_ops): + def test_message_endpoint_update_cosmos_db_sql_container_error(self, fixture_cmd, fixture_update_endpoint_ops): # Cannot do both types of Authentication with pytest.raises(MutuallyExclusiveArgumentError) as e: subject.message_endpoint_update_cosmos_db_container( @@ -848,3 +878,182 @@ def test_message_endpoint_update_cosmos_db_sql_collections_error(self, fixture_c hub_name=hub_name, endpoint_name=generate_names(), ) + + @pytest.mark.parametrize( + "req", + [ + {}, + { + "endpoint_resource_group": generate_names(), + "endpoint_subscription_id": generate_names(), + "database_name": generate_names(), + "connection_string": generate_names(), + "primary_key": None, + "secondary_key": None, + "endpoint_uri": generate_names(), + "partition_key_name": None, + "partition_key_template": None, + "identity": None, + "resource_group_name": None, + }, + { + "endpoint_resource_group": None, + "endpoint_subscription_id": None, + "database_name": None, + "connection_string": None, + "primary_key": None, + "secondary_key": None, + "endpoint_uri": generate_names(), + "partition_key_name": generate_names(), + "partition_key_template": generate_names(), + "identity": generate_names(), + "resource_group_name": generate_names(), + }, + { + "endpoint_resource_group": None, + "endpoint_subscription_id": None, + "database_name": None, + "connection_string": None, + "primary_key": None, + "secondary_key": None, + "endpoint_uri": None, + "partition_key_name": None, + "partition_key_template": None, + "identity": "[system]", + "resource_group_name": None, + }, + { + "endpoint_resource_group": None, + "endpoint_subscription_id": None, + "database_name": None, + "connection_string": generate_names(), + "primary_key": None, + "secondary_key": generate_names(), + "endpoint_uri": None, + "partition_key_name": None, + "partition_key_template": generate_names(), + "identity": None, + "resource_group_name": None, + }, + { + "endpoint_resource_group": generate_names(), + "endpoint_subscription_id": None, + "database_name": None, + "connection_string": generate_names(), + "primary_key": generate_names(), + "secondary_key": generate_names(), + "endpoint_uri": None, + "partition_key_name": generate_names(), + "partition_key_template": None, + "identity": None, + "resource_group_name": None, + }, + { + "endpoint_resource_group": None, + "endpoint_subscription_id": None, + "database_name": generate_names(), + "connection_string": None, + "primary_key": None, + "secondary_key": None, + "endpoint_uri": None, + "partition_key_name": None, + "partition_key_template": None, + "identity": None, + "resource_group_name": None, + }, + ] + ) + def test_message_endpoint_update_cosmos_db_sql_collections( + self, mocker, fixture_cmd, fixture_update_endpoint_backwards_comp_ops, req + ): + result = subject.message_endpoint_update_cosmos_db_container( + cmd=fixture_cmd, + hub_name=hub_name, + endpoint_name=endpoint_name, + **req + ) + fixture_find_resource = fixture_update_endpoint_backwards_comp_ops + + assert result == generic_response + resource_group = fixture_find_resource.call_args[0][2] + assert req.get("resource_group_name") == resource_group + hub_resource = fixture_find_resource.call_args[0][0].client.begin_create_or_update.call_args[0][2] + # TODO: @vilit fix once service fixes their naming + endpoints = hub_resource.properties.routing.endpoints.cosmos_db_sql_collections + assert len(endpoints) == 1 + endpoint = endpoints[0] + + assert endpoint.name == endpoint_name + mock = mocker.Mock + + # if a prop is not set, it will be a Mock object + # Props that will always be set if present + if req.get("endpoint_resource_group"): + assert endpoint.resource_group == req.get("endpoint_resource_group") + else: + assert isinstance(endpoint.resource_group, mock) + + if req.get("endpoint_subscription_id"): + assert endpoint.subscription_id == req.get("endpoint_subscription_id") + else: + assert isinstance(endpoint.subscription_id, mock) + + if req.get("database_name"): + assert endpoint.database_name == req.get("database_name").lower() + else: + assert isinstance(endpoint.database_name, mock) + + if req.get("partition_key_name"): + partition_key_name = req.get("partition_key_name") + if partition_key_name == "": + assert endpoint.partition_key_name is None + else: + endpoint.partition_key_name == partition_key_name + else: + assert isinstance(endpoint.partition_key_name, mock) + + if req.get("partition_key_template"): + partition_key_template = req.get("partition_key_template") + if partition_key_template == "": + assert endpoint.partition_key_template is None + else: + endpoint.partition_key_template == partition_key_template + else: + assert isinstance(endpoint.partition_key_template, mock) + + # Connection strings dont exist + assert isinstance(endpoint.connection_string, mock) + + # Authentication props + if req.get("identity"): + assert endpoint.authentication_type == AuthenticationType.IdentityBased.value + assert endpoint.primary_key is None + assert endpoint.secondary_key is None + identity = req.get("identity") + if identity == "[system]": + assert endpoint.identity is None + else: + assert isinstance(endpoint.identity, ManagedIdentity) + assert endpoint.identity.user_assigned_identity == identity + elif any([req.get("connection_string"), req.get("primary_key"), req.get("secondary_key")]): + assert endpoint.authentication_type == AuthenticationType.KeyBased.value + assert endpoint.identity is None + assert endpoint.entity_path is None + connection_string = req.get("connection_string") + primary_key = req.get("primary_key") + secondary_key = req.get("secondary_key") + endpoint_uri = req.get("endpoint_uri") + + if primary_key: + assert endpoint.primary_key == primary_key + if secondary_key: + assert endpoint.secondary_key == secondary_key + if not primary_key and not secondary_key and connection_string: + assert endpoint.primary_key == endpoint.secondary_key == "get_cosmos_db_account_key" + + if endpoint_uri: + assert endpoint.endpoint_uri == endpoint_uri + elif connection_string: + assert endpoint.endpoint_uri == "get_cosmos_db_account_endpoint" + else: + assert isinstance(endpoint.authentication_type, mock) From 0273882c7d79a07485ac33ba4334c8317da12e4a Mon Sep 17 00:00:00 2001 From: Victoria Litvinova Date: Wed, 16 Aug 2023 11:34:25 -0700 Subject: [PATCH 3/4] Add enum, remove print --- azext_iot/iothub/common.py | 14 ++++ .../iothub/providers/message_endpoint.py | 80 ++++++++++--------- azext_iot/tests/iothub/conftest.py | 1 - 3 files changed, 58 insertions(+), 37 deletions(-) diff --git a/azext_iot/iothub/common.py b/azext_iot/iothub/common.py index e7dc69ef4..9b3aa4dd5 100644 --- a/azext_iot/iothub/common.py +++ b/azext_iot/iothub/common.py @@ -193,3 +193,17 @@ class CertificateAuthorityVersions(Enum): """ v2 = "v2" v1 = "v1" + + +class IoTHubSDKVersion(Enum): + """ + Types to determine which object properties the hub supports for backwards compatibility with the + control plane sdk. Currently has these distinctions (from oldest to newest versions): + + No cosmos endpoints + Cosmos endpoints as collections + Cosmos endpoints as containers + """ + NoCosmos = 0 + CosmosCollections = 1 + CosmosContainers = 2 diff --git a/azext_iot/iothub/providers/message_endpoint.py b/azext_iot/iothub/providers/message_endpoint.py index 81db51995..b47c62f6a 100644 --- a/azext_iot/iothub/providers/message_endpoint.py +++ b/azext_iot/iothub/providers/message_endpoint.py @@ -23,7 +23,8 @@ SYSTEM_ASSIGNED_IDENTITY, AuthenticationType, EncodingFormat, - EndpointType + EndpointType, + IoTHubSDKVersion ) from azext_iot.iothub.providers.base import IoTHubProvider from azext_iot.common._azure import parse_cosmos_db_connection_string @@ -43,11 +44,11 @@ def __init__( ): super(MessageEndpoint, self).__init__(cmd, hub_name, rg, dataplane=False) # Temporary flag to check for which cosmos property to look for. - self.support_cosmos = 0 + self.support_cosmos = IoTHubSDKVersion.NoCosmos.value if hasattr(self.hub_resource.properties.routing.endpoints, "cosmos_db_sql_collections"): - self.support_cosmos = 1 + self.support_cosmos = IoTHubSDKVersion.CosmosCollections.value if hasattr(self.hub_resource.properties.routing.endpoints, "cosmos_db_sql_containers"): - self.support_cosmos = 2 + self.support_cosmos = IoTHubSDKVersion.CosmosContainers.value self.cli = EmbeddedCLI(cli_ctx=self.cmd.cli_ctx) def create( @@ -189,13 +190,13 @@ def create( "partitionKeyName": partition_key_name, "partitionKeyTemplate": partition_key_template, }) - # TODO @vilit - None checks for when the service breaks things - if self.support_cosmos == 2: + # @vilit - None checks for when the service breaks things + if self.support_cosmos == IoTHubSDKVersion.CosmosContainers.value: new_endpoint["containerName"] = container_name if endpoints.cosmos_db_sql_containers is None: endpoints.cosmos_db_sql_containers = [] endpoints.cosmos_db_sql_containers.append(new_endpoint) - if self.support_cosmos == 1: + if self.support_cosmos == IoTHubSDKVersion.CosmosCollections.value: new_endpoint["collectionName"] = container_name if endpoints.cosmos_db_sql_collections is None: endpoints.cosmos_db_sql_collections = [] @@ -380,10 +381,11 @@ def _show_by_type(self, endpoint_name: str, endpoint_type: Optional[str] = None) endpoint_list.extend(endpoints.service_bus_topics) if endpoint_type is None or endpoint_type.lower() == EndpointType.AzureStorageContainer.value: endpoint_list.extend(endpoints.storage_containers) - if self.support_cosmos == 2 and (endpoint_type is None or endpoint_type.lower() == EndpointType.CosmosDBContainer.value): - endpoint_list.extend(endpoints.cosmos_db_sql_containers) - if self.support_cosmos == 1 and (endpoint_type is None or endpoint_type.lower() == EndpointType.CosmosDBContainer.value): - endpoint_list.extend(endpoints.cosmos_db_sql_collections) + if (endpoint_type is None or endpoint_type.lower() == EndpointType.CosmosDBContainer.value): + if self.support_cosmos == IoTHubSDKVersion.CosmosContainers.value: + endpoint_list.extend(endpoints.cosmos_db_sql_containers) + elif self.support_cosmos == IoTHubSDKVersion.CosmosCollections.value: + endpoint_list.extend(endpoints.cosmos_db_sql_collections) for endpoint in endpoint_list: if endpoint.name.lower() == endpoint_name.lower(): @@ -410,10 +412,11 @@ def list(self, endpoint_type: Optional[str] = None): return endpoints.service_bus_queues elif EndpointType.ServiceBusTopic.value == endpoint_type: return endpoints.service_bus_topics - elif EndpointType.CosmosDBContainer.value == endpoint_type and self.support_cosmos == 2: - return endpoints.cosmos_db_sql_containers - elif EndpointType.CosmosDBContainer.value == endpoint_type and self.support_cosmos == 1: - return endpoints.cosmos_db_sql_collections + elif EndpointType.CosmosDBContainer.value == endpoint_type: + if self.support_cosmos == IoTHubSDKVersion.CosmosContainers.value: + return endpoints.cosmos_db_sql_containers + elif self.support_cosmos == IoTHubSDKVersion.CosmosCollections.value: + return endpoints.cosmos_db_sql_collections elif EndpointType.CosmosDBContainer.value == endpoint_type: raise InvalidArgumentValueError(INVALID_CLI_CORE_FOR_COSMOS) elif EndpointType.AzureStorageContainer.value == endpoint_type: @@ -428,7 +431,9 @@ def delete( endpoints = self.hub_resource.properties.routing.endpoints if endpoint_type: endpoint_type = endpoint_type.lower() - if EndpointType.CosmosDBContainer.value == endpoint_type and self.support_cosmos == 0: + if ( + EndpointType.CosmosDBContainer.value == endpoint_type and self.support_cosmos == IoTHubSDKVersion.NoCosmos.value + ): raise InvalidArgumentValueError(INVALID_CLI_CORE_FOR_COSMOS) if self.hub_resource.properties.routing.enrichments or self.hub_resource.properties.routing.routes: @@ -448,10 +453,11 @@ def delete( endpoint_names.extend([e.name for e in endpoints.service_bus_queues]) if not endpoint_type or endpoint_type == EndpointType.ServiceBusTopic.value: endpoint_names.extend([e.name for e in endpoints.service_bus_topics]) - if self.support_cosmos == 2 and not endpoint_type or endpoint_type == EndpointType.CosmosDBContainer.value: - endpoint_names.extend([e.name for e in endpoints.cosmos_db_sql_containers]) - if self.support_cosmos == 1 and not endpoint_type or endpoint_type == EndpointType.CosmosDBContainer.value: - endpoint_names.extend([e.name for e in endpoints.cosmos_db_sql_collections]) + if not endpoint_type or endpoint_type == EndpointType.CosmosDBContainer.value: + if self.support_cosmos == IoTHubSDKVersion.CosmosContainers.value: + endpoint_names.extend([e.name for e in endpoints.cosmos_db_sql_containers]) + if self.support_cosmos == IoTHubSDKVersion.CosmosCollections.value: + endpoint_names.extend([e.name for e in endpoints.cosmos_db_sql_collections]) if not endpoint_type or endpoint_type == EndpointType.AzureStorageContainer.value: endpoint_names.extend([e.name for e in endpoints.storage_containers]) @@ -498,16 +504,17 @@ def delete( endpoints.service_bus_queues = [e for e in endpoints.service_bus_queues if e.name.lower() != endpoint_name] if not endpoint_type or EndpointType.ServiceBusTopic.value == endpoint_type: endpoints.service_bus_topics = [e for e in endpoints.service_bus_topics if e.name.lower() != endpoint_name] - if self.support_cosmos == 2 and not endpoint_type or EndpointType.CosmosDBContainer.value == endpoint_type: - cosmos_db_endpoints = endpoints.cosmos_db_sql_containers if endpoints.cosmos_db_sql_containers else [] - endpoints.cosmos_db_sql_containers = [ - e for e in cosmos_db_endpoints if e.name.lower() != endpoint_name - ] - if self.support_cosmos == 1 and not endpoint_type or EndpointType.CosmosDBContainer.value == endpoint_type: - cosmos_db_endpoints = endpoints.cosmos_db_sql_collections if endpoints.cosmos_db_sql_collections else [] - endpoints.cosmos_db_sql_collections = [ - e for e in cosmos_db_endpoints if e.name.lower() != endpoint_name - ] + if not endpoint_type or endpoint_type == EndpointType.CosmosDBContainer.value: + if self.support_cosmos == IoTHubSDKVersion.CosmosContainers.value: + cosmos_db_endpoints = endpoints.cosmos_db_sql_containers if endpoints.cosmos_db_sql_containers else [] + endpoints.cosmos_db_sql_containers = [ + e for e in cosmos_db_endpoints if e.name.lower() != endpoint_name + ] + if self.support_cosmos == IoTHubSDKVersion.CosmosCollections.value: + cosmos_db_endpoints = endpoints.cosmos_db_sql_collections if endpoints.cosmos_db_sql_collections else [] + endpoints.cosmos_db_sql_collections = [ + e for e in cosmos_db_endpoints if e.name.lower() != endpoint_name + ] if not endpoint_type or EndpointType.AzureStorageContainer.value == endpoint_type: endpoints.storage_containers = [e for e in endpoints.storage_containers if e.name.lower() != endpoint_name] elif endpoint_type: @@ -518,10 +525,11 @@ def delete( endpoints.service_bus_queues = [] elif EndpointType.ServiceBusTopic.value == endpoint_type: endpoints.service_bus_topics = [] - elif EndpointType.CosmosDBContainer.value == endpoint_type and self.support_cosmos == 2: - endpoints.cosmos_db_sql_containers = [] - elif EndpointType.CosmosDBContainer.value == endpoint_type and self.support_cosmos == 1: - endpoints.cosmos_db_sql_collections = [] + elif EndpointType.CosmosDBContainer.value == endpoint_type: + if self.support_cosmos == IoTHubSDKVersion.CosmosContainers.value: + endpoints.cosmos_db_sql_containers = [] + elif self.support_cosmos == IoTHubSDKVersion.CosmosCollections.value: + endpoints.cosmos_db_sql_collections = [] elif EndpointType.AzureStorageContainer.value == endpoint_type: endpoints.storage_containers = [] else: @@ -529,9 +537,9 @@ def delete( endpoints.event_hubs = [] endpoints.service_bus_queues = [] endpoints.service_bus_topics = [] - if self.support_cosmos == 2: + if self.support_cosmos == IoTHubSDKVersion.CosmosContainers.value: endpoints.cosmos_db_sql_containers = [] - if self.support_cosmos == 1: + if self.support_cosmos == IoTHubSDKVersion.CosmosCollections.value: endpoints.cosmos_db_sql_collections = [] endpoints.storage_containers = [] diff --git a/azext_iot/tests/iothub/conftest.py b/azext_iot/tests/iothub/conftest.py index b927b9462..6286af6c4 100644 --- a/azext_iot/tests/iothub/conftest.py +++ b/azext_iot/tests/iothub/conftest.py @@ -693,7 +693,6 @@ def _cosmos_db_provisioner(): collection_name = generate_hub_depenency_id() partition_key_path = "/test" location = "eastus" - print(f"--locations regionName={location}") cosmos_obj = cli.invoke( "cosmosdb create --name {} --resource-group {} --locations regionName={} failoverPriority=0".format( account_name, RG, location From f16ebc7c99eb7a46b9dac343d45918fe2a455cf0 Mon Sep 17 00:00:00 2001 From: Victoria Litvinova Date: Wed, 23 Aug 2023 09:17:12 -0700 Subject: [PATCH 4/4] fix int test --- .../test_iothub_message_endpoint_int.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/azext_iot/tests/iothub/message_endpoint/test_iothub_message_endpoint_int.py b/azext_iot/tests/iothub/message_endpoint/test_iothub_message_endpoint_int.py index 14b8788c3..c1e12b3e3 100644 --- a/azext_iot/tests/iothub/message_endpoint/test_iothub_message_endpoint_int.py +++ b/azext_iot/tests/iothub/message_endpoint/test_iothub_message_endpoint_int.py @@ -1041,7 +1041,8 @@ def test_iot_cosmos_endpoint_lifecycle(provisioned_cosmosdb_with_identity_module ).as_json() assert len(cosmos_list) == 3 - assert endpoint_list["cosmosDbSqlCollections"] == cosmos_list + expected_list = endpoint_list.get("cosmosDbSqlCollections", []) + endpoint_list.get("cosmosDbSqlContainers", []) + assert cosmos_list == expected_list # Update # Keybased -> User, add pkn + pkt @@ -1457,8 +1458,7 @@ def build_expected_endpoint( expected["connectionString"] = connection_string if entity_path: expected["entityPath"] = entity_path - if container_name and not database_name: - # storage container + if container_name: expected["containerName"] = container_name if encoding: expected["encoding"] = encoding @@ -1471,9 +1471,6 @@ def build_expected_endpoint( expected["maxChunkSizeInBytes"] = max_chunk_size_in_bytes * max_chunk_size_constant if database_name: expected["databaseName"] = database_name - if container_name and database_name: - # cosmosdb container - expected["collectionName"] = container_name if partition_key_name: expected["partitionKeyName"] = partition_key_name if partition_key_template: @@ -1522,9 +1519,15 @@ def assert_endpoint_properties(result: dict, expected: dict): if "entityPath" in expected: assert result["entityPath"] == expected["entityPath"] - # Storage Account only + # Shared between Storage and Cosmos Db: if "containerName" in expected: - assert result["containerName"] == expected["containerName"] + resulting_container_name = result.get("containerName") + if resulting_container_name is None: + # older version of cosmos + resulting_container_name = result.get("collectionName") + assert resulting_container_name == expected["containerName"] + + # Storage Account only if "encoding" in expected: assert result["encoding"] == expected["encoding"] if "fileNameFormat" in expected: @@ -1537,8 +1540,6 @@ def assert_endpoint_properties(result: dict, expected: dict): # Cosmos DB only if "databaseName" in expected: assert result["databaseName"] == expected["databaseName"] - if "collectionName" in expected: - assert result["collectionName"] == expected["collectionName"] if "partitionKeyName" in expected: assert result["partitionKeyName"] == expected["partitionKeyName"] if "partitionKeyTemplate" in expected: