Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions azext_iot/iothub/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,17 @@ class CertificateAuthorityVersions(Enum):
"""
v2 = "v2"
v1 = "v1"


class IoTHubSDKVersion(Enum):
"""
Types to determine which object properties the hub supports for backwards compatibility with the
control plane sdk. Currently has these distinctions (from oldest to newest versions):

No cosmos endpoints
Cosmos endpoints as collections
Cosmos endpoints as containers
"""
NoCosmos = 0
CosmosCollections = 1
CosmosContainers = 2
78 changes: 56 additions & 22 deletions azext_iot/iothub/providers/message_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
SYSTEM_ASSIGNED_IDENTITY,
AuthenticationType,
EncodingFormat,
EndpointType
EndpointType,
IoTHubSDKVersion
)
from azext_iot.iothub.providers.base import IoTHubProvider
from azext_iot.common._azure import parse_cosmos_db_connection_string
Expand All @@ -42,7 +43,12 @@ def __init__(
rg: Optional[str] = None,
):
super(MessageEndpoint, self).__init__(cmd, hub_name, rg, dataplane=False)
self.support_cosmos = hasattr(self.hub_resource.properties.routing.endpoints, "cosmos_db_sql_collections")
# Temporary flag to check for which cosmos property to look for.
self.support_cosmos = IoTHubSDKVersion.NoCosmos.value
if hasattr(self.hub_resource.properties.routing.endpoints, "cosmos_db_sql_collections"):
self.support_cosmos = IoTHubSDKVersion.CosmosCollections.value
if hasattr(self.hub_resource.properties.routing.endpoints, "cosmos_db_sql_containers"):
self.support_cosmos = IoTHubSDKVersion.CosmosContainers.value
self.cli = EmbeddedCLI(cli_ctx=self.cmd.cli_ctx)

def create(
Expand Down Expand Up @@ -179,16 +185,22 @@ def create(
del new_endpoint["connectionString"]
new_endpoint.update({
"databaseName": database_name,
"collectionName": container_name,
"primaryKey": primary_key,
"secondaryKey": secondary_key,
"partitionKeyName": partition_key_name,
"partitionKeyTemplate": partition_key_template,
})
# TODO @vilit - why is this None if empty
if endpoints.cosmos_db_sql_collections is None:
endpoints.cosmos_db_sql_collections = []
endpoints.cosmos_db_sql_collections.append(new_endpoint)
# @vilit - None checks for when the service breaks things
if self.support_cosmos == IoTHubSDKVersion.CosmosContainers.value:
new_endpoint["containerName"] = container_name
if endpoints.cosmos_db_sql_containers is None:
endpoints.cosmos_db_sql_containers = []
endpoints.cosmos_db_sql_containers.append(new_endpoint)
if self.support_cosmos == IoTHubSDKVersion.CosmosCollections.value:
new_endpoint["collectionName"] = container_name
if endpoints.cosmos_db_sql_collections is None:
endpoints.cosmos_db_sql_collections = []
endpoints.cosmos_db_sql_collections.append(new_endpoint)
elif endpoint_type.lower() == EndpointType.AzureStorageContainer.value:
if fetch_connection_string:
# try to get connection string
Expand Down Expand Up @@ -369,8 +381,11 @@ def _show_by_type(self, endpoint_name: str, endpoint_type: Optional[str] = None)
endpoint_list.extend(endpoints.service_bus_topics)
if endpoint_type is None or endpoint_type.lower() == EndpointType.AzureStorageContainer.value:
endpoint_list.extend(endpoints.storage_containers)
if self.support_cosmos and (endpoint_type is None or endpoint_type.lower() == EndpointType.CosmosDBContainer.value):
endpoint_list.extend(endpoints.cosmos_db_sql_collections)
if (endpoint_type is None or endpoint_type.lower() == EndpointType.CosmosDBContainer.value):
if self.support_cosmos == IoTHubSDKVersion.CosmosContainers.value:
endpoint_list.extend(endpoints.cosmos_db_sql_containers)
elif self.support_cosmos == IoTHubSDKVersion.CosmosCollections.value:
endpoint_list.extend(endpoints.cosmos_db_sql_collections)

for endpoint in endpoint_list:
if endpoint.name.lower() == endpoint_name.lower():
Expand All @@ -397,8 +412,11 @@ def list(self, endpoint_type: Optional[str] = None):
return endpoints.service_bus_queues
elif EndpointType.ServiceBusTopic.value == endpoint_type:
return endpoints.service_bus_topics
elif EndpointType.CosmosDBContainer.value == endpoint_type and self.support_cosmos:
return endpoints.cosmos_db_sql_collections
elif EndpointType.CosmosDBContainer.value == endpoint_type:
if self.support_cosmos == IoTHubSDKVersion.CosmosContainers.value:
return endpoints.cosmos_db_sql_containers
elif self.support_cosmos == IoTHubSDKVersion.CosmosCollections.value:
return endpoints.cosmos_db_sql_collections
elif EndpointType.CosmosDBContainer.value == endpoint_type:
raise InvalidArgumentValueError(INVALID_CLI_CORE_FOR_COSMOS)
elif EndpointType.AzureStorageContainer.value == endpoint_type:
Expand All @@ -413,7 +431,9 @@ def delete(
endpoints = self.hub_resource.properties.routing.endpoints
if endpoint_type:
endpoint_type = endpoint_type.lower()
if EndpointType.CosmosDBContainer.value == endpoint_type and not self.support_cosmos:
if (
EndpointType.CosmosDBContainer.value == endpoint_type and self.support_cosmos == IoTHubSDKVersion.NoCosmos.value
):
raise InvalidArgumentValueError(INVALID_CLI_CORE_FOR_COSMOS)

if self.hub_resource.properties.routing.enrichments or self.hub_resource.properties.routing.routes:
Expand All @@ -433,8 +453,11 @@ def delete(
endpoint_names.extend([e.name for e in endpoints.service_bus_queues])
if not endpoint_type or endpoint_type == EndpointType.ServiceBusTopic.value:
endpoint_names.extend([e.name for e in endpoints.service_bus_topics])
if self.support_cosmos and not endpoint_type or endpoint_type == EndpointType.CosmosDBContainer.value:
endpoint_names.extend([e.name for e in endpoints.cosmos_db_sql_collections])
if not endpoint_type or endpoint_type == EndpointType.CosmosDBContainer.value:
if self.support_cosmos == IoTHubSDKVersion.CosmosContainers.value:
endpoint_names.extend([e.name for e in endpoints.cosmos_db_sql_containers])
if self.support_cosmos == IoTHubSDKVersion.CosmosCollections.value:
endpoint_names.extend([e.name for e in endpoints.cosmos_db_sql_collections])
if not endpoint_type or endpoint_type == EndpointType.AzureStorageContainer.value:
endpoint_names.extend([e.name for e in endpoints.storage_containers])

Expand Down Expand Up @@ -481,11 +504,17 @@ def delete(
endpoints.service_bus_queues = [e for e in endpoints.service_bus_queues if e.name.lower() != endpoint_name]
if not endpoint_type or EndpointType.ServiceBusTopic.value == endpoint_type:
endpoints.service_bus_topics = [e for e in endpoints.service_bus_topics if e.name.lower() != endpoint_name]
if self.support_cosmos and not endpoint_type or EndpointType.CosmosDBContainer.value == endpoint_type:
cosmos_db_endpoints = endpoints.cosmos_db_sql_collections if endpoints.cosmos_db_sql_collections else []
endpoints.cosmos_db_sql_collections = [
e for e in cosmos_db_endpoints if e.name.lower() != endpoint_name
]
if not endpoint_type or endpoint_type == EndpointType.CosmosDBContainer.value:
if self.support_cosmos == IoTHubSDKVersion.CosmosContainers.value:
cosmos_db_endpoints = endpoints.cosmos_db_sql_containers if endpoints.cosmos_db_sql_containers else []
endpoints.cosmos_db_sql_containers = [
e for e in cosmos_db_endpoints if e.name.lower() != endpoint_name
]
if self.support_cosmos == IoTHubSDKVersion.CosmosCollections.value:
cosmos_db_endpoints = endpoints.cosmos_db_sql_collections if endpoints.cosmos_db_sql_collections else []
endpoints.cosmos_db_sql_collections = [
e for e in cosmos_db_endpoints if e.name.lower() != endpoint_name
]
if not endpoint_type or EndpointType.AzureStorageContainer.value == endpoint_type:
endpoints.storage_containers = [e for e in endpoints.storage_containers if e.name.lower() != endpoint_name]
elif endpoint_type:
Expand All @@ -496,16 +525,21 @@ def delete(
endpoints.service_bus_queues = []
elif EndpointType.ServiceBusTopic.value == endpoint_type:
endpoints.service_bus_topics = []
elif EndpointType.CosmosDBContainer.value == endpoint_type and self.support_cosmos:
endpoints.cosmos_db_sql_collections = []
elif EndpointType.CosmosDBContainer.value == endpoint_type:
if self.support_cosmos == IoTHubSDKVersion.CosmosContainers.value:
endpoints.cosmos_db_sql_containers = []
elif self.support_cosmos == IoTHubSDKVersion.CosmosCollections.value:
endpoints.cosmos_db_sql_collections = []
elif EndpointType.AzureStorageContainer.value == endpoint_type:
endpoints.storage_containers = []
else:
# Delete all endpoints
endpoints.event_hubs = []
endpoints.service_bus_queues = []
endpoints.service_bus_topics = []
if self.support_cosmos:
if self.support_cosmos == IoTHubSDKVersion.CosmosContainers.value:
endpoints.cosmos_db_sql_containers = []
if self.support_cosmos == IoTHubSDKVersion.CosmosCollections.value:
endpoints.cosmos_db_sql_collections = []
endpoints.storage_containers = []

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1041,7 +1041,8 @@ def test_iot_cosmos_endpoint_lifecycle(provisioned_cosmosdb_with_identity_module
).as_json()

assert len(cosmos_list) == 3
assert endpoint_list["cosmosDbSqlCollections"] == cosmos_list
expected_list = endpoint_list.get("cosmosDbSqlCollections", []) + endpoint_list.get("cosmosDbSqlContainers", [])
assert cosmos_list == expected_list

# Update
# Keybased -> User, add pkn + pkt
Expand Down Expand Up @@ -1457,8 +1458,7 @@ def build_expected_endpoint(
expected["connectionString"] = connection_string
if entity_path:
expected["entityPath"] = entity_path
if container_name and not database_name:
# storage container
if container_name:
expected["containerName"] = container_name
if encoding:
expected["encoding"] = encoding
Expand All @@ -1471,9 +1471,6 @@ def build_expected_endpoint(
expected["maxChunkSizeInBytes"] = max_chunk_size_in_bytes * max_chunk_size_constant
if database_name:
expected["databaseName"] = database_name
if container_name and database_name:
# cosmosdb container
expected["collectionName"] = container_name
if partition_key_name:
expected["partitionKeyName"] = partition_key_name
if partition_key_template:
Expand Down Expand Up @@ -1522,9 +1519,15 @@ def assert_endpoint_properties(result: dict, expected: dict):
if "entityPath" in expected:
assert result["entityPath"] == expected["entityPath"]

# Storage Account only
# Shared between Storage and Cosmos Db:
if "containerName" in expected:
assert result["containerName"] == expected["containerName"]
resulting_container_name = result.get("containerName")
if resulting_container_name is None:
# older version of cosmos
resulting_container_name = result.get("collectionName")
assert resulting_container_name == expected["containerName"]

# Storage Account only
if "encoding" in expected:
assert result["encoding"] == expected["encoding"]
if "fileNameFormat" in expected:
Expand All @@ -1537,8 +1540,6 @@ def assert_endpoint_properties(result: dict, expected: dict):
# Cosmos DB only
if "databaseName" in expected:
assert result["databaseName"] == expected["databaseName"]
if "collectionName" in expected:
assert result["collectionName"] == expected["collectionName"]
if "partitionKeyName" in expected:
assert result["partitionKeyName"] == expected["partitionKeyName"]
if "partitionKeyTemplate" in expected:
Expand Down
Loading